Skip to content

Commit

Permalink
ci: lower hf trainer accuracy target + improve failure messages (#9322)
Browse files Browse the repository at this point in the history
  • Loading branch information
eecsliu authored Jun 7, 2024
1 parent 84299a6 commit 96c061b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion e2e_tests/tests/nightly/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_hf_trainer_api_accuracy() -> None:
validations = _get_validation_metrics(detobj, trials[0].trial.id)
validation_accuracies = [v["eval_accuracy"] for v in validations]

target_accuracy = 0.82
target_accuracy = 0.75
assert max(validation_accuracies) > target_accuracy, (
f"hf_trainer_api did not reach minimum target accuracy {target_accuracy}."
f" full validation accuracy history: {validation_accuracies}"
Expand Down
10 changes: 6 additions & 4 deletions e2e_tests/tests/run/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_run_kill() -> None:
killResp = bindings.post_KillRuns(
sess, body=bindings.v1KillRunsRequest(runIds=[run_id], projectId=1)
)
assert len(killResp.results) == 1
assert len(killResp.results) == 1, f"failed to kill run {run_id} from exp {exp_id}"
assert killResp.results[0].id == run_id
assert killResp.results[0].error == ""

Expand All @@ -94,7 +94,9 @@ def test_run_kill() -> None:
)

# validate response
assert len(killResp.results) == 1
assert (
len(killResp.results) == 1
), f"error when trying to terminate run {run_id} from exp {exp_id} a second time"
assert killResp.results[0].id == run_id
assert killResp.results[0].error == ""

Expand Down Expand Up @@ -144,8 +146,8 @@ def test_run_kill_filter() -> None:
searchResp = bindings.get_SearchRuns(sess, filter=runFilter)

# validate response
assert len(searchResp.runs) > 0
assert len(killResp.results) > 0
assert len(killResp.results) > 0, f"failed to kill runs in exp {exp_id}"
assert len(searchResp.runs) > 0, f"failed to search runs in exp {exp_id}"
assert len(killResp.results) == len(searchResp.runs)
for res in killResp.results:
assert res.error == ""
Expand Down

0 comments on commit 96c061b

Please sign in to comment.