Skip to content

Commit

Permalink
fix: fix CreateExperiment for Remote Users (#8700)
Browse files Browse the repository at this point in the history
  • Loading branch information
salonig23 authored Jan 17, 2024
1 parent f00768f commit f2899cc
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
4 changes: 4 additions & 0 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1680,6 +1680,10 @@ func (a *apiServer) CreateExperiment(
if err != nil {
return nil, err
}

if taskSpec.ExtraEnvVars == nil {
taskSpec.ExtraEnvVars = map[string]string{}
}
maps.Copy(taskSpec.ExtraEnvVars, pachyEnvVars)

if err = experiment.AuthZProvider.Get().CanCreateExperiment(ctx, *user, p); err != nil {
Expand Down
27 changes: 19 additions & 8 deletions master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,11 @@ func (e *internalExperiment) stop() error {
}
e.syslog.Infof("PostStop state changed to %s", e.State)

taskSpec, err := e.taskSpec.Clone()
if err != nil {
return fmt.Errorf("cloning checkpoint gc task spec: %w", err)
}

checkpoints, err := e.db.ExperimentCheckpointsToGCRaw(
e.Experiment.ID,
e.activeConfig.CheckpointStorage().SaveExperimentBest(),
Expand All @@ -465,14 +470,17 @@ func (e *internalExperiment) stop() error {
e.syslog.WithError(err).Error("")
}

taskSpec := *e.taskSpec
if err := e.db.DeleteSnapshotsForExperiment(e.Experiment.ID); err != nil {
e.syslog.WithError(err).Errorf(
"failure to delete snapshots for experiment: %d", e.Experiment.ID)
}

// May be no checkpoints to gc, if so skip
if len(checkpoints) > 0 {
taskID := model.TaskID(fmt.Sprintf("%d.%s", e.ID, uuid.New()))
go func() {
err := runCheckpointGCTask(
e.rm, e.db, taskID, e.JobID, e.StartTime, taskSpec,
e.rm, e.db, taskID, e.JobID, e.StartTime, *taskSpec,
e.Experiment.ID, e.activeConfig.AsLegacy(), checkpoints, []string{fullDeleteGlob},
false, taskSpec.AgentUserGroup, taskSpec.Owner, e.logCtx,
)
Expand All @@ -482,11 +490,6 @@ func (e *internalExperiment) stop() error {
}()
}

if err := e.db.DeleteSnapshotsForExperiment(e.Experiment.ID); err != nil {
e.syslog.WithError(err).Errorf(
"failure to delete snapshots for experiment: %d", e.Experiment.ID)
}

if err := user.DeleteSessionByToken(
context.TODO(),
taskSpec.UserSessionToken,
Expand Down Expand Up @@ -794,9 +797,17 @@ func (e *internalExperiment) processOperations(
config := schemas.Copy(e.activeConfig)
state := experiment.TrialSearcherState{Create: op, Complete: true}
e.TrialSearcherState[op.RequestID] = state

clonedSpec, err := e.taskSpec.Clone()
if err != nil {
e.syslog.WithError(err).Error("failed to create trial")
e.trialClosed(op.RequestID, ptrs.Ptr(model.Errored))
continue
}

t, err := newTrial(
e.logCtx, trialTaskID(e.ID, op.RequestID), e.JobID, e.StartTime, e.ID, e.State,
state, e.rm, e.db, config, checkpoint, e.taskSpec, e.generatedKeys, false,
state, e.rm, e.db, config, checkpoint, clonedSpec, e.generatedKeys, false,
nil, continueFromTrialID, e.TrialClosed,
)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions master/pkg/tasks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

docker "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"github.com/jinzhu/copier"

"github.com/determined-ai/determined/master/pkg/archive"
"github.com/determined-ai/determined/master/pkg/cproto"
Expand Down Expand Up @@ -113,6 +114,17 @@ type TaskSpec struct {
UniqueExposedPortRequests map[string]int
}

// Clone deep copies a taskSpec.
func (t *TaskSpec) Clone() (*TaskSpec, error) {
var res TaskSpec
if err := copier.CopyWithOption(
&res, t, copier.Option{DeepCopy: true, IgnoreEmpty: true},
); err != nil {
return nil, fmt.Errorf("copying task spec %+v: %w", t, err)
}
return &res, nil
}

// ResolveWorkDir resolves the work dir.
func (t *TaskSpec) ResolveWorkDir() {
agentUser := ""
Expand Down
32 changes: 32 additions & 0 deletions master/pkg/tasks/task_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package tasks

import (
"testing"

"github.com/stretchr/testify/require"
k8sV1 "k8s.io/api/core/v1"

"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)

func TestTaskSpecClone(t *testing.T) {
//nolint:exhaustruct
orig := &TaskSpec{
Environment: expconf.EnvironmentConfig{
RawPodSpec: &expconf.PodSpec{
Spec: k8sV1.PodSpec{
ServiceAccountName: "test",
},
},
},
ExtraEnvVars: map[string]string{"a": "true"},
}

cloned, err := orig.Clone()
require.NoError(t, err)
require.Equal(t, orig, cloned)

// Actually deep cloned.
orig.ExtraEnvVars["a"] = "diff"
require.Equal(t, map[string]string{"a": "true"}, cloned.ExtraEnvVars)
}

0 comments on commit f2899cc

Please sign in to comment.