Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: condense trial update functions #8808

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,7 @@ func (a *apiServer) PostTrialRunnerMetadata(
return nil, err
}

if err := a.m.db.UpdateTrialRunnerMetadata(int(req.TrialId), req.Metadata); err != nil {
if err := a.m.db.UpdateTrialFields(int(req.TrialId), req.Metadata, 0, 0); err != nil {
return nil, err
}

Expand Down
5 changes: 1 addition & 4 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ type DB interface {
id int,
experimentBest, trialBest, trialLatest int,
) ([]uuid.UUID, error)
UpdateTrialRunnerState(id int, state string) error
UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error
UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID, newRestarts int) error
TrialRunIDAndRestarts(trialID int) (int, int, error)
UpdateTrialRunID(id, runID int) error
UpdateTrialRestarts(id, restarts int) error
AddTrainingMetrics(ctx context.Context, m *trialv1.TrialMetrics) error
AddValidationMetrics(
ctx context.Context, m *trialv1.TrialMetrics,
Expand Down
74 changes: 37 additions & 37 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, erro
return &t, nil
}

// UpdateTrial updates an existing trial. Fields that are nil or zero are not
// updated. end_time is set if the trial moves to a terminal state.
// UpdateTrial updates the state of an existing trial.
// end_time is set if the trial moves to a terminal state.
func UpdateTrial(ctx context.Context, id int, newState model.State) error {
trial, err := TrialByID(ctx, id)
if err != nil {
return fmt.Errorf("error finding trial %v to update: %w", id, err)
}

// Update trial state if necessary.
if trial.State == newState {
return nil
}
Expand All @@ -168,7 +169,7 @@ func UpdateTrial(ctx context.Context, id int, newState model.State) error {

if model.TerminalStates[newState] && trial.EndTime != nil {
if _, err := tx.NewRaw(`UPDATE tasks SET end_time = ? FROM run_id_task_id
WHERE run_id_task_id.task_id = tasks.task_id AND run_id_task_id.run_id = ? AND end_time IS NULL`,
WHERE run_id_task_id.task_id = tasks.task_id AND run_id_task_id.run_id = ? AND end_time IS NULL`,
*trial.EndTime, id).Exec(ctx); err != nil {
return fmt.Errorf("completing task: %w", err)
}
Expand All @@ -178,20 +179,41 @@ func UpdateTrial(ctx context.Context, id int, newState model.State) error {
})
}

// UpdateTrialRunnerState records state as the runner state of the trial with the given id. It
// wraps state in a TrialRunnerMetadata message and delegates to UpdateTrialRunnerMetadata.
func (db *PgDB) UpdateTrialRunnerState(id int, state string) error {
	md := &trialv1.TrialRunnerMetadata{State: state}
	return db.UpdateTrialRunnerMetadata(id, md)
}
// UpdateTrialFields updates selected columns of the Run model derived from the trial with ID id.
// Fields that are nil or zero are not updated.
//
// The trial is first loaded with TrialByID; a lookup failure is returned wrapped. Each argument
// then opts one column into the update:
//   - newRunnerMetadata: when non-nil, its State is written to runner_state.
//   - newRunID: when greater than zero, written to restart_id.
//   - newRestarts: when greater than zero, written to restarts.
//
// Because the checks are "> 0", negative values are ignored as well, and neither counter can be
// set back to zero through this method. The error from executing the update is returned
// unwrapped.
func (db *PgDB) UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID,
newRestarts int,
) error {
// This method takes no context parameter, so a placeholder context is used for both queries.
ctx := context.TODO()
trial, err := TrialByID(ctx, id)
if err != nil {
return fmt.Errorf("error finding trial %v to update: %w", id, err)
}

// NOTE(review): the UpdateTrialRunnerMetadata header and body lines interleaved below (through
// its `return nil`) look like the removed side of this diff hunk, left in by the page export —
// confirm against the repository before treating them as live code.
// UpdateTrialRunnerMetadata updates a trial's metadata about its runner.
func (db *PgDB) UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error {
if _, err := db.sql.Exec(`
UPDATE runs
SET runner_state = $2
WHERE id = $1`, id, md.State); err != nil {
return errors.Wrap(err, "saving trial runner state")
}
// Names of the columns that the final UPDATE should write.
var toUpdate []string

// Update trial runner's state if necessary.
if newRunnerMetadata != nil {
trial.RunnerState = newRunnerMetadata.State
toUpdate = append(toUpdate, "runner_state")
}
return nil

// Update trial's run id if necessary.
// The Trial field is RunID, but the column written is restart_id.
if newRunID > 0 {
trial.RunID = newRunID
toUpdate = append(toUpdate, "restart_id")
}

// Update trial's restart count if necessary.
if newRestarts > 0 {
trial.Restarts = newRestarts
toUpdate = append(toUpdate, "restarts")
}

// Only the Run half of the conversion is written; the second returned value is discarded.
// NOTE(review): when no argument requested a change, toUpdate is empty and Column() receives no
// names. Confirm what bun does in that case — if it falls back to writing every model column,
// this would overwrite the row with the values read above; consider returning early.
run, _ := trial.ToRunAndTrialV2()
_, err = Bun().NewUpdate().Model(run).Column(toUpdate...).Where("id = ?", id).Exec(ctx)

return err
}

// TrialRunIDAndRestarts returns the run id and restart count for a trial.
Expand All @@ -206,28 +228,6 @@ WHERE id = $1`, trialID).Scan(&runID, &restart); err != nil {
return runID, restart, nil
}

// UpdateTrialRunID writes runID into the restart_id column of the runs row whose id equals id.
// A failure from the database is returned wrapped with "updating trial run id".
func (db *PgDB) UpdateTrialRunID(id, runID int) error {
	const query = `
UPDATE runs
SET restart_id = $2
WHERE id = $1`

	_, err := db.sql.Exec(query, id, runID)
	if err != nil {
		return errors.Wrap(err, "updating trial run id")
	}
	return nil
}

// UpdateTrialRestarts writes restartCount into the restarts column of the runs row whose id
// equals id. A failure from the database is returned wrapped with "updating trial restarts".
func (db *PgDB) UpdateTrialRestarts(id, restartCount int) error {
	const query = `
UPDATE runs
SET restarts = $2
WHERE id = $1`

	_, err := db.sql.Exec(query, id, restartCount)
	if err != nil {
		return errors.Wrap(err, "updating trial restarts")
	}
	return nil
}

// fullTrialSummaryMetricsRecompute recomputes all summary metrics for a given trial.
func (db *PgDB) fullTrialSummaryMetricsRecompute(
ctx context.Context, tx *sqlx.Tx, trialID int,
Expand Down
14 changes: 7 additions & 7 deletions master/internal/db/postgres_trial_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func addMetrics(ctx context.Context,
},
}))
trialRunID++
require.NoError(t, db.UpdateTrialRunID(trialID, trialRunID))
require.NoError(t, db.UpdateTrialFields(trialID, nil, trialRunID, 0))
}

metrics, err := structpb.NewStruct(m)
Expand Down Expand Up @@ -86,7 +86,7 @@ func addMetrics(ctx context.Context,
},
}))
trialRunID++
require.NoError(t, db.UpdateTrialRunID(trialID, trialRunID))
require.NoError(t, db.UpdateTrialFields(trialID, nil, trialRunID, 0))
}

metrics, err := structpb.NewStruct(m)
Expand Down Expand Up @@ -935,7 +935,7 @@ func TestAddValidationMetricsDupeCheckpoints(t *testing.T) {
StartTime: ptrs.Ptr(time.Now()),
}
require.NoError(t, AddAllocation(ctx, a))
require.NoError(t, db.UpdateTrialRunID(tr.ID, 1))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, 1, 0))

// Now trial runs validation.
require.NoError(t, db.AddValidationMetrics(ctx, &trialv1.TrialMetrics{
Expand Down Expand Up @@ -1019,7 +1019,7 @@ func TestBatchesProcessedNRollbacks(t *testing.T) {
testMetricReporting := func(typ string, trialRunId, batches, expectedTotalBatches int,
expectedRollbacks Rollbacks,
) error {
require.NoError(t, db.UpdateTrialRunID(tr.ID, trialRunId))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, trialRunId, 0))
trialMetrics := &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
TrialRunId: int32(trialRunId),
Expand Down Expand Up @@ -1117,9 +1117,9 @@ func TestUpdateTrialRunnerMetadata(t *testing.T) {
exp := RequireMockExperiment(t, db, user)
trialID := RequireMockTrialID(t, db, exp)

require.NoError(t, db.UpdateTrialRunnerMetadata(trialID, &trialv1.TrialRunnerMetadata{
require.NoError(t, db.UpdateTrialFields(trialID, &trialv1.TrialRunnerMetadata{
State: "expectedState",
}))
}, 0, 0))

actual := struct {
bun.BaseModel `bun:"table:runs"`
Expand Down Expand Up @@ -1168,7 +1168,7 @@ func TestGenericMetricsIO(t *testing.T) {

trialRunID := 1
batches := 10
require.NoError(t, db.UpdateTrialRunID(tr.ID, trialRunID))
require.NoError(t, db.UpdateTrialFields(tr.ID, nil, trialRunID, 0))
trialMetrics := &trialv1.TrialMetrics{
TrialId: int32(tr.ID),
TrialRunId: int32(trialRunID),
Expand Down
4 changes: 2 additions & 2 deletions master/internal/trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ func (t *trial) addTask(ctx context.Context) error {
}

func (t *trial) buildTaskSpecifier() (*tasks.TrialSpec, error) {
if err := t.db.UpdateTrialRunID(t.id, t.runID); err != nil {
if err := t.db.UpdateTrialFields(t.id, nil, t.runID, 0); err != nil {
return nil, errors.Wrap(err, "failed to save trial run ID")
}

Expand Down Expand Up @@ -616,7 +616,7 @@ func (t *trial) handleAllocationExit(exit *task.AllocationExited) error {
WithError(exit.Err).
Errorf("trial failed (restart %d/%d)", t.restarts, t.config.MaxRestarts())
t.restarts++
if err := t.db.UpdateTrialRestarts(t.id, t.restarts); err != nil {
if err := t.db.UpdateTrialFields(t.id, nil, 0, t.restarts); err != nil {
return t.transition(model.StateWithReason{
State: model.ErrorState,
InformationalReason: err.Error(),
Expand Down
6 changes: 6 additions & 0 deletions master/pkg/model/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ type Trial struct {
TotalBatches int `db:"total_batches"`
ExternalTrialID *string `db:"external_trial_id"`
RunID int `db:"run_id"` // run_id as in restart_id not "runs" id.
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
LastActivity *time.Time `db:"last_activity"`
}

Expand All @@ -474,6 +476,8 @@ func (t *Trial) ToRunAndTrialV2() (*Run, *TrialV2) {
TotalBatches: t.TotalBatches,
ExternalRunID: t.ExternalTrialID,
RestartID: t.RunID,
Restarts: t.Restarts,
RunnerState: t.RunnerState,
LastActivity: t.LastActivity,
}
v2 := &TrialV2{
Expand Down Expand Up @@ -508,6 +512,8 @@ type Run struct {
TotalBatches int `db:"total_batches"`
ExternalRunID *string `db:"external_trial_id"`
RestartID int `db:"restart_id"`
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
LastActivity *time.Time `db:"last_activity"`
}

Expand Down
Loading