Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlflow datasets #1119

Merged
merged 47 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
a99525b
add logger
KuuCi Apr 18, 2024
a9e7b0d
reqs
KuuCi Apr 18, 2024
ecdfaca
small fix
KuuCi Apr 18, 2024
c29cbc2
import mlflow
KuuCi Apr 18, 2024
b787cca
parse_uri
KuuCi Apr 18, 2024
ab7268f
parse_uri
KuuCi Apr 18, 2024
a06fcb2
finished debug
KuuCi Apr 18, 2024
aae821d
precommit
KuuCi Apr 18, 2024
996fb01
more code fix
KuuCi Apr 18, 2024
e507eac
revert setup
KuuCi Apr 18, 2024
f570842
better dovs
KuuCi Apr 18, 2024
1245c64
rm docstr
KuuCi Apr 18, 2024
c9006c5
precommit
KuuCi Apr 18, 2024
d79410c
Update tests to not rely on mistral (#1117)
dakinggg Apr 18, 2024
d47caea
Bump transformers to 4.40 (#1118)
dakinggg Apr 18, 2024
af68170
merge
KuuCi Apr 18, 2024
ca698b8
revert setup
KuuCi Apr 18, 2024
802dd8c
Merge branch 'main' into mlflow-datasets
KuuCi Apr 18, 2024
47bf6cb
precommit
KuuCi Apr 18, 2024
bbcabcc
precommit
KuuCi Apr 19, 2024
cf7c9df
tweaks to resolve comments
KuuCi Apr 19, 2024
eb2afbb
unit test
KuuCi Apr 19, 2024
05c2461
code quality
KuuCi Apr 19, 2024
bb86a78
quotation
KuuCi Apr 19, 2024
f3d8348
quote
KuuCi Apr 19, 2024
c44fafa
more quality
KuuCi Apr 19, 2024
6250a44
optional
KuuCi Apr 19, 2024
e38964b
pyright
KuuCi Apr 19, 2024
0199788
type check
KuuCi Apr 19, 2024
90bcad0
rm typechecking
KuuCi Apr 19, 2024
0d66d1d
yapf
KuuCi Apr 19, 2024
d286e15
first pass
KuuCi Apr 19, 2024
6a6632c
fix
KuuCi Apr 19, 2024
969c1c0
get refactor
KuuCi Apr 19, 2024
d472282
refactor
KuuCi Apr 19, 2024
cbf0c30
local hf path
KuuCi Apr 19, 2024
8dd8cec
dbfs
KuuCi Apr 19, 2024
3457eb5
rm local
KuuCi Apr 19, 2024
48843f9
typo
KuuCi Apr 19, 2024
39ba332
second pass
KuuCi Apr 22, 2024
5e0f853
update
KuuCi Apr 22, 2024
84a7930
Merge branch 'main' into mlflow-datasets
KuuCi Apr 22, 2024
6aae695
Merge branch 'main' into mlflow-datasets
dakinggg Apr 23, 2024
f63c325
Merge branch 'main' into mlflow-datasets
dakinggg Apr 23, 2024
a9fda9c
third pass
KuuCi Apr 23, 2024
85fa2df
os.path.join
KuuCi Apr 24, 2024
6cecaec
precommit
KuuCi Apr 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 96 additions & 6 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@
import logging
import math
import warnings
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union, TYPE_CHECKING

from composer.utils import dist
from composer.utils import dist, parse_uri
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.layers_registry import ffns_with_megablocks
from llmfoundry.models.utils import init_empty_weights

try:
KuuCi marked this conversation as resolved.
Show resolved Hide resolved
import mlflow
except ImportError:
mlflow = None
if TYPE_CHECKING: # for pyright
import mlflow

log = logging.getLogger(__name__)

__all__ = [
Expand Down Expand Up @@ -178,9 +185,92 @@ def log_config(cfg: DictConfig) -> None:
wandb.config.update(om.to_container(cfg, resolve=True))

if 'mlflow' in cfg.get('loggers', {}):
try:
import mlflow
except ImportError as e:
raise e
if not mlflow:
KuuCi marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError('MLflow is required but not installed.')
if mlflow.active_run():
mlflow.log_params(params=om.to_container(cfg, resolve=True))
log_dataset_uri(cfg)


def parse_source_dataset(cfg: DictConfig):
    """Parse a run config for dataset information.

    For each of the ``train`` and ``eval`` splits, at most one data source is
    reported, chosen in this priority order:

    1. ``source_dataset_<split>`` starting with ``/Volumes`` -> ``'uc_volume'``
    2. ``source_dataset_<split>`` with at least three dot-separated parts
       -> ``'delta_table'``
    3. ``<split>_loader.dataset.hf_name`` -> the URI backend returned by
       ``parse_uri`` if there is one, otherwise ``'hf'``
    4. ``<split>_loader.dataset.remote`` -> the URI backend returned by
       ``parse_uri``
    5. ``<split>_loader.dataset.local`` -> ``'local'``

    Args:
        cfg (DictConfig): The run config to inspect.

    Returns:
        A set of ``(dataset_type, path, data_split)`` tuples, where
        ``data_split`` is ``'train'`` or ``'eval'``.
    """
    data_paths = set()

    for data_split in ['train', 'eval']:
        # Look the nested dataset config up once; ``or {}`` tolerates loaders
        # or datasets that are present in the config but set to None.
        loader_cfg = cfg.get(f'{data_split}_loader', None) or {}
        dataset_cfg = loader_cfg.get('dataset', None) or {}
        split = dataset_cfg.get('split', None)

        source_dataset_path = cfg.get(f'source_dataset_{data_split}', None)
        hf_path = dataset_cfg.get('hf_name', None)
        remote_path = dataset_cfg.get('remote', None)
        local_path = dataset_cfg.get('local', None)

        # Check for UC volume first: volume paths may legitimately contain
        # dots (e.g. file extensions or version suffixes) and must not be
        # mistaken for a `catalog.schema.table` Delta table name.
        if source_dataset_path and source_dataset_path.startswith('/Volumes'):
            data_paths.add(('uc_volume', source_dataset_path, data_split))
        # Check for Delta table
        elif source_dataset_path and len(source_dataset_path.split('.')) >= 3:
            data_paths.add(('delta_table', source_dataset_path, data_split))
        # Check for HF path
        elif hf_path:
            backend, _, _ = parse_uri(hf_path)
            if backend:
                # `hf_name` points at object storage; append the split folder.
                if split:
                    hf_path = f'{hf_path.rstrip("/")}/{split}'
                data_paths.add((backend, hf_path, data_split))
            else:
                data_paths.add(('hf', hf_path, data_split))
        # Check for remote path
        elif remote_path:
            backend, _, _ = parse_uri(remote_path)
            if split:
                remote_path = f'{remote_path.rstrip("/")}/{split}/'
            data_paths.add((backend, remote_path, data_split))
        # Check for local path
        elif local_path:
            data_paths.add(('local', local_path, data_split))

    return data_paths


def log_dataset_uri(cfg: DictConfig) -> None:
    """Logs dataset tracking information to MLflow.

    Every ``(dataset_type, path, split)`` entry produced by
    ``parse_source_dataset`` is wrapped in the matching MLflow dataset source
    and logged with ``mlflow.log_input`` as a ``MetaDataset`` named after the
    split. Dataset types without a dedicated source class fall back to
    ``HTTPDatasetSource``.

    Args:
        cfg (DictConfig): The run config to read dataset locations from.

    Returns:
        None. If MLflow is not installed a warning is logged and nothing else
        happens.
    """
    # NOTE: the return annotation deliberately does not reference `mlflow`.
    # Annotations are evaluated when the function is defined, so touching
    # `mlflow.data...` there would break importing this module whenever the
    # optional `mlflow` import fell back to None.
    if mlflow is None:
        log.warning('MLflow is not installed. Skipping dataset logging.')
        return None

    # Figure out which data source to use
    data_paths = parse_source_dataset(cfg)

    dataset_source_mapping = {
        's3': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'oci': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'https': mlflow.data.http_dataset_source.HTTPDatasetSource,
        'hf': mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource,
        'delta_table': mlflow.data.delta_dataset_source.DeltaDatasetSource,
        'uc_volume': mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource,
        'local': mlflow.data.http_dataset_source.HTTPDatasetSource,
    }

    for dataset_type, path, split in data_paths:
        source_class = dataset_source_mapping.get(dataset_type)

        # Each source class names its location argument differently.
        if source_class is None:
            log.info(
                f'{dataset_type} unknown, defaulting to http dataset source')
            source = mlflow.data.http_dataset_source.HTTPDatasetSource(url=path)
        elif dataset_type == 'delta_table':
            source = source_class(delta_table_name=path)
        elif dataset_type in ('hf', 'uc_volume'):
            source = source_class(path=path)
        else:
            source = source_class(url=path)

        mlflow.log_input(
            mlflow.data.meta_dataset.MetaDataset(source, name=split))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22',
'mlflow>=2.10,<3',
'mlflow>=2.12.1,<3',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.40,<4.41',
'mosaicml-streaming>=0.7.5,<0.8',
Expand Down
99 changes: 99 additions & 0 deletions tests/utils/test_mlflow_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import patch

import pytest
from omegaconf import OmegaConf

from llmfoundry.utils.config_utils import log_dataset_uri, parse_source_dataset

mlflow = pytest.importorskip('mlflow')


def create_config(**kwargs: Any):
    """Build an OmegaConf config whose top-level keys are the given kwargs."""
    config = OmegaConf.create(kwargs)
    return config


def test_parse_source_dataset_delta_table():
    """Three-part dotted names are reported as Delta tables for both splits."""
    train_table = 'db.schema.train_table'
    eval_table = 'db.schema.eval_table'
    cfg = create_config(source_dataset_train=train_table,
                        source_dataset_eval=eval_table)

    parsed = parse_source_dataset(cfg)

    assert parsed == {
        ('delta_table', train_table, 'train'),
        ('delta_table', eval_table, 'eval'),
    }


def test_parse_source_dataset_uc_volume():
    """Paths under ``/Volumes`` are reported as UC volumes for both splits."""
    train_volume = '/Volumes/train_data'
    eval_volume = '/Volumes/eval_data'
    cfg = create_config(source_dataset_train=train_volume,
                        source_dataset_eval=eval_volume)

    parsed = parse_source_dataset(cfg)

    assert parsed == {
        ('uc_volume', train_volume, 'train'),
        ('uc_volume', eval_volume, 'eval'),
    }


def test_parse_source_dataset_hf():
    """An ``hf_name`` without a URI scheme is reported with the ``hf`` type."""
    train_name = 'huggingface/train_dataset'
    eval_name = 'huggingface/eval_dataset'
    cfg = create_config(
        train_loader={'dataset': {'hf_name': train_name}},
        eval_loader={'dataset': {'hf_name': eval_name}},
    )

    parsed = parse_source_dataset(cfg)

    assert parsed == {
        ('hf', train_name, 'train'),
        ('hf', eval_name, 'eval'),
    }


def test_parse_source_dataset_remote():
    """A ``remote`` URL is reported under its URI scheme (here ``https``)."""
    train_url = 'https://remote/train_dataset'
    eval_url = 'https://remote/eval_dataset'
    cfg = create_config(
        train_loader={'dataset': {'remote': train_url}},
        eval_loader={'dataset': {'remote': eval_url}},
    )

    parsed = parse_source_dataset(cfg)

    assert parsed == {
        ('https', train_url, 'train'),
        ('https', eval_url, 'eval'),
    }


def test_parse_source_dataset_local():
    """A ``local`` dataset path is reported with the ``local`` type."""
    train_dir = '/local/train_dataset'
    eval_dir = '/local/eval_dataset'
    cfg = create_config(
        train_loader={'dataset': {'local': train_dir}},
        eval_loader={'dataset': {'local': eval_dir}},
    )

    parsed = parse_source_dataset(cfg)

    assert parsed == {
        ('local', train_dir, 'train'),
        ('local', eval_dir, 'eval'),
    }


@pytest.mark.usefixtures('mock_mlflow_classes')
def test_log_dataset_uri_all_sources():
    """``log_dataset_uri`` logs exactly one MLflow input per data split."""
    loaders = {
        'train_loader': {'dataset': {'hf_name': 'huggingface/train_dataset'}},
        'eval_loader': {'dataset': {'hf_name': 'huggingface/eval_dataset'}},
    }
    cfg = create_config(source_dataset_train='db.schema.train_table',
                        source_dataset_eval='/Volumes/eval_data',
                        **loaders)

    # Stub out MetaDataset construction and capture the log_input calls.
    with patch('mlflow.data.meta_dataset.MetaDataset'), \
         patch('mlflow.log_input') as mock_log_input:
        log_dataset_uri(cfg)
        assert mock_log_input.call_count == 2
KuuCi marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture
def mock_mlflow_classes():
    """Patch the four MLflow dataset source classes for the test's duration.

    Yields the mocks as ``(http, huggingface, delta, uc_volume)``.
    """
    with patch(
        'mlflow.data.http_dataset_source.HTTPDatasetSource',
    ) as http_mock, patch(
        'mlflow.data.huggingface_dataset_source.HuggingFaceDatasetSource',
    ) as hf_mock, patch(
        'mlflow.data.delta_dataset_source.DeltaDatasetSource',
    ) as delta_mock, patch(
        'mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource',
    ) as uc_mock:
        yield http_mock, hf_mock, delta_mock, uc_mock
Loading