Skip to content

Commit

Permalink
perf: improve GetExperiments + SearchExperiments counting (#8801)
Browse files: browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey authored Feb 13, 2024
1 parent d8d9965 commit 7a13863
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
23 changes: 11 additions & 12 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,11 @@ func getExperimentColumns(q *bun.SelectQuery) *bun.SelectQuery {
Column("e.checkpoint_count").
Column("e.unmanaged").
Column("e.external_experiment_id").
Column(`t.external_trial_id`).
Join("JOIN users u ON e.owner_id = u.id").
Join("JOIN projects p ON e.project_id = p.id").
Join("JOIN workspaces w ON p.workspace_id = w.id").
Join("LEFT JOIN trials AS t ON t.id = e.best_trial_id")
ColumnExpr(`r.external_run_id AS external_trial_id`).
Join("LEFT JOIN users u ON e.owner_id = u.id").
Join("LEFT JOIN projects p ON e.project_id = p.id").
Join("LEFT JOIN workspaces w ON p.workspace_id = w.id").
Join("LEFT JOIN runs AS r ON r.id = e.best_trial_id")
}

func (a *apiServer) GetExperiments(
Expand Down Expand Up @@ -616,7 +616,7 @@ func (a *apiServer) GetExperiments(
apiv1.GetExperimentsRequest_SORT_BY_USER: "display_name",
apiv1.GetExperimentsRequest_SORT_BY_FORKED_FROM: "e.parent_id",
apiv1.GetExperimentsRequest_SORT_BY_RESOURCE_POOL: "resource_pool",
apiv1.GetExperimentsRequest_SORT_BY_PROJECT_ID: "project_id",
apiv1.GetExperimentsRequest_SORT_BY_PROJECT_ID: "e.project_id",
apiv1.GetExperimentsRequest_SORT_BY_CHECKPOINT_SIZE: "checkpoint_size",
apiv1.GetExperimentsRequest_SORT_BY_CHECKPOINT_COUNT: "checkpoint_count",
apiv1.GetExperimentsRequest_SORT_BY_SEARCHER_METRIC_VAL: `(
Expand Down Expand Up @@ -702,7 +702,7 @@ func (a *apiServer) GetExperiments(
return nil, err
}

query = query.Where("project_id = ?", req.ProjectId)
query = query.Where("e.project_id = ?", req.ProjectId)
}
if query, err = experiment.AuthZProvider.Get().
FilterExperimentsQuery(ctx, *curUser, proj, query,
Expand Down Expand Up @@ -2453,7 +2453,7 @@ func sortExperiments(sortString *string, experimentQuery *bun.SelectQuery) error
"user": "display_name",
"forkedFrom": "e.parent_id",
"resourcePool": "resource_pool",
"projectId": "project_id",
"projectId": "e.project_id",
"checkpointSize": "checkpoint_size",
"checkpointCount": "checkpoint_count",
"duration": "duration",
Expand All @@ -2466,7 +2466,7 @@ func sortExperiments(sortString *string, experimentQuery *bun.SelectQuery) error
LIMIT 1
) `,
"externalExperimentId": "e.external_experiment_id",
"externalTrialId": "trials.external_trial_id",
"externalTrialId": "r.external_run_id",
}
sortByMap := map[string]string{
"asc": "ASC",
Expand Down Expand Up @@ -2494,7 +2494,7 @@ func sortExperiments(sortString *string, experimentQuery *bun.SelectQuery) error
if err != nil {
return err
}
experimentQuery.OrderExpr("trials.summary_metrics->?->?->>? ?",
experimentQuery.OrderExpr("r.summary_metrics->?->?->>? ?",
metricGroup, metricName, metricQualifier, bun.Safe(sortDirection))
default:
if _, ok := orderColMap[paramDetail[0]]; !ok {
Expand Down Expand Up @@ -2522,7 +2522,6 @@ func (a *apiServer) SearchExperiments(
Model(&experiments).
ModelTableExpr("experiments as e").
Column("e.best_trial_id").
Join("LEFT JOIN trials ON trials.id = e.best_trial_id").
Apply(getExperimentColumns)

curUser, _, err := grpcutil.GetUser(ctx)
Expand All @@ -2536,7 +2535,7 @@ func (a *apiServer) SearchExperiments(
return nil, err
}

experimentQuery = experimentQuery.Where("project_id = ?", req.ProjectId)
experimentQuery = experimentQuery.Where("e.project_id = ?", req.ProjectId)
}
if experimentQuery, err = experiment.AuthZProvider.Get().
FilterExperimentsQuery(ctx, *curUser, proj, experimentQuery,
Expand Down
32 changes: 20 additions & 12 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,14 @@ func TestSearchExperiments(t *testing.T) {
Exec(ctx)
require.NoError(t, err)

// Sort by external trial ID.
resp, err = api.SearchExperiments(ctx, &apiv1.SearchExperimentsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("externalTrialId=asc"),
})
require.NoError(t, err)
require.Len(t, resp.Experiments, 3)

resp, err = api.SearchExperiments(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Experiments, 3)
Expand Down Expand Up @@ -1656,18 +1664,18 @@ func TestExperimentSearchApiFilterParsing(t *testing.T) {
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_EXPERIMENT", "columnName":"tags","kind":"field","operator":"notContains", "value":"val"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((e.config->>'labels' NOT ILIKE '%val%')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_EXPERIMENT", "columnName":"duration","kind":"field","operator":">", "value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((extract(epoch FROM coalesce(e.end_time, now()) - e.start_time) > 0)))`},
{`{"filterGroup":{"children":[{"columnName":"projectId","location":"LOCATION_TYPE_EXPERIMENT", "kind":"field","operator":">=","value":-1}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((project_id >= -1)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.mean","kind":"field","operator":">=","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_accuracy'->>'mean')::float8 >= 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.min","kind":"field","operator":"=","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((trials.summary_metrics->'validation_metrics'->'validation_string'->>'min' = 'string')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.max","kind":"field","operator":"!=","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((trials.summary_metrics->'validation_metrics'->'validation_string'->>'max' != 'string')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.mean","kind":"field","operator":"contains","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((trials.summary_metrics->'validation_metrics'->'validation_string'->>'mean' LIKE '%string%')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.min","kind":"field","operator":"notContains","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((trials.summary_metrics->'validation_metrics'->'validation_string'->>'min' NOT LIKE '%string%')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.min","kind":"field","operator":">=","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_error'->>'min')::float8 >= 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.max","kind":"field","operator":"notEmpty","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_error'->>'max')::float8 IS NOT NULL)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.max","kind":"field","operator":"isEmpty","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_error'->>'max')::float8 IS NULL)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.x.max","kind":"field","operator":"=","value": 0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'x'->>'max')::float8 = 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.loss.last","kind":"field","operator":"!=","value":0.004}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'loss'->>'last')::float8 != 0.004)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.max","kind":"field","operator":"<","value":-3}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_accuracy'->>'max')::float8 < -3)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.min","kind":"field","operator":"<=","value":10}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((trials.summary_metrics->'validation_metrics'->'validation_accuracy'->>'min')::float8 <= 10)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.mean","kind":"field","operator":">=","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_accuracy'->>'mean')::float8 >= 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.min","kind":"field","operator":"=","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((r.summary_metrics->'validation_metrics'->'validation_string'->>'min' = 'string')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.max","kind":"field","operator":"!=","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((r.summary_metrics->'validation_metrics'->'validation_string'->>'max' != 'string')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.mean","kind":"field","operator":"contains","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((r.summary_metrics->'validation_metrics'->'validation_string'->>'mean' LIKE '%string%')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_string.min","kind":"field","operator":"notContains","value":"string"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((r.summary_metrics->'validation_metrics'->'validation_string'->>'min' NOT LIKE '%string%')))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.min","kind":"field","operator":">=","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_error'->>'min')::float8 >= 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.max","kind":"field","operator":"notEmpty","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_error'->>'max')::float8 IS NOT NULL)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_error.max","kind":"field","operator":"isEmpty","value":0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_error'->>'max')::float8 IS NULL)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.x.max","kind":"field","operator":"=","value": 0}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'x'->>'max')::float8 = 0)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.loss.last","kind":"field","operator":"!=","value":0.004}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'loss'->>'last')::float8 != 0.004)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.max","kind":"field","operator":"<","value":-3}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_accuracy'->>'max')::float8 < -3)))`},
{`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_NUMBER","location":"LOCATION_TYPE_VALIDATIONS", "columnName":"validation.validation_accuracy.min","kind":"field","operator":"<=","value":10}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((r.summary_metrics->'validation_metrics'->'validation_accuracy'->>'min')::float8 <= 10)))`},
{`{"filterGroup":{"children":[{"columnName":"projectId","kind":"field","operator":">=","value":null}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((true)))`},
{`{"filterGroup":{"children":[{"columnName":"id","kind":"field","operator":"=","value":1},{"children":[{"columnName":"id","kind":"field","operator":"=","value":2},{"columnName":"id","kind":"field","operator":"=","value":3}],"conjunction":"and","kind":"group"},{"columnName":"id","kind":"field","operator":"=","value":4},{"children":[{"columnName":"id","kind":"field","operator":"=","value":5}],"conjunction":"and","kind":"group"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `(((e.id = 1)) AND (((e.id = 2)) AND ((e.id = 3))) AND ((e.id = 4)) AND (((e.id = 5))))`},
{`{"filterGroup":{"children":[{"children":[{"columnName":"checkpointCount","kind":"field","operator":"=","value":4},{"columnName":"numTrials","kind":"field","operator":"=","value":1},{"columnName":"progress","kind":"field","operator":"=","value":100}],"conjunction":"and","kind":"group"}],"conjunction":"and","kind":"group"},"showArchived":true}`, `((((e.checkpoint_count = 4)) AND (((SELECT COUNT(*) FROM trials t WHERE e.id = t.experiment_id) = 1)) AND ((ROUND(COALESCE(progress, 0) * 100)::INTEGER = 100))))`},
Expand Down
4 changes: 2 additions & 2 deletions master/internal/experiment_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func expColumnNameToSQL(columnName string) (string, error) {
LIMIT 1
) `,
"externalExperimentId": "e.external_experiment_id",
"externalTrialId": "trials.external_trial_id",
"externalTrialId": "r.external_run_id",
}
var exists bool
col, exists := filterExperimentColMap[columnName]
Expand Down Expand Up @@ -394,7 +394,7 @@ func (e experimentFilter) toSQL(q *bun.SelectQuery,
var col string
var queryArgs []interface{}
var queryString string
col = `trials.summary_metrics->?->?->>?`
col = `r.summary_metrics->?->?->>?`
queryArgs = append(queryArgs, metricGroup, metricName, metricQualifier)
if queryColumnType == projectv1.ColumnType_COLUMN_TYPE_NUMBER.String() {
col = fmt.Sprintf(`(%v)::float8`, col)
Expand Down

0 comments on commit 7a13863

Please sign in to comment.