From 7af22b80808fa2c387762cbfa5df6c3f325b52d8 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 29 Dec 2023 13:56:35 -0800 Subject: [PATCH 01/18] Implementation of MLFlowObjectStore --- composer/utils/__init__.py | 6 +- composer/utils/object_store/__init__.py | 5 +- .../utils/object_store/mlflow_object_store.py | 371 ++++++++++++++++++ composer/utils/object_store/object_store.py | 3 +- setup.py | 5 +- .../object_store/test_mlflow_object_store.py | 280 +++++++++++++ 6 files changed, 663 insertions(+), 7 deletions(-) create mode 100644 composer/utils/object_store/mlflow_object_store.py create mode 100644 tests/utils/object_store/test_mlflow_object_store.py diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index dc3d029e94..30930250d9 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -21,8 +21,9 @@ from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp, is_notebook, model_eval_mode, using_torch_2) -from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, - OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore) +from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, + UCObjectStore) from composer.utils.retrying import retry from composer.utils.string_enum import StringEnum @@ -44,6 +45,7 @@ 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore', + 'MLFlowObjectStore', 'MissingConditionalImportError', 'import_object', 'is_model_deepspeed', diff --git a/composer/utils/object_store/__init__.py b/composer/utils/object_store/__init__.py index de28ec9674..d7757c2284 100644 --- a/composer/utils/object_store/__init__.py +++ b/composer/utils/object_store/__init__.py @@ -5,6 +5,7 @@ from composer.utils.object_store.gcs_object_store import GCSObjectStore from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore +from composer.utils.object_store.mlflow_object_store import MLFlowObjectStore from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError from composer.utils.object_store.oci_object_store import OCIObjectStore from composer.utils.object_store.s3_object_store import S3ObjectStore @@ -12,6 +13,6 @@ from composer.utils.object_store.uc_object_store import UCObjectStore __all__ = [ - 'ObjectStore', 'ObjectStoreTransientError', 'LibcloudObjectStore', 'S3ObjectStore', 'SFTPObjectStore', - 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore' + 'ObjectStore', 'ObjectStoreTransientError', 'LibcloudObjectStore', 'MLFlowObjectStore', 'S3ObjectStore', + 'SFTPObjectStore', 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore' ] diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py new file mode 100644 index 0000000000..fe4f04291b --- /dev/null +++ b/composer/utils/object_store/mlflow_object_store.py @@ -0,0 +1,371 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""MLFlow Artifacts object store.""" + +from __future__ import annotations + +import logging +import os +import pathlib +import uuid +from typing import Callable, List, Optional, Tuple, Union + +import mlflow +from databricks.sdk import WorkspaceClient +from mlflow import MlflowClient +from mlflow.entities import 
FileInfo
+from mlflow.exceptions import (ABORTED, DATA_LOSS, DEADLINE_EXCEEDED, ENDPOINT_NOT_FOUND, INTERNAL_ERROR, INVALID_STATE,
+                               NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST, RESOURCE_EXHAUSTED,
+                               TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException)
+
+from composer.utils import dist
+from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError
+
+__all__ = ['MLFlowObjectStore']
+
+MLFLOW_DATABRICKS_TRACKING_URI = 'databricks'
+MLFLOW_DBFS_PATH_PREFIX = 'databricks/mlflow-tracking/'
+
+DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store'
+
+PLACEHOLDER_EXPERIMENT_ID = 'EXPERIMENT_ID'
+PLACEHOLDER_RUN_ID = 'RUN_ID'
+
+# https://github.com/mlflow/mlflow/blob/master/mlflow/exceptions.py for used error codes.
+# https://github.com/mlflow/mlflow/blob/master/mlflow/protos/databricks.proto for code descriptions.
+_RETRYABLE_SERVER_CODES = [
+    ErrorCode.Name(code) for code in [
+        DATA_LOSS,
+        INTERNAL_ERROR,
+        INVALID_STATE,
+        TEMPORARILY_UNAVAILABLE,
+        DEADLINE_EXCEEDED,
+    ]
+]
+
+_RETRYABLE_CLIENT_CODES = [ErrorCode.Name(code) for code in [ABORTED, REQUEST_LIMIT_EXCEEDED, RESOURCE_EXHAUSTED]]
+
+_NOT_FOUND_CODES = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]]
+
+log = logging.getLogger(__name__)
+
+
+def _wrap_mlflow_exceptions(uri: str, e: MlflowException):
+    """Wraps retryable MLFlow errors in ObjectStoreTransientError for automatic retry handling."""
+    if e.error_code in _RETRYABLE_SERVER_CODES or e.error_code in _RETRYABLE_CLIENT_CODES:
+        raise ObjectStoreTransientError(e.error_code) from e
+    elif e.error_code in _NOT_FOUND_CODES:
+        raise FileNotFoundError(f'Object {uri} not found') from e
+
+    raise e
+
+
+class MLFlowObjectStore(ObjectStore):
+    """Utility class for uploading and downloading artifacts from MLFlow.
+
+    It can be initialized for an existing run, a new run in an existing experiment, the active run used by the `mlflow`
+    module, or a new run in a new experiment. See the documentation for ``path`` for more details.
+
+    .. note::
+        At this time, only Databricks-managed MLFlow with a 'databricks' tracking URI is supported.
+        Using this object store requires setting `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
+        environment variables with the right credentials.
+
+    Because Databricks-managed MLFlow artifact URIs are stored in DBFS, artifact/object names are relative to an
+    artifacts root directory for the experiment/run in DBFS, `databricks/mlflow-tracking/<experiment_id>/<run_id>`.
+    The URI format is inconsistent with that of other object stores, so the full DBFS path is sometimes used as the
+    object name in upstream code that uses this class.
+
+    For simplicity in such upstream code, :class:`MLFlowObjectStore` accepts both relative MLFlow artifact paths and
+    absolute DBFS paths as object names. If an object name takes the form of
+    `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`,
+    it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLFlow.
+    Otherwise, the object name is assumed to be a relative MLFlow artifact path, and the full provided name will be
+    used as the artifact path when uploading to MLFlow.
+
+    Args:
+        path (str): A DBFS path of the form
+            `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
+            `experiment_id` and `run_id` can be set as the format string placeholders `'EXPERIMENT_ID'` and `'RUN_ID'`.
+
+            If both `experiment_id` and `run_id` are set as placeholders, the MLFlowObjectStore will be associated with
+            the currently active MLFlow run if one exists. 
If no active run exists, a new run will be created under a
+            default experiment name, or the experiment name specified by the `MLFLOW_EXPERIMENT_NAME` environment
+            variable if one is set.
+
+            If `experiment_id` is provided and `run_id` is not, the MLFlowObjectStore will create a new run in the
+            provided experiment.
+
+            Providing a `run_id` without an `experiment_id` will raise an error.
+        multipart_upload_chunk_size(int, optional): The maximum size of single chunk in an MLFlow multipart upload.
+            The maximum number of chunks supported by MLFlow is 10,000, so the max file size that can
+            be uploaded is `10 000 * multipart_upload_chunk_size`. Defaults to 100MB for a max upload size of 1TB.
+    """
+
+    def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 1024) -> None:
+        tracking_uri = os.getenv('MLFLOW_TRACKING_URI', MLFLOW_DATABRICKS_TRACKING_URI)
+        if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI:
+            raise ValueError(
+                'MLFlowObjectStore currently only supports Databricks-hosted MLFlow tracking. '
+                f'Environment variable `MLFLOW_TRACKING_URI` is set to a non-Databricks URI {tracking_uri}. '
+                f'Please unset it or set the value to `{MLFLOW_DATABRICKS_TRACKING_URI}`.')
+
+        # Use the Databricks WorkspaceClient to check that credentials are set up correctly.
+        try:
+            WorkspaceClient()
+        except Exception as e:
+            raise ValueError(
+                f'Databricks SDK credentials not correctly set up. '
+                'Visit https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication '
+                'to identify different ways to set up credentials.') from e
+
+        self._mlflow_client = MlflowClient(MLFLOW_DATABRICKS_TRACKING_URI)
+        mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size)
+
+        experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)
+        if experiment_id == PLACEHOLDER_EXPERIMENT_ID:
+            experiment_id = None
+        if run_id == PLACEHOLDER_RUN_ID:
+            run_id = None
+
+        # Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided.
+        if not dist.is_initialized() or dist.get_global_rank() == 0:
+            experiment_id, run_id = self._init_run_info(experiment_id, run_id)
+
+        if dist.is_initialized():
+            mlflow_info = [experiment_id, run_id]
+            dist.broadcast_object_list(mlflow_info, src=0)
+            experiment_id, run_id = mlflow_info
+
+        self.experiment_id = experiment_id
+        self.run_id = run_id
+
+    def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> Tuple[str, str]:
+        """Returns the experiment ID and run ID for the MLFlow run backing this object store.
+
+        In a distributed setting, this should only be called on the rank 0 process.
+        """
+        if experiment_id is None:
+            if run_id is not None:
+                raise ValueError('A `run_id` cannot be provided without a valid `experiment_id`.')
+
+            active_run = mlflow.active_run()
+            if active_run is not None:
+                experiment_id = active_run.info.experiment_id
+                run_id = active_run.info.run_id
+                log.info(f'MLFlowObjectStore using active MLFlow run (run_id={run_id})')
+            else:
+                # If no active run exists, create a new run for the default experiment. 
+                experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
+                                            DEFAULT_MLFLOW_EXPERIMENT_NAME)
+
+                experiment = self._mlflow_client.get_experiment_by_name(experiment_name)
+                if experiment is not None:
+                    experiment_id = experiment.experiment_id
+                else:
+                    experiment_id = self._mlflow_client.create_experiment(experiment_name)
+
+                run_id = self._mlflow_client.create_run(experiment_id).info.run_id
+
+                log.info(f'MLFlowObjectStore using a new MLFlow run (run_id={run_id}) '
+                         f'for new experiment "{experiment_name}" (experiment_id={experiment_id})')
+        else:
+            if run_id is not None:
+                # If a `run_id` is provided, check that it belongs to the provided experiment.
+                run = self._mlflow_client.get_run(run_id)
+                if run.info.experiment_id != experiment_id:
+                    raise ValueError(
+                        f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. '
+                        f'Found experiment {run.info.experiment_id}.')
+
+                log.info(f'MLFlowObjectStore using provided MLFlow run (run_id={run_id}) '
+                         f'for provided experiment (experiment_id={experiment_id})')
+            else:
+                # If no `run_id` is provided, create a new run in the provided experiment.
+                run = self._mlflow_client.create_run(experiment_id)
+                run_id = run.info.run_id
+
+                log.info(f'MLFlowObjectStore using new MLFlow run (run_id={run_id}) '
+                         f'for provided experiment (experiment_id={experiment_id})')
+
+        if experiment_id is None or run_id is None:
+            raise ValueError('MLFlowObjectStore failed to initialize experiment and run ID.')
+
+        return experiment_id, run_id
+
+    @staticmethod
+    def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
+        """Parses a DBFS path to extract the MLFlow experiment ID, run ID, and relative artifact path.
+
+        The path is expected to be of the format
+        `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
+
+        Args:
+            path (str): The DBFS path to parse.
+
+        Returns:
+            (str, str, str): (experiment_id, run_id, artifact_path)
+
+        Raises:
+            ValueError: If the path is not of the expected format.
+        """
+        if not path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            raise ValueError(f'DBFS MLFlow path should start with {MLFLOW_DBFS_PATH_PREFIX}. Got: {path}')
+
+        # Strip `databricks/mlflow-tracking/` and split into
+        # `<experiment_id>`, `<run_id>`, `'artifacts'`, `<artifact_path>`
+        subpath = path[len(MLFLOW_DBFS_PATH_PREFIX):]
+        mlflow_parts = subpath.split('/', maxsplit=3)
+
+        if len(mlflow_parts) != 4 or mlflow_parts[2] != 'artifacts':
+            raise ValueError(f'Databricks MLFlow artifact path expected to be of the format '
+                             f'{MLFLOW_DBFS_PATH_PREFIX}/<experiment_id>/<run_id>/artifacts/<artifact_path>. '
+                             f'Found path={path}')
+
+        return mlflow_parts[0], mlflow_parts[1], mlflow_parts[3]
+
+    def get_artifact_path(self, object_name: str) -> str:
+        """Converts an object name into an MLFlow relative artifact path.
+
+        Args:
+            object_name (str): The object name to convert. If the object name is a DBFS path beginning with
+                ``MLFLOW_DBFS_PATH_PREFIX``, the path will be parsed to extract the MLFlow relative artifact path.
+                Otherwise, the object name is assumed to be a relative artifact path and will be returned as-is. 
+ """ + if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX): + _, _, object_name = self.parse_dbfs_path(object_name) + return object_name + + def get_dbfs_path(self, object_name: str) -> str: + """Converts an object name to a full DBFS path.""" + artifact_path = self.get_artifact_path(object_name) + return f'{MLFLOW_DBFS_PATH_PREFIX}{self.experiment_id}/{self.run_id}/artifacts/{artifact_path}' + + def get_uri(self, object_name: str) -> str: + return 'dbfs:/' + self.get_dbfs_path(object_name) + + def upload_object(self, + object_name: str, + filename: Union[str, pathlib.Path], + callback: Optional[Callable[[int, int], None]] = None): + del callback # unused + + # Extract relative path from DBFS path. + artifact_path = self.get_artifact_path(object_name) + artifact_base_name = os.path.basename(artifact_path) + artifact_dir = os.path.dirname(artifact_path) + + # Since MLFlow doesn't support uploading artifacts with a different base name than the local file, + # create a temporary symlink to the local file with the same base name as the desired artifact name. + filename = os.path.abspath(filename) + tmp_dir = f'/tmp/{uuid.uuid4()}' + os.makedirs(tmp_dir, exist_ok=True) + tmp_symlink_path = os.path.join(tmp_dir, artifact_base_name) + os.symlink(filename, tmp_symlink_path) + + try: + self._mlflow_client.log_artifact(self.run_id, tmp_symlink_path, artifact_dir) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(object_name), e) + + def get_object_size(self, object_name: str) -> int: + artifact = None + try: + artifact = self._get_artifact_info(object_name) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(object_name), e) + + if artifact is not None: + return artifact.file_size + else: + raise FileNotFoundError(f'Object {object_name} not found') + + def download_object( + self, + object_name: str, + filename: Union[str, pathlib.Path], + overwrite: bool = False, + callback: Optional[Callable[[int, int], None]] = None, + ) -> None: + del callback # unused + + # Since MlflowClient.download_artifacts only raises MlflowException with 500 Internal Error, + # check for existence to surface a FileNotFoundError if necessary. + artifact_path = self.get_artifact_path(object_name) + artifact_info = self._get_artifact_info(object_name) + if artifact_info is None: + raise FileNotFoundError(f'Object {self.get_dbfs_path(artifact_path)} not found') + + filename = os.path.abspath(filename) + if os.path.exists(filename) and not overwrite: + raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.') + + # MLFLow doesn't support downloading artifacts directly to a specified filename, so instead + # download to a temporary directory and then move the file to the desired location. + tmp_dir = f'/tmp/{uuid.uuid4()}' + os.makedirs(tmp_dir, exist_ok=True) + try: + self._mlflow_client.download_artifacts( + run_id=self.run_id, + artifact_path=artifact_path, + dst_path=tmp_dir, + ) + tmp_path = os.path.join(tmp_dir, artifact_path) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + if overwrite: + os.replace(tmp_path, filename) + else: + os.rename(tmp_path, filename) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(artifact_path), e) + + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """See :meth:`~composer.utils.ObjectStore.list_objects`. + + MLFlowObjectStore does not support listing objects with a prefix, so the ``prefix`` argument is ignored. 
+ """ + del prefix # not supported for MLFlowObjectStore + + objects = [] + self._list_objects_helper(None, objects) + return objects + + def _list_objects_helper(self, prefix: Optional[str], objects: List[str]) -> None: + """Helper to recursively populate the full list of objects for ``list_objects``. + + Args: + prefix (str | None): An artifact path prefix for artifacts to find. + objects (list[str]): The list of DBFS object paths to populate. + """ + artifact = None + try: + for artifact in self._mlflow_client.list_artifacts(self.run_id, prefix): + if artifact.is_dir: + self._list_objects_helper(artifact.path, objects) + else: + objects.append(artifact.path) + except MlflowException as e: + uri = '' if artifact is None else self.get_uri(artifact.path) + _wrap_mlflow_exceptions(uri, e) + + def _get_artifact_info(self, object_name) -> Optional[FileInfo]: + """Get the :class:`~mlflow.entities.FileInfo` for the given object name. + + Args: + object_name (str): The name of the object, either as an absolute DBFS path or a relative MLFlow artifact path. + + Returns: + Optional[FileInfo]: The :class:`~mlflow.entities.FileInfo` for the object, or None if it does not exist. + """ + # MLFlow doesn't support info for a singleton artifact, so we need to list all artifacts in the + # parent path and find the one with the matching name. + artifact_path = self.get_artifact_path(object_name) + artifact_dir = os.path.dirname(artifact_path) + artifacts = self._mlflow_client.list_artifacts(self.run_id, artifact_dir) + for artifact in artifacts: + if not artifact.is_dir and artifact.path == artifact_path: + return artifact + + return None diff --git a/composer/utils/object_store/object_store.py b/composer/utils/object_store/object_store.py index ef39b847a3..11f763ef68 100644 --- a/composer/utils/object_store/object_store.py +++ b/composer/utils/object_store/object_store.py @@ -87,7 +87,7 @@ def upload_object(self, Args: object_name (str): Object name (where object will be stored in the container) - filename (str | pathlib.Path): Path the the object on disk + filename (str | pathlib.Path): Path to the object on disk callback ((int, int) -> None, optional): If specified, the callback is periodically called with the number of bytes uploaded and the total size of the object being uploaded. **kwargs: other arguments to the upload object function are supported @@ -133,6 +133,7 @@ def download_object( downloaded and the total size of the object. Raises: + FileExistsError: If ``filename`` already exists and ``overwrite`` is ``False``. FileNotFoundError: If the file was not found in the object store. ObjectStoreTransientError: If there was a transient connection issue with downloading the object. 
""" diff --git a/setup.py b/setup.py index dd2b7b00ae..29d0c67302 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,8 @@ def package_files(prefix: str, directory: str, extension: str): break else: assert end != -1, 'there should be a balanced number of start and ends' - long_description = long_description[:start] + long_description[end + len(end_tag):] + long_description = long_description[:start] + \ + long_description[end + len(end_tag):] install_requires = [ 'pyyaml>=6.0,<7', @@ -223,7 +224,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['mlflow'] = [ - 'mlflow>=2.9.0,<3.0', + 'mlflow>=2.9.2,<3.0', ] extra_deps['pandas'] = ['pandas>=2.0.0,<3.0'] diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py new file mode 100644 index 0000000000..33cb7819e0 --- /dev/null +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -0,0 +1,280 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from composer.utils import MLFlowObjectStore +from composer.utils.object_store.mlflow_object_store import PLACEHOLDER_EXPERIMENT_ID, PLACEHOLDER_RUN_ID + +TEST_PATH_FORMAT = 'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/' +EXPERIMENT_ID = '123' +EXPERIMENT_NAME = 'test-experiment' +RUN_ID = '456' +RUN_NAME = 'test-run' +ARTIFACT_PATH = 'path/to/artifact' +DEFAULT_PATH = TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=RUN_ID) + + +def test_parse_dbfs_path(): + full_artifact_path = DEFAULT_PATH + ARTIFACT_PATH + assert MLFlowObjectStore.parse_dbfs_path(full_artifact_path) == (EXPERIMENT_ID, RUN_ID, ARTIFACT_PATH) + + # Test with bad prefix + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'bad-prefix/{EXPERIMENT_ID}/{RUN_ID}/artifacts/{ARTIFACT_PATH}') + + # Test without artifacts + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'databricks/mlflow-tracking/{EXPERIMENT_ID}/{RUN_ID}/') + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'databricks/mlflow-tracking/{EXPERIMENT_ID}/{RUN_ID}/not-artifacts/') + + +def test_init_fail_without_databricks_tracking_uri(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'not-databricks') + with pytest.raises(ValueError): + MLFlowObjectStore(DEFAULT_PATH) + + +@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') +def test_init_with_experiment_and_run(mock_mlflow_client, mock_ws_client): + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID)) + + store = MLFlowObjectStore(DEFAULT_PATH) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') +def test_init_with_experiment_and_no_run(mock_mlflow_client, mock_ws_client): + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + 
+@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +def test_init_with_run_and_no_experiment(mock_ws_client): + with pytest.raises(ValueError): + MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=RUN_ID)) + + +@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.mlflow') +def test_init_with_active_run(mock_mlflow, mock_mlflow_client, mock_ws_client): + mock_mlflow.active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID)) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') +def test_init_with_existing_experiment_and_no_run(mock_mlflow_client, mock_ws_client): + mock_mlflow_client.return_value.get_experiment_by_name.return_value = MagicMock(experiment_id=EXPERIMENT_ID) + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') +@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') +def test_init_with_no_experiment_and_no_run(mock_mlflow_client, mock_ws_client): + mock_mlflow_client.return_value.get_experiment_by_name.return_value = None + mock_mlflow_client.return_value.create_experiment.return_value = EXPERIMENT_ID + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +@pytest.fixture() +def mlflow_object_store(): + + def mock_mlflow_client_list_artifacts(*args, **kwargs): + """Mock behavior for MlflowClient.list_artifacts(). 
+ + Behaves as if artifacts are stored under the following structure: + - dir1/ + - a.txt + - b.txt + - dir2/ + - c.txt + - dir3/ + - d.txt + """ + path = args[1] + if not path: + return [ + MagicMock(path='dir1', is_dir=True, file_size=None), + MagicMock(path='dir2', is_dir=True, file_size=None) + ] + elif path == 'dir1': + return [ + MagicMock(path='dir1/a.txt', is_dir=False, file_size=100), + MagicMock(path='dir1/b.txt', is_dir=False, file_size=200) + ] + elif path == 'dir2': + return [ + MagicMock(path='dir2/c.txt', is_dir=False, file_size=300), + MagicMock(path='dir2/dir3', is_dir=True, file_size=None) + ] + elif path == 'dir2/dir3': + return [MagicMock(path='dir2/dir3/d.txt', is_dir=False, file_size=400)] + else: + return [] + + with mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient'): + with mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') as mock_mlflow_client: + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock( + experiment_id=EXPERIMENT_ID)) + mock_mlflow_client.return_value.list_artifacts.side_effect = mock_mlflow_client_list_artifacts + yield MLFlowObjectStore(DEFAULT_PATH) + + +def test_get_artifact_path(mlflow_object_store): + # Relative MLFlow artifact path + assert mlflow_object_store.get_artifact_path(ARTIFACT_PATH) == ARTIFACT_PATH + + # Absolute DBFS path + assert mlflow_object_store.get_artifact_path(DEFAULT_PATH + ARTIFACT_PATH) == ARTIFACT_PATH + + +def test_get_dbfs_path(mlflow_object_store): + experiment_id = mlflow_object_store.experiment_id + run_id = mlflow_object_store.run_id + + expected_dbfs_path = f'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/{ARTIFACT_PATH}' + assert mlflow_object_store.get_dbfs_path(ARTIFACT_PATH) == expected_dbfs_path + + +def test_get_uri(mlflow_object_store): + experiment_id = mlflow_object_store.experiment_id + run_id = mlflow_object_store.run_id + expected_uri = f'dbfs:/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/{ARTIFACT_PATH}' + + # Relative MLFlow artifact path + assert mlflow_object_store.get_uri(ARTIFACT_PATH) == expected_uri + + # Absolute DBFS path + assert mlflow_object_store.get_uri(DEFAULT_PATH + ARTIFACT_PATH) == expected_uri + + +def test_get_artifact_info(mlflow_object_store): + assert mlflow_object_store._get_artifact_info('dir1/a.txt').path == 'dir1/a.txt' + assert mlflow_object_store._get_artifact_info('dir1/b.txt').path == 'dir1/b.txt' + assert mlflow_object_store._get_artifact_info('dir2/c.txt').path == 'dir2/c.txt' + assert mlflow_object_store._get_artifact_info('dir2/dir3/d.txt').path == 'dir2/dir3/d.txt' + + # Test with absolute DBFS path + assert mlflow_object_store._get_artifact_info(DEFAULT_PATH + 'dir1/a.txt').path == 'dir1/a.txt' + + # Verify directories are not returned + assert mlflow_object_store._get_artifact_info('dir1') is None + + # Test non-existent artifact + assert mlflow_object_store._get_artifact_info('nonexistent.txt') is None + + +def test_get_object_size(mlflow_object_store): + assert mlflow_object_store.get_object_size('dir1/a.txt') == 100 + assert mlflow_object_store.get_object_size('dir1/b.txt') == 200 + assert mlflow_object_store.get_object_size('dir2/c.txt') == 300 + assert mlflow_object_store.get_object_size('dir2/dir3/d.txt') == 400 + + # Test with absolute DBFS path + assert mlflow_object_store.get_object_size(DEFAULT_PATH + 'dir1/a.txt') == 100 + + # Verify FileNotFoundError is raised for non-existent artifact + with pytest.raises(FileNotFoundError): + 
mlflow_object_store.get_object_size('dir1') + with pytest.raises(FileNotFoundError): + mlflow_object_store.get_object_size('nonexistent.txt') + + +def test_download_object(mlflow_object_store, tmp_path): + + def mock_mlflow_client_download_artifacts(*args, **kwargs): + artifact_path = kwargs['artifact_path'] + dst_path = kwargs['dst_path'] + local_path = os.path.join(dst_path, artifact_path) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + size = mlflow_object_store.get_object_size(artifact_path) + file_content = bytes('0' * (size), 'utf-8') + + print(local_path) + + with open(local_path, 'wb') as fp: + fp.write(file_content) + return local_path + + mlflow_object_store._mlflow_client.download_artifacts.side_effect = mock_mlflow_client_download_artifacts + + # Test downloading file + object_name = 'dir1/a.txt' + file_to_download = str(tmp_path / Path(object_name)) + mlflow_object_store.download_object(object_name, file_to_download) + assert os.path.exists(file_to_download) + assert os.path.getsize(file_to_download) == mlflow_object_store.get_object_size(object_name) + + # Test cannot overwrite existing file when `overwrite` is False + with pytest.raises(FileExistsError): + mlflow_object_store.download_object(object_name, file_to_download, overwrite=False) + + # Test can overwrite existing file when `overwrite` is True + mlflow_object_store.download_object(object_name, file_to_download, overwrite=True) + + # Test downloading file under different name + object_name = 'dir1/a.txt' + file_to_download = str(tmp_path / Path('renamed.txt')) + mlflow_object_store.download_object(object_name, file_to_download) + assert os.path.exists(file_to_download) + assert os.path.getsize(file_to_download) == mlflow_object_store.get_object_size(object_name) + + # Raises FileNotFound when artifact does not exist + with pytest.raises(FileNotFoundError): + mlflow_object_store.download_object('nonexistent.txt', file_to_download) + + +def test_upload_object(mlflow_object_store, tmp_path): + file_to_upload = str(tmp_path / Path('file.txt')) + with open(file_to_upload, 'wb') as f: + f.write(bytes(range(20))) + + object_name = 'dir1/file.txt' + mlflow_object_store.upload_object(object_name=object_name, filename=file_to_upload) + run_id, local_path, artifact_dir = mlflow_object_store._mlflow_client.log_artifact.call_args.args + assert run_id == mlflow_object_store.run_id + assert os.path.basename(local_path) == os.path.basename(object_name) + assert artifact_dir == os.path.dirname(object_name) + + # Test basename symlink is created with correct name when object base name is different + object_name = 'dir1/renamed.txt' + mlflow_object_store.upload_object(object_name=object_name, filename=file_to_upload) + _, local_path, _ = mlflow_object_store._mlflow_client.log_artifact.call_args.args + assert os.path.basename(local_path) == os.path.basename(object_name) + + +def test_list_objects(mlflow_object_store): + expected = {'dir1/a.txt', 'dir1/b.txt', 'dir2/c.txt', 'dir2/dir3/d.txt'} + assert set(mlflow_object_store.list_objects()) == expected From 2468120fd20fb6c62427b47e96c6c0dce4478889 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 29 Dec 2023 14:37:35 -0800 Subject: [PATCH 02/18] Update object store test settings --- composer/utils/object_store/mlflow_object_store.py | 13 ++++++------- tests/utils/object_store/object_store_settings.py | 9 +++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py 
b/composer/utils/object_store/mlflow_object_store.py
index fe4f04291b..172ec42e5b 100644
--- a/composer/utils/object_store/mlflow_object_store.py
+++ b/composer/utils/object_store/mlflow_object_store.py
@@ -72,13 +72,12 @@ class MLFlowObjectStore(ObjectStore):
         Using this object store requires setting `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
         environment variables with the right credentials.
 
-    Because Databricks-managed MLFlow artifact URIs are stored in DBFS, artifact/object names are relative to an
-    artifacts root directory for the experiment/run in DBFS, `databricks/mlflow-tracking/<experiment_id>/<run_id>`.
-    The URI format is inconsistent with that of other object stores, so the full DBFS path is sometimes used as the
-    object name in upstream code that uses this class.
-
-    For simplicity in such upstream code, :class:`MLFlowObjectStore` accepts both relative MLFlow artifact paths and
-    absolute DBFS paths as object names. If an object name takes the form of
+    Unlike other object stores, the DBFS URI scheme for MLFlow artifacts has no bucket, and the path is prefixed
+    with the artifacts root directory for a given experiment/run,
+    `databricks/mlflow-tracking/<experiment_id>/<run_id>/`. However, object names are also sometimes passed by
+    upstream code as artifact paths relative to this root, rather than the full path. To keep upstream code simple,
+    :class:`MLFlowObjectStore` accepts both relative MLFlow artifact paths and absolute DBFS paths as object names.
+    If an object name takes the form of
     `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`,
     it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLFlow.
     Otherwise, the object name is assumed to be a relative MLFlow artifact path, and the full provided name will be
diff --git a/tests/utils/object_store/object_store_settings.py b/tests/utils/object_store/object_store_settings.py
index d94cd70fd6..05f459db9c 100644
--- a/tests/utils/object_store/object_store_settings.py
+++ b/tests/utils/object_store/object_store_settings.py
@@ -14,8 +14,8 @@
 import composer.utils.object_store
 import composer.utils.object_store.sftp_object_store
-from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, ObjectStore, OCIObjectStore,
-                                         S3ObjectStore, SFTPObjectStore, UCObjectStore)
+from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
+                                         OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore)
 from composer.utils.object_store.sftp_object_store import SFTPObjectStore
 from tests.common import get_module_subclasses
@@ -56,8 +56,9 @@
 object_stores = [
     pytest.param(x, marks=_object_store_marks[x], id=x.__name__)
     for x in get_module_subclasses(composer.utils.object_store, ObjectStore)
-    # Note: OCI, GCS and UC have their own test suite, so they are exempt from being included in this one.``
-    if not issubclass(x, OCIObjectStore) and not issubclass(x, GCSObjectStore) and not issubclass(x, UCObjectStore)
+    # Note: OCI, GCS, UC, and MLFlow have their own test suite, so they are exempt from being included in this one.``
+    if not issubclass(x, OCIObjectStore) and not issubclass(x, GCSObjectStore) and not issubclass(x, UCObjectStore) and
+    not issubclass(x, MLFlowObjectStore)
 ]

From e750618c5294816ecf59398be57470f82cd30e47 Mon Sep 17 00:00:00 2001
From: Jerry Chen
Date: Sat, 30 Dec 2023 12:42:45 -0800
Subject: [PATCH 03/18] Import mlflow dependencies inline

---
 .../utils/object_store/mlflow_object_store.py | 75 +++++++++++--------
 .../utils/object_store/uc_object_store.py     |  2 +-
 2 files changed, 46 
insertions(+), 31 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 172ec42e5b..ec112b1059 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -11,15 +11,8 @@ import uuid from typing import Callable, List, Optional, Tuple, Union -import mlflow -from databricks.sdk import WorkspaceClient -from mlflow import MlflowClient -from mlflow.entities import FileInfo -from mlflow.exceptions import (ABORTED, DATA_LOSS, DEADLINE_EXCEEDED, ENDPOINT_NOT_FOUND, INTERNAL_ERROR, INVALID_STATE, - NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST, RESOURCE_EXHAUSTED, - TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException) - from composer.utils import dist +from composer.utils.import_helpers import MissingConditionalImportError from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError __all__ = ['MLFlowObjectStore'] @@ -32,31 +25,34 @@ PLACEHOLDER_EXPERIMENT_ID = 'EXPERIMENT_ID' PLACEHOLDER_RUN_ID = 'RUN_ID' -# https://github.com/mlflow/mlflow/blob/master/mlflow/exceptions.py for used error codes. -# https://github.com/mlflow/mlflow/blob/master/mlflow/protos/databricks.proto for code descriptions. -_RETRYABLE_SERVER_CODES = [ - ErrorCode.Name(code) for code in [ - DATA_LOSS, - INTERNAL_ERROR, - INVALID_STATE, - TEMPORARILY_UNAVAILABLE, - DEADLINE_EXCEEDED, - ] -] - -_RETRYABLE_CLIENT_CODES = [ErrorCode.Name(code) for code in [ABORTED, REQUEST_LIMIT_EXCEEDED, RESOURCE_EXHAUSTED]] - -_NOT_FOUND_CODES = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]] - log = logging.getLogger(__name__) -def _wrap_mlflow_exceptions(uri: str, e: MlflowException): +def _wrap_mlflow_exceptions(uri: str, e: Exception): """Wraps retryable MLFlow errors in ObjectStoreTransientError for automatic retry handling.""" - if e.error_code in _RETRYABLE_SERVER_CODES or e.error_code in _RETRYABLE_CLIENT_CODES: - raise ObjectStoreTransientError(e.error_code) from e - elif e.error_code in _NOT_FOUND_CODES: - raise FileNotFoundError(f'Object {uri} not found') from e + from mlflow.exceptions import (ABORTED, DATA_LOSS, DEADLINE_EXCEEDED, ENDPOINT_NOT_FOUND, INTERNAL_ERROR, + INVALID_STATE, NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST, + RESOURCE_EXHAUSTED, TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException) + + # https://github.com/mlflow/mlflow/blob/master/mlflow/exceptions.py for used error codes. + # https://github.com/mlflow/mlflow/blob/master/mlflow/protos/databricks.proto for code descriptions. 
+ retryable_server_codes = [ + ErrorCode.Name(code) for code in [ + DATA_LOSS, + INTERNAL_ERROR, + INVALID_STATE, + TEMPORARILY_UNAVAILABLE, + DEADLINE_EXCEEDED, + ] + ] + retryable_client_codes = [ErrorCode.Name(code) for code in [ABORTED, REQUEST_LIMIT_EXCEEDED, RESOURCE_EXHAUSTED]] + not_found_codes = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]] + + if isinstance(e, MlflowException): + if e.error_code in retryable_server_codes or e.error_code in retryable_client_codes: + raise ObjectStoreTransientError(e.error_code) from e + elif e.error_code in not_found_codes: + raise FileNotFoundError(f'Object {uri} not found') from e raise e @@ -103,6 +99,17 @@ class MLFlowObjectStore(ObjectStore): """ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 1024) -> None: + try: + import mlflow + from mlflow import MlflowClient + except ImportError as e: + raise MissingConditionalImportError('mlflow', conda_package='mlflow>=2.9.2,<3.0') from e + + try: + from databricks.sdk import WorkspaceClient + except ImportError as e: + raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e + tracking_uri = os.getenv('MLFLOW_TRACKING_URI', MLFLOW_DATABRICKS_TRACKING_URI) if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI: raise ValueError( @@ -145,6 +152,8 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> In a distributed setting, this should only be called on the rank 0 process. """ + import mlflow + if experiment_id is None: if run_id is not None: raise ValueError('A `run_id` cannot be provided without a valid `experiment_id`.') @@ -249,6 +258,7 @@ def upload_object(self, filename: Union[str, pathlib.Path], callback: Optional[Callable[[int, int], None]] = None): del callback # unused + from mlflow.exceptions import MlflowException # Extract relative path from DBFS path. artifact_path = self.get_artifact_path(object_name) @@ -269,6 +279,8 @@ def upload_object(self, _wrap_mlflow_exceptions(self.get_uri(object_name), e) def get_object_size(self, object_name: str) -> int: + from mlflow.exceptions import MlflowException + artifact = None try: artifact = self._get_artifact_info(object_name) @@ -288,6 +300,7 @@ def download_object( callback: Optional[Callable[[int, int], None]] = None, ) -> None: del callback # unused + from mlflow.exceptions import MlflowException # Since MlflowClient.download_artifacts only raises MlflowException with 500 Internal Error, # check for existence to surface a FileNotFoundError if necessary. @@ -338,6 +351,8 @@ def _list_objects_helper(self, prefix: Optional[str], objects: List[str]) -> Non prefix (str | None): An artifact path prefix for artifacts to find. objects (list[str]): The list of DBFS object paths to populate. """ + from mlflow.exceptions import MlflowException + artifact = None try: for artifact in self._mlflow_client.list_artifacts(self.run_id, prefix): @@ -349,7 +364,7 @@ def _list_objects_helper(self, prefix: Optional[str], objects: List[str]) -> Non uri = '' if artifact is None else self.get_uri(artifact.path) _wrap_mlflow_exceptions(uri, e) - def _get_artifact_info(self, object_name) -> Optional[FileInfo]: + def _get_artifact_info(self, object_name): """Get the :class:`~mlflow.entities.FileInfo` for the given object name. 
Args: diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 8317675134..08e77cb910 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -53,7 +53,7 @@ def __init__(self, path: str) -> None: try: from databricks.sdk import WorkspaceClient except ImportError as e: - raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.8.0,<1.0') from e + raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e try: self.client = WorkspaceClient() From eaf998c078a8c5d75ede229c27d82e2cb0c9eb1e Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 30 Dec 2023 17:04:44 -0800 Subject: [PATCH 04/18] Fix tests and ignore some pyright --- .../utils/object_store/mlflow_object_store.py | 7 +- .../utils/object_store/uc_object_store.py | 10 ++- .../object_store/test_mlflow_object_store.py | 85 +++++++++++++------ 3 files changed, 69 insertions(+), 33 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index ec112b1059..4c31d42227 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -49,9 +49,10 @@ def _wrap_mlflow_exceptions(uri: str, e: Exception): not_found_codes = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]] if isinstance(e, MlflowException): - if e.error_code in retryable_server_codes or e.error_code in retryable_client_codes: - raise ObjectStoreTransientError(e.error_code) from e - elif e.error_code in not_found_codes: + error_code = e.error_code # pyright: ignore + if error_code in retryable_server_codes or error_code in retryable_client_codes: + raise ObjectStoreTransientError(error_code) from e + elif error_code in not_found_codes: raise FileNotFoundError(f'Object {uri} not found') from e raise e diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 08e77cb910..8759ce7b30 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -167,7 +167,7 @@ def download_object(self, try: from databricks.sdk.core import DatabricksError try: - with self.client.files.download(self._get_object_path(object_name)).contents as resp: + with self.client.files.download(self._get_object_path(object_name)).contents as resp: # pyright: ignore with open(tmp_path, 'wb') as f: # Chunk the data into multiple blocks of 64MB to avoid # OOMs when downloading really large files @@ -199,11 +199,15 @@ def get_object_size(self, object_name: str) -> int: Raises: FileNotFoundError: If the file was not found in the object store. + IsADirectoryError: If the object is a directory, not a file. 
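+
+            A sketch of handling the directory case (the UC volume path is illustrative):
+
+            .. code-block:: python
+
+                try:
+                    size = uc_store.get_object_size('Volumes/main/default/my_volume/model.pt')
+                except IsADirectoryError:
+                    size = None  # the name pointed at a directory, not a file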
""" from databricks.sdk.core import DatabricksError try: file_info = self.client.files.get_status(self._get_object_path(object_name)) - return file_info.file_size + if file_info.is_dir: + raise IsADirectoryError(f'{object_name} is a UC directory, not a file.') + + return file_info.file_size # pyright: ignore except DatabricksError as e: _wrap_errors(self.get_uri(object_name), e) @@ -231,6 +235,6 @@ def list_objects(self, prefix: Optional[str]) -> List[str]: path=self._UC_VOLUME_LIST_API_ENDPOINT, data=data, headers={'Source': 'mosaicml/composer'}) - return [f['path'] for f in resp.get('files', []) if not f['is_dir']] + return [f['path'] for f in resp.get('files', []) if not f['is_dir']] # pyright: ignore except DatabricksError as e: _wrap_errors(self.get_uri(prefix), e) diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py index 33cb7819e0..300dd7c3df 100644 --- a/tests/utils/object_store/test_mlflow_object_store.py +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -3,7 +3,6 @@ import os from pathlib import Path -from unittest import mock from unittest.mock import MagicMock import pytest @@ -41,9 +40,14 @@ def test_init_fail_without_databricks_tracking_uri(monkeypatch): MLFlowObjectStore(DEFAULT_PATH) -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') -def test_init_with_experiment_and_run(mock_mlflow_client, mock_ws_client): +def test_init_with_experiment_and_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID)) store = MLFlowObjectStore(DEFAULT_PATH) @@ -51,9 +55,14 @@ def test_init_with_experiment_and_run(mock_mlflow_client, mock_ws_client): assert store.run_id == RUN_ID -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') -def test_init_with_experiment_and_no_run(mock_mlflow_client, mock_ws_client): +def test_init_with_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + mock_mlflow_client.return_value.create_run.return_value = MagicMock( info=MagicMock(run_id=RUN_ID, run_name='test-run')) @@ -62,17 +71,24 @@ def test_init_with_experiment_and_no_run(mock_mlflow_client, mock_ws_client): assert store.run_id == RUN_ID -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -def test_init_with_run_and_no_experiment(mock_ws_client): +def test_init_with_run_and_no_experiment(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + with pytest.raises(ValueError): MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=RUN_ID)) -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') 
-@mock.patch('composer.utils.object_store.mlflow_object_store.mlflow') -def test_init_with_active_run(mock_mlflow, mock_mlflow_client, mock_ws_client): - mock_mlflow.active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID)) +def test_init_with_active_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_active_run = MagicMock() + monkeypatch.setattr(mlflow, 'active_run', mock_active_run) + monkeypatch.setattr(mlflow, 'MlflowClient', MagicMock()) + + mock_active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID)) store = MLFlowObjectStore( TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) @@ -80,9 +96,14 @@ def test_init_with_active_run(mock_mlflow, mock_mlflow_client, mock_ws_client): assert store.run_id == RUN_ID -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') -def test_init_with_existing_experiment_and_no_run(mock_mlflow_client, mock_ws_client): +def test_init_with_existing_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + mock_mlflow_client.return_value.get_experiment_by_name.return_value = MagicMock(experiment_id=EXPERIMENT_ID) mock_mlflow_client.return_value.create_run.return_value = MagicMock( info=MagicMock(run_id=RUN_ID, run_name='test-run')) @@ -93,9 +114,14 @@ def test_init_with_existing_experiment_and_no_run(mock_mlflow_client, mock_ws_cl assert store.run_id == RUN_ID -@mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient') -@mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') -def test_init_with_no_experiment_and_no_run(mock_mlflow_client, mock_ws_client): +def test_init_with_no_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + mock_mlflow_client.return_value.get_experiment_by_name.return_value = None mock_mlflow_client.return_value.create_experiment.return_value = EXPERIMENT_ID mock_mlflow_client.return_value.create_run.return_value = MagicMock( @@ -108,7 +134,7 @@ def test_init_with_no_experiment_and_no_run(mock_mlflow_client, mock_ws_client): @pytest.fixture() -def mlflow_object_store(): +def mlflow_object_store(monkeypatch): def mock_mlflow_client_list_artifacts(*args, **kwargs): """Mock behavior for MlflowClient.list_artifacts(). 
@@ -143,12 +169,17 @@ def mock_mlflow_client_list_artifacts(*args, **kwargs): else: return [] - with mock.patch('composer.utils.object_store.mlflow_object_store.WorkspaceClient'): - with mock.patch('composer.utils.object_store.mlflow_object_store.MlflowClient') as mock_mlflow_client: - mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock( - experiment_id=EXPERIMENT_ID)) - mock_mlflow_client.return_value.list_artifacts.side_effect = mock_mlflow_client_list_artifacts - yield MLFlowObjectStore(DEFAULT_PATH) + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID)) + mock_mlflow_client.return_value.list_artifacts.side_effect = mock_mlflow_client_list_artifacts + + yield MLFlowObjectStore(DEFAULT_PATH) def test_get_artifact_path(mlflow_object_store): From 18b51b65e4c3a227b7e2f3e37af83b42b848ec8c Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Mon, 1 Jan 2024 14:18:07 -0800 Subject: [PATCH 05/18] Bugfix --- composer/utils/object_store/mlflow_object_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 4c31d42227..bf48840ddf 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -321,7 +321,7 @@ def download_object( try: self._mlflow_client.download_artifacts( run_id=self.run_id, - artifact_path=artifact_path, + path=artifact_path, dst_path=tmp_dir, ) tmp_path = os.path.join(tmp_dir, artifact_path) From 8fa66dd12fba8c6ed3b7565e1dbb862ae245a207 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Mon, 1 Jan 2024 15:45:41 -0800 Subject: [PATCH 06/18] Enforce experiment and run in get_artifact_path --- .../utils/object_store/mlflow_object_store.py | 8 +++++++- .../object_store/test_mlflow_object_store.py | 20 ++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index bf48840ddf..72a175eb56 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -243,7 +243,13 @@ def get_artifact_path(self, object_name: str) -> str: Otherwise, the object name is assumed to be a relative artifact path and will be returned as-is. 
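+
+        A sketch of the added validation (IDs are illustrative; assume the store is bound to
+        experiment ``123`` and run ``456``):
+
+        .. code-block:: python
+
+            store.get_artifact_path('databricks/mlflow-tracking/123/456/artifacts/a.txt')
+            # == 'a.txt'
+            store.get_artifact_path('databricks/mlflow-tracking/999/456/artifacts/a.txt')
+            # raises ValueError: the experiment ID does not match the store's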
""" if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX): - _, _, object_name = self.parse_dbfs_path(object_name) + experiment_id, run_id, object_name = self.parse_dbfs_path(object_name) + if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID): + raise ValueError(f'Object {object_name} belongs to experiment with id={experiment_id}, ' + f'but MLFlowObjectStore is associated with experiment {self.experiment_id}.') + if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID): + raise ValueError(f'Object {object_name} belongs to run with id={run_id}, ' + f'but MLFlowObjectStore is associated with run {self.run_id}.') return object_name def get_dbfs_path(self, object_name: str) -> str: diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py index 300dd7c3df..d46fc493a4 100644 --- a/tests/utils/object_store/test_mlflow_object_store.py +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -189,6 +189,20 @@ def test_get_artifact_path(mlflow_object_store): # Absolute DBFS path assert mlflow_object_store.get_artifact_path(DEFAULT_PATH + ARTIFACT_PATH) == ARTIFACT_PATH + # Absolute DBFS path with placeholders + path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + assert mlflow_object_store.get_artifact_path(path) == ARTIFACT_PATH + + # Raises ValueError for different experiment ID + path = TEST_PATH_FORMAT.format(experiment_id='different-experiment', run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + with pytest.raises(ValueError): + mlflow_object_store.get_artifact_path(path) + + # Raises ValueError for different run ID + path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id='different-run') + ARTIFACT_PATH + with pytest.raises(ValueError): + mlflow_object_store.get_artifact_path(path) + def test_get_dbfs_path(mlflow_object_store): experiment_id = mlflow_object_store.experiment_id @@ -245,12 +259,12 @@ def test_get_object_size(mlflow_object_store): def test_download_object(mlflow_object_store, tmp_path): def mock_mlflow_client_download_artifacts(*args, **kwargs): - artifact_path = kwargs['artifact_path'] + path = kwargs['path'] dst_path = kwargs['dst_path'] - local_path = os.path.join(dst_path, artifact_path) + local_path = os.path.join(dst_path, path) os.makedirs(os.path.dirname(local_path), exist_ok=True) - size = mlflow_object_store.get_object_size(artifact_path) + size = mlflow_object_store.get_object_size(path) file_content = bytes('0' * (size), 'utf-8') print(local_path) From 2980cbc6a5333ae1b83e62e7e214fc1339292206 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 09:49:32 -0800 Subject: [PATCH 07/18] Update placeholders --- composer/utils/object_store/mlflow_object_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 72a175eb56..00884f214f 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -22,8 +22,8 @@ DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store' -PLACEHOLDER_EXPERIMENT_ID = 'EXPERIMENT_ID' -PLACEHOLDER_RUN_ID = 'RUN_ID' +PLACEHOLDER_EXPERIMENT_ID = 'MLFLOW_EXPERIMENT_ID' +PLACEHOLDER_RUN_ID = 'MLFLOW_RUN_ID' log = logging.getLogger(__name__) From 881d18e2519f4a373af3a2a4fac18e5fee5c84d2 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 13:40:46 -0800 Subject: 
[PATCH 08/18] Make logs debug instead of info --- composer/utils/object_store/mlflow_object_store.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 00884f214f..52cfc9646e 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -163,7 +163,7 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> if active_run is not None: experiment_id = active_run.info.experiment_id run_id = active_run.info.run_id - log.info(f'MLFlowObjectStore using active MLFlow run (run_id={run_id})') + log.debug(f'MLFlowObjectStore using active MLFlow run (run_id={run_id})') else: # If no active run exists, create a new run for the default experiment. experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name, @@ -177,8 +177,8 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> run_id = self._mlflow_client.create_run(experiment_id).info.run_id - log.info(f'MLFlowObjectStore using a new MLFlow run (run_id={run_id}) ' - f'for new experiment "{experiment_name}" (experiment_id={experiment_id})') + log.debug(f'MLFlowObjectStore using a new MLFlow run (run_id={run_id}) ' + f'for new experiment "{experiment_name}" (experiment_id={experiment_id})') else: if run_id is not None: # If a `run_id` is provided, check that it belongs to the provided experiment. @@ -188,15 +188,15 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. ' f'Found experiment {run.info.experiment_id}.') - log.info(f'MLFlowObjectStore using provided MLFlow run (run_id={run_id}) ' - f'for provided experiment (experiment_id={experiment_id})') + log.debug(f'MLFlowObjectStore using provided MLFlow run (run_id={run_id}) ' + f'for provided experiment (experiment_id={experiment_id})') else: # If no `run_id` is provided, create a new run in the provided experiment. 
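# (Sketch, not part of the diff: the full ID-resolution logic of _init_run_info,
#  pieced together from the hunks above and below.)
#   - both experiment_id and run_id are placeholders -> reuse the active MLflow
#     run if one exists, else create a new run under the default experiment (or
#     the one named by MLFLOW_EXPERIMENT_NAME)
#   - experiment_id and run_id both given -> verify the run belongs to that
#     experiment, else raise ValueError
#   - experiment_id only -> create a new run in that experiment
#   - run_id only -> raises ValueError (see test_init_with_run_and_no_experiment)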
run = self._mlflow_client.create_run(experiment_id) run_id = run.info.run_id - log.info(f'MLFlowObjectStore using new MLFlow run (run_id={run_id}) ' - f'for provided experiment (experiment_id={experiment_id})') + log.debug(f'MLFlowObjectStore using new MLFlow run (run_id={run_id}) ' + f'for provided experiment (experiment_id={experiment_id})') if experiment_id is None or run_id is None: raise ValueError('MLFlowObjectStore failed to initialize experiment and run ID.') From c1d6d721fd32f88f5dc17714bd40ec8271275036 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 4 Jan 2024 16:38:02 -0800 Subject: [PATCH 09/18] Minor PR comments --- composer/utils/object_store/__init__.py | 11 +++++-- .../utils/object_store/mlflow_object_store.py | 30 +++++++++---------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/composer/utils/object_store/__init__.py b/composer/utils/object_store/__init__.py index d7757c2284..e623c385f0 100644 --- a/composer/utils/object_store/__init__.py +++ b/composer/utils/object_store/__init__.py @@ -13,6 +13,13 @@ from composer.utils.object_store.uc_object_store import UCObjectStore __all__ = [ - 'ObjectStore', 'ObjectStoreTransientError', 'LibcloudObjectStore', 'MLFlowObjectStore', 'S3ObjectStore', - 'SFTPObjectStore', 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore' + 'ObjectStore', + 'ObjectStoreTransientError', + 'LibcloudObjectStore', + 'MLFlowObjectStore', + 'S3ObjectStore', + 'SFTPObjectStore', + 'OCIObjectStore', + 'GCSObjectStore', + 'UCObjectStore', ] diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 52cfc9646e..dc9e2a0371 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -34,8 +34,8 @@ def _wrap_mlflow_exceptions(uri: str, e: Exception): INVALID_STATE, NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST, RESOURCE_EXHAUSTED, TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException) - # https://github.com/mlflow/mlflow/blob/master/mlflow/exceptions.py for used error codes. - # https://github.com/mlflow/mlflow/blob/master/mlflow/protos/databricks.proto for code descriptions. + # https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/exceptions.py for used error codes. + # https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/protos/databricks.proto for code descriptions. retryable_server_codes = [ ErrorCode.Name(code) for code in [ DATA_LOSS, @@ -94,7 +94,7 @@ class MLFlowObjectStore(ObjectStore): provided experiment. Providing a `run_id` without an `experiment_id` will raise an error. - multipart_upload_chunk_size(int, optional): The maximum size of single chunk in an MLFlow multipart upload. + multipart_upload_chunk_size(int, optional): The maximum size of a single chunk in an MLFlow multipart upload. The maximum number of chunks supported by MLFlow is 10,000, so the max file size that can be uploaded is `10 000 * multipart_upload_chunk_size`. Defaults to 100MB for a max upload size of 1TB. """ @@ -163,7 +163,7 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> if active_run is not None: experiment_id = active_run.info.experiment_id run_id = active_run.info.run_id - log.debug(f'MLFlowObjectStore using active MLFlow run (run_id={run_id})') + log.debug(f'MLFlowObjectStore using active MLFlow run {run_id=}') else: # If no active run exists, create a new run for the default experiment. 
             experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
@@ -177,8 +177,8 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
 
             run_id = self._mlflow_client.create_run(experiment_id).info.run_id
 
-            log.debug(f'MLFlowObjectStore using a new MLFlow run (run_id={run_id}) '
-                      f'for new experiment "{experiment_name}" (experiment_id={experiment_id})')
+            log.debug(f'MLFlowObjectStore using a new MLFlow run {run_id=} '
+                      f'for new experiment "{experiment_name}" {experiment_id=}')
         else:
             if run_id is not None:
                 # If a `run_id` is provided, check that it belongs to the provided experiment.
@@ -188,15 +188,15 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
                         f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. '
                         f'Found experiment {run.info.experiment_id}.')
 
-                log.debug(f'MLFlowObjectStore using provided MLFlow run (run_id={run_id}) '
-                          f'for provided experiment (experiment_id={experiment_id})')
+                log.debug(f'MLFlowObjectStore using provided MLFlow run {run_id=} '
+                          f'for provided experiment {experiment_id=}')
             else:
                 # If no `run_id` is provided, create a new run in the provided experiment.
                 run = self._mlflow_client.create_run(experiment_id)
                 run_id = run.info.run_id
 
-                log.debug(f'MLFlowObjectStore using new MLFlow run (run_id={run_id}) '
-                          f'for provided experiment (experiment_id={experiment_id})')
+                log.debug(f'MLFlowObjectStore using new MLFlow run {run_id=} '
+                          f'for provided experiment {experiment_id=}')
 
         if experiment_id is None or run_id is None:
             raise ValueError('MLFlowObjectStore failed to initialize experiment and run ID.')
@@ -230,7 +230,7 @@ def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
         if len(mlflow_parts) != 4 or mlflow_parts[2] != 'artifacts':
             raise ValueError(f'Databricks MLFlow artifact path expected to be of the format '
                              f'{MLFLOW_DBFS_PATH_PREFIX}/<experiment_id>/<run_id>/artifacts/<artifact_path>. '
-                             f'Found path={path}')
+                             f'Found {path=}')
 
         return mlflow_parts[0], mlflow_parts[1], mlflow_parts[3]
 
@@ -245,11 +245,11 @@ def get_artifact_path(self, object_name: str) -> str:
         if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX):
             experiment_id, run_id, object_name = self.parse_dbfs_path(object_name)
             if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID):
-                raise ValueError(f'Object {object_name} belongs to experiment with id={experiment_id}, '
-                                 f'but MLFlowObjectStore is associated with experiment {self.experiment_id}.')
+                raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, '
+                                 f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.')
             if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID):
-                raise ValueError(f'Object {object_name} belongs to run with id={run_id}, '
-                                 f'but MLFlowObjectStore is associated with run {self.run_id}.')
+                raise ValueError(f'Object {object_name} belongs to run ID {run_id}, '
+                                 f'but MLFlowObjectStore is associated with run ID {self.run_id}.')
         return object_name
 
     def get_dbfs_path(self, object_name: str) -> str:

From 7c6823af1510bec3323a862281490380e53a4715 Mon Sep 17 00:00:00 2001
From: Jerry Chen
Date: Thu, 4 Jan 2024 16:48:30 -0800
Subject: [PATCH 10/18] MLflow casing

---
 .../utils/object_store/mlflow_object_store.py | 52 +++++++++----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py
index dc9e2a0371..c9145fcbb1 100644
--- a/composer/utils/object_store/mlflow_object_store.py
+++ b/composer/utils/object_store/mlflow_object_store.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""MLFlow Artifacts object store."""
+"""MLflow Artifacts object store."""
 
 from __future__ import annotations
 
@@ -29,7 +29,7 @@
 
 
 def _wrap_mlflow_exceptions(uri: str, e: Exception):
-    """Wraps retryable MLFlow errors in ObjectStoreTransientError for automatic retry handling."""
+    """Wraps retryable MLflow errors in ObjectStoreTransientError for automatic retry handling."""
     from mlflow.exceptions import (ABORTED, DATA_LOSS, DEADLINE_EXCEEDED, ENDPOINT_NOT_FOUND, INTERNAL_ERROR,
                                    INVALID_STATE, NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST,
                                    RESOURCE_EXHAUSTED, TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException)
@@ -59,26 +59,26 @@
 
 
 class MLFlowObjectStore(ObjectStore):
-    """Utility class for uploading and downloading artifacts from MLFlow.
+    """Utility class for uploading and downloading artifacts from MLflow.
 
     It can be initialized for an existing run, a new run in an existing experiment, the active run used by the
     `mlflow` module, or a new run in a new experiment. See the documentation for ``path`` for more details.
 
     .. note::
-        At this time, only Databricks-managed MLFlow with a 'databricks' tracking URI is supported.
+        At this time, only Databricks-managed MLflow with a 'databricks' tracking URI is supported.
         Using this object store requires setting `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
         environment variables with the right credentials.
 
-    Unlike other object stores, the DBFS URI scheme for MLFlow artifacts has no bucket, and the path is prefixed
+    Unlike other object stores, the DBFS URI scheme for MLflow artifacts has no bucket, and the path is prefixed
     with the artifacts root directory for a given experiment/run,
     `databricks/mlflow-tracking/<experiment_id>/<run_id>/`.
     However, object names are also sometimes passed by upstream code as artifact paths relative to this root,
     rather than the full path. To keep upstream code simple,
-    :class:`MLFlowObjectStore` accepts both relative MLFlow artifact paths and absolute DBFS paths as object names.
+    :class:`MLFlowObjectStore` accepts both relative MLflow artifact paths and absolute DBFS paths as object names.
 
     If an object name takes the form of `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`,
-    it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLFlow.
-    Otherwise, the object name is assumed to be a relative MLFlow artifact path, and the full provided name will be
-    used as the artifact path when uploading to MLFlow.
+    it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLflow.
+    Otherwise, the object name is assumed to be a relative MLflow artifact path, and the full provided name will be
+    used as the artifact path when uploading to MLflow.
 
     Args:
         path (str): A DBFS path of the form
            `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
@@ -86,7 +86,7 @@ class MLFlowObjectStore(ObjectStore):
             `experiment_id` and `run_id` can be set as the format string placeholders `'EXPERIMENT_ID'` and `'RUN_ID'`.
 
             If both `experiment_id` and `run_id` are set as placeholders, the MLFlowObjectStore will be associated with
-            the currently active MLFlow run if one exists. If no active run exists, a new run will be created under a
+            the currently active MLflow run if one exists. If no active run exists, a new run will be created under a
             default experiment name, or the experiment name specified by the `MLFLOW_EXPERIMENT_NAME` environment
             variable if one is set.
@@ -94,8 +94,8 @@ class MLFlowObjectStore(ObjectStore):
             provided experiment.
 
             Providing a `run_id` without an `experiment_id` will raise an error.
 
-        multipart_upload_chunk_size(int, optional): The maximum size of a single chunk in an MLFlow multipart upload.
-            The maximum number of chunks supported by MLFlow is 10,000, so the max file size that can
+        multipart_upload_chunk_size(int, optional): The maximum size of a single chunk in an MLflow multipart upload.
+            The maximum number of chunks supported by MLflow is 10,000, so the max file size that can
             be uploaded is `10 000 * multipart_upload_chunk_size`. Defaults to 100MB for a max upload size of 1TB.
     """
 
@@ -114,7 +114,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
         tracking_uri = os.getenv('MLFLOW_TRACKING_URI', MLFLOW_DATABRICKS_TRACKING_URI)
         if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI:
             raise ValueError(
-                'MLFlowObjectStore currently only supports Databricks-hosted MLFlow tracking. '
+                'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. '
                 f'Environment variable `MLFLOW_TRACKING_URI` is set to a non-Databricks URI {tracking_uri}. '
                 f'Please unset it or set the value to `{MLFLOW_DATABRICKS_TRACKING_URI}`.')
 
@@ -149,7 +149,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
         self.run_id = run_id
 
     def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> Tuple[str, str]:
-        """Returns the experiment ID and run ID for the MLFlow run backing this object store.
+        """Returns the experiment ID and run ID for the MLflow run backing this object store.
 
         In a distributed setting, this should only be called on the rank 0 process.
         """
@@ -163,7 +163,7 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
         if active_run is not None:
             experiment_id = active_run.info.experiment_id
             run_id = active_run.info.run_id
-            log.debug(f'MLFlowObjectStore using active MLFlow run {run_id=}')
+            log.debug(f'MLFlowObjectStore using active MLflow run {run_id=}')
         else:
             # If no active run exists, create a new run for the default experiment.
             experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
@@ -177,7 +177,7 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
 
             run_id = self._mlflow_client.create_run(experiment_id).info.run_id
 
-            log.debug(f'MLFlowObjectStore using a new MLFlow run {run_id=} '
+            log.debug(f'MLFlowObjectStore using a new MLflow run {run_id=} '
                       f'for new experiment "{experiment_name}" {experiment_id=}')
         else:
             if run_id is not None:
@@ -188,14 +188,14 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
                         f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. '
                         f'Found experiment {run.info.experiment_id}.')
 
-                log.debug(f'MLFlowObjectStore using provided MLFlow run {run_id=} '
+                log.debug(f'MLFlowObjectStore using provided MLflow run {run_id=} '
                           f'for provided experiment {experiment_id=}')
             else:
                 # If no `run_id` is provided, create a new run in the provided experiment.
                 run = self._mlflow_client.create_run(experiment_id)
                 run_id = run.info.run_id
 
-                log.debug(f'MLFlowObjectStore using new MLFlow run {run_id=} '
+                log.debug(f'MLFlowObjectStore using new MLflow run {run_id=} '
                           f'for provided experiment {experiment_id=}')
 
         if experiment_id is None or run_id is None:
@@ -205,7 +205,7 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
 
     @staticmethod
     def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
-        """Parses a DBFS path to extract the MLFlow experiment ID, run ID, and relative artifact path.
+        """Parses a DBFS path to extract the MLflow experiment ID, run ID, and relative artifact path.
 
         The path is expected to be of the format
         `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
@@ -220,7 +220,7 @@ def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
             ValueError: If the path is not of the expected format.
         """
         if not path.startswith(MLFLOW_DBFS_PATH_PREFIX):
-            raise ValueError(f'DBFS MLFlow path should start with {MLFLOW_DBFS_PATH_PREFIX}. Got: {path}')
+            raise ValueError(f'DBFS MLflow path should start with {MLFLOW_DBFS_PATH_PREFIX}. Got: {path}')
 
         # Strip `databricks/mlflow-tracking/` and split into
         # `<experiment_id>`, `<run_id>`, `'artifacts'`, `<artifact_path>`
@@ -228,18 +228,18 @@ def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
         mlflow_parts = subpath.split('/', maxsplit=3)
 
         if len(mlflow_parts) != 4 or mlflow_parts[2] != 'artifacts':
-            raise ValueError(f'Databricks MLFlow artifact path expected to be of the format '
+            raise ValueError(f'Databricks MLflow artifact path expected to be of the format '
                              f'{MLFLOW_DBFS_PATH_PREFIX}/<experiment_id>/<run_id>/artifacts/<artifact_path>. '
                              f'Found {path=}')
 
         return mlflow_parts[0], mlflow_parts[1], mlflow_parts[3]
 
     def get_artifact_path(self, object_name: str) -> str:
-        """Converts an object name into an MLFlow relative artifact path.
+        """Converts an object name into an MLflow relative artifact path.
 
         Args:
             object_name (str): The object name to convert. If the object name is a DBFS path beginning with
-                ``MLFLOW_DBFS_PATH_PREFIX``, the path will be parsed to extract the MLFlow relative artifact path.
+ ``MLFLOW_DBFS_PATH_PREFIX``, the path will be parsed to extract the MLflow relative artifact path. Otherwise, the object name is assumed to be a relative artifact path and will be returned as-is. """ if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX): @@ -272,7 +272,7 @@ def upload_object(self, artifact_base_name = os.path.basename(artifact_path) artifact_dir = os.path.dirname(artifact_path) - # Since MLFlow doesn't support uploading artifacts with a different base name than the local file, + # Since MLflow doesn't support uploading artifacts with a different base name than the local file, # create a temporary symlink to the local file with the same base name as the desired artifact name. filename = os.path.abspath(filename) tmp_dir = f'/tmp/{uuid.uuid4()}' @@ -375,12 +375,12 @@ def _get_artifact_info(self, object_name): """Get the :class:`~mlflow.entities.FileInfo` for the given object name. Args: - object_name (str): The name of the object, either as an absolute DBFS path or a relative MLFlow artifact path. + object_name (str): The name of the object, either as an absolute DBFS path or a relative MLflow artifact path. Returns: Optional[FileInfo]: The :class:`~mlflow.entities.FileInfo` for the object, or None if it does not exist. """ - # MLFlow doesn't support info for a singleton artifact, so we need to list all artifacts in the + # MLflow doesn't support info for a singleton artifact, so we need to list all artifacts in the # parent path and find the one with the matching name. artifact_path = self.get_artifact_path(object_name) artifact_dir = os.path.dirname(artifact_path) From 503294746e3cbc8e7863bdac7e8219fd5d549d4d Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 4 Jan 2024 16:59:22 -0800 Subject: [PATCH 11/18] tracking_uri fixes --- composer/utils/object_store/mlflow_object_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index c9145fcbb1..d004258286 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -111,7 +111,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10 except ImportError as e: raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e - tracking_uri = os.getenv('MLFLOW_TRACKING_URI', MLFLOW_DATABRICKS_TRACKING_URI) + tracking_uri = mlflow.get_tracking_uri() if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI: raise ValueError( 'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. 
'
                f'Environment variable `MLFLOW_TRACKING_URI` is set to a non-Databricks URI {tracking_uri}. '
                f'Please unset it or set the value to `{MLFLOW_DATABRICKS_TRACKING_URI}`.')
@@ -127,7 +127,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
                 'Visit https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication '
                 'to identify different ways to setup credentials.') from e
 
-        self._mlflow_client = MlflowClient(MLFLOW_DATABRICKS_TRACKING_URI)
+        self._mlflow_client = MlflowClient(tracking_uri)
         mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size)
 
         experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)

From d6e84a6325f971b37768949118b071fc680d2ab6 Mon Sep 17 00:00:00 2001
From: Jerry Chen
Date: Thu, 4 Jan 2024 17:11:00 -0800
Subject: [PATCH 12/18] Update comments

---
 composer/utils/object_store/mlflow_object_store.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py
index d004258286..7afb660b81 100644
--- a/composer/utils/object_store/mlflow_object_store.py
+++ b/composer/utils/object_store/mlflow_object_store.py
@@ -66,8 +66,9 @@ class MLFlowObjectStore(ObjectStore):
 
     .. note::
         At this time, only Databricks-managed MLflow with a 'databricks' tracking URI is supported.
-        Using this object store requires setting `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
-        environment variables with the right credentials.
+        Using this object store requires configuring Databricks authentication through a configuration file or
+        environment variables. See
+        https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication
 
     Unlike other object stores, the DBFS URI scheme for MLflow artifacts has no bucket, and the path is prefixed
     with the artifacts root directory for a given experiment/run,
@@ -83,7 +84,8 @@ class MLFlowObjectStore(ObjectStore):
     Args:
         path (str): A DBFS path of the form
             `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
-            `experiment_id` and `run_id` can be set as the format string placeholders `'EXPERIMENT_ID'` and `'RUN_ID'`.
+            `experiment_id` and `run_id` can be set as the format string placeholders
+            `{mlflow_experiment_id}` and `{mlflow_run_id}`.
 
             If both `experiment_id` and `run_id` are set as placeholders, the MLFlowObjectStore will be associated with
             the currently active MLflow run if one exists.
If no active run exists, a new run will be created under a From f91f317e17c1f3069d8ec016fa526afef16411c5 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 4 Jan 2024 18:02:15 -0800 Subject: [PATCH 13/18] Update placeholders --- composer/utils/object_store/mlflow_object_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 7afb660b81..14f469180f 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -22,8 +22,8 @@ DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store' -PLACEHOLDER_EXPERIMENT_ID = 'MLFLOW_EXPERIMENT_ID' -PLACEHOLDER_RUN_ID = 'MLFLOW_RUN_ID' +PLACEHOLDER_EXPERIMENT_ID = '{mlflow_experiment_id}' +PLACEHOLDER_RUN_ID = '{mlflow_run_id}' log = logging.getLogger(__name__) From 87db74ace59c333a4210ba8b8c1ed6c17f811613 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 4 Jan 2024 18:02:20 -0800 Subject: [PATCH 14/18] Fix tests --- .../utils/object_store/test_mlflow_object_store.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py index d46fc493a4..ddbd82034f 100644 --- a/tests/utils/object_store/test_mlflow_object_store.py +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -41,6 +41,8 @@ def test_init_fail_without_databricks_tracking_uri(monkeypatch): def test_init_with_experiment_and_run(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -56,6 +58,8 @@ def test_init_with_experiment_and_run(monkeypatch): def test_init_with_experiment_and_no_run(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -72,6 +76,8 @@ def test_init_with_experiment_and_no_run(monkeypatch): def test_init_with_run_and_no_experiment(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -80,6 +86,8 @@ def test_init_with_run_and_no_experiment(monkeypatch): def test_init_with_active_run(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -97,6 +105,8 @@ def test_init_with_active_run(monkeypatch): def test_init_with_existing_experiment_and_no_run(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -115,6 +125,8 @@ def test_init_with_existing_experiment_and_no_run(monkeypatch): def test_init_with_no_experiment_and_no_run(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') + dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -135,6 +147,7 @@ def test_init_with_no_experiment_and_no_run(monkeypatch): @pytest.fixture() def mlflow_object_store(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') def mock_mlflow_client_list_artifacts(*args, **kwargs): """Mock behavior for MlflowClient.list_artifacts(). 
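The setenv lines added throughout this patch all follow one pattern: the store now resolves the tracking URI at construction time, so each test pins it to 'databricks' before building anything. A minimal sketch of the converse case, as a hypothetical variant of the existing test_init_fail_without_databricks_tracking_uri (the localhost URI is illustrative):

import pytest

def test_init_rejects_other_tracking_uris(monkeypatch):
    # With any non-'databricks' tracking URI pinned, __init__ raises ValueError
    # before any Databricks or MLflow client is constructed.
    monkeypatch.setenv('MLFLOW_TRACKING_URI', 'http://localhost:5000')
    with pytest.raises(ValueError):
        MLFlowObjectStore(DEFAULT_PATH)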
From a29285112d0bceef0b964b8f66c35fe2e7b94e51 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Thu, 4 Jan 2024 18:37:44 -0800 Subject: [PATCH 15/18] Fix pyright --- composer/utils/object_store/uc_object_store.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 8759ce7b30..23e8440354 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -167,7 +167,10 @@ def download_object(self, try: from databricks.sdk.core import DatabricksError try: - with self.client.files.download(self._get_object_path(object_name)).contents as resp: # pyright: ignore + contents = self.client.files.download(self._get_object_path(object_name)).contents + assert contents is not None + + with contents as resp: # pyright: ignore with open(tmp_path, 'wb') as f: # Chunk the data into multiple blocks of 64MB to avoid # OOMs when downloading really large files @@ -235,6 +238,7 @@ def list_objects(self, prefix: Optional[str]) -> List[str]: path=self._UC_VOLUME_LIST_API_ENDPOINT, data=data, headers={'Source': 'mosaicml/composer'}) - return [f['path'] for f in resp.get('files', []) if not f['is_dir']] # pyright: ignore + assert isinstance(resp, dict) + return [f['path'] for f in resp.get('files', []) if not f['is_dir']] except DatabricksError as e: _wrap_errors(self.get_uri(prefix), e) From 14fa39cf34ce34c76d98d333cd3ff6b46426f2f2 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 5 Jan 2024 11:04:38 -0800 Subject: [PATCH 16/18] Use tempfile for temp dirs --- .../utils/object_store/mlflow_object_store.py | 51 +++++++++---------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 14f469180f..1d45a4724f 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -8,7 +8,7 @@ import logging import os import pathlib -import uuid +import tempfile from typing import Callable, List, Optional, Tuple, Union from composer.utils import dist @@ -277,15 +277,15 @@ def upload_object(self, # Since MLflow doesn't support uploading artifacts with a different base name than the local file, # create a temporary symlink to the local file with the same base name as the desired artifact name. filename = os.path.abspath(filename) - tmp_dir = f'/tmp/{uuid.uuid4()}' - os.makedirs(tmp_dir, exist_ok=True) - tmp_symlink_path = os.path.join(tmp_dir, artifact_base_name) - os.symlink(filename, tmp_symlink_path) - try: - self._mlflow_client.log_artifact(self.run_id, tmp_symlink_path, artifact_dir) - except MlflowException as e: - _wrap_mlflow_exceptions(self.get_uri(object_name), e) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_symlink_path = os.path.join(tmp_dir, artifact_base_name) + os.symlink(filename, tmp_symlink_path) + + try: + self._mlflow_client.log_artifact(self.run_id, tmp_symlink_path, artifact_dir) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(object_name), e) def get_object_size(self, object_name: str) -> int: from mlflow.exceptions import MlflowException @@ -324,23 +324,22 @@ def download_object( # MLFLow doesn't support downloading artifacts directly to a specified filename, so instead # download to a temporary directory and then move the file to the desired location. 
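# For reference, the pattern the hunk below switches to, as a standalone sketch
# (not part of the diff; the `fetch` callable is a hypothetical stand-in for the
# MlflowClient download call):
import os
import tempfile

def download_via_tempdir(fetch, artifact_path: str, filename: str, overwrite: bool = False) -> None:
    # TemporaryDirectory removes the scratch space even if fetch() raises, which
    # the manual /tmp/<uuid> directory did not guarantee.
    with tempfile.TemporaryDirectory() as tmp_dir:
        fetch(dst_path=tmp_dir)  # downloads `artifact_path` underneath `tmp_dir`
        tmp_path = os.path.join(tmp_dir, artifact_path)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        if overwrite:
            os.replace(tmp_path, filename)
        else:
            os.rename(tmp_path, filename)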
- tmp_dir = f'/tmp/{uuid.uuid4()}' - os.makedirs(tmp_dir, exist_ok=True) - try: - self._mlflow_client.download_artifacts( - run_id=self.run_id, - path=artifact_path, - dst_path=tmp_dir, - ) - tmp_path = os.path.join(tmp_dir, artifact_path) - - os.makedirs(os.path.dirname(filename), exist_ok=True) - if overwrite: - os.replace(tmp_path, filename) - else: - os.rename(tmp_path, filename) - except MlflowException as e: - _wrap_mlflow_exceptions(self.get_uri(artifact_path), e) + with tempfile.TemporaryDirectory() as tmp_dir: + try: + self._mlflow_client.download_artifacts( + run_id=self.run_id, + path=artifact_path, + dst_path=tmp_dir, + ) + tmp_path = os.path.join(tmp_dir, artifact_path) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + if overwrite: + os.replace(tmp_path, filename) + else: + os.rename(tmp_path, filename) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(artifact_path), e) def list_objects(self, prefix: Optional[str] = None) -> List[str]: """See :meth:`~composer.utils.ObjectStore.list_objects`. From b3b24e785cfbb81942661728cc09d673ad7d83c5 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 5 Jan 2024 11:29:58 -0800 Subject: [PATCH 17/18] Read tracking uri env var directly --- composer/utils/object_store/mlflow_object_store.py | 2 +- .../utils/object_store/test_mlflow_object_store.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 1d45a4724f..4ee89ab525 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -113,7 +113,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10 except ImportError as e: raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e - tracking_uri = mlflow.get_tracking_uri() + tracking_uri = os.getenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, MLFLOW_DATABRICKS_TRACKING_URI) if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI: raise ValueError( 'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. 
' diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py index ddbd82034f..d46fc493a4 100644 --- a/tests/utils/object_store/test_mlflow_object_store.py +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -41,8 +41,6 @@ def test_init_fail_without_databricks_tracking_uri(monkeypatch): def test_init_with_experiment_and_run(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -58,8 +56,6 @@ def test_init_with_experiment_and_run(monkeypatch): def test_init_with_experiment_and_no_run(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -76,8 +72,6 @@ def test_init_with_experiment_and_no_run(monkeypatch): def test_init_with_run_and_no_experiment(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -86,8 +80,6 @@ def test_init_with_run_and_no_experiment(monkeypatch): def test_init_with_active_run(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -105,8 +97,6 @@ def test_init_with_active_run(monkeypatch): def test_init_with_existing_experiment_and_no_run(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -125,8 +115,6 @@ def test_init_with_existing_experiment_and_no_run(monkeypatch): def test_init_with_no_experiment_and_no_run(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') - dbx_sdk = pytest.importorskip('databricks.sdk') monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) @@ -147,7 +135,6 @@ def test_init_with_no_experiment_and_no_run(monkeypatch): @pytest.fixture() def mlflow_object_store(monkeypatch): - monkeypatch.setenv('MLFLOW_TRACKING_URI', 'databricks') def mock_mlflow_client_list_artifacts(*args, **kwargs): """Mock behavior for MlflowClient.list_artifacts(). From cd83f64e4617c3c772b21ccabd69b2b58b268df0 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 5 Jan 2024 16:58:48 -0800 Subject: [PATCH 18/18] Remove dist from MLFlowObjectStore --- composer/utils/object_store/mlflow_object_store.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 4ee89ab525..15f50bcdb0 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -11,7 +11,6 @@ import tempfile from typing import Callable, List, Optional, Tuple, Union -from composer.utils import dist from composer.utils.import_helpers import MissingConditionalImportError from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError @@ -139,16 +138,7 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10 run_id = None # Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided. 
- if not dist.is_initialized() or dist.get_global_rank() == 0: - experiment_id, run_id = self._init_run_info(experiment_id, run_id) - - if dist.is_initialized(): - mlflow_info = [experiment_id, run_id] - dist.broadcast_object_list(mlflow_info, src=0) - experiment_id, run_id = mlflow_info - - self.experiment_id = experiment_id - self.run_id = run_id + self.experiment_id, self.run_id = self._init_run_info(experiment_id, run_id) def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> Tuple[str, str]: """Returns the experiment ID and run ID for the MLflow run backing this object store.
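With this final patch, MLFlowObjectStore no longer coordinates experiment and run IDs across ranks itself; `_init_run_info` runs unconditionally and any cross-rank coordination is left to the caller. A sketch of what that caller-side coordination might look like, assuming `torch.distributed` is already initialized and `path` is a DBFS path as above (hypothetical code, mirroring the removed broadcast):

import torch.distributed as dist

# Rank 0 resolves (or creates) the MLflow experiment and run, then shares the
# IDs so every rank points at the same run.
mlflow_info = [None, None]
if dist.get_rank() == 0:
    store = MLFlowObjectStore(path)
    mlflow_info = [store.experiment_id, store.run_id]
dist.broadcast_object_list(mlflow_info, src=0)
experiment_id, run_id = mlflow_info

# Non-zero ranks build a store against the now-known experiment and run.
if dist.get_rank() != 0:
    store = MLFlowObjectStore(f'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/')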