diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index dc3d029e94..30930250d9 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -21,8 +21,9 @@ from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp, is_notebook, model_eval_mode, using_torch_2) -from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, - OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore) +from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, + UCObjectStore) from composer.utils.retrying import retry from composer.utils.string_enum import StringEnum @@ -44,6 +45,7 @@ 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore', + 'MLFlowObjectStore', 'MissingConditionalImportError', 'import_object', 'is_model_deepspeed', diff --git a/composer/utils/object_store/__init__.py b/composer/utils/object_store/__init__.py index de28ec9674..e623c385f0 100644 --- a/composer/utils/object_store/__init__.py +++ b/composer/utils/object_store/__init__.py @@ -5,6 +5,7 @@ from composer.utils.object_store.gcs_object_store import GCSObjectStore from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore +from composer.utils.object_store.mlflow_object_store import MLFlowObjectStore from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError from composer.utils.object_store.oci_object_store import OCIObjectStore from composer.utils.object_store.s3_object_store import S3ObjectStore @@ -12,6 +13,13 @@ from composer.utils.object_store.uc_object_store import UCObjectStore __all__ = [ - 'ObjectStore', 'ObjectStoreTransientError', 'LibcloudObjectStore', 'S3ObjectStore', 'SFTPObjectStore', - 'OCIObjectStore', 'GCSObjectStore', 'UCObjectStore' + 'ObjectStore', + 'ObjectStoreTransientError', + 'LibcloudObjectStore', + 'MLFlowObjectStore', + 'S3ObjectStore', + 'SFTPObjectStore', + 'OCIObjectStore', + 'GCSObjectStore', + 'UCObjectStore', ] diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py new file mode 100644 index 0000000000..15f50bcdb0 --- /dev/null +++ b/composer/utils/object_store/mlflow_object_store.py @@ -0,0 +1,383 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""MLflow Artifacts object store.""" + +from __future__ import annotations + +import logging +import os +import pathlib +import tempfile +from typing import Callable, List, Optional, Tuple, Union + +from composer.utils.import_helpers import MissingConditionalImportError +from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError + +__all__ = ['MLFlowObjectStore'] + +MLFLOW_DATABRICKS_TRACKING_URI = 'databricks' +MLFLOW_DBFS_PATH_PREFIX = 'databricks/mlflow-tracking/' + +DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store' + +PLACEHOLDER_EXPERIMENT_ID = '{mlflow_experiment_id}' +PLACEHOLDER_RUN_ID = '{mlflow_run_id}' + +log = logging.getLogger(__name__) + + +def _wrap_mlflow_exceptions(uri: str, e: Exception): + """Wraps retryable MLflow errors in ObjectStoreTransientError for automatic retry handling.""" + from mlflow.exceptions import (ABORTED, DATA_LOSS, DEADLINE_EXCEEDED, ENDPOINT_NOT_FOUND, INTERNAL_ERROR, + 
+                                   INVALID_STATE, NOT_FOUND, REQUEST_LIMIT_EXCEEDED, RESOURCE_DOES_NOT_EXIST,
+                                   RESOURCE_EXHAUSTED, TEMPORARILY_UNAVAILABLE, ErrorCode, MlflowException)
+
+    # https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/exceptions.py for used error codes.
+    # https://github.com/mlflow/mlflow/blob/39b76b5b05407af5d223e892b03e450b7264576a/mlflow/protos/databricks.proto for code descriptions.
+    retryable_server_codes = [
+        ErrorCode.Name(code) for code in [
+            DATA_LOSS,
+            INTERNAL_ERROR,
+            INVALID_STATE,
+            TEMPORARILY_UNAVAILABLE,
+            DEADLINE_EXCEEDED,
+        ]
+    ]
+    retryable_client_codes = [ErrorCode.Name(code) for code in [ABORTED, REQUEST_LIMIT_EXCEEDED, RESOURCE_EXHAUSTED]]
+    not_found_codes = [ErrorCode.Name(code) for code in [RESOURCE_DOES_NOT_EXIST, NOT_FOUND, ENDPOINT_NOT_FOUND]]
+
+    if isinstance(e, MlflowException):
+        error_code = e.error_code  # pyright: ignore
+        if error_code in retryable_server_codes or error_code in retryable_client_codes:
+            raise ObjectStoreTransientError(error_code) from e
+        elif error_code in not_found_codes:
+            raise FileNotFoundError(f'Object {uri} not found') from e
+
+    raise e
+
+
+class MLFlowObjectStore(ObjectStore):
+    """Utility class for uploading and downloading artifacts from MLflow.
+
+    It can be initialized for an existing run, a new run in an existing experiment, the active run used by the
+    `mlflow` module, or a new run in a new experiment. See the documentation for ``path`` for more details.
+
+    .. note::
+        At this time, only Databricks-managed MLflow with a 'databricks' tracking URI is supported.
+        Using this object store requires configuring Databricks authentication through a configuration file or
+        environment variables. See
+        https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication
+
+    Unlike other object stores, the DBFS URI scheme for MLflow artifacts has no bucket, and the path is prefixed
+    with the artifacts root directory for a given experiment/run,
+    `databricks/mlflow-tracking/<experiment_id>/<run_id>/`. However, object names are also sometimes passed by
+    upstream code as artifact paths relative to this root, rather than the full path. To keep upstream code simple,
+    :class:`MLFlowObjectStore` accepts both relative MLflow artifact paths and absolute DBFS paths as object names.
+    If an object name takes the form of
+    `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`,
+    it is assumed to be an absolute DBFS path, and the `<artifact_path>` is used when uploading objects to MLflow.
+    Otherwise, the object name is assumed to be a relative MLflow artifact path, and the full provided name will be
+    used as the artifact path when uploading to MLflow.
+
+    Args:
+        path (str): A DBFS path of the form
+            `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<path>`.
+            `experiment_id` and `run_id` can be set as the format string placeholders
+            `{mlflow_experiment_id}` and `{mlflow_run_id}`.
+
+            If both `experiment_id` and `run_id` are set as placeholders, the MLFlowObjectStore will be associated with
+            the currently active MLflow run if one exists. If no active run exists, a new run will be created under a
+            default experiment name, or the experiment name specified by the `MLFLOW_EXPERIMENT_NAME` environment
+            variable if one is set.
+
+            If `experiment_id` is provided and `run_id` is not, the MLFlowObjectStore will create a new run in the
+            provided experiment.
+
+            Providing a `run_id` without an `experiment_id` will raise an error.
+        multipart_upload_chunk_size (int, optional): The maximum size of a single chunk in an MLflow multipart upload.
+            The maximum number of chunks supported by MLflow is 10,000, so the max file size that can
+            be uploaded is `10,000 * multipart_upload_chunk_size`. Defaults to 100 MB for a max upload size of 1 TB.
+    """
+
+    def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 1024) -> None:
+        try:
+            import mlflow
+            from mlflow import MlflowClient
+        except ImportError as e:
+            raise MissingConditionalImportError('mlflow', conda_package='mlflow>=2.9.2,<3.0') from e
+
+        try:
+            from databricks.sdk import WorkspaceClient
+        except ImportError as e:
+            raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e
+
+        tracking_uri = os.getenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, MLFLOW_DATABRICKS_TRACKING_URI)
+        if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI:
+            raise ValueError(
+                'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. '
+                f'Environment variable `MLFLOW_TRACKING_URI` is set to a non-Databricks URI {tracking_uri}. '
+                f'Please unset it or set the value to `{MLFLOW_DATABRICKS_TRACKING_URI}`.')
+
+        # Use the Databricks WorkspaceClient to check that credentials are set up correctly.
+        try:
+            WorkspaceClient()
+        except Exception as e:
+            raise ValueError(
+                'Databricks SDK credentials not correctly set up. '
+                'Visit https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#databricks-native-authentication '
+                'to identify different ways to set up credentials.') from e
+
+        self._mlflow_client = MlflowClient(tracking_uri)
+        mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size)
+
+        experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)
+        if experiment_id == PLACEHOLDER_EXPERIMENT_ID:
+            experiment_id = None
+        if run_id == PLACEHOLDER_RUN_ID:
+            run_id = None
+
+        # Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided.
+        self.experiment_id, self.run_id = self._init_run_info(experiment_id, run_id)
+
+    def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) -> Tuple[str, str]:
+        """Returns the experiment ID and run ID for the MLflow run backing this object store.
+
+        In a distributed setting, this should only be called on the rank 0 process.
+        """
+        import mlflow
+
+        if experiment_id is None:
+            if run_id is not None:
+                raise ValueError('A `run_id` cannot be provided without a valid `experiment_id`.')
+
+            active_run = mlflow.active_run()
+            if active_run is not None:
+                experiment_id = active_run.info.experiment_id
+                run_id = active_run.info.run_id
+                log.debug(f'MLFlowObjectStore using active MLflow run {run_id=}')
+            else:
+                # If no active run exists, create a new run for the default experiment.
+                experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
+                                            DEFAULT_MLFLOW_EXPERIMENT_NAME)
+
+                experiment = self._mlflow_client.get_experiment_by_name(experiment_name)
+                if experiment is not None:
+                    experiment_id = experiment.experiment_id
+                else:
+                    experiment_id = self._mlflow_client.create_experiment(experiment_name)
+
+                run_id = self._mlflow_client.create_run(experiment_id).info.run_id
+
+                log.debug(f'MLFlowObjectStore using a new MLflow run {run_id=} '
+                          f'for new experiment "{experiment_name}" {experiment_id=}')
+        else:
+            if run_id is not None:
+                # If a `run_id` is provided, check that it belongs to the provided experiment.
+                run = self._mlflow_client.get_run(run_id)
+                if run.info.experiment_id != experiment_id:
+                    raise ValueError(
+                        f'Provided `run_id` {run_id} does not belong to provided experiment {experiment_id}. '
+                        f'Found experiment {run.info.experiment_id}.')
+
+                log.debug(f'MLFlowObjectStore using provided MLflow run {run_id=} '
+                          f'for provided experiment {experiment_id=}')
+            else:
+                # If no `run_id` is provided, create a new run in the provided experiment.
+                run = self._mlflow_client.create_run(experiment_id)
+                run_id = run.info.run_id
+
+                log.debug(f'MLFlowObjectStore using new MLflow run {run_id=} '
+                          f'for provided experiment {experiment_id=}')
+
+        if experiment_id is None or run_id is None:
+            raise ValueError('MLFlowObjectStore failed to initialize experiment and run ID.')
+
+        return experiment_id, run_id
+
+    @staticmethod
+    def parse_dbfs_path(path: str) -> Tuple[str, str, str]:
+        """Parses a DBFS path to extract the MLflow experiment ID, run ID, and relative artifact path.
+
+        The path is expected to be of the format
+        `databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<artifact_path>`.
+
+        Args:
+            path (str): The DBFS path to parse.
+
+        Returns:
+            (str, str, str): (experiment_id, run_id, artifact_path)
+
+        Raises:
+            ValueError: If the path is not of the expected format.
+        """
+        if not path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            raise ValueError(f'DBFS MLflow path should start with {MLFLOW_DBFS_PATH_PREFIX}. Got: {path}')
+
+        # Strip `databricks/mlflow-tracking/` and split into
+        # `<experiment_id>`, `<run_id>`, `'artifacts'`, `<artifact_path>`
+        subpath = path[len(MLFLOW_DBFS_PATH_PREFIX):]
+        mlflow_parts = subpath.split('/', maxsplit=3)
+
+        if len(mlflow_parts) != 4 or mlflow_parts[2] != 'artifacts':
+            raise ValueError(f'Databricks MLflow artifact path expected to be of the format '
+                             f'{MLFLOW_DBFS_PATH_PREFIX}<experiment_id>/<run_id>/artifacts/<artifact_path>. '
+                             f'Found {path=}')
+
+        return mlflow_parts[0], mlflow_parts[1], mlflow_parts[3]
+
+    def get_artifact_path(self, object_name: str) -> str:
+        """Converts an object name into an MLflow relative artifact path.
+
+        Args:
+            object_name (str): The object name to convert. If the object name is a DBFS path beginning with
+                ``MLFLOW_DBFS_PATH_PREFIX``, the path will be parsed to extract the MLflow relative artifact path.
+                Otherwise, the object name is assumed to be a relative artifact path and will be returned as-is.
+        """
+        if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            experiment_id, run_id, object_name = self.parse_dbfs_path(object_name)
+            if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID):
+                raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, '
+                                 f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.')
+            if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID):
+                raise ValueError(f'Object {object_name} belongs to run ID {run_id}, '
+                                 f'but MLFlowObjectStore is associated with run ID {self.run_id}.')
+        return object_name
+
+    def get_dbfs_path(self, object_name: str) -> str:
+        """Converts an object name to a full DBFS path."""
+        artifact_path = self.get_artifact_path(object_name)
+        return f'{MLFLOW_DBFS_PATH_PREFIX}{self.experiment_id}/{self.run_id}/artifacts/{artifact_path}'
+
+    def get_uri(self, object_name: str) -> str:
+        return 'dbfs:/' + self.get_dbfs_path(object_name)
+
+    def upload_object(self,
+                      object_name: str,
+                      filename: Union[str, pathlib.Path],
+                      callback: Optional[Callable[[int, int], None]] = None):
+        del callback  # unused
+        from mlflow.exceptions import MlflowException
+
+        # Extract relative path from DBFS path.
+ artifact_path = self.get_artifact_path(object_name) + artifact_base_name = os.path.basename(artifact_path) + artifact_dir = os.path.dirname(artifact_path) + + # Since MLflow doesn't support uploading artifacts with a different base name than the local file, + # create a temporary symlink to the local file with the same base name as the desired artifact name. + filename = os.path.abspath(filename) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_symlink_path = os.path.join(tmp_dir, artifact_base_name) + os.symlink(filename, tmp_symlink_path) + + try: + self._mlflow_client.log_artifact(self.run_id, tmp_symlink_path, artifact_dir) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(object_name), e) + + def get_object_size(self, object_name: str) -> int: + from mlflow.exceptions import MlflowException + + artifact = None + try: + artifact = self._get_artifact_info(object_name) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(object_name), e) + + if artifact is not None: + return artifact.file_size + else: + raise FileNotFoundError(f'Object {object_name} not found') + + def download_object( + self, + object_name: str, + filename: Union[str, pathlib.Path], + overwrite: bool = False, + callback: Optional[Callable[[int, int], None]] = None, + ) -> None: + del callback # unused + from mlflow.exceptions import MlflowException + + # Since MlflowClient.download_artifacts only raises MlflowException with 500 Internal Error, + # check for existence to surface a FileNotFoundError if necessary. + artifact_path = self.get_artifact_path(object_name) + artifact_info = self._get_artifact_info(object_name) + if artifact_info is None: + raise FileNotFoundError(f'Object {self.get_dbfs_path(artifact_path)} not found') + + filename = os.path.abspath(filename) + if os.path.exists(filename) and not overwrite: + raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.') + + # MLFLow doesn't support downloading artifacts directly to a specified filename, so instead + # download to a temporary directory and then move the file to the desired location. + with tempfile.TemporaryDirectory() as tmp_dir: + try: + self._mlflow_client.download_artifacts( + run_id=self.run_id, + path=artifact_path, + dst_path=tmp_dir, + ) + tmp_path = os.path.join(tmp_dir, artifact_path) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + if overwrite: + os.replace(tmp_path, filename) + else: + os.rename(tmp_path, filename) + except MlflowException as e: + _wrap_mlflow_exceptions(self.get_uri(artifact_path), e) + + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """See :meth:`~composer.utils.ObjectStore.list_objects`. + + MLFlowObjectStore does not support listing objects with a prefix, so the ``prefix`` argument is ignored. + """ + del prefix # not supported for MLFlowObjectStore + + objects = [] + self._list_objects_helper(None, objects) + return objects + + def _list_objects_helper(self, prefix: Optional[str], objects: List[str]) -> None: + """Helper to recursively populate the full list of objects for ``list_objects``. + + Args: + prefix (str | None): An artifact path prefix for artifacts to find. + objects (list[str]): The list of DBFS object paths to populate. 
+ """ + from mlflow.exceptions import MlflowException + + artifact = None + try: + for artifact in self._mlflow_client.list_artifacts(self.run_id, prefix): + if artifact.is_dir: + self._list_objects_helper(artifact.path, objects) + else: + objects.append(artifact.path) + except MlflowException as e: + uri = '' if artifact is None else self.get_uri(artifact.path) + _wrap_mlflow_exceptions(uri, e) + + def _get_artifact_info(self, object_name): + """Get the :class:`~mlflow.entities.FileInfo` for the given object name. + + Args: + object_name (str): The name of the object, either as an absolute DBFS path or a relative MLflow artifact path. + + Returns: + Optional[FileInfo]: The :class:`~mlflow.entities.FileInfo` for the object, or None if it does not exist. + """ + # MLflow doesn't support info for a singleton artifact, so we need to list all artifacts in the + # parent path and find the one with the matching name. + artifact_path = self.get_artifact_path(object_name) + artifact_dir = os.path.dirname(artifact_path) + artifacts = self._mlflow_client.list_artifacts(self.run_id, artifact_dir) + for artifact in artifacts: + if not artifact.is_dir and artifact.path == artifact_path: + return artifact + + return None diff --git a/composer/utils/object_store/object_store.py b/composer/utils/object_store/object_store.py index ef39b847a3..11f763ef68 100644 --- a/composer/utils/object_store/object_store.py +++ b/composer/utils/object_store/object_store.py @@ -87,7 +87,7 @@ def upload_object(self, Args: object_name (str): Object name (where object will be stored in the container) - filename (str | pathlib.Path): Path the the object on disk + filename (str | pathlib.Path): Path to the object on disk callback ((int, int) -> None, optional): If specified, the callback is periodically called with the number of bytes uploaded and the total size of the object being uploaded. **kwargs: other arguments to the upload object function are supported @@ -133,6 +133,7 @@ def download_object( downloaded and the total size of the object. Raises: + FileExistsError: If ``filename`` already exists and ``overwrite`` is ``False``. FileNotFoundError: If the file was not found in the object store. ObjectStoreTransientError: If there was a transient connection issue with downloading the object. 
""" diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 8317675134..23e8440354 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -53,7 +53,7 @@ def __init__(self, path: str) -> None: try: from databricks.sdk import WorkspaceClient except ImportError as e: - raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.8.0,<1.0') from e + raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e try: self.client = WorkspaceClient() @@ -167,7 +167,10 @@ def download_object(self, try: from databricks.sdk.core import DatabricksError try: - with self.client.files.download(self._get_object_path(object_name)).contents as resp: + contents = self.client.files.download(self._get_object_path(object_name)).contents + assert contents is not None + + with contents as resp: # pyright: ignore with open(tmp_path, 'wb') as f: # Chunk the data into multiple blocks of 64MB to avoid # OOMs when downloading really large files @@ -199,11 +202,15 @@ def get_object_size(self, object_name: str) -> int: Raises: FileNotFoundError: If the file was not found in the object store. + IsADirectoryError: If the object is a directory, not a file. """ from databricks.sdk.core import DatabricksError try: file_info = self.client.files.get_status(self._get_object_path(object_name)) - return file_info.file_size + if file_info.is_dir: + raise IsADirectoryError(f'{object_name} is a UC directory, not a file.') + + return file_info.file_size # pyright: ignore except DatabricksError as e: _wrap_errors(self.get_uri(object_name), e) @@ -231,6 +238,7 @@ def list_objects(self, prefix: Optional[str]) -> List[str]: path=self._UC_VOLUME_LIST_API_ENDPOINT, data=data, headers={'Source': 'mosaicml/composer'}) + assert isinstance(resp, dict) return [f['path'] for f in resp.get('files', []) if not f['is_dir']] except DatabricksError as e: _wrap_errors(self.get_uri(prefix), e) diff --git a/setup.py b/setup.py index 40155d9108..7322bdc49e 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,8 @@ def package_files(prefix: str, directory: str, extension: str): break else: assert end != -1, 'there should be a balanced number of start and ends' - long_description = long_description[:start] + long_description[end + len(end_tag):] + long_description = long_description[:start] + \ + long_description[end + len(end_tag):] install_requires = [ 'pyyaml>=6.0,<7', @@ -223,7 +224,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['mlflow'] = [ - 'mlflow>=2.9.0,<3.0', + 'mlflow>=2.9.2,<3.0', ] extra_deps['pandas'] = ['pandas>=2.0.0,<3.0'] diff --git a/tests/utils/object_store/object_store_settings.py b/tests/utils/object_store/object_store_settings.py index d94cd70fd6..05f459db9c 100644 --- a/tests/utils/object_store/object_store_settings.py +++ b/tests/utils/object_store/object_store_settings.py @@ -14,8 +14,8 @@ import composer.utils.object_store import composer.utils.object_store.sftp_object_store -from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, ObjectStore, OCIObjectStore, - S3ObjectStore, SFTPObjectStore, UCObjectStore) +from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore) from composer.utils.object_store.sftp_object_store import SFTPObjectStore from tests.common import get_module_subclasses 
@@ -56,8 +56,9 @@ object_stores = [ pytest.param(x, marks=_object_store_marks[x], id=x.__name__) for x in get_module_subclasses(composer.utils.object_store, ObjectStore) - # Note: OCI, GCS and UC have their own test suite, so they are exempt from being included in this one.`` - if not issubclass(x, OCIObjectStore) and not issubclass(x, GCSObjectStore) and not issubclass(x, UCObjectStore) + # Note: OCI, GCS, UC, and MLFlow have their own test suite, so they are exempt from being included in this one.`` + if not issubclass(x, OCIObjectStore) and not issubclass(x, GCSObjectStore) and not issubclass(x, UCObjectStore) and + not issubclass(x, MLFlowObjectStore) ] diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py new file mode 100644 index 0000000000..d46fc493a4 --- /dev/null +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -0,0 +1,325 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from composer.utils import MLFlowObjectStore +from composer.utils.object_store.mlflow_object_store import PLACEHOLDER_EXPERIMENT_ID, PLACEHOLDER_RUN_ID + +TEST_PATH_FORMAT = 'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/' +EXPERIMENT_ID = '123' +EXPERIMENT_NAME = 'test-experiment' +RUN_ID = '456' +RUN_NAME = 'test-run' +ARTIFACT_PATH = 'path/to/artifact' +DEFAULT_PATH = TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=RUN_ID) + + +def test_parse_dbfs_path(): + full_artifact_path = DEFAULT_PATH + ARTIFACT_PATH + assert MLFlowObjectStore.parse_dbfs_path(full_artifact_path) == (EXPERIMENT_ID, RUN_ID, ARTIFACT_PATH) + + # Test with bad prefix + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'bad-prefix/{EXPERIMENT_ID}/{RUN_ID}/artifacts/{ARTIFACT_PATH}') + + # Test without artifacts + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'databricks/mlflow-tracking/{EXPERIMENT_ID}/{RUN_ID}/') + with pytest.raises(ValueError): + MLFlowObjectStore.parse_dbfs_path(f'databricks/mlflow-tracking/{EXPERIMENT_ID}/{RUN_ID}/not-artifacts/') + + +def test_init_fail_without_databricks_tracking_uri(monkeypatch): + monkeypatch.setenv('MLFLOW_TRACKING_URI', 'not-databricks') + with pytest.raises(ValueError): + MLFlowObjectStore(DEFAULT_PATH) + + +def test_init_with_experiment_and_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID)) + + store = MLFlowObjectStore(DEFAULT_PATH) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +def test_init_with_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == 
EXPERIMENT_ID + assert store.run_id == RUN_ID + + +def test_init_with_run_and_no_experiment(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + with pytest.raises(ValueError): + MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=RUN_ID)) + + +def test_init_with_active_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_active_run = MagicMock() + monkeypatch.setattr(mlflow, 'active_run', mock_active_run) + monkeypatch.setattr(mlflow, 'MlflowClient', MagicMock()) + + mock_active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID)) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +def test_init_with_existing_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.get_experiment_by_name.return_value = MagicMock(experiment_id=EXPERIMENT_ID) + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +def test_init_with_no_experiment_and_no_run(monkeypatch): + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.get_experiment_by_name.return_value = None + mock_mlflow_client.return_value.create_experiment.return_value = EXPERIMENT_ID + mock_mlflow_client.return_value.create_run.return_value = MagicMock( + info=MagicMock(run_id=RUN_ID, run_name='test-run')) + + store = MLFlowObjectStore( + TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + assert store.experiment_id == EXPERIMENT_ID + assert store.run_id == RUN_ID + + +@pytest.fixture() +def mlflow_object_store(monkeypatch): + + def mock_mlflow_client_list_artifacts(*args, **kwargs): + """Mock behavior for MlflowClient.list_artifacts(). 
+ + Behaves as if artifacts are stored under the following structure: + - dir1/ + - a.txt + - b.txt + - dir2/ + - c.txt + - dir3/ + - d.txt + """ + path = args[1] + if not path: + return [ + MagicMock(path='dir1', is_dir=True, file_size=None), + MagicMock(path='dir2', is_dir=True, file_size=None) + ] + elif path == 'dir1': + return [ + MagicMock(path='dir1/a.txt', is_dir=False, file_size=100), + MagicMock(path='dir1/b.txt', is_dir=False, file_size=200) + ] + elif path == 'dir2': + return [ + MagicMock(path='dir2/c.txt', is_dir=False, file_size=300), + MagicMock(path='dir2/dir3', is_dir=True, file_size=None) + ] + elif path == 'dir2/dir3': + return [MagicMock(path='dir2/dir3/d.txt', is_dir=False, file_size=400)] + else: + return [] + + dbx_sdk = pytest.importorskip('databricks.sdk') + monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) + + mlflow = pytest.importorskip('mlflow') + mock_mlflow_client = MagicMock() + monkeypatch.setattr(mlflow, 'MlflowClient', mock_mlflow_client) + + mock_mlflow_client.return_value.get_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID)) + mock_mlflow_client.return_value.list_artifacts.side_effect = mock_mlflow_client_list_artifacts + + yield MLFlowObjectStore(DEFAULT_PATH) + + +def test_get_artifact_path(mlflow_object_store): + # Relative MLFlow artifact path + assert mlflow_object_store.get_artifact_path(ARTIFACT_PATH) == ARTIFACT_PATH + + # Absolute DBFS path + assert mlflow_object_store.get_artifact_path(DEFAULT_PATH + ARTIFACT_PATH) == ARTIFACT_PATH + + # Absolute DBFS path with placeholders + path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + assert mlflow_object_store.get_artifact_path(path) == ARTIFACT_PATH + + # Raises ValueError for different experiment ID + path = TEST_PATH_FORMAT.format(experiment_id='different-experiment', run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + with pytest.raises(ValueError): + mlflow_object_store.get_artifact_path(path) + + # Raises ValueError for different run ID + path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id='different-run') + ARTIFACT_PATH + with pytest.raises(ValueError): + mlflow_object_store.get_artifact_path(path) + + +def test_get_dbfs_path(mlflow_object_store): + experiment_id = mlflow_object_store.experiment_id + run_id = mlflow_object_store.run_id + + expected_dbfs_path = f'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/{ARTIFACT_PATH}' + assert mlflow_object_store.get_dbfs_path(ARTIFACT_PATH) == expected_dbfs_path + + +def test_get_uri(mlflow_object_store): + experiment_id = mlflow_object_store.experiment_id + run_id = mlflow_object_store.run_id + expected_uri = f'dbfs:/databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/{ARTIFACT_PATH}' + + # Relative MLFlow artifact path + assert mlflow_object_store.get_uri(ARTIFACT_PATH) == expected_uri + + # Absolute DBFS path + assert mlflow_object_store.get_uri(DEFAULT_PATH + ARTIFACT_PATH) == expected_uri + + +def test_get_artifact_info(mlflow_object_store): + assert mlflow_object_store._get_artifact_info('dir1/a.txt').path == 'dir1/a.txt' + assert mlflow_object_store._get_artifact_info('dir1/b.txt').path == 'dir1/b.txt' + assert mlflow_object_store._get_artifact_info('dir2/c.txt').path == 'dir2/c.txt' + assert mlflow_object_store._get_artifact_info('dir2/dir3/d.txt').path == 'dir2/dir3/d.txt' + + # Test with absolute DBFS path + assert mlflow_object_store._get_artifact_info(DEFAULT_PATH + 'dir1/a.txt').path == 
'dir1/a.txt' + + # Verify directories are not returned + assert mlflow_object_store._get_artifact_info('dir1') is None + + # Test non-existent artifact + assert mlflow_object_store._get_artifact_info('nonexistent.txt') is None + + +def test_get_object_size(mlflow_object_store): + assert mlflow_object_store.get_object_size('dir1/a.txt') == 100 + assert mlflow_object_store.get_object_size('dir1/b.txt') == 200 + assert mlflow_object_store.get_object_size('dir2/c.txt') == 300 + assert mlflow_object_store.get_object_size('dir2/dir3/d.txt') == 400 + + # Test with absolute DBFS path + assert mlflow_object_store.get_object_size(DEFAULT_PATH + 'dir1/a.txt') == 100 + + # Verify FileNotFoundError is raised for non-existent artifact + with pytest.raises(FileNotFoundError): + mlflow_object_store.get_object_size('dir1') + with pytest.raises(FileNotFoundError): + mlflow_object_store.get_object_size('nonexistent.txt') + + +def test_download_object(mlflow_object_store, tmp_path): + + def mock_mlflow_client_download_artifacts(*args, **kwargs): + path = kwargs['path'] + dst_path = kwargs['dst_path'] + local_path = os.path.join(dst_path, path) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + size = mlflow_object_store.get_object_size(path) + file_content = bytes('0' * (size), 'utf-8') + + print(local_path) + + with open(local_path, 'wb') as fp: + fp.write(file_content) + return local_path + + mlflow_object_store._mlflow_client.download_artifacts.side_effect = mock_mlflow_client_download_artifacts + + # Test downloading file + object_name = 'dir1/a.txt' + file_to_download = str(tmp_path / Path(object_name)) + mlflow_object_store.download_object(object_name, file_to_download) + assert os.path.exists(file_to_download) + assert os.path.getsize(file_to_download) == mlflow_object_store.get_object_size(object_name) + + # Test cannot overwrite existing file when `overwrite` is False + with pytest.raises(FileExistsError): + mlflow_object_store.download_object(object_name, file_to_download, overwrite=False) + + # Test can overwrite existing file when `overwrite` is True + mlflow_object_store.download_object(object_name, file_to_download, overwrite=True) + + # Test downloading file under different name + object_name = 'dir1/a.txt' + file_to_download = str(tmp_path / Path('renamed.txt')) + mlflow_object_store.download_object(object_name, file_to_download) + assert os.path.exists(file_to_download) + assert os.path.getsize(file_to_download) == mlflow_object_store.get_object_size(object_name) + + # Raises FileNotFound when artifact does not exist + with pytest.raises(FileNotFoundError): + mlflow_object_store.download_object('nonexistent.txt', file_to_download) + + +def test_upload_object(mlflow_object_store, tmp_path): + file_to_upload = str(tmp_path / Path('file.txt')) + with open(file_to_upload, 'wb') as f: + f.write(bytes(range(20))) + + object_name = 'dir1/file.txt' + mlflow_object_store.upload_object(object_name=object_name, filename=file_to_upload) + run_id, local_path, artifact_dir = mlflow_object_store._mlflow_client.log_artifact.call_args.args + assert run_id == mlflow_object_store.run_id + assert os.path.basename(local_path) == os.path.basename(object_name) + assert artifact_dir == os.path.dirname(object_name) + + # Test basename symlink is created with correct name when object base name is different + object_name = 'dir1/renamed.txt' + mlflow_object_store.upload_object(object_name=object_name, filename=file_to_upload) + _, local_path, _ = mlflow_object_store._mlflow_client.log_artifact.call_args.args 
+ assert os.path.basename(local_path) == os.path.basename(object_name) + + +def test_list_objects(mlflow_object_store): + expected = {'dir1/a.txt', 'dir1/b.txt', 'dir2/c.txt', 'dir2/dir3/d.txt'} + assert set(mlflow_object_store.list_objects()) == expected
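---

Editor's note (not part of the diff above): a minimal usage sketch of the new MLFlowObjectStore, assuming Databricks credentials are already configured via environment variables or a config file and that the MLFLOW_TRACKING_URI is left unset or set to 'databricks'. The checkpoint object name and local file paths are illustrative only.

    # Usage sketch only -- assumes a reachable Databricks-hosted MLflow workspace and valid credentials.
    from composer.utils import MLFlowObjectStore

    # Placeholders in the path let the store resolve the active MLflow run, or create a new run/experiment.
    store = MLFlowObjectStore('databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/')

    # Object names may be relative MLflow artifact paths or absolute DBFS paths.
    store.upload_object('checkpoints/ep1.pt', filename='/tmp/ep1.pt')  # '/tmp/ep1.pt' is a hypothetical local file
    print(store.get_uri('checkpoints/ep1.pt'))  # dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/checkpoints/ep1.pt
    store.download_object('checkpoints/ep1.pt', '/tmp/ep1_copy.pt', overwrite=True)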