| package database |
| |
| import ( |
| "bufio" |
| "database/sql" |
| "flag" |
| "fmt" |
| "os" |
| "strings" |
| "time" |
| |
| _ "github.com/go-sql-driver/mysql" |
| "go.skia.org/infra/go/metadata" |
| "go.skia.org/infra/go/sklog" |
| "go.skia.org/infra/go/util" |
| ) |
| |
const (
	// DEFAULT_DRIVER is the name of the database/sql driver used to open
	// connections.
	DEFAULT_DRIVER = "mysql"

	// DB_CONN_TMPL is the format string for connection strings. Its verbs
	// are, in order: user, password, host, port and database name.
	DB_CONN_TMPL = "%s:%s@tcp(%s:%d)/%s?parseTime=true"

	// Well-known MySQL user names.
	USER_ROOT = "root"      // The root user.
	USER_RW   = "readwrite" // The readwrite user.
)
| |
| // DatabaseConfig contains information required to create a database connection. |
| type DatabaseConfig struct { |
| Host string |
| Port int |
| User string |
| Name string |
| Password string |
| MigrationSteps []MigrationStep |
| } |
| |
| // ConfigFromPrefixedFlags adds command-line flags for the database with the given prefix. |
| func ConfigFromPrefixedFlags(defaultHost string, defaultPort int, defaultUser, defaultDatabase string, migrationSteps []MigrationStep, prefix string) *DatabaseConfig { |
| c := DatabaseConfig{ |
| MigrationSteps: migrationSteps, |
| } |
| flag.StringVar(&c.Host, prefix+"db_host", defaultHost, "Hostname of the MySQL database server.") |
| flag.IntVar(&c.Port, prefix+"db_port", defaultPort, "Port number of the MySQL database.") |
| flag.StringVar(&c.User, prefix+"db_user", defaultUser, "MySQL user name.") |
| flag.StringVar(&c.Name, prefix+"db_name", defaultDatabase, "Name of the MySQL database.") |
| return &c |
| } |
| |
| // ConfigFromPrefixedFlags adds command-line flags for the database. |
| func ConfigFromFlags(defaultHost string, defaultPort int, defaultUser, defaultDatabase string, migrationSteps []MigrationStep) *DatabaseConfig { |
| return ConfigFromPrefixedFlags(defaultHost, defaultPort, defaultUser, defaultDatabase, migrationSteps, "") |
| } |
| |
| // validate returns an error if the command-line flags have not been set. |
| func (c *DatabaseConfig) validate() error { |
| if c.Host == "" || c.Port == 0 || c.User == "" || c.Name == "" { |
| return fmt.Errorf( |
| "One or more of the required command-line flags was not set. " + |
| "Did you call forget to call flag.Parse?") |
| } |
| return nil |
| } |
| |
| // PromptForPassword prompts for a password and sets the Password field. |
| func (c *DatabaseConfig) PromptForPassword() error { |
| if err := c.validate(); err != nil { |
| return err |
| } |
| |
| reader := bufio.NewReader(os.Stdin) |
| fmt.Printf("Enter password for MySQL user %s: ", c.User) |
| pw, err := reader.ReadString('\n') |
| if err != nil { |
| return fmt.Errorf("Failed to get password: %v", err) |
| } |
| c.Password = strings.Trim(pw, "\n") |
| return nil |
| } |
| |
| // GetPasswordFromMetadata retrieve the password from metadata and sets the Password field. |
| func (c *DatabaseConfig) GetPasswordFromMetadata() error { |
| if err := c.validate(); err != nil { |
| return err |
| } |
| key := "" |
| if c.User == USER_RW { |
| key = metadata.DATABASE_RW_PASSWORD |
| } else if c.User == USER_ROOT { |
| key = metadata.DATABASE_ROOT_PASSWORD |
| } |
| if key == "" { |
| return fmt.Errorf("Unknown user %s; could not obtain password from metadata.", c.User) |
| } |
| password, err := metadata.ProjectGet(key) |
| if err != nil { |
| return fmt.Errorf("Failed to find metadata.") |
| } |
| c.Password = password |
| return nil |
| } |
| |
| // MySQLString returns a MySQL connection string derived from the DatabaseConfig. |
| func (c *DatabaseConfig) MySQLString() string { |
| return fmt.Sprintf(DB_CONN_TMPL, c.User, c.Password, c.Host, c.Port, c.Name) |
| } |
| |
// MigrationStep is a single step that migrates the database from one version
// to the next and back again.
type MigrationStep struct {
	// MySQLUp holds the statements that move the schema up one version and
	// MySQLDown the statements that undo them.
	MySQLUp, MySQLDown []string
}

// VersionedDB is a database handle used to send queries to the underlying
// database. It also carries the migration steps for that database.
type VersionedDB struct {
	// DB is the database instance that is backed by MySQL.
	DB *sql.DB

	// migrationSteps lists the migration steps for this database; its length
	// is the highest schema version available.
	migrationSteps []MigrationStep
}
| |
| // Init must be called once before DB is used. |
| // |
| // Since it used glog, make sure it is also called after flag.Parse is called. |
| func (c *DatabaseConfig) NewVersionedDB() (*VersionedDB, error) { |
| if err := c.validate(); err != nil { |
| return nil, err |
| } |
| |
| // If there is a connection string then connect to the MySQL server. |
| // This is for testing only. In production we get the relevant information |
| // from the metadata server. |
| var err error |
| var DB *sql.DB = nil |
| |
| sklog.Infoln("Opening SQL database.") |
| DB, err = sql.Open(DEFAULT_DRIVER, c.MySQLString()) |
| if err != nil { |
| return nil, fmt.Errorf("Failed to open connection to SQL server: %v", err) |
| } |
| |
| sklog.Infoln("Sending Ping.") |
| if err := DB.Ping(); err != nil { |
| return nil, fmt.Errorf("Failed to ping SQL server: %v", err) |
| } |
| |
| // As outlined in this comment: |
| // https://github.com/go-sql-driver/mysql/issues/257#issuecomment-48985975 |
| // We can remove this once we have determined it's not necessary. |
| DB.SetMaxIdleConns(0) |
| DB.SetMaxOpenConns(200) |
| |
| result := &VersionedDB{ |
| DB: DB, |
| migrationSteps: c.MigrationSteps, |
| } |
| |
| // Make sure the migration table exists. |
| if err := result.checkVersionTable(); err != nil { |
| return nil, fmt.Errorf("Attempt to create version table returned: %v", err) |
| } |
| sklog.Infoln("Version table OK.") |
| |
| // Ping the database occasionally to keep the connection fresh. |
| go func() { |
| c := time.Tick(1 * time.Minute) |
| for range c { |
| if err := result.DB.Ping(); err != nil { |
| sklog.Warningln("Database failed to respond:", err) |
| } |
| sklog.Infof("db: Successful ping") |
| } |
| }() |
| |
| return result, nil |
| } |
| |
| // Close the underlying database. |
| func (vdb *VersionedDB) Close() error { |
| return vdb.DB.Close() |
| } |
| |
| // Migrates the database to the specified target version. Use DBVersion() to |
| // retrieve the current version of the database. |
| func (vdb *VersionedDB) Migrate(targetVersion int) (rv error) { |
| if (targetVersion < 0) || (targetVersion > vdb.MaxDBVersion()) { |
| sklog.Fatalf("Target db version must be in range: [0 .. %d]", vdb.MaxDBVersion()) |
| } |
| |
| currentVersion, err := vdb.DBVersion() |
| if err != nil { |
| return err |
| } |
| |
| if currentVersion > vdb.MaxDBVersion() { |
| sklog.Fatalf("Version table is out of date with current DB version.") |
| } |
| |
| if targetVersion == currentVersion { |
| return nil |
| } |
| |
| // start a transaction |
| txn, err := vdb.DB.Begin() |
| if err != nil { |
| return err |
| } |
| defer func() { rv = CommitOrRollback(txn, rv) }() |
| |
| // run through the transactions |
| runSteps := vdb.getMigrations(currentVersion, targetVersion) |
| if len(runSteps) == 0 { |
| sklog.Fatalln("Unable to find migration steps.") |
| } |
| |
| for _, step := range runSteps { |
| for _, stmt := range step { |
| sklog.Infoln("EXECUTING: \n", stmt) |
| if _, err = txn.Exec(stmt); err != nil { |
| return err |
| } |
| } |
| } |
| |
| // update the dbversion table |
| if err = vdb.setDBVersion(txn, targetVersion); err != nil { |
| return err |
| } |
| |
| return nil |
| } |
| |
| // Returns the current version of the database. It assumes that the |
| // Migrate function has already been called and the version table has been |
| // created in the database. |
| func (vdb *VersionedDB) DBVersion() (int, error) { |
| stmt := ` |
| SELECT version |
| FROM sk_db_version |
| WHERE id=1` |
| |
| var version int |
| err := vdb.DB.QueryRow(stmt).Scan(&version) |
| return version, err |
| } |
| |
| // Returns the highest version currently available. |
| func (vdb *VersionedDB) MaxDBVersion() int { |
| return len(vdb.migrationSteps) |
| } |
| |
| // IsLatestVersion returns true or false depending on whether the DB is up to |
| // date. If an error occured it will log it and return false. If the |
| // current version is not the latest possible version it will log to info. |
| func (vdb *VersionedDB) IsLatestVersion() bool { |
| dbVersion, err := vdb.DBVersion() |
| if err != nil { |
| sklog.Errorf("Error retrieving db version: %s", err) |
| return false |
| } |
| |
| if dbVersion != vdb.MaxDBVersion() { |
| sklog.Infof("The current DB version is %d. The latest available version is %d", dbVersion, vdb.MaxDBVersion()) |
| return false |
| } |
| return true |
| } |
| |
| // Returns an error if the version table does not exist. |
| func (vdb *VersionedDB) checkVersionTable() error { |
| // Check if the table exists in MySQL. |
| stmt := "SHOW TABLES LIKE 'sk_db_version'" |
| |
| var temp string |
| err := vdb.DB.QueryRow(stmt).Scan(&temp) |
| if err != nil { |
| // See if we can create the version table. |
| return vdb.ensureVersionTable() |
| } |
| |
| return nil |
| } |
| |
| func (vdb *VersionedDB) setDBVersion(txn *sql.Tx, newDBVersion int) error { |
| stmt := `REPLACE INTO sk_db_version (id, version, updated) VALUES(1, ?, ?)` |
| _, err := txn.Exec(stmt, newDBVersion, time.Now().Unix()) |
| return err |
| } |
| |
| func (vdb *VersionedDB) ensureVersionTable() (rv error) { |
| txn, err := vdb.DB.Begin() |
| defer func() { rv = CommitOrRollback(txn, rv) }() |
| |
| if err != nil { |
| return fmt.Errorf("Unable to start database transaction. %s", err) |
| } |
| |
| stmt := `CREATE TABLE IF NOT EXISTS sk_db_version ( |
| id INTEGER NOT NULL PRIMARY KEY, |
| version INTEGER NOT NULL, |
| updated BIGINT NOT NULL |
| )` |
| if _, err = txn.Exec(stmt); err != nil { |
| return fmt.Errorf("Creating version table failed: %s", err) |
| } |
| |
| stmt = "SELECT COUNT(*) FROM sk_db_version" |
| var count int |
| if err = txn.QueryRow(stmt).Scan(&count); err != nil { |
| return fmt.Errorf("Unable to read version table: %s", err) |
| } |
| |
| // In both cases we want the transaction to roll back. |
| if count == 0 { |
| err = vdb.setDBVersion(txn, 0) |
| } else if count > 1 { |
| err = fmt.Errorf("Version table contains more than one row.") |
| } |
| |
| return err |
| } |
| |
| // Returns the SQL statements base on whether we are using MySQL and the |
| // current and target DB version. |
| // This function assumes that currentVersion != targetVersion. |
| func (vdb *VersionedDB) getMigrations(currentVersion int, targetVersion int) [][]string { |
| inc := util.SignInt(targetVersion - currentVersion) |
| idx := currentVersion |
| if inc < 0 { |
| idx = currentVersion - 1 |
| } |
| delta := util.AbsInt(targetVersion - currentVersion) |
| result := make([][]string, 0, delta) |
| |
| for i := 0; i < delta; i++ { |
| var temp []string |
| switch { |
| // using mysqlp |
| case (inc > 0): |
| temp = vdb.migrationSteps[idx].MySQLUp |
| case (inc < 0): |
| temp = vdb.migrationSteps[idx].MySQLDown |
| } |
| result = append(result, temp) |
| idx += inc |
| } |
| return result |
| } |
| |
// Tx is the subset of a database transaction needed to finish it: committing
// or rolling back.
type Tx interface {
	Commit() error
	Rollback() error
}

// CommitOrRollback finishes a database transaction based on the error of the
// surrounding function: with a nil error the transaction is committed and the
// result of Commit is returned; otherwise it is rolled back and an error is
// returned that contains the original error and the outcome of the rollback.
// Use it like this:
//
//	defer func() { rv = CommitOrRollback(tx, rv) }()
func CommitOrRollback(tx Tx, err error) error {
	if err == nil {
		return tx.Commit()
	}

	rollbackErr := tx.Rollback()
	if rollbackErr == nil {
		return fmt.Errorf("%v; transaction rolled back.", err)
	}
	return fmt.Errorf("%v; failed to rollback: %v", err, rollbackErr)
}