Skip to content

Commit

Permalink
feat: Pause & Resume run (#9129)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron authored Jun 6, 2024
1 parent df3919c commit 2588eea
Show file tree
Hide file tree
Showing 9 changed files with 1,838 additions and 557 deletions.
110 changes: 110 additions & 0 deletions e2e_tests/tests/run/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,113 @@ def test_run_kill_filter() -> None:
for res in killResp.results:
assert res.error == ""
wait_for_run_state(sess, res.id, bindings.trialv1State.CANCELED)


@pytest.mark.e2e_cpu
def test_run_pause_and_resume() -> None:
    """Pause, resume and finally kill the run of a freshly created experiment.

    The test creates an experiment from the ``no_op/single.yaml`` fixture,
    looks up one of its runs through ``SearchRuns`` and then checks that:

    * ``PauseRuns`` reports no error and the run reaches ``PAUSED``;
    * ``ResumeRuns`` reports no error and the run returns to ``ACTIVE``;
    * ``KillRuns`` moves the run to ``CANCELED`` (cleanup).
    """
    # Local import keeps this block self-contained; the module's import
    # section is not part of this block.
    import json

    sess = api_utils.user_session()
    exp_id = exp.create_experiment(
        sess, conf.fixtures_path("no_op/single.yaml"), conf.fixtures_path("no_op")
    )

    # Build the filter from a dict so the payload is always valid JSON and is
    # independent of how this source file happens to be wrapped or indented.
    run_filter = json.dumps(
        {
            "filterGroup": {
                "children": [
                    {
                        "columnName": "experimentId",
                        "kind": "field",
                        "location": "LOCATION_TYPE_RUN",
                        "operator": "=",
                        "type": "COLUMN_TYPE_NUMBER",
                        "value": exp_id,
                    }
                ],
                "conjunction": "and",
                "kind": "group",
            },
            "showArchived": False,
        }
    )
    searchResp = bindings.get_SearchRuns(sess, limit=1, filter=run_filter)

    # Fail with a readable message instead of an IndexError if nothing matched.
    assert len(searchResp.runs) > 0, f"no runs found for experiment {exp_id}"
    assert searchResp.runs[0].state == bindings.trialv1State.ACTIVE
    run_id = searchResp.runs[0].id

    pauseResp = bindings.post_PauseRuns(
        sess, body=bindings.v1PauseRunsRequest(runIds=[run_id], projectId=1)
    )

    # validate response
    assert len(pauseResp.results) == 1
    assert pauseResp.results[0].id == run_id
    assert pauseResp.results[0].error == ""

    # ensure that run is paused
    wait_for_run_state(sess, run_id, bindings.trialv1State.PAUSED)

    resumeResp = bindings.post_ResumeRuns(
        sess, body=bindings.v1ResumeRunsRequest(runIds=[run_id], projectId=1)
    )

    assert len(resumeResp.results) == 1
    assert resumeResp.results[0].id == run_id
    assert resumeResp.results[0].error == ""

    # ensure that run is unpaused
    wait_for_run_state(sess, run_id, bindings.trialv1State.ACTIVE)

    # kill run for cleanup
    _ = bindings.post_KillRuns(sess, body=bindings.v1KillRunsRequest(runIds=[run_id], projectId=1))
    wait_for_run_state(sess, run_id, bindings.trialv1State.CANCELED)


def _multi_trial_run_filter(exp_id: int) -> str:
    """Return the JSON run filter: runs of ``exp_id`` with ``hp.n_filters2 >= 40``."""
    # Local import keeps this block self-contained; the module's import
    # section is not part of this block.
    import json

    return json.dumps(
        {
            "filterGroup": {
                "children": [
                    {
                        "columnName": "experimentId",
                        "kind": "field",
                        "location": "LOCATION_TYPE_RUN",
                        "operator": "=",
                        "type": "COLUMN_TYPE_NUMBER",
                        "value": exp_id,
                    },
                    {
                        "columnName": "hp.n_filters2",
                        "kind": "field",
                        "location": "LOCATION_TYPE_RUN_HYPERPARAMETERS",
                        "operator": ">=",
                        "type": "COLUMN_TYPE_NUMBER",
                        "value": 40,
                    },
                ],
                "conjunction": "and",
                "kind": "group",
            },
            "showArchived": False,
        }
    )


@pytest.mark.e2e_cpu
def test_run_pause_and_resume_filter_skip_empty() -> None:
    """Pause/resume by filter must be rejected per run for a multi-trial experiment.

    An experiment is created from the ``mnist_pytorch/adaptive_short.yaml``
    fixture and ``PauseRuns`` / ``ResumeRuns`` are called with an empty
    ``runIds`` list plus a filter selecting that experiment's runs whose
    ``hp.n_filters2`` is at least 40. Every result returned by either call is
    expected to carry the "part of multi-trial" error. Each run reported by
    the resume call is then checked to be ``ACTIVE`` and killed for cleanup.
    """
    sess = api_utils.user_session()
    exp_id = exp.create_experiment(
        sess,
        conf.fixtures_path("mnist_pytorch/adaptive_short.yaml"),
        conf.fixtures_path("mnist_pytorch"),
    )

    # Built via json.dumps so the payload is valid JSON regardless of how the
    # source is formatted.
    runFilter = _multi_trial_run_filter(exp_id)

    pauseResp = bindings.post_PauseRuns(
        sess,
        body=bindings.v1PauseRunsRequest(
            runIds=[],
            filter=runFilter,
            projectId=1,
        ),
    )

    # validate response
    for r in pauseResp.results:
        assert r.error == "Cannot pause/unpause run '" + str(r.id) + "' (part of multi-trial)."

    resumeResp = bindings.post_ResumeRuns(
        sess,
        body=bindings.v1ResumeRunsRequest(runIds=[], projectId=1, filter=runFilter),
    )

    for res in resumeResp.results:
        assert res.error == "Cannot pause/unpause run '" + str(res.id) + "' (part of multi-trial)."
        wait_for_run_state(sess, res.id, bindings.trialv1State.ACTIVE)

        # kill run for cleanup; kept inside the loop so every matched run is
        # cleaned up and an empty result list cannot leave `res` undefined.
        _ = bindings.post_KillRuns(
            sess, body=bindings.v1KillRunsRequest(runIds=[res.id], projectId=1)
        )
        wait_for_run_state(sess, res.id, bindings.trialv1State.CANCELED)
158 changes: 158 additions & 0 deletions harness/determined/common/api/bindings.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2588eea

Please sign in to comment.