fix: stopping states are not handled in restore properly [RM-69] (#8958)
NicholasBlaskey authored Mar 7, 2024
1 parent b06c923 commit 2395dcb
Showing 11 changed files with 158 additions and 45 deletions.
2 changes: 2 additions & 0 deletions .circleci/devcluster/double-reattach.devcluster.yaml
@@ -12,6 +12,8 @@ stages:
password: postgres
user: postgres
name: determined
__internal:
preemption_timeout: 60s
checkpoint_storage:
type: shared_fs
host_path: /tmp
6 changes: 6 additions & 0 deletions docs/release-notes/reattach-stopping.rst
@@ -0,0 +1,6 @@
:orphan:

**Bug Fixes**

- Experiments: Fix an issue where experiments that were in the ``STOPPING_CANCELED`` state when the
  master restarted would leave unkillable containers running on agents.
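
At a high level, the fix stops force-terminating experiments that are in a stopping state when the master restarts. Instead, they are restored like running experiments: their trials and allocations are reattached, and the stopping signal is replayed so the preemption timeout re-arms and any still-running containers are eventually killed. The following is a condensed sketch of that restore-time step, built from identifiers that appear in the master/internal/experiment.go diff below; it is an illustration, not a verbatim excerpt.

// Sketch only: after trials are restored, replay the stopping state so each
// trial re-registers its preemption timeout and can be killed if it never exits.
func resendStoppingOnRestore(e *internalExperiment) {
	e.restoreTrials() // reattaches allocations, including for stopping trials

	if model.StoppingStates[e.State] && e.State != model.StoppingCompletedState {
		e.patchTrialsState(model.StateWithReason{
			State:               e.State,
			InformationalReason: "resending stopping state signal on restore",
		})
	}
}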
102 changes: 102 additions & 0 deletions e2e_tests/tests/cluster/test_master_restart.py
@@ -4,6 +4,7 @@
import docker
import pytest
import requests
import urllib3

from determined.common import api
from determined.common.api import bindings
@@ -213,6 +214,107 @@ def test_master_restart_continued_experiment(
assert "resources exited successfully with a zero exit code" in "".join(log.log for log in logs)


def experiment_who_cancels_itself_then_waits(sess: api.Session) -> int:
return exp.create_experiment(
sess,
conf.fixtures_path("core_api/sleep.yaml"),
conf.fixtures_path("core_api"),
["--config", "entrypoint='det e cancel $DET_EXPERIMENT_ID && sleep 500'"],
)


@pytest.mark.managed_devcluster
def test_master_restart_stopping(
restartable_managed_cluster: managed_cluster.ManagedCluster,
) -> None:
_test_master_restart_stopping(restartable_managed_cluster)


@pytest.mark.e2e_k8s
def test_master_restart_stopping_k8s(
k8s_managed_cluster: managed_cluster_k8s.ManagedK8sCluster,
) -> None:
_test_master_restart_stopping(k8s_managed_cluster)


def _test_master_restart_stopping(managed_cluster_restarts: abstract_cluster.Cluster) -> None:
sess = api_utils.user_session()
sess._max_retries = urllib3.util.retry.Retry(total=5, backoff_factor=0.5)

exp_id = experiment_who_cancels_itself_then_waits(sess)
try:
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.STOPPING_CANCELED)
managed_cluster_restarts.kill_master()
managed_cluster_restarts.restart_master()

# Short wait so that we know it was killed by us and not preemption.
exp.wait_for_experiment_state(
sess, exp_id, bindings.experimentv1State.STOPPING_CANCELED, max_wait_secs=30
)
finally:
exp.kill_experiments(sess, [exp_id])
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED)

# All slots are empty, we don't leave a hanging container.
agentsResp = bindings.get_GetAgents(sess)
for a in agentsResp.agents:
if a.slots is not None:
for s in a.slots.values():
assert s.container is None, s.container.to_json()


@pytest.mark.managed_devcluster
def test_master_restart_stopping_ignore_preemption_still_gets_killed(
restartable_managed_cluster: managed_cluster.ManagedCluster,
) -> None:
sess = api_utils.user_session()
sess._max_retries = urllib3.util.retry.Retry(total=5, backoff_factor=0.5)

exp_id = experiment_who_cancels_itself_then_waits(sess)
try:
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.STOPPING_CANCELED)
restartable_managed_cluster.kill_master()
restartable_managed_cluster.restart_master()
exp.wait_for_experiment_state(
sess, exp_id, bindings.experimentv1State.CANCELED, max_wait_secs=90
)

trial_id = exp.experiment_first_trial(sess, exp_id)
exp.assert_patterns_in_trial_logs(sess, trial_id, ["137"])
finally:
exp.kill_experiments(sess, [exp_id])
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED)


@pytest.mark.managed_devcluster
def test_master_restart_stopping_container_gone(
restartable_managed_cluster: managed_cluster.ManagedCluster,
) -> None:
sess = api_utils.user_session()
exp_id = experiment_who_cancels_itself_then_waits(sess)

exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.STOPPING_CANCELED)

client = docker.from_env()
containers = client.containers.list()

label = "ai.determined.container.description"
containers = [c for c in containers if f"exp-{exp_id}" in c.labels.get(label, "")]
assert len(containers) == 1

restartable_managed_cluster.kill_agent()
restartable_managed_cluster.kill_master()
containers[0].kill()
restartable_managed_cluster.restart_master()
restartable_managed_cluster.restart_agent(wait_for_amnesia=False)

# TODO(RM-70) make this state be an error.
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.CANCELED)
trials = exp.experiment_trials(sess, exp_id)
assert len(trials) == 1
assert trials[0].trial.state == bindings.trialv1State.ERROR


@pytest.mark.managed_devcluster
@pytest.mark.parametrize("wait_for_amnesia", [True, False])
def test_master_restart_error_missing_docker_container(
1 change: 1 addition & 0 deletions master/internal/config/config.go
@@ -422,6 +422,7 @@ type InternalConfig struct {
AuditLoggingEnabled bool `json:"audit_logging_enabled"`
ExternalSessions model.ExternalSessions `json:"external_sessions"`
ProxiedServers []ProxiedServerConfig `json:"proxied_servers"`
PreemptionTimeout *model.Duration `json:"preemption_timeout"`
}

// Validate implements the check.Validatable interface.
8 changes: 1 addition & 7 deletions master/internal/db/postgres_experiments.go
@@ -718,13 +718,7 @@ WHERE unmanaged = false AND state IN (
}
continue
}
if model.StoppingStates[exp.State] {
finalState := model.StoppingToTerminalStates[exp.State]
if err := db.TerminateExperimentInRestart(exp.ID, finalState); err != nil {
log.WithError(err).Errorf("finalizing %v on restart", exp)
}
continue
}

exps = append(exps, &exp)
}
return exps, nil
34 changes: 23 additions & 11 deletions master/internal/experiment.go
@@ -293,6 +293,14 @@ func (e *internalExperiment) start() error {
}

e.restoreTrials()

// Resend the stopping state to trials so the preemption timeout (and related cleanup) is re-registered on restore.
if model.StoppingStates[e.State] && e.State != model.StoppingCompletedState {
e.patchTrialsState(model.StateWithReason{
State: e.State,
InformationalReason: "resending stopping state signal on restore",
})
}
return nil
}

@@ -934,7 +942,22 @@ func (e *internalExperiment) updateState(state model.StateWithReason) bool {
}

e.syslog.Infof("updateState changed to %s", state.State)
e.patchTrialsState(state)

// The database error is explicitly ignored.
if err := e.db.SaveExperimentState(e.Experiment); err != nil {
e.syslog.Errorf("error saving experiment state: %s", err)
}
if e.canTerminate() {
if err := e.stop(); err != nil {
e.syslog.WithError(err).Error("failed to stop experiment on updateState")
}
}

return true
}

func (e *internalExperiment) patchTrialsState(state model.StateWithReason) {
var g errgroup.Group
g.SetLimit(maxConcurrentTrialOps)
for _, t := range e.trials {
@@ -948,17 +971,6 @@ func (e *internalExperiment) updateState(state model.StateWithReason) bool {
})
}
_ = g.Wait() // Errors are handled in g.Go.

if err := e.db.SaveExperimentState(e.Experiment); err != nil {
e.syslog.Errorf("error saving experiment state: %s", err)
}
if e.canTerminate() {
if err := e.stop(); err != nil {
e.syslog.WithError(err).Error("failed to stop experiment on updateState")
}
}
// The database error is explicitly ignored.
return true
}

func (e *internalExperiment) canTerminate() bool {
24 changes: 2 additions & 22 deletions master/internal/restore.go
@@ -14,9 +14,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/telemetry"
"github.com/determined-ai/determined/master/internal/user"
"github.com/determined-ai/determined/master/internal/webhooks"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
@@ -64,23 +62,8 @@ func (m *Master) restoreExperiment(expModel *model.Experiment) error {
if err != nil {
return errors.Errorf("cannot restore experiment %d with unparsable config", expModel.ID)
}
if terminal, ok := model.StoppingToTerminalStates[expModel.State]; ok {
if err = m.db.TerminateExperimentInRestart(expModel.ID, terminal); err != nil {
return errors.Wrapf(err, "terminating experiment %d", expModel.ID)
}
expModel.State = terminal
telemetry.ReportExperimentStateChanged(m.db, expModel)
if err := webhooks.ReportExperimentStateChanged(
context.TODO(), *expModel, activeConfig,
); err != nil {
log.WithError(err).Error("failed to send experiment state change webhook in restore")
}
return nil
} else if _, ok := model.RunningStates[expModel.State]; !ok {
return errors.Errorf(
"cannot restore experiment %d from state %v", expModel.ID, expModel.State,
)
} else if err = activeConfig.Searcher().AssertCurrent(); err != nil {

if err := activeConfig.Searcher().AssertCurrent(); err != nil {
return errors.Errorf(
"cannot restore experiment %d with legacy searcher", expModel.ID,
)
@@ -176,9 +159,6 @@ func (e *internalExperiment) restoreTrial(
if model.TerminalStates[trial.State] {
l.Debugf("trial was in terminal state in restore: %s", trial.State)
terminal = true
} else if !model.RunningStates[trial.State] {
l.Debugf("cannot restore trial in state: %s", trial.State)
terminal = true
}
}

2 changes: 1 addition & 1 deletion master/internal/task/allocation_service.go
@@ -221,7 +221,7 @@ func (as *allocationService) Signal(
sig AllocationSignal,
reason string,
) error {
ref, err := as.getAllocation(id)
ref, err := as.waitForRestore(context.TODO(), id)
if err != nil {
return err
}
8 changes: 8 additions & 0 deletions master/internal/task/allocation_service_test.go
@@ -530,6 +530,14 @@ func TestRestore(t *testing.T) {
db, _, id, q, exitFuture := requireStarted(t, func(ar *sproto.AllocateRequest) {
*ar = restoredAr
})

rID, resources := requireAssigned(t, pgDB, restoredAr.AllocationID, q)
q.Put(&sproto.ResourcesAllocated{
ID: restoredAr.AllocationID,
ResourcePool: restoredAr.ResourcePool,
Resources: map[sproto.ResourcesID]sproto.Resources{rID: resources},
Recovered: true,
})
defer requireKilled(t, db, id, q, exitFuture)
}

8 changes: 7 additions & 1 deletion master/internal/task/preemptible/preemptible.go
@@ -8,6 +8,7 @@ import (

"github.com/google/uuid"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/pkg/syncx/waitgroupx"
)

@@ -84,7 +85,12 @@ func (p *Preemptible) Preempt(timeoutCallback TimeoutFn) {
if !p.preempted {
p.wg.Go(func(ctx context.Context) {
// don't acquire a lock in here without changing close to not lock while it waits.
t := time.NewTimer(DefaultTimeout)
timeout := DefaultTimeout
if debugTimeout := config.GetMasterConfig().InternalConfig.PreemptionTimeout; debugTimeout != nil {
timeout = time.Duration(*debugTimeout)
}

t := time.NewTimer(timeout)
defer t.Stop()

select {
8 changes: 5 additions & 3 deletions master/internal/trial.go
@@ -362,13 +362,15 @@ func (t *trial) continueSetup(continueFromTrialID *int) error {

// maybeAllocateTask checks if the trial should allocate state and allocates it if so.
func (t *trial) maybeAllocateTask() error {
if !(t.allocationID == nil &&
!t.searcher.Complete &&
t.state == model.ActiveState) {
// Only allocate for active trials, or trials that have been restored and are stopping.
// We need to allocate for stopping because we need to reattach the allocation.
shouldAllocateState := t.state == model.ActiveState || (t.restored && model.StoppingStates[t.state])
if t.allocationID != nil || t.searcher.Complete || !shouldAllocateState {
t.syslog.WithFields(logrus.Fields{
"allocation-id": t.allocationID,
"sercher-complete": t.searcher.Complete,
"trial-state": t.state,
"restored": t.restored,
}).Trace("decided not to allocate trial")
return nil
}
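
The comment in the trial.go hunk above carries the other half of the fix: a restored trial in a stopping state must still get its allocation back so the old containers can be reattached and then preempted or killed. Distilled into a standalone predicate, as a sketch built from the identifiers in this diff rather than code from the repository:

// shouldAllocate sketches the decision maybeAllocateTask now makes: allocate
// for active trials as before, and additionally for restored trials that are
// stopping, so their existing allocations can be reattached and torn down.
func shouldAllocate(t *trial) bool {
	if t.allocationID != nil || t.searcher.Complete {
		return false
	}
	return t.state == model.ActiveState ||
		(t.restored && model.StoppingStates[t.state])
}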
