blob: bdcbd1124cea28e2eca8362c3371b5b53832b722 [file] [log] [blame]
package exporter
import (
"fmt"
"reflect"
"slices"
"strings"
)
// Options controls what GenerateSQL emits in addition to the Schema constant.
type Options int

// The supported Options values.
const (
	// SchemaOnly emits only the Schema constant.
	SchemaOnly Options = iota

	// SchemaAndColumnNames emits the Schema constant followed by, for every
	// table, a []string variable named after the table that lists its column
	// names.
	SchemaAndColumnNames
)
// SchemaTarget selects the database that GenerateSQL produces a schema for.
type SchemaTarget int

// The supported SchemaTarget values.
const (
	// CockroachDB emits the column definitions from the "sql" struct tags
	// as written (apart from trimming surrounding whitespace).
	CockroachDB SchemaTarget = iota

	// Spanner rewrites the definitions for Spanner's PostgreSQL dialect and
	// adds sequences, indices and TTL clauses as needed.
	Spanner
)
// spannerConverter accumulates the state needed to rewrite a CockroachDB
// schema into one compatible with Spanner's PostgreSQL dialect. Some output
// (sequence declarations, index declarations, trailing primary keys) cannot
// be written inline with the column that triggers it, so it is collected here
// while columns are converted and emitted afterwards.
//
// TODO(ashwinpv): This will replace the other schema gen once spanner is fully rolled out.
type spannerConverter struct {
	// sequences holds the names of sequences that stand in for
	// unique_rowid(); getSequenceDeclarations renders them.
	sequences []string

	// indices holds "<name> on <table> <details>" specs that
	// getIndexDeclarations renders as CREATE INDEX statements.
	indices []string

	// indexNames holds every index name recorded so far, across all tables,
	// and is consulted to avoid reusing a name.
	indexNames []string

	// ttlExcludeTables lists the tables that do not get a TTL clause.
	ttlExcludeTables []string

	// primaryKeys maps a table name to the column whose PRIMARY KEY is
	// declared after all of that table's columns.
	primaryKeys map[string]string
}
// getSequenceDeclarations renders one CREATE SEQUENCE statement for each
// sequence name recorded in sc.sequences (the replacements for
// unique_rowid()), in the order they were recorded. It returns the empty
// string when no sequences were recorded.
func (sc *spannerConverter) getSequenceDeclarations() string {
	if len(sc.sequences) == 0 {
		return ""
	}
	var out strings.Builder
	for _, name := range sc.sequences {
		out.WriteString("CREATE SEQUENCE IF NOT EXISTS ")
		out.WriteString(name)
		out.WriteString(" bit_reversed_positive;\n")
	}
	return out.String()
}
// getIndexDeclarations renders one "CREATE INDEX IF NOT EXISTS <spec>;" line
// for each index spec recorded in sc.indices, in recording order. It returns
// the empty string when no indices were recorded.
func (sc *spannerConverter) getIndexDeclarations() string {
	if len(sc.indices) == 0 {
		return ""
	}
	var out strings.Builder
	for _, spec := range sc.indices {
		out.WriteString("CREATE INDEX IF NOT EXISTS ")
		out.WriteString(spec)
		out.WriteString(";\n")
	}
	return out.String()
}
// updateColumnTypesForSpanner rewrites one CockroachDB column or index
// definition (the value of a field's "sql" tag) into a form compatible with
// Spanner's PostgreSQL dialect. Anything that has to be emitted elsewhere in
// the schema is recorded on sc.
//
//   - A definition using unique_rowid() gets a "<tableName>_seq" sequence
//     appended to sc.sequences and uses nextval('<tableName>_seq') instead.
//     If it is also the PRIMARY KEY, the inline PRIMARY KEY is removed and
//     the column name is stored in sc.primaryKeys[tableName] so the key can
//     be declared after the columns.
//   - A definition containing INDEX ("INDEX <name> (<columns>) ...") is
//     recorded in sc.indices as "<name> on <tableName> <details>", with
//     STORING rewritten to INCLUDE. The empty string is returned because an
//     index is not part of the table body. Indices in the ignore list below
//     are dropped entirely.
//   - A computed column "... AS (<expr>) ..." becomes
//     "... GENERATED ALWAYS AS (<expr>) ...".
//   - The first occurrence of each CockroachDB-specific type or function is
//     replaced, e.g. STRING -> TEXT, BYTES -> BYTEA,
//     gen_random_uuid() -> spanner.generate_uuid(); UNIQUE is removed.
//
// It panics if an INDEX definition has nothing after the index name.
func (sc *spannerConverter) updateColumnTypesForSpanner(sqlColumnText string, tableName string) string {
	// replacement maps a CockroachDB token to its Spanner equivalent.
	type replacement struct{ from, to string }
	// A slice rather than a map, so replacements run in a fixed order and the
	// generated schema is deterministic.
	typeReplacements := []replacement{
		{"INT2", "INT8"},
		{"INT4", "INT8"},
		{"CHAR", "VARCHAR(1)"},
		{"STRING", "TEXT"},
		{"UUID", "TEXT"},
		{"BYTES", "BYTEA"},
		{"gen_random_uuid()", "spanner.generate_uuid()"},
		{"UNIQUE", ""},
		{"SERIAL", "INT8"},
	}

	// unique_rowid() generates a unique integer identifier for int columns,
	// which does not work in Spanner. Instead a SEQUENCE is declared and
	// nextval('<sequence_name>') supplies the value. The sequences are
	// tracked here so their creation statements can be generated later.
	const uniqueRowIdentifier = "unique_rowid()"
	if strings.Contains(sqlColumnText, uniqueRowIdentifier) {
		sequenceName := fmt.Sprintf("%s_seq", tableName)
		sc.sequences = append(sc.sequences, sequenceName)
		// Appended last so the table name embedded in the sequence name is
		// never rewritten by another replacement.
		typeReplacements = append(typeReplacements,
			replacement{uniqueRowIdentifier, fmt.Sprintf("nextval('%s')", sequenceName)})
		if strings.Contains(sqlColumnText, "PRIMARY KEY") {
			// The primary key statement should come after the columns when
			// nextval is used to generate unique row ids, so remember the
			// column (the first token of the definition) and strip it here.
			sc.primaryKeys[tableName] = strings.Fields(sqlColumnText)[0]
			sqlColumnText = strings.ReplaceAll(sqlColumnText, " PRIMARY KEY", "")
		}
	}

	if strings.Contains(sqlColumnText, "INDEX") {
		// Indices to ignore, keyed by table. Spanner does not support indexing
		// on JSONB objects. When we switch to Spanner these will either be
		// removed or updated with a compatible replacement. Names are matched
		// after de-duplication below, which is presumably why "keys_idx_1" is
		// listed.
		ignoreIndices := map[string][]string{
			"Traces":       {"keys_idx", "keys_idx_1"},
			"ValuesAtHead": {"keys_idx"},
		}

		// The index is specified as "INDEX <index_name> (<index columns>) ...".
		indexSpec := strings.TrimSpace(sqlColumnText[strings.Index(sqlColumnText, "INDEX")+len("INDEX"):])
		indexName, indexDetails, found := strings.Cut(indexSpec, " ")
		if !found {
			panic(fmt.Sprintf("malformed index specification for table %s: %q", tableName, sqlColumnText))
		}
		indexDetails = strings.ReplaceAll(strings.TrimSpace(indexDetails), "STORING", "INCLUDE")

		// sc.indexNames covers every table converted so far. A repeated name
		// gets the first unused suffix _1, _2, ...; reusing a name would make
		// the later "CREATE INDEX IF NOT EXISTS" a silent no-op.
		if slices.Contains(sc.indexNames, indexName) {
			baseName := indexName
			for n := 1; slices.Contains(sc.indexNames, indexName); n++ {
				indexName = fmt.Sprintf("%s_%d", baseName, n)
			}
		}
		if slices.Contains(ignoreIndices[tableName], indexName) {
			return ""
		}
		sc.indexNames = append(sc.indexNames, indexName)
		// Spanner expects the index definition to be
		// "CREATE INDEX <index_name> on <table_name> (<columns>)".
		sc.indices = append(sc.indices, fmt.Sprintf("%s on %s %s", indexName, tableName, indexDetails))
		// The index is not specified in the schema as a column, so return
		// the empty string.
		return ""
	}

	// A computed column in the CockroachDB schema, e.g.
	//   corpus TEXT AS (keys->>'source_type') STORED NOT NULL
	// must be written for Spanner as
	//   corpus TEXT GENERATED ALWAYS AS (keys->>'source_type') STORED NOT NULL
	// The character just before "AS (" (expected to be a space) is replaced.
	if insertIndex := strings.Index(sqlColumnText, "AS ("); insertIndex > 0 {
		sqlColumnText = sqlColumnText[:insertIndex-1] + " GENERATED ALWAYS " + sqlColumnText[insertIndex:]
	}

	for _, r := range typeReplacements {
		sqlColumnText = strings.Replace(sqlColumnText, r.from, r.to, 1)
	}
	return sqlColumnText
}
// GenerateSQL takes in a "table type", that is a struct whose fields are
// slices. Each field will be interpreted as a table named after the field.
// Each field of the slice's element type is a column whose SQL definition is
// taken from its "sql" struct tag. The result is Go source for package pkg
// that declares the schema as the constant Schema. When opt is
// SchemaAndColumnNames, the column-name variables produced by columnNames are
// appended.
//
// When schemaTarget is Spanner:
//   - each definition is converted for Spanner's PostgreSQL dialect;
//   - a createdat column is appended to every table;
//   - tables not listed in ttlExcludeTables get a 1095-day (3-year) TTL on
//     createdat;
//   - sequence declarations are placed before the tables and index
//     declarations after them.
//
// ttlExcludeTables is ignored for other targets.
//
// If a malformed type is passed in (a field that is not a slice, or a column
// without an "sql" tag), this function will panic.
func GenerateSQL(inputType any, pkg string, opt Options, schemaTarget SchemaTarget, ttlExcludeTables []string) string {
	header := fmt.Sprintf("package %s\n\n// Generated by //go/sql/exporter/\n// DO NOT EDIT\n\nconst Schema = `", pkg)

	var sc *spannerConverter
	if schemaTarget == Spanner {
		sc = &spannerConverter{
			sequences:        []string{},
			indices:          []string{},
			indexNames:       []string{},
			ttlExcludeTables: ttlExcludeTables,
			primaryKeys:      map[string]string{},
		}
	}

	body := strings.Builder{}
	t := reflect.TypeOf(inputType)
	for i := 0; i < t.NumField(); i++ {
		table := t.Field(i) // Fields of the outer type are expected to be tables.
		if table.Type.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + table.Name)
		}
		row := table.Type.Elem()

		// Collect the table's definitions first, so commas only ever go
		// between entries. This holds even when every tag was an index
		// specification that the Spanner conversion turned into "".
		var defs []string
		for j := 0; j < row.NumField(); j++ {
			col := row.Field(j)
			sqlText, ok := col.Tag.Lookup("sql")
			if !ok {
				panic(`Field missing "sql" tag:` + table.Name + "." + col.Name)
			}
			// If generating for spanner, update the column types to be compatible.
			if sc != nil {
				sqlText = sc.updateColumnTypesForSpanner(sqlText, table.Name)
			}
			if sqlText = strings.TrimSpace(sqlText); sqlText != "" {
				defs = append(defs, sqlText)
			}
		}
		if sc != nil {
			// Automatically create a TTL column for tables in Spanner.
			defs = append(defs, "createdat TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP")
			if pk, ok := sc.primaryKeys[table.Name]; ok {
				defs = append(defs, "PRIMARY KEY ("+pk+")")
			}
		}

		body.WriteString("CREATE TABLE IF NOT EXISTS ")
		body.WriteString(table.Name)
		body.WriteString(" (")
		for k, def := range defs {
			if k > 0 {
				body.WriteString(",")
			}
			body.WriteString("\n ")
			body.WriteString(def)
		}
		if sc != nil && !slices.Contains(sc.ttlExcludeTables, table.Name) {
			// Add TTL spec of 3 years by default.
			body.WriteString("\n) TTL INTERVAL '1095 days' ON createdat;\n")
		} else {
			body.WriteString("\n);\n")
		}
	}

	sequences := ""
	if sc != nil {
		// Sequence creation statements for the unique_rowid() replacements go
		// before the tables; index creation statements go after them.
		sequences = sc.getSequenceDeclarations()
		body.WriteString(sc.getIndexDeclarations())
	}
	body.WriteString("`\n")

	cols := ""
	if opt == SchemaAndColumnNames {
		cols = columnNames(inputType)
	}
	return header + sequences + body.String() + cols
}
// columnNames takes in a "table type", that is a struct whose fields are
// slices. Each field will be interpreted as a table. The sql struct tags will
// be used to generate Go source that declares, for each table, a []string
// variable named after the table. That variable lists the column names in the
// order they appear in the struct; a column name is the first space-separated
// token of its tag. If a malformed type is passed in, this function will panic.
//
// Index, inverted index and primary key specifications, stored computed
// columns, and empty tags are ignored.
func columnNames(inputType any) string {
	body := strings.Builder{}
	t := reflect.TypeOf(inputType)
	for i := 0; i < t.NumField(); i++ {
		table := t.Field(i) // Fields of the outer type are expected to be tables.
		if table.Type.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + table.Name)
		}
		body.WriteString("\nvar ")
		body.WriteString(table.Name)
		body.WriteString(" = []string{")

		row := table.Type.Elem()
		for j := 0; j < row.NumField(); j++ {
			col := row.Field(j)
			sqlText, ok := col.Tag.Lookup("sql")
			if !ok {
				// Report the offending field (col.Name), not the row type's name.
				panic(`Field missing "sql" tag:` + table.Name + "." + col.Name)
			}
			sqlText = strings.TrimSpace(sqlText)
			if sqlText == "" ||
				strings.Contains(sqlText, "STORED") ||
				strings.HasPrefix(sqlText, "INDEX") ||
				strings.HasPrefix(sqlText, "PRIMARY") ||
				strings.HasPrefix(sqlText, "INVERTED") {
				continue
			}
			colName, _, _ := strings.Cut(sqlText, " ")
			body.WriteString("\n\t\"")
			body.WriteString(colName)
			body.WriteString(`",`)
		}
		body.WriteString("\n}\n")
	}
	return body.String()
}