Skip to content

Commit

Permalink
perf: avoid loading model def in experiment model (#8742)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey authored Feb 2, 2024
1 parent 7698452 commit 422f5aa
Show file tree
Hide file tree
Showing 15 changed files with 128 additions and 128 deletions.
28 changes: 14 additions & 14 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1392,10 +1392,10 @@ func (a *apiServer) GetExperimentCheckpoints(
}

func (a *apiServer) createUnmanagedExperimentTx(
ctx context.Context, idb bun.IDB, dbExp *model.Experiment, activeConfig expconf.ExperimentConfigV0,
taskSpec *tasks.TaskSpec, user *model.User,
ctx context.Context, idb bun.IDB, dbExp *model.Experiment, modelDef []byte,
activeConfig expconf.ExperimentConfigV0, taskSpec *tasks.TaskSpec, user *model.User,
) (*apiv1.CreateExperimentResponse, error) {
e, _, err := newUnmanagedExperiment(ctx, idb, a.m, dbExp, activeConfig, taskSpec)
e, _, err := newUnmanagedExperiment(ctx, idb, a.m, dbExp, modelDef, activeConfig, taskSpec)
if err != nil {
return nil, fmt.Errorf("failed to make new unmanaged experiment: %w", err)
}
Expand Down Expand Up @@ -1494,20 +1494,18 @@ func (a *apiServer) ContinueExperiment(
return nil, err
}

dbExp, activeConfig, _, taskSpec, err := a.m.parseCreateExperiment(
dbExp, modelDef, activeConfig, _, taskSpec, err := a.m.parseCreateExperiment(
&apiv1.CreateExperimentRequest{
Config: string(configBytes),
ParentId: req.Id, // Use parent logic.
Config: string(configBytes),
}, user,
)
if err != nil {
return nil, fmt.Errorf("parsing continue experiment request: %w", err)
}
dbExp.ParentID = nil // Not actually a parent though.
dbExp.ID = int(req.Id)
dbExp.JobID = origExperiment.JobID // Revive job.

e, launchWarnings, err := newExperiment(a.m, dbExp, activeConfig, taskSpec)
e, launchWarnings, err := newExperiment(a.m, dbExp, modelDef, activeConfig, taskSpec)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create experiment: %s", err)
}
Expand Down Expand Up @@ -1653,7 +1651,7 @@ func (a *apiServer) CreateExperiment(
}
}

dbExp, activeConfig, p, taskSpec, err := a.m.parseCreateExperiment(
dbExp, modelDef, activeConfig, p, taskSpec, err := a.m.parseCreateExperiment(
req, user,
)
if err != nil {
Expand Down Expand Up @@ -1681,7 +1679,7 @@ func (a *apiServer) CreateExperiment(
}

if req.Unmanaged != nil && *req.Unmanaged {
return a.createUnmanagedExperimentTx(ctx, db.Bun(), dbExp, activeConfig, taskSpec, user)
return a.createUnmanagedExperimentTx(ctx, db.Bun(), dbExp, modelDef, activeConfig, taskSpec, user)
}
// Check user has permission for what they are trying to do
// before actually saving the experiment.
Expand All @@ -1691,10 +1689,11 @@ func (a *apiServer) CreateExperiment(
}
}

e, launchWarnings, err := newExperiment(a.m, dbExp, activeConfig, taskSpec)
e, launchWarnings, err := newExperiment(a.m, dbExp, modelDef, activeConfig, taskSpec)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create experiment: %s", err)
}
modelDef = nil //nolint:ineffassign

if err = e.Start(); err != nil {
return nil, errors.Wrapf(err, "failed to start experiment %d", e.ID)
Expand Down Expand Up @@ -1734,7 +1733,7 @@ func (a *apiServer) PutExperiment(
return nil, status.Errorf(codes.Internal, "failed to get the user: %s", err)
}

dbExp, activeConfig, p, taskSpec, err := a.m.parseCreateExperiment(
dbExp, modelDef, activeConfig, p, taskSpec, err := a.m.parseCreateExperiment(
req.CreateExperimentRequest, user,
)
if err != nil {
Expand All @@ -1748,8 +1747,9 @@ func (a *apiServer) PutExperiment(

dbExp.ExternalExperimentID = &req.ExternalExperimentId

innerResp, err = a.createUnmanagedExperimentTx(ctx, db.Bun(), dbExp, activeConfig, taskSpec, user)

innerResp, err = a.createUnmanagedExperimentTx(
ctx, db.Bun(), dbExp, modelDef, activeConfig, taskSpec, user,
)
if err != nil {
return nil, fmt.Errorf("failed to create unmanaged experiment: %w", err)
}
Expand Down
66 changes: 31 additions & 35 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,18 +574,17 @@ func TestGetExperiments(t *testing.T) {
})
activeConfig0 = schemas.WithDefaults(activeConfig0)
exp0 := &model.Experiment{
StartTime: startTime,
EndTime: &endTime,
ModelDefinitionBytes: []byte{1, 2, 3},
JobID: model.JobID(job0ID),
Archived: false,
State: model.PausedState,
Notes: "notes",
Config: activeConfig0.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(1)),
ProjectID: int(pid),
StartTime: startTime,
EndTime: &endTime,
JobID: model.JobID(job0ID),
Archived: false,
State: model.PausedState,
Notes: "notes",
Config: activeConfig0.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(1)),
ProjectID: int(pid),
}
require.NoError(t, api.m.db.AddExperiment(exp0, activeConfig0))
require.NoError(t, api.m.db.AddExperiment(exp0, []byte{1, 2, 3}, activeConfig0))
for i := 0; i < 3; i++ {
task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, db.AddTask(ctx, task))
Expand Down Expand Up @@ -631,17 +630,16 @@ func TestGetExperiments(t *testing.T) {
})
activeConfig1 = schemas.WithDefaults(activeConfig1)
exp1 := &model.Experiment{
StartTime: secondStartTime,
ModelDefinitionBytes: []byte{1, 2, 3},
JobID: model.JobID(job1ID),
Archived: true,
State: model.ErrorState,
ParentID: ptrs.Ptr(exp0.ID),
Config: activeConfig1.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(userResp.User.Id)),
ProjectID: int(pid),
StartTime: secondStartTime,
JobID: model.JobID(job1ID),
Archived: true,
State: model.ErrorState,
ParentID: ptrs.Ptr(exp0.ID),
Config: activeConfig1.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(userResp.User.Id)),
ProjectID: int(pid),
}
require.NoError(t, api.m.db.AddExperiment(exp1, activeConfig1))
require.NoError(t, api.m.db.AddExperiment(exp1, []byte{1, 2, 3}, activeConfig1))
exp1Expected := &experimentv1.Experiment{
StartTime: timestamppb.New(secondStartTime),
Duration: ptrs.Ptr(int32(0)),
Expand Down Expand Up @@ -1028,18 +1026,17 @@ func benchmarkGetExperiments(b *testing.B, n int) {
})
activeConfig = schemas.WithDefaults(activeConfig)
exp := &model.Experiment{
ModelDefinitionBytes: []byte{1, 2, 3},
State: model.PausedState,
Config: activeConfig.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(userResp.User.Id)),
ProjectID: 1,
State: model.PausedState,
Config: activeConfig.AsLegacy(),
OwnerID: ptrs.Ptr(model.UserID(userResp.User.Id)),
ProjectID: 1,
}
for i := 0; i < n; i++ {
jobID := uuid.New().String()
exp.ID = 0
exp.JobID = model.JobID(jobID)

if err := api.m.db.AddExperiment(exp, activeConfig); err != nil {
if err := api.m.db.AddExperiment(exp, []byte{1, 2, 3}, activeConfig); err != nil {
b.Fatal(err)
}
}
Expand Down Expand Up @@ -1081,15 +1078,14 @@ func createTestExpWithProjectID(
})
activeConfig = schemas.WithDefaults(activeConfig)
exp := &model.Experiment{
JobID: model.JobID(uuid.New().String()),
State: model.PausedState,
OwnerID: &curUser.ID,
ProjectID: projectID,
StartTime: time.Now(),
ModelDefinitionBytes: []byte{10, 11, 12},
Config: activeConfig.AsLegacy(),
JobID: model.JobID(uuid.New().String()),
State: model.PausedState,
OwnerID: &curUser.ID,
ProjectID: projectID,
StartTime: time.Now(),
Config: activeConfig.AsLegacy(),
}
require.NoError(t, api.m.db.AddExperiment(exp, activeConfig))
require.NoError(t, api.m.db.AddExperiment(exp, []byte{10, 11, 12}, activeConfig))

// Get experiment as our API mostly will to make it easier to mock.
exp, err := db.ExperimentByID(context.TODO(), exp.ID)
Expand Down
17 changes: 8 additions & 9 deletions master/internal/core_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,15 @@ func mockExperimentS3(
})

exp := model.Experiment{
JobID: model.NewJobID(),
State: model.ActiveState,
Config: cfg.AsLegacy(),
ModelDefinitionBytes: db.ReadTestModelDefiniton(t, folderPath),
StartTime: time.Now().Add(-time.Hour),
OwnerID: &user.ID,
Username: user.Username,
ProjectID: 1,
JobID: model.NewJobID(),
State: model.ActiveState,
Config: cfg.AsLegacy(),
StartTime: time.Now().Add(-time.Hour),
OwnerID: &user.ID,
Username: user.Username,
ProjectID: 1,
}
err := pgDB.AddExperiment(&exp, cfg)
err := pgDB.AddExperiment(&exp, db.ReadTestModelDefiniton(t, folderPath), cfg)
require.NoError(t, err, "failed to add experiment")
return &exp
}
36 changes: 18 additions & 18 deletions master/internal/core_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,21 @@ func getCreateExperimentsProject(
}

func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner *model.User) (
*model.Experiment, expconf.ExperimentConfig, *projectv1.Project, *tasks.TaskSpec, error,
*model.Experiment, []byte, expconf.ExperimentConfig, *projectv1.Project, *tasks.TaskSpec, error,
) {
ctx := context.TODO()
// Read the config as the user provided it.
config, err := expconf.ParseAnyExperimentConfigYAML([]byte(req.Config))
if err != nil {
return nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
return nil, nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
}

// Apply the template that the user specified.
if req.Template != nil {
var tc expconf.ExperimentConfig
err := templates.UnmarshalTemplateConfig(ctx, *req.Template, owner, &tc, true)
if err != nil {
return nil, config, nil, nil, err
return nil, nil, config, nil, nil, err
}
config = schemas.Merge(config, tc)
}
Expand All @@ -286,34 +286,34 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner

p, err := getCreateExperimentsProject(m, req, owner, defaulted)
if err != nil {
return nil, config, nil, nil, err
return nil, nil, config, nil, nil, err
}
workspaceModel, err := workspace.WorkspaceByProjectID(ctx, int(p.Id))
if err != nil && errors.Cause(err) != sql.ErrNoRows {
return nil, config, nil, nil, err
return nil, nil, config, nil, nil, err
}
workspaceID := resolveWorkspaceID(workspaceModel)
poolName, err := m.rm.ResolveResourcePool(
resources.ResourcePool(), workspaceID, resources.SlotsPerTrial())
if err != nil {
return nil, config, nil, nil, errors.Wrapf(err, "invalid resource configuration")
return nil, nil, config, nil, nil, errors.Wrapf(err, "invalid resource configuration")
}
isSingleNode := resources.IsSingleNode() != nil && *resources.IsSingleNode()
if err = m.rm.ValidateResources(poolName, resources.SlotsPerTrial(), isSingleNode); err != nil {
return nil, config, nil, nil, errors.Wrapf(err, "error validating resources")
return nil, nil, config, nil, nil, errors.Wrapf(err, "error validating resources")
}
taskContainerDefaults, err := m.rm.TaskContainerDefaults(
poolName,
m.config.TaskContainerDefaults,
)
if err != nil {
return nil, config, nil, nil, errors.Wrapf(err, "error getting TaskContainerDefaults")
return nil, nil, config, nil, nil, errors.Wrapf(err, "error getting TaskContainerDefaults")
}
taskSpec := *m.taskSpec
taskSpec.TaskContainerDefaults = taskContainerDefaults
taskSpec.TaskContainerDefaults.MergeIntoExpConfig(&config)
if defaulted.RawEntrypoint == nil && (req.Unmanaged == nil || !*req.Unmanaged) {
return nil, config, nil, nil, errors.New("managed experiments require entrypoint")
return nil, nil, config, nil, nil, errors.New("managed experiments require entrypoint")
}

// Merge in workspace's checkpoint storage into the conifg.
Expand All @@ -322,7 +322,7 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner
Where("id = ?", p.WorkspaceId).
Column("checkpoint_storage_config").
Scan(ctx); err != nil {
return nil, config, nil, nil, err
return nil, nil, config, nil, nil, err
}
config.RawCheckpointStorage = schemas.Merge(
config.RawCheckpointStorage, w.CheckpointStorageConfig)
Expand All @@ -337,12 +337,12 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner

// Make sure the experiment config has all eventuallyRequired fields.
if err = schemas.IsComplete(config); err != nil {
return nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
return nil, nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
}

// Disallow EOL searchers.
if err = config.Searcher().AssertCurrent(); err != nil {
return nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
return nil, nil, config, nil, nil, errors.Wrap(err, "invalid experiment configuration")
}

modelBytes := []byte{}
Expand All @@ -352,34 +352,34 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner
var dbErr error
modelBytes, dbErr = m.db.ExperimentModelDefinitionRaw(int(req.ParentId))
if dbErr != nil {
return nil, config, nil, nil, errors.Wrapf(
return nil, nil, config, nil, nil, errors.Wrapf(
dbErr, "unable to find parent experiment %v", req.ParentId)
}
} else {
var compressErr error
if req.ModelDefinition != nil {
modelBytes, compressErr = archive.ToTarGz(filesToArchive(req.ModelDefinition))
if compressErr != nil {
return nil, config, nil, nil, errors.Wrapf(
return nil, nil, config, nil, nil, errors.Wrapf(
compressErr, "unable to find compress model definition")
}
}
}

token, createSessionErr := user.StartSession(ctx, owner)
if createSessionErr != nil {
return nil, config, nil, nil, errors.Wrapf(
return nil, nil, config, nil, nil, errors.Wrapf(
createSessionErr, "unable to create user session inside task")
}
taskSpec.UserSessionToken = token
taskSpec.Owner = owner

dbExp, err := model.NewExperiment(
config, req.Config, modelBytes, parentID, false,
config, req.Config, parentID, false,
int(p.Id), req.Unmanaged != nil && *req.Unmanaged,
)
if err != nil {
return nil, config, nil, nil, err
return nil, nil, config, nil, nil, err
}

if owner != nil {
Expand All @@ -393,5 +393,5 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner
taskSpec.Labels = append(taskSpec.Labels, label)
}

return dbExp, config, p, &taskSpec, err
return dbExp, modelBytes, config, p, &taskSpec, err
}
2 changes: 1 addition & 1 deletion master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type DB interface {
CheckExperimentExists(id int) (bool, error)
CheckTrialExists(id int) (bool, error)
TrialExperimentAndRequestID(id int) (int, model.RequestID, error)
AddExperiment(experiment *model.Experiment, activeConfig expconf.ExperimentConfig) error
AddExperiment(experiment *model.Experiment, modelDef []byte, activeConfig expconf.ExperimentConfig) error
ExperimentIDByTrialID(trialID int) (int, error)
NonTerminalExperiments() ([]*model.Experiment, error)
TerminateExperimentInRestart(id int, state model.State) error
Expand Down
Loading

0 comments on commit 422f5aa

Please sign in to comment.