Skip to content

Commit

Permalink
refactor: condense trial update functions (#8808)
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Feb 7, 2024
1 parent 45c578b commit a35696d
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 51 deletions.
2 changes: 1 addition & 1 deletion master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,7 @@ func (a *apiServer) PostTrialRunnerMetadata(
return nil, err
}

if err := a.m.db.UpdateTrialRunnerMetadata(int(req.TrialId), req.Metadata); err != nil {
if err := a.m.db.UpdateTrialFields(int(req.TrialId), req.Metadata, 0, 0); err != nil {
return nil, err
}

Expand Down
5 changes: 1 addition & 4 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ type DB interface {
id int,
experimentBest, trialBest, trialLatest int,
) ([]uuid.UUID, error)
UpdateTrialRunnerState(id int, state string) error
UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error
UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID, newRestarts int) error
TrialRunIDAndRestarts(trialID int) (int, int, error)
UpdateTrialRunID(id, runID int) error
UpdateTrialRestarts(id, restarts int) error
AddTrainingMetrics(ctx context.Context, m *trialv1.TrialMetrics) error
AddValidationMetrics(
ctx context.Context, m *trialv1.TrialMetrics,
Expand Down
74 changes: 37 additions & 37 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, erro
return &t, nil
}

// UpdateTrial updates an existing trial. Fields that are nil or zero are not
// updated. end_time is set if the trial moves to a terminal state.
// UpdateTrial updates the state of an existing trial.
// end_time is set if the trial moves to a terminal state.
func UpdateTrial(ctx context.Context, id int, newState model.State) error {
trial, err := TrialByID(ctx, id)
if err != nil {
return fmt.Errorf("error finding trial %v to update: %w", id, err)
}

// Update trial state if necessary.
if trial.State == newState {
return nil
}
Expand All @@ -168,7 +169,7 @@ func UpdateTrial(ctx context.Context, id int, newState model.State) error {

if model.TerminalStates[newState] && trial.EndTime != nil {
if _, err := tx.NewRaw(`UPDATE tasks SET end_time = ? FROM run_id_task_id
WHERE run_id_task_id.task_id = tasks.task_id AND run_id_task_id.run_id = ? AND end_time IS NULL`,
WHERE run_id_task_id.task_id = tasks.task_id AND run_id_task_id.run_id = ? AND end_time IS NULL`,
*trial.EndTime, id).Exec(ctx); err != nil {
return fmt.Errorf("completing task: %w", err)
}
Expand All @@ -178,20 +179,41 @@ func UpdateTrial(ctx context.Context, id int, newState model.State) error {
})
}

// UpdateTrialRunnerState updates a trial runner's state.
func (db *PgDB) UpdateTrialRunnerState(id int, state string) error {
return db.UpdateTrialRunnerMetadata(id, &trialv1.TrialRunnerMetadata{State: state})
}
// UpdateTrialFields updates the specified fields of trial with ID id. Fields that are nil or zero
// are not updated.
func (db *PgDB) UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID,
newRestarts int,
) error {
ctx := context.TODO()
trial, err := TrialByID(ctx, id)
if err != nil {
return fmt.Errorf("error finding trial %v to update: %w", id, err)
}

// UpdateTrialRunnerMetadata updates a trial's metadata about its runner.
func (db *PgDB) UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error {
if _, err := db.sql.Exec(`
UPDATE runs
SET runner_state = $2
WHERE id = $1`, id, md.State); err != nil {
return errors.Wrap(err, "saving trial runner state")
var toUpdate []string

// Update trial runner's state if necessary.
if newRunnerMetadata != nil {
trial.RunnerState = newRunnerMetadata.State
toUpdate = append(toUpdate, "runner_state")
}
return nil

// Update trial's run id if necessary.
if newRunID > 0 {
trial.RunID = newRunID
toUpdate = append(toUpdate, "restart_id")
}

// Update trial's restart count if necessary.
if newRestarts > 0 {
trial.Restarts = newRestarts
toUpdate = append(toUpdate, "restarts")
}

run, _ := trial.ToRunAndTrialV2()
_, err = Bun().NewUpdate().Model(run).Column(toUpdate...).Where("id = ?", id).Exec(ctx)

return err
}

// TrialRunIDAndRestarts returns the run id and restart count for a trial.
Expand All @@ -206,28 +228,6 @@ WHERE id = $1`, trialID).Scan(&runID, &restart); err != nil {
return runID, restart, nil
}

// UpdateTrialRunID sets the trial's run ID.
func (db *PgDB) UpdateTrialRunID(id, runID int) error {
if _, err := db.sql.Exec(`
UPDATE runs
SET restart_id = $2
WHERE id = $1`, id, runID); err != nil {
return errors.Wrap(err, "updating trial run id")
}
return nil
}

// UpdateTrialRestarts sets the trial's restart count.
func (db *PgDB) UpdateTrialRestarts(id, restartCount int) error {
if _, err := db.sql.Exec(`
UPDATE runs
SET restarts = $2
WHERE id = $1`, id, restartCount); err != nil {
return errors.Wrap(err, "updating trial restarts")
}
return nil
}

// fullTrialSummaryMetricsRecompute recomputes all summary metrics for a given trial.
func (db *PgDB) fullTrialSummaryMetricsRecompute(
ctx context.Context, tx *sqlx.Tx, trialID int,
Expand Down
14 changes: 7 additions & 7 deletions master/internal/db/postgres_trial_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func addMetrics(ctx context.Context,
},
}))
trialRunID++
require.NoError(t, db.UpdateTrialRunID(trialID, trialRunID))
require.NoError(t, db.UpdateTrialFields(trialID, nil, trialRunID, 0))
}

metrics, err := structpb.NewStruct(m)
Expand Down Expand Up @@ -86,7 +86,7 @@ func addMetrics(ctx context.Context,
},
}))
trialRunID++
require.NoError(t, db.UpdateTrialRunID(trialID, trialRunID))
require.NoError(t, db.UpdateTrialFields(trialID, nil, trialRunID, 0))
}

metrics, err := structpb.NewStruct(m)
Expand Down Expand Up @@ -935,7 +935,7 @@ func TestAddValidationMetricsDupeCheckpoints(t *testing.T) {
StartTime: ptrs.Ptr(time.Now()),
}
require.NoError(t, AddAllocation(ctx, a))
require.NoError(t, db.UpdateTrialRunID(tr.ID, 1))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, 1, 0))

// Now trial runs validation.
require.NoError(t, db.AddValidationMetrics(ctx, &trialv1.TrialMetrics{
Expand Down Expand Up @@ -1019,7 +1019,7 @@ func TestBatchesProcessedNRollbacks(t *testing.T) {
testMetricReporting := func(typ string, trialRunId, batches, expectedTotalBatches int,
expectedRollbacks Rollbacks,
) error {
require.NoError(t, db.UpdateTrialRunID(tr.ID, trialRunId))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, trialRunId, 0))
trialMetrics := &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
TrialRunId: int32(trialRunId),
Expand Down Expand Up @@ -1117,9 +1117,9 @@ func TestUpdateTrialRunnerMetadata(t *testing.T) {
exp := RequireMockExperiment(t, db, user)
trialID := RequireMockTrialID(t, db, exp)

require.NoError(t, db.UpdateTrialRunnerMetadata(trialID, &trialv1.TrialRunnerMetadata{
require.NoError(t, db.UpdateTrialFields(trialID, &trialv1.TrialRunnerMetadata{
State: "expectedState",
}))
}, 0, 0))

actual := struct {
bun.BaseModel `bun:"table:runs"`
Expand Down Expand Up @@ -1168,7 +1168,7 @@ func TestGenericMetricsIO(t *testing.T) {

trialRunID := 1
batches := 10
require.NoError(t, db.UpdateTrialRunID(tr.ID, trialRunID))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, trialRunID, 0))
trialMetrics := &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
TrialRunId: int32(trialRunID),
Expand Down
4 changes: 2 additions & 2 deletions master/internal/trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ func (t *trial) addTask(ctx context.Context) error {
}

func (t *trial) buildTaskSpecifier() (*tasks.TrialSpec, error) {
if err := t.db.UpdateTrialRunID(t.id, t.runID); err != nil {
if err := t.db.UpdateTrialFields(t.id, nil, t.runID, 0); err != nil {
return nil, errors.Wrap(err, "failed to save trial run ID")
}

Expand Down Expand Up @@ -616,7 +616,7 @@ func (t *trial) handleAllocationExit(exit *task.AllocationExited) error {
WithError(exit.Err).
Errorf("trial failed (restart %d/%d)", t.restarts, t.config.MaxRestarts())
t.restarts++
if err := t.db.UpdateTrialRestarts(t.id, t.restarts); err != nil {
if err := t.db.UpdateTrialFields(t.id, nil, 0, t.restarts); err != nil {
return t.transition(model.StateWithReason{
State: model.ErrorState,
InformationalReason: err.Error(),
Expand Down
6 changes: 6 additions & 0 deletions master/pkg/model/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ type Trial struct {
TotalBatches int `db:"total_batches"`
ExternalTrialID *string `db:"external_trial_id"`
RunID int `db:"run_id"` // run_id as in restart_id not "runs" id.
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
LastActivity *time.Time `db:"last_activity"`
}

Expand All @@ -474,6 +476,8 @@ func (t *Trial) ToRunAndTrialV2() (*Run, *TrialV2) {
TotalBatches: t.TotalBatches,
ExternalRunID: t.ExternalTrialID,
RestartID: t.RunID,
Restarts: t.Restarts,
RunnerState: t.RunnerState,
LastActivity: t.LastActivity,
}
v2 := &TrialV2{
Expand Down Expand Up @@ -508,6 +512,8 @@ type Run struct {
TotalBatches int `db:"total_batches"`
ExternalRunID *string `db:"external_trial_id"`
RestartID int `db:"restart_id"`
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
LastActivity *time.Time `db:"last_activity"`
}

Expand Down

0 comments on commit a35696d

Please sign in to comment.