diff --git a/e2e_tests/tests/cluster/test_checkpoints.py b/e2e_tests/tests/cluster/test_checkpoints.py index 13d6070262c..9e9ca09e0e7 100644 --- a/e2e_tests/tests/cluster/test_checkpoints.py +++ b/e2e_tests/tests/cluster/test_checkpoints.py @@ -21,19 +21,26 @@ EXPECT_TIMEOUT = 5 -def wait_for_gc_to_finish(experiment_id: int) -> None: +def wait_for_gc_to_finish(experiment_ids: List[int]) -> None: certs.cli_cert = certs.default_load(conf.make_master_url()) authentication.cli_auth = authentication.Authentication(conf.make_master_url()) - saw_gc = False + + seen_gc_experiment_ids = set() + done_gc_experiment_ids = set() # Don't wait longer than 5 minutes (as 600 half-seconds to improve our sampling resolution). for _ in range(600): r = api.get(conf.make_master_url(), "tasks").json() names = [task["name"] for task in r.values()] - gc_name = f"Checkpoint GC (Experiment {experiment_id})" - if gc_name in names: - saw_gc = True - elif saw_gc: - # We previously saw checkpoint gc but now we don't, so it must have finished. + + for experiment_id in experiment_ids: + gc_name = f"Checkpoint GC (Experiment {experiment_id})" + if gc_name in names: + seen_gc_experiment_ids.add(experiment_id) + elif experiment_id in seen_gc_experiment_ids: + # We saw the gc before but now don't so we assume it is done. + done_gc_experiment_ids.add(experiment_id) + + if len(done_gc_experiment_ids) == len(experiment_ids): return time.sleep(0.5) @@ -145,15 +152,12 @@ def test_delete_checkpoints() -> None: config, model_def_path=conf.fixtures_path("no_op"), expected_trials=1 ) - wait_for_gc_to_finish(exp_id_1) - wait_for_gc_to_finish(exp_id_2) - test_session = api_utils.determined_test_session() exp_1_checkpoints = bindings.get_GetExperimentCheckpoints( session=test_session, id=exp_id_1 ).checkpoints exp_2_checkpoints = bindings.get_GetExperimentCheckpoints( - session=test_session, id=exp_id_1 + session=test_session, id=exp_id_2 ).checkpoints assert len(exp_1_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_1}" assert len(exp_2_checkpoints) > 0, f"no checkpoints found in experiment with ID:{exp_id_2}" @@ -182,8 +186,7 @@ def test_delete_checkpoints() -> None: delete_body = bindings.v1DeleteCheckpointsRequest(checkpointUuids=d_checkpoint_uuids) bindings.delete_DeleteCheckpoints(session=test_session, body=delete_body) - wait_for_gc_to_finish(exp_id_1) - wait_for_gc_to_finish(exp_id_2) + wait_for_gc_to_finish([exp_id_1, exp_id_2]) for d_c in d_checkpoint_uuids: ensure_checkpoint_deleted(test_session, d_c, storage_manager) @@ -265,7 +268,7 @@ def run_gc_checkpoints_test(checkpoint_storage: Dict[str, str]) -> None: # In some configurations, checkpoint GC will run on an auxillary machine, which may have to # be spun up still. So we'll wait for it to run. - wait_for_gc_to_finish(experiment_id) + wait_for_gc_to_finish([experiment_id]) # Checkpoints are not marked as deleted until gc_checkpoint task starts. retries = 5 @@ -465,7 +468,7 @@ def assert_checkpoint_state( checkpointUuids=[completed_checkpoints[0].uuid], ) bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish(exp_id) + wait_for_gc_to_finish([exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -491,7 +494,7 @@ def assert_checkpoint_state( checkpointUuids=[completed_checkpoints[0].uuid], ) bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish(exp_id) + wait_for_gc_to_finish([exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -509,7 +512,7 @@ def assert_checkpoint_state( checkpointUuids=[completed_checkpoints[0].uuid], ) bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish(exp_id) + wait_for_gc_to_finish([exp_id]) assert_checkpoint_state( completed_checkpoints[0].uuid, @@ -532,7 +535,7 @@ def assert_checkpoint_state( checkpointUuids=[completed_checkpoints[1].uuid], ) bindings.post_CheckpointsRemoveFiles(test_session, body=remove_body) - wait_for_gc_to_finish(exp_id) + wait_for_gc_to_finish([exp_id]) assert_checkpoint_state( completed_checkpoints[1].uuid, diff --git a/e2e_tests/tests/experiment/test_core.py b/e2e_tests/tests/experiment/test_core.py index a1653fbe356..7880ec06fff 100644 --- a/e2e_tests/tests/experiment/test_core.py +++ b/e2e_tests/tests/experiment/test_core.py @@ -160,7 +160,7 @@ def test_end_to_end_adaptive() -> None: None, ) - wait_for_gc_to_finish(experiment_id=exp_id) + wait_for_gc_to_finish(experiment_ids=[exp_id]) # Check that validation accuracy look sane (more than 93% on MNIST). trials = exp.experiment_trials(exp_id)