Skip to content

Commit

Permalink
chore: only connect to the database once (#9456)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey authored May 31, 2024
1 parent 0fdb822 commit d960f29
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 53 deletions.
29 changes: 16 additions & 13 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,31 +1129,34 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
return errors.Wrap(err, "could not set static root")
}

isBrandNewCluster, err := db.IsNew(&m.config.DB)
if err != nil {
return errors.Wrap(err, "could not verify database version")
}
var isOldCluster bool
newClustersRequirePasswords := func(*db.PgDB) error {
isOldCluster, err = db.Bun().NewSelect().Table("pg_tables").
Where("schemaname = 'public'").
Where("tablename = 'gopg_migrations'").
Exists(ctx)
if err != nil {
return fmt.Errorf("checking if database is fresh: %w", err)
}

if isBrandNewCluster && slices.Contains(m.config.FeatureSwitches, "prevent_blank_password") {
// This has to happen before setup, to minimize risk of creating a database in a state that looks like
// there are already users, then aborting, which would allow a subsequent cluster to come up ignoring
// this check.
password := m.config.Security.InitialUserPassword
if password == "" {
if !isOldCluster &&
slices.Contains(m.config.FeatureSwitches, "prevent_blank_password") &&
m.config.Security.InitialUserPassword == "" {
log.Error("This cluster was deployed without an initial password for the built-in `determined` " +
"and `admin` users. New clusters can be deployed with initial passwords set using the " +
"`security.initial_user_password` setting.")
return errors.New("could not deploy without initial password")
}
}

m.db, err = db.Setup(&m.config.DB)
return nil
}
m.db, err = db.Setup(&m.config.DB, newClustersRequirePasswords)
if err != nil {
return err
}
defer closeWithErrCheck("db", m.db)

if isBrandNewCluster {
if !isOldCluster {
// This has to happen after setup, since creating the built-in users without a
// password is part of the first migration.
password := m.config.Security.InitialUserPassword
Expand Down
49 changes: 9 additions & 40 deletions master/internal/db/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"database/sql"
"fmt"

"github.com/go-pg/migrations/v8"
"github.com/go-pg/pg/v10"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

Expand Down Expand Up @@ -85,50 +83,21 @@ func Connect(opts *config.DBConfig) (*PgDB, error) {
return db, nil
}

// IsNew checks to see if the database's migration tracking tables have been
// created, and if so, if it's above version 0. It returns `false` if the current
// version exists and is higher than zero, `true` otherwise.
// This is not guaranteed to be accurate if the database is otherwise in a bad or
// incomplete state. If an error is returned, the bool should be ignored.
func IsNew(opts *config.DBConfig) (bool, error) {
dbURL := fmt.Sprintf(cnxTpl, opts.User, opts.Password, opts.Host, opts.Port, opts.Name)
dbURL += fmt.Sprintf(sslTpl, opts.SSLMode, opts.SSLRootCert)
pgOpts, err := makeGoPgOpts(dbURL)
if err != nil {
return false, err
}

pgConn := pg.Connect(pgOpts)
defer func() {
if errd := pgConn.Close(); errd != nil {
log.Errorf("error closing pg connection: %s", errd)
}
}()

exist, err := tablesExist(pgConn, []string{"gopg_migrations", "schema_migrations"})
if err != nil {
return false, err
}
if !exist["gopg_migrations"] {
return true, nil
}

collection := migrations.NewCollection()
collection.DisableSQLAutodiscover(true)
version, err := collection.Version(pgConn)
if err != nil {
return false, err
}
return version == 0, nil
}

// Setup connects to the database and run any necessary migrations.
func Setup(opts *config.DBConfig) (db *PgDB, err error) {
func Setup(
opts *config.DBConfig, postConnectHooks ...func(*PgDB) error,
) (db *PgDB, err error) {
db, err = Connect(opts)
if err != nil {
return db, err
}

for _, hook := range postConnectHooks {
if err := hook(db); err != nil {
return nil, err
}
}

err = db.Migrate(opts.Migrations, opts.ViewsAndTriggers, []string{"up"})
if err != nil {
return nil, fmt.Errorf("error running migrations: %s", err)
Expand Down

0 comments on commit d960f29

Please sign in to comment.