diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 9ed6415dce..dcbed33d96 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -148,7 +148,7 @@ def init(self, state: State, logger: Logger) -> None: # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume. self.tags = self.tags or {} - self.tags['run_name'] = state.run_name + self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name) # Adjust name and group based on `rank_zero_only`. if not self._rank_zero_only: @@ -171,16 +171,6 @@ def init(self, state: State, logger: Logger) -> None: output_format='list', ) - # Check for the old tag (`composer_run_name`) For backwards compatibility in case a run using the old - # tag fails and the run is resumed with a newer version of Composer that uses `run_name` instead of - # `composer_run_name`. - if len(existing_runs) == 0: - existing_runs = mlflow.search_runs( - experiment_ids=[self._experiment_id], - filter_string=f'tags.composer_run_name = "{state.run_name}"', - output_format='list', - ) - if len(existing_runs) > 0: self._run_id = existing_runs[0].info.run_id else: diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 22b0445e2d..b97c108633 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -6,7 +6,7 @@ import os import time from pathlib import Path -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -190,24 +190,52 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch): assert test_logger._run_id == existing_id -def test_mlflow_experiment_init_existing_composer_run_with_old_tag(monkeypatch): - """ Test that an existing MLFlow run is used if one exists with the old `composer_run_name` tag. - """ +@pytest.fixture +def mock_mlflow_client(): + with patch('mlflow.tracking.MlflowClient') as MockClient: + mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id'))) + MockClient.return_value.create_run = mock_create_run + yield MockClient + + +def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client): + """Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set.""" mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + from composer.loggers.mlflow_logger import MLFlowLogger + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' + monkeypatch.setenv('RUN_NAME', 'env-run-name') + + test_logger = MLFlowLogger() + test_logger.init(state=mock_state, logger=MagicMock()) + + assert test_logger.tags is not None + assert test_logger.tags['run_name'] == 'env-run-name' + monkeypatch.delenv('RUN_NAME') + + +def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client): + """Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set.""" + mlflow = pytest.importorskip('mlflow') + + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) mock_state = MagicMock() - mock_state.composer_run_name = 'dummy-run-name' + mock_state.run_name = 'state-run-name' existing_id = 'dummy-id' mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))]) monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs) + from composer.loggers.mlflow_logger import MLFlowLogger test_logger = MLFlowLogger() test_logger.init(state=mock_state, logger=MagicMock()) - assert test_logger._run_id == existing_id + assert test_logger.tags is not None + assert test_logger.tags['run_name'] == 'state-run-name' def test_mlflow_experiment_set_up(tmp_path):