ENH: Add support to show run outputs in SDK v2 (#947)
fepegar committed Aug 23, 2024
1 parent c78a707 commit 598b78f
Showing 2 changed files with 20 additions and 12 deletions.
6 changes: 5 additions & 1 deletion hi-ml-azure/src/health_azure/himl.py
@@ -413,6 +413,7 @@ def submit_run_v2(
     tags: Optional[Dict[str, str]] = None,
     docker_shm_size: str = "",
     wait_for_completion: bool = False,
+    wait_for_completion_show_output: bool = False,
     identity_based_auth: bool = False,
     hyperparam_args: Optional[Dict[str, Any]] = None,
     num_nodes: int = 1,
@@ -444,6 +445,8 @@ def submit_run_v2(
     :param docker_shm_size: The Docker shared memory size that should be used when creating a new Docker image.
     :param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
         the completion of this run (if True).
+    :param wait_for_completion_show_output: If wait_for_completion is True, this parameter indicates whether to show
+        the run output on sys.stdout.
     :param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
     :param num_nodes: The number of nodes to use for the job in AzureML. The value must be 1 or greater.
     :param pytorch_processes_per_node: For plain PyTorch multi-GPU processing: The number of processes per node.
@@ -547,7 +550,7 @@ def create_command_job(cmd: str) -> Command:
     print("==============================================================================\n")
     if wait_for_completion:
         print("Waiting for the completion of the AzureML job.")
-        wait_for_job_completion(ml_client, job_name=returned_job.name)
+        wait_for_job_completion(ml_client, job_name=returned_job.name, show_output=wait_for_completion_show_output)
         print("AzureML job completed.")
         # After waiting, ensure that the caller gets the latest version job object
         returned_job = ml_client.jobs.get(returned_job.name)
@@ -1001,6 +1004,7 @@ def submit_to_azure_if_needed(  # type: ignore
             display_name=display_name,
             docker_shm_size=docker_shm_size,
             wait_for_completion=wait_for_completion,
+            wait_for_completion_show_output=wait_for_completion_show_output,
             identity_based_auth=identity_based_auth,
             hyperparam_args=hyperparam_args,
             num_nodes=num_nodes,
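Usage note (not part of this commit): a minimal caller-side sketch of the new flag on the SDK v2 path. The cluster name, environment file, and the strictly_aml_v1 argument below are illustrative assumptions about a typical hi-ml submission script, not taken from this diff.

from pathlib import Path

from health_azure import submit_to_azure_if_needed

# Hypothetical submission script; cluster and environment values are placeholders.
run_info = submit_to_azure_if_needed(
    compute_cluster_name="gpu-cluster",              # assumed cluster name
    conda_environment_file=Path("environment.yml"),  # assumed environment file
    strictly_aml_v1=False,                           # take the SDK v2 (MLClient/Job) code path
    wait_for_completion=True,                        # block until the AzureML job finishes
    wait_for_completion_show_output=True,            # now also honoured on the v2 path: stream job output to sys.stdout
)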
26 changes: 15 additions & 11 deletions hi-ml-azure/src/health_azure/utils.py
@@ -903,24 +903,28 @@ def is_job_completed(job: Job) -> bool:
     return job.status == JobStatus.COMPLETED.value


-def wait_for_job_completion(ml_client: MLClient, job_name: str) -> None:
+def wait_for_job_completion(ml_client: MLClient, job_name: str, *, show_output: bool = False) -> None:
     """Wait until the job of the given ID is completed or failed with an error. If the job did not complete
     successfully, a ValueError is raised.

     :param ml_client: An MLClient object for the workspace where the job lives.
     :param job_name: The name (id) of the job to wait for.
+    :param show_output: If True, log the run output on sys.stdout.
     :raises ValueError: If the job did not complete successfully (any status other than Completed)
     """
-
-    while True:
-        # Get the latest job status by reading the whole job info again via the MLClient
-        updated_job = ml_client.jobs.get(name=job_name)
-        current_job_status = updated_job.status
-        if JobStatus.is_finished_state(current_job_status):
-            break
-        time.sleep(10)
-    if not is_job_completed(updated_job):
-        raise ValueError(f"Job {updated_job.name} jobs failed with status {current_job_status}.")
+    if show_output:
+        ml_client.jobs.stream(job_name)
+        job = ml_client.jobs.get(name=job_name)
+    else:
+        while True:
+            # Get the latest job status by reading the whole job info again via the MLClient
+            job = ml_client.jobs.get(name=job_name)
+            current_job_status = job.status
+            if JobStatus.is_finished_state(current_job_status):
+                break
+            time.sleep(10)
+    if not is_job_completed(job):
+        raise ValueError(f'Job "{job.name}" failed with status "{current_job_status}"')


 def get_most_recent_run_id(run_recovery_file: Path) -> str:
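Usage note (not part of this commit): the updated helper can also be driven directly with an azure-ai-ml MLClient. With show_output=True it delegates to ml_client.jobs.stream(), which blocks while printing the job logs; otherwise it polls the job status every 10 seconds. The workspace identifiers and job name below are placeholders.

from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient

from health_azure.utils import wait_for_job_completion

# Placeholder workspace details for illustration only.
ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# Streams the job's logs to sys.stdout until it reaches a terminal state,
# then raises ValueError if the job did not end as "Completed".
wait_for_job_completion(ml_client, job_name="<job-name>", show_output=True)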
