Skip to content

Commit

Permalink
fix: Use experiment config to determine is_multi_trial in api_runs qu…
Browse files Browse the repository at this point in the history
…eries (#9475)
  • Loading branch information
AmanuelAaron committed Jun 10, 2024
1 parent f87214b commit dde6362
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
8 changes: 4 additions & 4 deletions master/internal/api_runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func getRunsColumns(q *bun.SelectQuery) *bun.SelectQuery {
'progress', e.progress,
'forked_from', e.parent_id,
'external_experiment_id', e.external_experiment_id,
'is_multitrial', ((SELECT COUNT(*) FROM runs r WHERE e.id = r.experiment_id) > 1),
'is_multitrial', (e.config->'searcher'->>'name' != 'single'),
'pachyderm_integration', NULLIF(e.config#>'{integrations,pachyderm}', 'null'),
'id', e.id) AS experiment`).
Join("LEFT JOIN experiments AS e ON r.experiment_id=e.id").
Expand Down Expand Up @@ -196,7 +196,7 @@ func sortRuns(sortString *string, runQuery *bun.SelectQuery) error {
"externalExperimentId": "e.external_experiment_id",
"externalRunId": "r.external_run_id",
"experimentId": "e.id",
"isExpMultitrial": "((SELECT COUNT(*) FROM runs r WHERE e.id = r.experiment_id) > 1)",
"isExpMultitrial": "(e.config->'searcher'->>'name' != 'single')",
"parentArchived": "(w.archived OR p.archived)",
}
sortParams := strings.Split(*sortString, ",")
Expand Down Expand Up @@ -310,7 +310,7 @@ func (a *apiServer) MoveRuns(
Column("r.id").
ColumnExpr("COALESCE((r.archived OR e.archived OR p.archived OR w.archived), FALSE) AS archived").
ColumnExpr("r.experiment_id as exp_id").
ColumnExpr("((SELECT COUNT(*) FROM runs r WHERE e.id = r.experiment_id) > 1) as is_multitrial").
ColumnExpr("(e.config->'searcher'->>'name' != 'single') as is_multitrial").
Join("LEFT JOIN experiments e ON r.experiment_id=e.id").
Join("JOIN projects p ON r.project_id = p.id").
Join("JOIN workspaces w ON p.workspace_id = w.id").
Expand Down Expand Up @@ -539,7 +539,7 @@ func (a *apiServer) DeleteRuns(ctx context.Context, req *apiv1.DeleteRunsRequest
Column("r.id").
ColumnExpr("COALESCE((r.archived OR e.archived OR p.archived OR w.archived), FALSE) AS archived").
ColumnExpr("r.experiment_id as exp_id").
ColumnExpr("((SELECT COUNT(*) FROM runs r WHERE e.id = r.experiment_id) > 1) as is_multitrial").
ColumnExpr("(e.config->'searcher'->>'name' != 'single') as is_multitrial").
ColumnExpr("r.state IN (?) AS is_terminal", bun.In(model.StatesToStrings(model.TerminalStates))).
Where("r.project_id = ?", req.ProjectId)

Expand Down
52 changes: 49 additions & 3 deletions master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/taskv1"
)
Expand Down Expand Up @@ -402,7 +403,52 @@ func setUpMultiTrialExperiments(ctx context.Context, t *testing.T, api *apiServe

func TestMoveRunsMultiTrialSkip(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
sourceprojectID, destprojectID, runID1, runID2, _ := setUpMultiTrialExperiments(ctx, t, api, curUser)
_, projectIDInt := createProjectAndWorkspace(ctx, t, api)
_, projectID2Int := createProjectAndWorkspace(ctx, t, api)
sourceprojectID := int32(projectIDInt)
destprojectID := int32(projectID2Int)

// nolint: exhaustruct
experimentConfig := expconf.ExperimentConfig{
RawDescription: ptrs.Ptr("descnew"),
RawName: expconf.Name{RawString: ptrs.Ptr("name")},
RawSearcher: &expconf.SearcherConfigV0{
RawRandomConfig: &expconf.RandomConfigV0{
RawMaxLength: &expconf.LengthV0{
Unit: expconf.Batches,
Units: 1,
},
},
},
}

activeConfig := schemas.WithDefaults(schemas.Merge(experimentConfig, minExpConfig))

exp := createTestExpWithActiveConfig(t, api, curUser, projectIDInt, activeConfig)

task1 := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, db.AddTask(ctx, task1))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.CompletedState,
ExperimentID: exp.ID,
StartTime: time.Now(),
}, task1.TaskID))

task2 := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, db.AddTask(ctx, task2))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.CompletedState,
ExperimentID: exp.ID,
StartTime: time.Now(),
}, task2.TaskID))

req := &apiv1.SearchRunsRequest{
ProjectId: &sourceprojectID,
Sort: ptrs.Ptr("id=asc"),
}
resp, err := api.SearchRuns(ctx, req)
require.NoError(t, err)
runID1, runID2 := resp.Runs[0].Id, resp.Runs[1].Id

moveIds := []int32{runID1}

Expand All @@ -420,11 +466,11 @@ func TestMoveRunsMultiTrialSkip(t *testing.T) {
moveResp.Results[0].Error)

// run still in old project
req := &apiv1.SearchRunsRequest{
req = &apiv1.SearchRunsRequest{
ProjectId: &sourceprojectID,
Sort: ptrs.Ptr("id=asc"),
}
resp, err := api.SearchRuns(ctx, req)
resp, err = api.SearchRuns(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Runs, 2)
require.Equal(t, runID1, resp.Runs[0].Id)
Expand Down

0 comments on commit dde6362

Please sign in to comment.