| package exporter |
| |
| import ( |
| "fmt" |
| "reflect" |
| "strings" |
| ) |
| |
// Options selects what GenerateSQL emits alongside the Schema constant.
type Options int

const (
	// SchemaOnly makes GenerateSQL emit only the Schema constant.
	SchemaOnly Options = 0

	// SchemaAndColumnNames makes GenerateSQL also append, for each table, a
	// variable listing that table's column names.
	SchemaAndColumnNames Options = 1
)
| |
| // GenerateSQL takes in a "table type", that is a table whose fields are slices. |
| // Each field will be interpreted as a table. The sql struct tags will be used |
| // to generate the SQL schema. A package name is taken in to be included in the |
| // returned string. If a malformed type is passed in, this function will panic. |
| func GenerateSQL(inputType interface{}, pkg string, opt Options) string { |
| header := fmt.Sprintf("package %s\n\n// Generated by //go/sql/exporter/\n// DO NOT EDIT\n\nconst Schema = `", pkg) |
| |
| body := strings.Builder{} |
| t := reflect.TypeOf(inputType) |
| for i := 0; i < t.NumField(); i++ { |
| table := t.Field(i) // Fields of the outer type are expected to be tables. |
| if table.Type.Kind() != reflect.Slice { |
| panic(`Expected table should be a slice: ` + table.Name) |
| } |
| body.WriteString("CREATE TABLE IF NOT EXISTS ") |
| body.WriteString(table.Name) |
| body.WriteString(" (") |
| row := table.Type.Elem() |
| wasFirst := true |
| for j := 0; j < row.NumField(); j++ { |
| col := row.Field(j) |
| sqlText, ok := col.Tag.Lookup("sql") |
| if !ok { |
| panic(`Field missing "sql" tag:` + table.Name + "." + row.Name()) |
| } |
| if !wasFirst { |
| body.WriteString(",") |
| } |
| wasFirst = false |
| body.WriteString("\n ") |
| body.WriteString(strings.TrimSpace(sqlText)) |
| } |
| body.WriteString("\n);\n") |
| } |
| body.WriteString("`\n") |
| cols := "" |
| if opt == SchemaAndColumnNames { |
| cols += columnNames(inputType) |
| } |
| return header + body.String() + cols |
| } |
| |
// columnNames takes in a "table type", that is a struct whose fields are
// slices. Each field will be interpreted as a table. The sql struct tags will
// be used to generate a variable for each table, named after the table field,
// that contains the column names in the order they appear in the row struct.
// A column's name is the first whitespace-separated token of its trimmed sql
// tag.
//
// Indexes (tags starting with INDEX, INVERTED, or UNIQUE), table-level primary
// keys (tags starting with PRIMARY), and stored computed columns (tags
// containing STORED) are ignored.
//
// If a malformed type is passed in (not a struct, a field that is not a slice
// of structs, or a column with a missing or empty sql tag), this function will
// panic.
func columnNames(inputType interface{}) string {
	t := reflect.TypeOf(inputType)
	if t == nil || t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("Expected a struct of tables, got %v", t))
	}

	var body strings.Builder
	for i := 0; i < t.NumField(); i++ {
		table := t.Field(i) // Fields of the outer type are expected to be tables.
		if table.Type.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + table.Name)
		}
		row := table.Type.Elem()
		if row.Kind() != reflect.Struct {
			panic(`Expected table rows to be structs: ` + table.Name)
		}
		body.WriteString("\nvar ")
		body.WriteString(table.Name)
		body.WriteString(" = []string{")
		for j := 0; j < row.NumField(); j++ {
			col := row.Field(j)
			sqlText, ok := col.Tag.Lookup("sql")
			if !ok {
				// Name the offending field, not the row struct's type.
				panic(`Field missing "sql" tag: ` + table.Name + "." + col.Name)
			}
			sqlText = strings.TrimSpace(sqlText)
			if sqlText == "" {
				panic(`Field has empty "sql" tag: ` + table.Name + "." + col.Name)
			}
			// These entries are index/constraint definitions or computed
			// columns rather than plain columns, so they get no name entry.
			// UNIQUE covers "UNIQUE INDEX ...", which would otherwise be
			// reported as a column called "UNIQUE".
			if strings.Contains(sqlText, "STORED") ||
				strings.HasPrefix(sqlText, "INDEX") ||
				strings.HasPrefix(sqlText, "PRIMARY") ||
				strings.HasPrefix(sqlText, "INVERTED") ||
				strings.HasPrefix(sqlText, "UNIQUE") {
				continue
			}
			// Fields (rather than splitting on one space) tolerates tabs or
			// repeated spaces between the column name and its type. The
			// empty-tag check above guarantees at least one token.
			colName := strings.Fields(sqlText)[0]
			body.WriteString("\n\t\"")
			body.WriteString(colName)
			body.WriteString(`",`)
		}
		body.WriteString("\n}\n")
	}
	return body.String()
}