Skip to content

Commit

Permalink
feat: add a master API to fetch a trial by external id. (#8730)
Browse files Browse the repository at this point in the history
  • Loading branch information
ioga authored Feb 16, 2024
1 parent e78a4c0 commit 6b63750
Show file tree
Hide file tree
Showing 11 changed files with 3,831 additions and 3,274 deletions.
53 changes: 53 additions & 0 deletions harness/determined/common/api/bindings.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,33 @@ func (a *apiServer) GetTrial(ctx context.Context, req *apiv1.GetTrialRequest) (
return resp, nil
}

func (a *apiServer) GetTrialByExternalID(ctx context.Context, req *apiv1.GetTrialByExternalIDRequest) (
*apiv1.GetTrialByExternalIDResponse, error,
) {
var trialID int
err := db.Bun().NewRaw(`
SELECT t.id
FROM trials t JOIN experiments e
ON t.experiment_id = e.id
WHERE t.external_trial_id = ? AND e.external_experiment_id = ?`,
req.ExternalTrialId, req.ExternalExperimentId).Scan(ctx, &trialID)
if err != nil {
return nil, db.MatchSentinelError(err)
}

proxyReq := apiv1.GetTrialRequest{TrialId: int32(trialID)}
proxyResp, err := a.GetTrial(ctx, &proxyReq)
if err != nil {
return nil, err
}

resp := apiv1.GetTrialByExternalIDResponse{
Trial: proxyResp.Trial,
}

return &resp, nil
}

func (a *apiServer) formatMetrics(
m *apiv1.DownsampledMetrics, metricMeasurements []db.MetricMeasurements,
) error {
Expand Down
27 changes: 27 additions & 0 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1366,3 +1366,30 @@ func TestTrialSourceInfoModelVersion(t *testing.T) {
require.Equal(t, 1, len(getMVResp.Metrics))
require.Equal(t, int32(infTrial.ID), getMVResp.Metrics[0].TrialId)
}

func TestGetTrialByExternalID(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
trial, _ := createTestTrial(t, api, curUser)
externalExpID := uuid.New().String()
externalTrialID := "trial"

_, err := db.Bun().NewUpdate().Model(&model.Experiment{}).
Where("id = ?", trial.ExperimentID).
Set("external_experiment_id = ?", externalExpID).
Exec(ctx)
require.NoError(t, err)

_, err = db.Bun().NewUpdate().Model(&model.Run{}).
Where("id = ?", trial.ID).
Set("external_run_id = ?", externalTrialID).
Exec(ctx)
require.NoError(t, err)

resp, err := api.GetTrialByExternalID(ctx, &apiv1.GetTrialByExternalIDRequest{
ExternalExperimentId: externalExpID,
ExternalTrialId: externalTrialID,
})
require.NoError(t, err)

require.Equal(t, int(resp.Trial.Id), trial.ID)
}
2 changes: 1 addition & 1 deletion master/pkg/model/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ type Run struct {
HParams map[string]any `db:"hparams" bun:"hparams"`
WarmStartCheckpointID *int `db:"warm_start_checkpoint_id"`
TotalBatches int `db:"total_batches"`
ExternalRunID *string `db:"external_trial_id"`
ExternalRunID *string `db:"external_run_id"`
RestartID int `db:"restart_id"`
Restarts int `db:"restarts"`
RunnerState string `db:"runner_state"`
Expand Down
4 changes: 2 additions & 2 deletions master/static/srv/get_task.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ SELECT
t.start_time,
t.end_time,
CASE WHEN t.task_state is NULL THEN NULL
ELSE CONCAT('GENERIC_TASK_STATE_', t.task_state)
END as task_state,
ELSE CONCAT('GENERIC_TASK_STATE_', t.task_state)
END as task_state,
(
SELECT
COALESCE(
Expand Down
Loading

0 comments on commit 6b63750

Please sign in to comment.