Skip to content

Commit

Permalink
chore: Bunify and add test coverage for ExperimentTotalStepTime and…
Browse files Browse the repository at this point in the history
… `ExperimentNumSteps` (#9333)
  • Loading branch information
ShreyaLnuHpe committed May 9, 2024
1 parent 3c0eac6 commit 86aa319
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 46 deletions.
2 changes: 0 additions & 2 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ type DB interface {
ExperimentHasCheckpointsInRegistry(id int) (bool, error)
SaveExperimentProgress(id int, progress *float64) error
ActiveExperimentConfig(id int) (expconf.ExperimentConfig, error)
ExperimentTotalStepTime(id int) (float64, error)
ExperimentNumTrials(id int) (int64, error)
ExperimentTrialIDs(expID int) ([]int, error)
ExperimentNumSteps(id int) (int64, error)
ExperimentModelDefinitionRaw(id int) ([]byte, error)
UpdateTrialFields(id int, newRunnerMetadata *trialv1.TrialRunnerMetadata, newRunID, newRestarts int) error
TrialRunIDAndRestarts(trialID int) (int, int, error)
Expand Down
37 changes: 19 additions & 18 deletions master/internal/db/postgres_experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,16 +937,16 @@ func ActiveLogPolicies(
// ExperimentTotalStepTime returns the total elapsed time, in seconds, summed over every
// allocation that belongs to a trial of the experiment with the given ID. The per-allocation
// duration is end_time - start_time, so an allocation whose end_time is NULL does not
// contribute. When nothing contributes (including an unknown experiment ID) the COALESCE in
// the query makes the result 0 rather than NULL, and no error is returned.
func (db *PgDB) ExperimentTotalStepTime(id int) (float64, error) {
func ExperimentTotalStepTime(ctx context.Context, id int) (float64, error) {
var seconds float64
// Raw-SQL form of the query, issued through db.sql.
if err := db.sql.Get(&seconds, `
SELECT COALESCE(extract(epoch from sum(a.end_time - a.start_time)), 0)
FROM allocations a
JOIN run_id_task_id tasks ON a.task_id = tasks.task_id
JOIN trials t ON tasks.run_id = t.id
WHERE t.experiment_id = $1
`, id); err != nil {
return 0, errors.Wrapf(err, "querying for total step time of experiment %v", id)
// Bun query-builder form: same expression, tables, joins and filter, scanned into seconds.
if err := Bun().NewSelect().
ColumnExpr("COALESCE(extract(epoch from sum(a.end_time - a.start_time)), 0)").
TableExpr("allocations AS a").
Join("JOIN run_id_task_id AS tasks ON a.task_id = tasks.task_id").
Join("JOIN trials AS t ON tasks.run_id = t.id").
Where("t.experiment_id = ?", id).
Scan(ctx, &seconds); err != nil {
return 0.0, fmt.Errorf("querying for total step time of experiment %v: %w", id, err)
}
return seconds, nil
}
Expand Down Expand Up @@ -1010,16 +1010,17 @@ func ExperimentsTrialAndTaskIDs(ctx context.Context, idb bun.IDB, expIDs []int)
}

// ExperimentNumSteps returns the total number of steps for all trials of the experiment,
// counted as the rows of raw_steps whose trial belongs to the experiment with the given ID.
// An ID that matches no rows yields a count of 0 and no error.
func (db *PgDB) ExperimentNumSteps(id int) (int64, error) {
var numSteps int64
// Raw-SQL form of the count, issued through db.sql.
if err := db.sql.Get(&numSteps, `
SELECT count(*)
FROM raw_steps s, trials t
WHERE t.experiment_id = $1 AND s.trial_id = t.id
`, id); err != nil {
return 0, errors.Wrapf(err, "querying for number of steps of experiment %v", id)
func ExperimentNumSteps(ctx context.Context, id int) (int64, error) {
// Bun query-builder form: the implicit join above is written as an explicit JOIN.
numSteps, err := Bun().NewSelect().
TableExpr("raw_steps AS s").
Join("JOIN trials AS t ON t.id = s.trial_id").
Where("t.experiment_id = ?", id).
Count(ctx)
if err != nil {
return int64(0), fmt.Errorf("querying for number of steps of experiment %v: %w", id, err)
}
return numSteps, nil

// The Count result is converted to int64 to match the declared return type.
return int64(numSteps), nil
}

// ExperimentModelDefinitionRaw returns the zipped model definition for an experiment as a byte
Expand Down
166 changes: 146 additions & 20 deletions master/internal/db/postgres_experiments_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,26 @@ func TestTopTrialsByMetric(t *testing.T) {
}
}

// createMetric builds a TrialMetrics message for trial trID that reports sc
// completed steps and one averaged metric, keyed by defaultSearcherMetric, whose
// numeric value is mv. The batch metrics list is empty but non-nil.
func createMetric(sc int32, mv float64, trID int) *trialv1.TrialMetrics {
	searcherValue := &structpb.Value{
		Kind: &structpb.Value_NumberValue{NumberValue: mv},
	}
	averaged := &structpb.Struct{
		Fields: map[string]*structpb.Value{defaultSearcherMetric: searcherValue},
	}

	return &trialv1.TrialMetrics{
		TrialId:        int32(trID),
		StepsCompleted: &sc,
		Metrics: &commonv1.Metrics{
			AvgMetrics:   averaged,
			BatchMetrics: []*structpb.Struct{},
		},
	}
}

func TestDeleteExperiments(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -698,26 +718,6 @@ func TestDeleteExperiments(t *testing.T) {
numExptSns = 1 // Experiment snapshots per experiment
)

createMetric := func(sc int32, mv float64, trID int) *trialv1.TrialMetrics {
m := &trialv1.TrialMetrics{
TrialId: int32(trID),
StepsCompleted: &sc,
Metrics: &commonv1.Metrics{
AvgMetrics: &structpb.Struct{
Fields: map[string]*structpb.Value{
defaultSearcherMetric: {
Kind: &structpb.Value_NumberValue{
NumberValue: mv,
},
},
},
},
BatchMetrics: []*structpb.Struct{},
},
}
return m
}

checkPointIndex := 1
for i := 0; i < numExpts; i++ { // Create experiments
exp := RequireMockExperiment(t, db, user)
Expand Down Expand Up @@ -976,3 +976,129 @@ func validateExperimentMatch(t *testing.T, expected []*model.Experiment, actual
// map should be empty
require.Equal(t, len(m), 0)
}

// TestExperimentTotalStepTime verifies that ExperimentTotalStepTime sums the
// durations of an experiment's completed allocations, in seconds, and returns 0
// with no error when there is nothing to sum.
func TestExperimentTotalStepTime(t *testing.T) {
	ctx := context.Background()

	require.NoError(t, etc.SetRootPath(RootFromDB))
	db := MustResolveTestPostgres(t)
	MustMigrateTestPostgres(t, db, MigrationsFromDB)

	t.Run("invalid experiment, return 0.0, no error", func(t *testing.T) {
		sec, err := ExperimentTotalStepTime(ctx, -1)
		require.NoError(t, err)
		require.Equal(t, 0.0, sec)
	})

	t.Run("experiment with no trials", func(t *testing.T) {
		user := RequireMockUser(t, db)
		exp := RequireMockExperiment(t, db, user)
		timeInSeconds, err := ExperimentTotalStepTime(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, 0.0, timeInSeconds)
	})

	t.Run("experiment with single trial/task with null endtime", func(t *testing.T) {
		user := RequireMockUser(t, db)
		exp := RequireMockExperiment(t, db, user)
		// Give the experiment a trial with an allocation that is never
		// completed, so the query actually has an allocation row to skip.
		// NOTE(review): assumes RequireMockAllocation stores the allocation
		// without an end time — confirm against the helper.
		_, task := RequireMockTrial(t, db, exp)
		_ = RequireMockAllocation(t, db, task.TaskID)
		timeInSeconds, err := ExperimentTotalStepTime(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, 0.0, timeInSeconds)
	})

	t.Run("experiment with single trial/task with set endtime", func(t *testing.T) {
		user := RequireMockUser(t, db)
		exp := RequireMockExperiment(t, db, user)
		_, task := RequireMockTrial(t, db, exp)
		alloc := RequireMockAllocation(t, db, task.TaskID)
		endTime := alloc.StartTime.Add(time.Hour)
		// Setting the field only updates the in-memory struct; the call to
		// CompleteAllocation below is what writes it to the database.
		alloc.EndTime = &endTime
		require.NoError(t, CompleteAllocation(ctx, alloc))
		timeInSeconds, err := ExperimentTotalStepTime(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, 3600.0, timeInSeconds)
	})

	t.Run("experiment with multiple trials/tasks", func(t *testing.T) {
		user := RequireMockUser(t, db)
		exp := RequireMockExperiment(t, db, user)

		// Add three trials, each with one completed allocation lasting an
		// hour, a minute and a second respectively (3661 seconds in total).
		for _, d := range []time.Duration{time.Hour, time.Minute, time.Second} {
			_, task := RequireMockTrial(t, db, exp)
			alloc := RequireMockAllocation(t, db, task.TaskID)
			endTime := alloc.StartTime.Add(d)
			alloc.EndTime = &endTime
			require.NoError(t, CompleteAllocation(ctx, alloc))
		}

		timeInSeconds, err := ExperimentTotalStepTime(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, 3661.0, timeInSeconds)
	})
}

// TestExperimentNumSteps checks that ExperimentNumSteps counts the training
// metric reports (raw_steps) recorded across every trial of an experiment, and
// that an unknown experiment ID yields 0 without an error.
func TestExperimentNumSteps(t *testing.T) {
	require.NoError(t, etc.SetRootPath(RootFromDB))
	pgDB := MustResolveTestPostgres(t)
	MustMigrateTestPostgres(t, pgDB, MigrationsFromDB)

	ctx := context.Background()

	t.Run("invalid experiment, return 0, no error", func(t *testing.T) {
		numSteps, err := ExperimentNumSteps(ctx, -1)
		require.Equal(t, int64(0), numSteps)
		require.NoError(t, err)
	})

	t.Run("experiment with single trial metrics", func(t *testing.T) {
		exp := RequireMockExperiment(t, pgDB, RequireMockUser(t, pgDB))
		trialID := RequireMockTrialID(t, pgDB, exp)

		// One training metrics report for the only trial.
		require.NoError(t, pgDB.AddTrainingMetrics(ctx, createMetric(10, 0.5, trialID)))

		numSteps, err := ExperimentNumSteps(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, int64(1), numSteps)
	})

	t.Run("experiment with single trial multiple raw metrics", func(t *testing.T) {
		exp := RequireMockExperiment(t, pgDB, RequireMockUser(t, pgDB))
		trialID := RequireMockTrialID(t, pgDB, exp)

		// Four reports for the same trial, at steps 1 through 4.
		for step := int32(1); step <= 4; step++ {
			report := createMetric(step, 0.5, trialID)
			require.NoError(t, pgDB.AddTrainingMetrics(ctx, report))
		}

		numSteps, err := ExperimentNumSteps(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, int64(4), numSteps)
	})

	t.Run("experiment with multiple trial raw metrics", func(t *testing.T) {
		exp := RequireMockExperiment(t, pgDB, RequireMockUser(t, pgDB))

		// Three trials, each with a single report.
		for created := 0; created < 3; created++ {
			report := createMetric(10, 0.5, RequireMockTrialID(t, pgDB, exp))
			require.NoError(t, pgDB.AddTrainingMetrics(ctx, report))
		}

		numSteps, err := ExperimentNumSteps(ctx, exp.ID)
		require.NoError(t, err)
		require.Equal(t, int64(3), numSteps)
	})
}
12 changes: 6 additions & 6 deletions master/internal/telemetry/reports.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,17 @@ func fetchNumTrials(db db.DB, experimentID int) *int64 {
return &result
}

// fetchNumSteps looks up the experiment's step count via ExperimentNumSteps and
// returns a pointer to it. If the lookup fails, the error is logged as a warning
// and nil is returned.
func fetchNumSteps(db db.DB, experimentID int) *int64 {
result, err := db.ExperimentNumSteps(experimentID)
func fetchNumSteps(experimentID int) *int64 {
result, err := db.ExperimentNumSteps(context.TODO(), experimentID)
if err != nil {
syslog.WithError(err).Warn("failed to fetch telemetry metrics")
return nil
}
return &result
}

func fetchTotalStepTime(db db.DB, experimentID int) *float64 {
result, err := db.ExperimentTotalStepTime(experimentID)
func fetchTotalStepTime(experimentID int) *float64 {
result, err := db.ExperimentTotalStepTime(context.TODO(), experimentID)
if err != nil {
syslog.WithError(err).Warn("failed to fetch telemetry metrics")
return nil
Expand All @@ -200,8 +200,8 @@ func ReportExperimentStateChanged(db db.DB, e *model.Experiment) {
// Report additional metrics when an experiment reaches a terminal state.
// These metrics are null for non-terminal state transitions.
numTrials = fetchNumTrials(db, e.ID)
numSteps = fetchNumSteps(db, e.ID)
totalStepTime = fetchTotalStepTime(db, e.ID)
numSteps = fetchNumSteps(e.ID)
totalStepTime = fetchTotalStepTime(e.ID)
}

defaultTelemeter.track(
Expand Down

0 comments on commit 86aa319

Please sign in to comment.