blob: ee430db8191cfc35521240eac9c761d8c404aa41 [file] [log] [blame]
package sqltest
import (
"context"
"crypto/rand"
"fmt"
"math"
"math/big"
"os/exec"
"reflect"
"strings"
"testing"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/stretchr/testify/require"
"go.skia.org/infra/go/emulators"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/testutils/unittest"
"go.skia.org/infra/golden/go/sql"
"go.skia.org/infra/golden/go/sql/schema"
)
// NewCockroachDBForTests creates a randomly named database on a test CockroachDB instance (aka the
// CockroachDB emulator) and returns a connection pool aimed at that database. The test is gated by
// unittest.RequiresCockroachDB, and the `cockroach` CLI must be on the PATH because it is used to
// create the database. The returned pool is closed automatically via t.Cleanup when the test
// finishes.
func NewCockroachDBForTests(ctx context.Context, t *testing.T) *pgxpool.Pool {
	unittest.RequiresCockroachDB(t)

	// Probe for the CLI first so a missing binary produces a helpful failure message.
	versionOutput, err := exec.Command("cockroach", "version").CombinedOutput()
	require.NoError(t, err, "Do you have 'cockroach' on your path? %s", versionOutput)

	// A random numeric suffix makes it very likely that each call gets its own fresh database.
	suffix, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
	require.NoError(t, err)
	dbName := "for_tests" + suffix.String()

	host := emulators.GetEmulatorHostEnvVar(emulators.CockroachDB)
	hostFlag := "--host=" + host
	executeFlag := "--execute=CREATE DATABASE IF NOT EXISTS " + dbName
	createOutput, err := exec.Command("cockroach", "sql", "--insecure", hostFlag, executeFlag).CombinedOutput()
	require.NoError(t, err, `creating test database: %s
If running locally, make sure you set the env var COCKROACHDB_EMULATOR_STORE_DIR and ran:
./scripts/run_emulators/run_emulators start
and now currently have %s set. Even though we call it an "emulator",
this sets up a real version of cockroachdb.
`, createOutput, emulators.GetEmulatorHostEnvVarName(emulators.CockroachDB))

	pool, err := pgxpool.Connect(ctx, fmt.Sprintf("postgresql://root@%s/%s?sslmode=disable", host, dbName))
	require.NoError(t, err)
	// pool is never reassigned, so registering the method value is equivalent to a closure.
	t.Cleanup(pool.Close)
	return pool
}
// NewCockroachDBForTestsWithProductionSchema behaves like NewCockroachDBForTests (a fresh, randomly
// named database whose pool is closed at test cleanup), but additionally executes schema.Schema
// against it before returning, so the database starts out with the production schema.
func NewCockroachDBForTestsWithProductionSchema(ctx context.Context, t *testing.T) *pgxpool.Pool {
	pool := NewCockroachDBForTests(ctx, t)
	// Any failure while applying the schema fails the test immediately.
	_, schemaErr := pool.Exec(ctx, schema.Schema)
	require.NoError(t, schemaErr)
	return pool
}
// SQLExporter is implemented by types whose values can each be written as one row of a SQL table.
type SQLExporter interface {
	// ToSQLRow reports the names of the columns to populate for this row and, position by
	// position, the values to store in those columns.
	ToSQLRow() (columns []string, values []interface{})
}
// BulkInsertDataTables adds all the data from tables to the provided database. tables is expected
// to be a struct (or a pointer to a struct) whose fields are slices of types implementing
// SQLExporter; the name of each field is used as the name of the SQL table its rows are inserted
// into. The tables are inserted one at a time, in the same order that the fields are declared in
// the struct - if there are foreign key relationships, be sure to order them correctly. This
// method panics if the passed in tables parameter is of the wrong type.
func BulkInsertDataTables(ctx context.Context, db *pgxpool.Pool, tables interface{}) error {
	// Accept both tables and &tables; reflect.Indirect leaves non-pointer values unchanged.
	v := reflect.Indirect(reflect.ValueOf(tables))
	if v.Kind() != reflect.Struct {
		panic(fmt.Sprintf("Expected tables to be a struct or a pointer to a struct, got %T", tables))
	}
	vType := v.Type()
	// It's tempting to insert these in parallel, but that could make foreign keys flaky.
	for i := 0; i < v.NumField(); i++ {
		tableName := vType.Field(i).Name
		table := v.Field(i) // Fields of the outer type are expected to be tables.
		if table.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + tableName)
		}
		if err := writeToTable(ctx, db, tableName, table); err != nil {
			return skerr.Wrap(err)
		}
	}
	return nil
}
// writeToTable inserts every element of table (a slice whose elements implement SQLExporter) as a
// row of the SQL table called name. Column names come from the first row, and every row must
// report exactly those column names. Rows are written with multi-row INSERT statements, split into
// batches so that no single statement carries more bind parameters than the Postgres extended
// protocol allows. An empty table (or rows with no columns) results in no query.
//
// It panics if an element does not implement SQLExporter, if a row's column names differ from the
// first row's, or if a row's value count differs from its column count.
func writeToTable(ctx context.Context, db *pgxpool.Pool, name string, table reflect.Value) error {
	// The Postgres wire protocol (and pgx) limit a single statement to 65535 bind parameters.
	const maxBindParameters = 65535

	numRows := table.Len()
	var arguments []interface{}
	var colNames []string
	// Go through each element of the table slice, cast it to SQLExporter and then call ToSQLRow
	// on it to get the arguments needed for that row.
	for j := 0; j < numRows; j++ {
		row, ok := table.Index(j).Interface().(SQLExporter)
		if !ok {
			panic(`Expected table should be a slice of types that implement ToSQLRow: ` + name)
		}
		cn, args := row.ToSQLRow()
		if j == 0 {
			colNames = cn
		} else if !sameColumnNames(colNames, cn) {
			// Inserting with mismatched columns would silently store values in the wrong columns.
			panic(`Expected every row to have the same column names for ` + name)
		}
		if len(colNames) != len(args) {
			panic(`Expected length of colNames and arguments to match for ` + name)
		}
		arguments = append(arguments, args...)
	}
	if len(arguments) == 0 {
		return nil
	}

	// Every row contributed exactly numCols arguments, so numCols > 0 here and
	// len(arguments) == numRows*numCols.
	numCols := len(colNames)
	rowsPerBatch := maxBindParameters / numCols
	if rowsPerBatch < 1 {
		// A single row already exceeds the limit; let the database report the error.
		rowsPerBatch = 1
	}
	insertPrefix := fmt.Sprintf(`INSERT INTO %s (%s) VALUES `, name, strings.Join(colNames, ","))
	for start := 0; start < numRows; start += rowsPerBatch {
		end := start + rowsPerBatch
		if end > numRows {
			end = numRows
		}
		batchRows := end - start
		insert := insertPrefix + sql.ValuesPlaceholders(numCols, batchRows)
		if _, err := db.Exec(ctx, insert, arguments[start*numCols:end*numCols]...); err != nil {
			return skerr.Wrapf(err, "Inserting %d rows into table %s", batchRows, name)
		}
	}
	return nil
}

// sameColumnNames reports whether a and b contain the same column names in the same order.
func sameColumnNames(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}