Skip to content

Commit

Permalink
chore: migrate db schema trials to runs (#8723)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey authored Jan 31, 2024
1 parent dfbb926 commit 905e449
Show file tree
Hide file tree
Showing 29 changed files with 701 additions and 142 deletions.
4 changes: 2 additions & 2 deletions master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
func createVersionTwoCheckpoint(
ctx context.Context, t *testing.T, api *apiServer, curUser model.User, resources map[string]int64,
) string {
_, task := createTestTrial(t, api, curUser)
tr, task := createTestTrial(t, api, curUser)

aID := model.AllocationID(string(task.TaskID) + "-1")
aIn := &model.Allocation{
Expand All @@ -58,7 +58,7 @@ func createVersionTwoCheckpoint(
"steps_completed": 5,
},
}
require.NoError(t, db.AddCheckpointMetadata(ctx, checkpoint))
require.NoError(t, db.AddCheckpointMetadata(ctx, checkpoint, tr.ID))

return checkpoint.UUID.String()
}
Expand Down
11 changes: 4 additions & 7 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1592,9 +1592,6 @@ func (a *apiServer) ContinueExperiment(

// Zero out trial restarts. We do lose some information about how many times
// the previous run failed, but people likely care only about the current run.
// TODO consider moving this to trial_id_task_id or some other level to preserve
// the history of what happened during the trial. We should also do this
// with submitted config yamls likely and display these in the webui.
var trialIDs []int32
for _, t := range trialsResp.Trials {
trialIDs = append(trialIDs, t.Id)
Expand Down Expand Up @@ -2627,17 +2624,17 @@ func (a *apiServer) SearchExperiments(
Column("trials.checkpoint_count").
Column("trials.summary_metrics").
ColumnExpr(`(
SELECT tt.task_id FROM trial_id_task_id tt
SELECT tt.task_id FROM run_id_task_id tt
JOIN tasks ta ON tt.task_id = ta.task_id
WHERE tt.trial_id = trials.id
WHERE tt.run_id = trials.id
ORDER BY ta.start_time
LIMIT 1
) AS task_id`).
ColumnExpr(`(
(SELECT json_agg(task_id) FROM (
SELECT tt.task_id FROM trial_id_task_id tt
SELECT tt.task_id FROM run_id_task_id tt
JOIN tasks ta ON tt.task_id = ta.task_id
WHERE tt.trial_id = trials.id
WHERE tt.run_id = trials.id
ORDER BY ta.start_time
) sub_tasks)) AS task_ids`).
ColumnExpr("proto_time(trials.start_time) AS start_time").
Expand Down
15 changes: 14 additions & 1 deletion master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -1385,9 +1385,22 @@ func (a *apiServer) ReportCheckpoint(
default:
}

if err := db.AddCheckpointMetadata(ctx, c); err != nil {
task, err := db.TaskByID(ctx, model.TaskID(req.Checkpoint.TaskId))
if err != nil {
return nil, fmt.Errorf("looking up task to decide if trial: %w", err)
}
if task.TaskType != model.TaskTypeTrial {
return nil, fmt.Errorf("can only report checkpoints on trial's tasks")
}
trial, err := db.TrialByTaskID(ctx, task.TaskID)
if err != nil {
return nil, fmt.Errorf("getting trial by task ID: %w", err)
}

if err := db.AddCheckpointMetadata(ctx, c, trial.ID); err != nil {
return nil, err
}

return &apiv1.ReportCheckpointResponse{}, nil
}

Expand Down
101 changes: 92 additions & 9 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"

"github.com/determined-ai/determined/master/pkg/protoutils/protoconverter"

Expand All @@ -29,6 +31,7 @@ import (
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
"github.com/determined-ai/determined/proto/pkg/commonv1"
"github.com/determined-ai/determined/proto/pkg/trialv1"
)
Expand Down Expand Up @@ -530,6 +533,86 @@ func TestTrialsNonNumericMetrics(t *testing.T) {
})
}

// TestReportCheckpoint reports a completed checkpoint against a trial's task
// and checks that reading it back through GetCheckpoint returns the reported
// checkpoint with its training metadata (trial ID, experiment ID and
// experiment config) filled in.
func TestReportCheckpoint(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)

	trial, trialTask := createTestTrial(t, api, curUser)

	metadata, err := structpb.NewStruct(map[string]any{
		"steps_completed": 1,
	})
	require.NoError(t, err)

	ckptUUID := uuid.New().String()
	// NOTE(review): the report time is truncated to milliseconds, presumably so
	// it compares equal after a round trip through the database -- confirm.
	reported := &checkpointv1.Checkpoint{
		TaskId:       string(trialTask.TaskID),
		AllocationId: nil,
		Uuid:         ckptUUID,
		ReportTime:   timestamppb.New(time.Now().Truncate(time.Millisecond)),
		Resources:    nil,
		Metadata:     metadata,
		State:        checkpointv1.State_STATE_COMPLETED,
	}
	_, err = api.ReportCheckpoint(ctx, &apiv1.ReportCheckpointRequest{Checkpoint: reported})
	require.NoError(t, err)

	getCheckpointResp, err := api.GetCheckpoint(ctx, &apiv1.GetCheckpointRequest{
		CheckpointUuid: ckptUUID,
	})
	require.NoError(t, err)

	gotJSON, err := json.MarshalIndent(getCheckpointResp.Checkpoint, "", "\t")
	require.NoError(t, err)

	getExperimentResp, err := api.GetExperiment(ctx, &apiv1.GetExperimentRequest{
		ExperimentId: int32(trial.ExperimentID),
	})
	require.NoError(t, err)

	// The expected checkpoint is the one we reported plus the training
	// metadata that ties it to the trial and its experiment.
	reported.Training = &checkpointv1.CheckpointTrainingMetadata{
		TrialId:           wrapperspb.Int32(int32(trial.ID)),
		ExperimentId:      wrapperspb.Int32(int32(trial.ExperimentID)),
		ExperimentConfig:  getExperimentResp.Config,
		Hparams:           nil,
		TrainingMetrics:   &commonv1.Metrics{},
		ValidationMetrics: &commonv1.Metrics{},
	}
	wantJSON, err := json.MarshalIndent(reported, "", "\t")
	require.NoError(t, err)

	require.Equal(t, string(wantJSON), string(gotJSON))
}

// TestReportCheckpointNonTrialErrors verifies that ReportCheckpoint returns an
// error when the checkpoint's task is the one created by
// mockNotebookWithWorkspaceID rather than a trial's task.
//
// Reporting on such a task may have worked at some point, but it definitely
// doesn't work after trials became one-to-many with tasks, since the fk
// reference was switched for some reason.
func TestReportCheckpointNonTrialErrors(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)

	nbTaskID := mockNotebookWithWorkspaceID(ctx, api, t, 1)

	metadata, err := structpb.NewStruct(map[string]any{
		"steps_completed": 1,
	})
	require.NoError(t, err)

	ckptUUID := uuid.New().String()
	reported := &checkpointv1.Checkpoint{
		TaskId:       string(nbTaskID),
		AllocationId: nil,
		Uuid:         ckptUUID,
		ReportTime:   timestamppb.New(time.Now().Truncate(time.Millisecond)),
		Resources:    nil,
		Metadata:     metadata,
		State:        checkpointv1.State_STATE_COMPLETED,
	}

	_, err = api.ReportCheckpoint(ctx, &apiv1.ReportCheckpointRequest{Checkpoint: reported})
	require.ErrorContains(t, err, "can only report checkpoints on trial's tasks")
}

func TestUnusualMetricNames(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
expectedMetricsMap := map[string]any{
Expand Down Expand Up @@ -761,9 +844,9 @@ func TestTrialProtoTaskIDs(t *testing.T) {
}
require.NoError(t, api.m.db.AddTask(task2))

_, err = db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
{TrialID: trial.ID, TaskID: task2.TaskID},
_, err = db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
{RunID: trial.ID, TaskID: task2.TaskID},
}).Exec(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -899,9 +982,9 @@ func TestTrialLogs(t *testing.T) {
}
require.NoError(t, api.m.db.AddTask(task2))

_, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
{TrialID: trial.ID, TaskID: task2.TaskID},
_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
{RunID: trial.ID, TaskID: task2.TaskID},
}).Exec(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -995,9 +1078,9 @@ func TestTrialLogFields(t *testing.T) {
}
require.NoError(t, api.m.db.AddTask(task2))

_, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
{TrialID: trial.ID, TaskID: task2.TaskID},
_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
{RunID: trial.ID, TaskID: task2.TaskID},
}).Exec(ctx)
require.NoError(t, err)

Expand Down
4 changes: 2 additions & 2 deletions master/internal/core_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ func addMockCheckpointDB(t *testing.T, pgDB *db.PgDB, id uuid.UUID, bucket strin
user := db.RequireMockUser(t, pgDB)
// Using a different path than DefaultTestSrcPath since we are one level up from most db tests.
exp := mockExperimentS3(t, pgDB, user, "../../examples/tutorials/mnist_pytorch", bucket)
_, task := db.RequireMockTrial(t, pgDB, exp)
tr, task := db.RequireMockTrial(t, pgDB, exp)
allocation := db.RequireMockAllocation(t, pgDB, task.TaskID)
// Create checkpoints
checkpoint := db.MockModelCheckpoint(id, allocation)
err := db.AddCheckpointMetadata(context.TODO(), &checkpoint)
err := db.AddCheckpointMetadata(context.TODO(), &checkpoint, tr.ID)
require.NoError(t, err)
}

Expand Down
17 changes: 7 additions & 10 deletions master/internal/db/postgres_checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,17 @@ func UpdateCheckpointSizeTx(ctx context.Context, idb bun.IDB, checkpoints []uuid
err := idb.NewRaw(`
UPDATE runs SET checkpoint_size=sub.size, checkpoint_count=sub.count FROM (
SELECT
trial_id,
run_id,
COALESCE(SUM(size) FILTER (WHERE state != 'DELETED'), 0) AS size,
COUNT(*) FILTER (WHERE state != 'DELETED') AS count
FROM checkpoints_v2
JOIN trial_id_task_id tt ON tt.task_id = checkpoints_v2.task_id
WHERE
trial_id IN (
SELECT tt.trial_id FROM checkpoints_v2
LEFT JOIN trial_id_task_id tt ON tt.task_id = checkpoints_v2.task_id
WHERE uuid IN (?)
)
GROUP BY trial_id
JOIN run_checkpoints rc ON rc.checkpoint_id = checkpoints_v2.uuid
WHERE rc.run_id IN (
SELECT run_id FROM run_checkpoints WHERE checkpoint_id IN (?)
)
GROUP BY run_id
) sub
WHERE runs.id = sub.trial_id
WHERE runs.id = sub.run_id
RETURNING experiment_id`, bun.In(checkpoints)).Scan(ctx, &experimentIDs)
if err != nil {
return errors.Wrap(err, "errors updating trial checkpoint sizes and counts")
Expand Down
18 changes: 9 additions & 9 deletions master/internal/db/postgres_checkpoints_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestUpdateCheckpointSize(t *testing.T) {

checkpoint := MockModelCheckpoint(ckpt, allocation)
checkpoint.Resources = resources[resourcesIndex]
err := AddCheckpointMetadata(ctx, &checkpoint)
err := AddCheckpointMetadata(ctx, &checkpoint, tr.ID)
require.NoError(t, err)

resourcesIndex++
Expand Down Expand Up @@ -165,21 +165,21 @@ func TestDeleteCheckpoints(t *testing.T) {
MustMigrateTestPostgres(t, db, MigrationsFromDB)
user := RequireMockUser(t, db)
exp := RequireMockExperiment(t, db, user)
_, task := RequireMockTrial(t, db, exp)
tr, task := RequireMockTrial(t, db, exp)
allocation := RequireMockAllocation(t, db, task.TaskID)

// Create checkpoints
ckpt1 := uuid.New()
checkpoint1 := MockModelCheckpoint(ckpt1, allocation)
err := AddCheckpointMetadata(ctx, &checkpoint1)
err := AddCheckpointMetadata(ctx, &checkpoint1, tr.ID)
require.NoError(t, err)
ckpt2 := uuid.New()
checkpoint2 := MockModelCheckpoint(ckpt2, allocation)
err = AddCheckpointMetadata(ctx, &checkpoint2)
err = AddCheckpointMetadata(ctx, &checkpoint2, tr.ID)
require.NoError(t, err)
ckpt3 := uuid.New()
checkpoint3 := MockModelCheckpoint(ckpt3, allocation)
err = AddCheckpointMetadata(ctx, &checkpoint3)
err = AddCheckpointMetadata(ctx, &checkpoint3, tr.ID)
require.NoError(t, err)

// Insert a model.
Expand Down Expand Up @@ -286,7 +286,7 @@ func BenchmarkUpdateCheckpointSize(b *testing.B) {
exp := RequireMockExperiment(t, db, user)
for j := 0; j < 10; j++ {
t.Logf("Adding trial #%d", j)
_, task := RequireMockTrial(t, db, exp)
tr, task := RequireMockTrial(t, db, exp)
allocation := RequireMockAllocation(t, db, task.TaskID)
for k := 0; k < 10; k++ {
ckpt := uuid.New()
Expand All @@ -300,7 +300,7 @@ func BenchmarkUpdateCheckpointSize(b *testing.B) {
checkpoint := MockModelCheckpoint(ckpt, allocation)
checkpoint.Resources = resources

err := AddCheckpointMetadata(ctx, &checkpoint)
err := AddCheckpointMetadata(ctx, &checkpoint, tr.ID)
require.NoError(t, err)
}
}
Expand All @@ -318,15 +318,15 @@ func TestPgDB_GroupCheckpointUUIDsByExperimentID(t *testing.T) {
user := RequireMockUser(t, db)
for i := 0; i < 3; i++ {
exp := RequireMockExperiment(t, db, user)
_, tk := RequireMockTrial(t, db, exp)
tr, tk := RequireMockTrial(t, db, exp)

var ids []uuid.UUID
for j := 0; j < 3; j++ {
id := uuid.New()
err := AddCheckpointMetadata(context.TODO(), &model.CheckpointV2{
UUID: id,
TaskID: tk.TaskID,
})
}, tr.ID)
require.NoError(t, err)
ids = append(ids, id)
}
Expand Down
Loading

0 comments on commit 905e449

Please sign in to comment.