Skip to content

Commit

Permalink
fix: master checks db newness before migrating [DET-10312] (#9414)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesse-amano-hpe authored May 23, 2024
1 parent da46208 commit 3cbe805
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 36 deletions.
2 changes: 1 addition & 1 deletion master/cmd/determined-master/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func runMigrate(cmd *cobra.Command, args []string) error {
}
}()

if _, err = database.Migrate(config.DB.Migrations, config.DB.ViewsAndTriggers, args); err != nil {
if err = database.Migrate(config.DB.Migrations, config.DB.ViewsAndTriggers, args); err != nil {
return errors.Wrap(err, "running migrations")
}

Expand Down
2 changes: 1 addition & 1 deletion master/cmd/determined-master/populate_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func runPopulate(cmd *cobra.Command, args []string) error {
}

masterConfig := config.GetMasterConfig()
database, _, err := db.Setup(&masterConfig.DB)
database, err := db.Setup(&masterConfig.DB)
if err != nil {
return err
}
Expand Down
33 changes: 23 additions & 10 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,34 +1129,47 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
return errors.Wrap(err, "could not set static root")
}

var isBrandNewCluster bool
m.db, isBrandNewCluster, err = db.Setup(&m.config.DB)
isBrandNewCluster, err := db.IsNew(&m.config.DB)
if err != nil {
return err
}
defer closeWithErrCheck("db", m.db)

m.ClusterID, err = m.db.GetOrCreateClusterID(m.config.Telemetry.ClusterID)
if err != nil {
return errors.Wrap(err, "could not fetch cluster id from database")
return errors.Wrap(err, "could not verify database version")
}

if isBrandNewCluster {
// This has to happen before setup, to minimize risk of creating a database in a state that looks like
// there are already users, then aborting, which would allow a subsequent cluster to come up ignoring
// this check.
password := m.config.Security.InitialUserPassword
if password == "" {
log.Error("This cluster was deployed without an initial password for the built-in `determined` " +
"and `admin` users. New clusters can be deployed with initial passwords set using the " +
"`security.initial_user_password` setting.")
return errors.New("could not deploy without initial password")
}
}

m.db, err = db.Setup(&m.config.DB)
if err != nil {
return err
}
defer closeWithErrCheck("db", m.db)

if isBrandNewCluster {
// This has to happen after setup, since creating the built-in users without a
// password is part of the first migration.
password := m.config.Security.InitialUserPassword
for _, username := range user.BuiltInUsers {
err := user.SetUserPassword(ctx, username, password)
if err != nil {
return fmt.Errorf("could not update default user password: %w", err)
return fmt.Errorf("could not set password for %s: %w", username, err)
}
}
}

m.ClusterID, err = m.db.GetOrCreateClusterID(m.config.Telemetry.ClusterID)
if err != nil {
return errors.Wrap(err, "could not fetch cluster id from database")
}

webhookManager, err := webhooks.New(ctx)
if err != nil {
return fmt.Errorf("initializing webhooks: %w", err)
Expand Down
102 changes: 102 additions & 0 deletions master/internal/core_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@
package internal

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/go-pg/pg/v10"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/mocks"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/tasks"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -80,3 +86,99 @@ func TestHealthCheck(t *testing.T) {
})
})
}

func TestRun(t *testing.T) {
type testScenario struct {
name string
initialPassword string
repeats int
checkRunErr func(require.TestingT, error, ...interface{})
}

test := func(t *testing.T, scenario testScenario) {
pgdb, teardown := db.MustResolveNewPostgresDatabase(t)
t.Cleanup(teardown)
mockRM := MockRM()

pgOpts, err := pg.ParseURL(pgdb.URL)
require.NoError(t, err)

addr := strings.SplitN(pgOpts.Addr, ":", 2)

for i := 0; i < scenario.repeats; i++ {
m := &Master{
rm: mockRM,
config: &config.Config{
Security: config.SecurityConfig{
InitialUserPassword: scenario.initialPassword,
},
InternalConfig: config.InternalConfig{
ExternalSessions: model.ExternalSessions{},
},
TaskContainerDefaults: model.TaskContainerDefaultsConfig{},
ResourceConfig: *config.DefaultResourceConfig(),
Logging: model.LoggingConfig{
DefaultLoggingConfig: &model.DefaultLoggingConfig{},
},
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
}
require.NoError(t, m.config.Resolve())
m.config.DB = config.DBConfig{
User: pgOpts.User,
Password: pgOpts.Password,
Migrations: "file://../static/migrations",
ViewsAndTriggers: "../static/views_and_triggers",
Host: addr[0],
Port: addr[1],
Name: pgOpts.Database,
SSLMode: "disable",
}
// listen on any available port, we don't care
m.config.Port = 0

ctx, cancel := context.WithCancel(context.Background())
gRPCLogInitDone := make(chan struct{})
var runErr error
go func() {
defer cancel()
runErr = m.Run(ctx, gRPCLogInitDone)
}()

select {
case <-gRPCLogInitDone:
cancel()
case <-ctx.Done():
require.ErrorIs(t, ctx.Err(), context.Canceled)
}
scenario.checkRunErr(t, runErr)
}
}

scenarios := []testScenario{
{
name: "blank password",
initialPassword: "",
repeats: 5,
checkRunErr: require.Error,
},
// TODO: DET-10314 - the "happy path" is much harder to test than errors,
// because once Run() gets all the way to actually serving endpoints etc.
// there's a delicate shutdown ordering needed to avoid nil derefs on
// logging, db connections, etc.
// Running bcrypt is also ~20 seconds per password, so initialization is
// inherently slow.
// {
// name: "strong enough password",
// initialPassword: "testPassword1",
// repeats: 1,
// checkRunErr: require.NoError,
// },
}

for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
test(t, scenario)
})
}
}
2 changes: 1 addition & 1 deletion master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// DB is an interface for _all_ the functionality packed into the DB.
type DB interface {
Migrate(migrationURL, codeURL string, actions []string) (isNew bool, err error)
Migrate(migrationURL, codeURL string, actions []string) error
Close() error
GetOrCreateClusterID(telemetryID string) (string, error)
TrialExperimentAndRequestID(id int) (int, model.RequestID, error)
Expand Down
26 changes: 13 additions & 13 deletions master/internal/db/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func makeGoPgOpts(dbURL string) (*pg.Options, error) {
return opts, nil
}

func tablesExist(tx *pg.Tx, tableNames []string) (map[string]bool, error) {
func tablesExist(tx pg.DBI, tableNames []string) (map[string]bool, error) {
existingTables := []string{}
result := map[string]bool{}
for _, tn := range tableNames {
Expand Down Expand Up @@ -216,7 +216,7 @@ var testOnlyDBLock func(sql *sqlx.DB) (unlock func())
// Migrate runs the migrations from the specified directory URL.
func (db *PgDB) Migrate(
migrationURL string, dbCodeDir string, actions []string,
) (isNew bool, err error) {
) error {
if testOnlyDBLock != nil {
// In integration tests, multiple processes can be running this code at once, which can lead to
// errors because PostgreSQL's CREATE TABLE IF NOT EXISTS is not great with concurrency.
Expand All @@ -226,14 +226,14 @@ func (db *PgDB) Migrate(

dbCodeFiles, hash, needToUpdateDBCode, err := db.readDBCodeAndCheckIfDifferent(dbCodeDir)
if err != nil {
return false, err
return err
}

// go-pg/migrations uses go-pg/pg connection API, which is not compatible
// with pgx, so we use a one-off go-pg/pg connection.
pgOpts, err := makeGoPgOpts(db.URL)
if err != nil {
return false, err
return err
}

pgConn := pg.Connect(pgOpts)
Expand All @@ -245,7 +245,7 @@ func (db *PgDB) Migrate(

tx, err := pgConn.Begin()
if err != nil {
return false, err
return err
}

defer func() {
Expand All @@ -256,33 +256,33 @@ func (db *PgDB) Migrate(
}()

if err = ensureMigrationUpgrade(tx); err != nil {
return false, errors.Wrap(err, "error upgrading migration metadata")
return errors.Wrap(err, "error upgrading migration metadata")
}

if err = tx.Commit(); err != nil {
return false, err
return err
}

log.Infof("running DB migrations from %s; this might take a while...", migrationURL)

re := regexp.MustCompile(`file://(.+)`)
match := re.FindStringSubmatch(migrationURL)
if len(match) != 2 {
return false, fmt.Errorf("failed to parse migrationsURL: %s", migrationURL)
return fmt.Errorf("failed to parse migrationsURL: %s", migrationURL)
}

collection := migrations.NewCollection()
collection.DisableSQLAutodiscover(true)
if err = collection.DiscoverSQLMigrations(match[1]); err != nil {
return false, err
return err
}
if len(collection.Migrations()) == 0 {
return false, errors.New("failed to discover any migrations")
return errors.New("failed to discover any migrations")
}

oldVersion, newVersion, err := collection.Run(pgConn, actions...)
if err != nil {
return false, errors.Wrap(err, "error applying migrations")
return errors.Wrap(err, "error applying migrations")
}

if oldVersion == newVersion {
Expand All @@ -295,13 +295,13 @@ func (db *PgDB) Migrate(
if needToUpdateDBCode {
log.Info("database views changed")
if err := db.addDBCode(dbCodeFiles, hash); err != nil {
return false, err
return err
}
} else {
log.Info("database views unchanged, will not updated")
}
}

log.Info("DB migrations completed")
return oldVersion == 0, nil
return nil
}
2 changes: 1 addition & 1 deletion master/internal/db/postgres_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func MigrateTestPostgres(db *PgDB, migrationsPath string, actions ...string) err
if len(actions) == 0 {
actions = []string{"up"}
}
_, err := db.Migrate(
err := db.Migrate(
migrationsPath, strings.ReplaceAll(migrationsPath+"/../views_and_triggers", "file://", ""), actions)
if err != nil {
return fmt.Errorf("failed to migrate postgres: %w", err)
Expand Down
53 changes: 46 additions & 7 deletions master/internal/db/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"database/sql"
"fmt"

"github.com/go-pg/migrations/v8"
"github.com/go-pg/pg/v10"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

Expand Down Expand Up @@ -83,24 +85,61 @@ func Connect(opts *config.DBConfig) (*PgDB, error) {
return db, nil
}

// IsNew checks to see if the database's migration tracking tables have been
// created, and if so, if it's above version 0. It returns `false` if the current
// version exists and is higher than zero, `true` otherwise.
// This is not guaranteed to be accurate if the database is otherwise in a bad or
// incomplete state. If an error is returned, the bool should be ignored.
func IsNew(opts *config.DBConfig) (bool, error) {
dbURL := fmt.Sprintf(cnxTpl, opts.User, opts.Password, opts.Host, opts.Port, opts.Name)
dbURL += fmt.Sprintf(sslTpl, opts.SSLMode, opts.SSLRootCert)
pgOpts, err := makeGoPgOpts(dbURL)
if err != nil {
return false, err
}

pgConn := pg.Connect(pgOpts)
defer func() {
if errd := pgConn.Close(); errd != nil {
log.Errorf("error closing pg connection: %s", errd)
}
}()

exist, err := tablesExist(pgConn, []string{"gopg_migrations", "schema_migrations"})
if err != nil {
return false, err
}
if !exist["gopg_migrations"] {
return true, nil
}

collection := migrations.NewCollection()
collection.DisableSQLAutodiscover(true)
version, err := collection.Version(pgConn)
if err != nil {
return false, err
}
return version == 0, nil
}

// Setup connects to the database and run any necessary migrations.
func Setup(opts *config.DBConfig) (db *PgDB, isNew bool, err error) {
func Setup(opts *config.DBConfig) (db *PgDB, err error) {
db, err = Connect(opts)
if err != nil {
return db, false, err
return db, err
}

isNew, err = db.Migrate(opts.Migrations, opts.ViewsAndTriggers, []string{"up"})
err = db.Migrate(opts.Migrations, opts.ViewsAndTriggers, []string{"up"})
if err != nil {
return nil, false, fmt.Errorf("error running migrations: %s", err)
return nil, fmt.Errorf("error running migrations: %s", err)
}

if err = InitAuthKeys(); err != nil {
return nil, false, err
return nil, err
}

if err = initAllocationSessions(context.TODO()); err != nil {
return nil, false, err
return nil, err
}
return db, isNew, nil
return db, nil
}
Loading

0 comments on commit 3cbe805

Please sign in to comment.