blob: a2335d2aef7785fe08829bbf078fad85f447352b [file] [log] [blame]
package sqltest
import (
"context"
"crypto/rand"
"fmt"
"math"
"math/big"
"os/exec"
"reflect"
"strings"
"testing"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.skia.org/infra/go/emulators"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/testutils/unittest"
"go.skia.org/infra/go/util"
"go.skia.org/infra/golden/go/sql"
"go.skia.org/infra/golden/go/sql/schema"
)
// NewCockroachDBForTests creates a database with a random name on the CockroachDB instance used
// for tests (the so-called CockroachDB emulator) and returns a connection pool aimed at it.
// The pool is registered with t.Cleanup, so it is closed automatically when the test finishes.
// Any failure along the way fails the test immediately.
func NewCockroachDBForTests(ctx context.Context, t testing.TB) *pgxpool.Pool {
	unittest.RequiresCockroachDB(t)

	// Running "cockroach version" first lets us give a targeted hint when the binary is missing.
	versionOut, err := exec.Command("cockroach", "version").CombinedOutput()
	require.NoError(t, err, "Do you have 'cockroach' on your path? %s", versionOut)

	// The database name is "for_tests" followed by a random non-negative number.
	suffix, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
	require.NoError(t, err)
	dbName := "for_tests" + suffix.String()

	// The database is created through the cockroach CLI rather than through a SQL connection.
	host := emulators.GetEmulatorHostEnvVar(emulators.CockroachDB)
	createCmd := exec.Command(
		"cockroach", "sql", "--insecure",
		"--host="+host,
		"--execute=CREATE DATABASE IF NOT EXISTS "+dbName,
	)
	createOut, err := createCmd.CombinedOutput()
	require.NoError(t, err, `creating test database: %s
If running locally, make sure you set the env var COCKROACHDB_EMULATOR_STORE_DIR and ran:
./scripts/run_emulators/run_emulators start
and now currently have %s set. Even though we call it an "emulator",
this sets up a real version of cockroachdb.
`, createOut, emulators.GetEmulatorHostEnvVarName(emulators.CockroachDB))

	// Connect as root without TLS to the freshly created database, capping the pool at 4
	// connections.
	cfg, err := pgxpool.ParseConfig(fmt.Sprintf("postgresql://root@%s/%s?sslmode=disable", host, dbName))
	require.NoError(t, err)
	cfg.MaxConns = 4

	pool, err := pgxpool.ConnectConfig(ctx, cfg)
	require.NoError(t, err)
	t.Cleanup(pool.Close)
	return pool
}
// NewCockroachDBForTestsWithProductionSchema creates a randomly named test database via
// NewCockroachDBForTests, executes schema.Schema against it and returns the resulting pool.
// The test fails immediately if the schema cannot be applied.
func NewCockroachDBForTestsWithProductionSchema(ctx context.Context, t testing.TB) *pgxpool.Pool {
	pool := NewCockroachDBForTests(ctx, t)
	// The whole schema is sent as a single Exec call.
	_, schemaErr := pool.Exec(ctx, schema.Schema)
	require.NoError(t, schemaErr)
	return pool
}
// SQLExporter is an abstraction around a type that can be written as a single row in a SQL table.
// BulkInsertDataTables expects every table field to be a slice whose elements implement this
// interface. When writing a table, writeToTable reuses the column names reported by the first
// row for every row of that table and panics if a row returns a number of values that differs
// from the number of column names.
type SQLExporter interface {
// ToSQLRow returns the column names and the column data that should be written for this row.
// colData needs one value per entry in colNames, in the same order, because the values are
// bound positionally to the generated INSERT statement.
ToSQLRow() (colNames []string, colData []interface{})
}
// SQLScanner is an abstraction around reading a single row from an SQL table.
// GetAllRows allocates a new value of the row type for each result row and calls ScanFrom on a
// pointer to it, so the interface has to be satisfied by the pointer type.
type SQLScanner interface {
// ScanFrom takes in a function that takes in any number of pointers and will fill in the data.
// The arguments passed into scan should be the row fields in the order they appear in the
// table schema, as they will be filled in via a SELECT * FROM Table.
// Any error returned here makes GetAllRows fail the test.
ScanFrom(scan func(...interface{}) error) error
}
// RowsOrder is an option that rows in a table can implement to specify the ordering of the returned
// data (to make for easier to debug tests). GetAllRows checks whether the row value passed to it
// implements this interface and, if so, appends the returned fragment to the end of its SELECT
// statement (after any where clause).
type RowsOrder interface {
// RowsOrderBy returns a SQL fragment like "ORDER BY my_field DESC".
// The fragment is used verbatim, so it must include the ORDER BY keywords itself.
RowsOrderBy() string
}
// BulkInsertDataTables adds all the data from tables to the provided database. tables is expected
// to be a struct (or a pointer to one) that contains fields which are slices of SQLExporter. The
// name of each field is used as the name of the table to insert into. The tables will be inserted
// in the same order that the fields are in the struct - if there are foreign key relationships,
// be sure to order them correctly.
//
// An error is returned if writing any table fails; tables after the failing one are not written.
// This method panics if the passed in tables parameter is of the wrong type (not a struct, a nil
// pointer, or a struct with a field that is not a slice of SQLExporter).
func BulkInsertDataTables(ctx context.Context, db *pgxpool.Pool, tables interface{}) error {
	v := reflect.ValueOf(tables)
	// Callers may hand in either the struct itself or a pointer to it. A nil pointer yields an
	// invalid Value here, which ends the loop and is rejected by the struct check below.
	for v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		// Without this check the reflect calls below would panic with a message that does not
		// mention what the caller did wrong.
		panic(fmt.Sprintf("Expected tables to be a struct of slices, but got %T", tables))
	}
	structType := v.Type()
	// It's tempting to insert these in parallel, but that could make foreign keys flaky.
	for i := 0; i < v.NumField(); i++ {
		tableName := structType.Field(i).Name
		table := v.Field(i) // Fields of the outer type are expected to be tables.
		if table.Kind() != reflect.Slice {
			panic(`Expected table should be a slice: ` + tableName)
		}
		if err := writeToTable(ctx, db, tableName, table); err != nil {
			return skerr.Wrap(err)
		}
	}
	return nil
}
// writeToTable inserts every element of table (a reflected slice of SQLExporter) into the SQL
// table called name. The rows are flattened into one long argument list and then written with
// multi-row INSERT statements of at most 3000 rows each. It returns nil without touching the
// database when there is nothing to write, and panics if an element does not implement
// SQLExporter or returns a number of values that differs from the number of column names.
func writeToTable(ctx context.Context, db *pgxpool.Pool, name string, table reflect.Value) error {
	// 3000 rows per statement keeps us under the 64k number limit as long as there are fewer
	// than 22 columns, which should be realistic for all our tables.
	const rowsPerBatch = 3000

	var (
		columns []string
		flat    []interface{}
	)
	rowCount := table.Len()
	for i := 0; i < rowCount; i++ {
		exporter, isExporter := table.Index(i).Interface().(SQLExporter)
		if !isExporter {
			panic(`Expected table should be a slice of types that implement ToSQLRow: ` + name)
		}
		names, values := exporter.ToSQLRow()
		// The column names are taken once and then reused for all following rows.
		if len(columns) == 0 {
			columns = names
		}
		if len(values) != len(columns) {
			panic(`Expected length of colNames and arguments to match for ` + name)
		}
		flat = append(flat, values...)
	}
	if len(flat) == 0 {
		return nil
	}

	width := len(columns)
	columnList := strings.Join(columns, ",")
	// insertBatch writes the rows in [from, to); the matching arguments live at
	// [from*width, to*width) in the flattened list.
	insertBatch := func(from int, to int) error {
		batch := flat[from*width : to*width]
		placeholders := sql.ValuesPlaceholders(width, to-from)
		stmt := fmt.Sprintf(`INSERT INTO %s (%s) VALUES %s`, name, columnList, placeholders)
		_, err := db.Exec(ctx, stmt, batch...)
		return skerr.Wrapf(err, "batch: %d-%d (%d-%d)", from, to, from*width, to*width)
	}
	err := util.ChunkIter(rowCount, rowsPerBatch, insertBatch)
	return skerr.Wrapf(err, "Inserting %d rows into table %s", rowCount, name)
}
// GetAllRows returns all rows for a given table. The passed in row param must be a pointer type
// that implements the SQLScanner interface. The passed in row may optionally implement the
// RowsOrder interface to specify an ordering to return the rows in (this can make for easier to
// debug tests).
//
// If whereClauses is provided, only its first element is used. It is appended verbatim after the
// table name, so it needs to be a complete fragment including the WHERE keyword.
//
// The returned value is a slice of the provided row type (without a pointer) and can be converted
// to a normal slice via a type assertion. If anything goes wrong (bad row type, query failure,
// scan failure, or an error while iterating over the results), the test is failed immediately.
func GetAllRows(ctx context.Context, t *testing.T, db *pgxpool.Pool, table string, row interface{}, whereClauses ...string) interface{} {
	if _, ok := row.(SQLScanner); !ok {
		require.Fail(t, "Row does not implement SQLScanner. Need pointer type.", "%#v", row)
	}
	// A value type with value receivers also satisfies SQLScanner, but calling Elem() on its
	// type below would panic. Report that as a regular test failure instead.
	ptrType := reflect.TypeOf(row)
	if ptrType.Kind() != reflect.Ptr {
		require.Fail(t, "Row must be a pointer to the row type.", "%#v", row)
	}
	whereClause := ""
	if len(whereClauses) > 0 {
		whereClause = whereClauses[0]
	}
	statement := `SELECT * FROM ` + table + " " + whereClause
	if ro, ok := row.(RowsOrder); ok {
		statement += " " + ro.RowsOrderBy()
	}
	rows, err := db.Query(ctx, statement)
	require.NoError(t, err)
	defer rows.Close()
	// typ now refers to the non-pointer type of row (aka RowType).
	typ := ptrType.Elem()
	// Make an empty slice of RowType.
	rv := reflect.MakeSlice(reflect.SliceOf(typ), 0, 5)
	for rows.Next() {
		// Create a new *RowType
		newVal := reflect.New(typ)
		// Convert it to SQLScanner. This should always succeed because of the type
		// assertion earlier.
		s := newVal.Interface().(SQLScanner)
		// Have the type extract its values from the rows object.
		require.NoError(t, s.ScanFrom(rows.Scan))
		// Add the dereferenced (non-pointer) type to the slice
		rv = reflect.Append(rv, newVal.Elem())
	}
	// Next() returns false both when the results are exhausted and when reading them failed.
	// Without this check a failure part way through would silently return a truncated slice.
	require.NoError(t, rows.Err())
	// Return the slice as type interface{}; It can be type asserted to []RowType by the caller.
	return rv.Interface()
}
// AssertNoFullTableScans obtains the query plan for statement (with args) from GetExplain and
// marks the test as failed if the text "FULL" shows up anywhere in that plan, i.e. if the plan
// includes a full table scan.
func AssertNoFullTableScans(t *testing.T, db *pgxpool.Pool, statement string, args ...interface{}) {
	plan := GetExplain(t, db, statement, args...)
	assert.NotContains(t, plan, "FULL")
}
// GetExplain returns the query plan for a given statement and arguments. It runs the statement
// prefixed with EXPLAIN (using a background context) and expects every result row to have three
// string columns, which are joined with " | ". The rows are joined with newlines.
// The test is failed immediately if the query, a scan, or the iteration over the results fails.
func GetExplain(t *testing.T, db *pgxpool.Pool, statement string, args ...interface{}) string {
	rows, err := db.Query(context.Background(), "EXPLAIN "+statement, args...)
	require.NoError(t, err)
	defer rows.Close()
	var explainRows []string
	for rows.Next() {
		var tree, field, desc string
		require.NoError(t, rows.Scan(&tree, &field, &desc))
		explainRows = append(explainRows, fmt.Sprintf("%s | %s | %s", tree, field, desc))
	}
	// Next() also returns false when reading the results failed. Without this check a broken
	// query would yield a partial or empty plan, which callers such as AssertNoFullTableScans
	// would then accept as free of full scans.
	require.NoError(t, rows.Err())
	return strings.Join(explainRows, "\n")
}