blob: 9d073e1c70da74ddf58efc58e4e6ab7e34a532f6 [file] [log] [blame] [edit]
package exporter
import (
"fmt"
"reflect"
"strings"
)
// Options selects what GenerateSQL emits after the Schema constant.
type Options int

// Values accepted by GenerateSQL for its opt parameter.
const (
	// SchemaOnly emits only the package clause and the Schema constant.
	SchemaOnly Options = 0
	// SchemaAndColumnNames additionally emits one []string variable per table
	// listing that table's column names (indexes and computed columns excluded).
	SchemaAndColumnNames Options = 1
)
// GenerateSQL takes in a "table type", that is a struct whose fields are
// slices. Each field is interpreted as a table named after the field, and each
// field of the slice's element struct contributes one entry (a column, index,
// or constraint) whose text is its trimmed "sql" struct tag. The entries are
// joined into a CREATE TABLE IF NOT EXISTS statement per table.
//
// The returned string is Go source for package pkg declaring a Schema constant
// (a raw string literal) that holds all the CREATE TABLE statements. If opt is
// SchemaAndColumnNames, per-table column-name variables are appended as well.
//
// GenerateSQL panics if inputType is malformed: it is not a struct, one of its
// fields is not a slice of structs, a row field has no "sql" tag or an empty
// one, or a tag contains a backtick.
func GenerateSQL(inputType interface{}, pkg string, opt Options) string {
	t := reflect.TypeOf(inputType)
	if t == nil || t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("Expected a struct of tables, got %v", t))
	}

	var body strings.Builder
	fmt.Fprintf(&body, "package %s\n\n// Generated by //go/sql/exporter/\n// DO NOT EDIT\n\nconst Schema = `", pkg)
	for i := 0; i < t.NumField(); i++ {
		table := t.Field(i) // Fields of the outer type are expected to be tables.
		if table.Type.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + table.Name)
		}
		row := table.Type.Elem()
		if row.Kind() != reflect.Struct {
			panic(`Expected table rows to be structs: ` + table.Name)
		}
		body.WriteString("CREATE TABLE IF NOT EXISTS ")
		body.WriteString(table.Name)
		body.WriteString(" (")
		for j := 0; j < row.NumField(); j++ {
			col := row.Field(j)
			sqlText, ok := col.Tag.Lookup("sql")
			if !ok {
				// Report the offending field, not the row type's name.
				panic(`Field missing "sql" tag: ` + table.Name + "." + col.Name)
			}
			sqlText = strings.TrimSpace(sqlText)
			if sqlText == "" {
				// An empty entry would leave a dangling comma in the statement.
				panic(`Field has empty "sql" tag: ` + table.Name + "." + col.Name)
			}
			if strings.Contains(sqlText, "`") {
				// Schema is emitted as a raw string literal, which cannot hold a backtick.
				panic(`Field "sql" tag contains a backtick: ` + table.Name + "." + col.Name)
			}
			// Every field contributes an entry, so only the first omits the separator.
			if j > 0 {
				body.WriteString(",")
			}
			body.WriteString("\n ")
			body.WriteString(sqlText)
		}
		body.WriteString("\n);\n")
	}
	body.WriteString("`\n")
	if opt == SchemaAndColumnNames {
		body.WriteString(columnNames(inputType))
	}
	return body.String()
}
// columnNames takes in a "table type", that is a struct whose fields are
// slices. Each field is interpreted as a table, and the "sql" struct tags of the
// slice's element struct are used to generate Go source declaring, for each
// table, a []string variable named after the field that lists the column names
// in the order they appear in the struct.
//
// Indexes and computed columns are ignored: tags whose first token starts with
// INDEX, INVERTED, UNIQUE or PRIMARY, and tags containing STORED, are skipped.
// The column name is the first whitespace-separated token of the tag. It is
// written as a Go-quoted string, so SQL-quoted identifiers such as "user" still
// produce valid Go.
//
// columnNames panics if inputType is malformed: it is not a struct, one of its
// fields is not a slice of structs, or a row field has no "sql" tag or an
// empty one.
func columnNames(inputType interface{}) string {
	t := reflect.TypeOf(inputType)
	if t == nil || t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("Expected a struct of tables, got %v", t))
	}

	var body strings.Builder
	for i := 0; i < t.NumField(); i++ {
		table := t.Field(i) // Fields of the outer type are expected to be tables.
		if table.Type.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + table.Name)
		}
		row := table.Type.Elem()
		if row.Kind() != reflect.Struct {
			panic(`Expected table rows to be structs: ` + table.Name)
		}
		body.WriteString("\nvar ")
		body.WriteString(table.Name)
		body.WriteString(" = []string{")
		for j := 0; j < row.NumField(); j++ {
			col := row.Field(j)
			sqlText, ok := col.Tag.Lookup("sql")
			if !ok {
				// Report the offending field, not the row type's name.
				panic(`Field missing "sql" tag: ` + table.Name + "." + col.Name)
			}
			// Fields splits on any whitespace (tabs, newlines), not just spaces.
			tokens := strings.Fields(sqlText)
			if len(tokens) == 0 {
				panic(`Field has empty "sql" tag: ` + table.Name + "." + col.Name)
			}
			name := tokens[0]
			// Skip stored computed columns, indexes, and primary/unique keys.
			if strings.Contains(sqlText, "STORED") ||
				strings.HasPrefix(name, "INDEX") ||
				strings.HasPrefix(name, "PRIMARY") ||
				strings.HasPrefix(name, "INVERTED") ||
				strings.HasPrefix(name, "UNIQUE") {
				continue
			}
			// %q escapes quotes and control characters so the output stays valid Go.
			fmt.Fprintf(&body, "\n\t%q,", name)
		}
		body.WriteString("\n}\n")
	}
	return body.String()
}