Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move MLFlow dataset outside of log_config #1234

Merged
merged 17 commits into from
May 24, 2024
7 changes: 5 additions & 2 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
'update_batch_size_info',
'process_init_device',
'log_config',
'log_dataset_uri',
]


Expand Down Expand Up @@ -508,7 +509,6 @@ def log_config(cfg: Dict[str, Any]) -> None:

if 'mlflow' in loggers and mlflow.active_run():
mlflow.log_params(params=cfg)
_log_dataset_uri(cfg)


def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
Expand Down Expand Up @@ -619,12 +619,15 @@ def _process_data_source(
log.warning('DataSource Not Found.')


def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
def log_dataset_uri(cfg: Dict[str, Any]) -> None:
KuuCi marked this conversation as resolved.
Show resolved Hide resolved
"""Logs dataset tracking information to MLflow.

Args:
cfg (DictConfig): A config dictionary of a run
"""
loggers = cfg.get('loggers', None) or {}
if 'mlflow' not in loggers or not mlflow.active_run():
return
# Figure out which data source to use
data_paths = _parse_source_dataset(cfg)

Expand Down
2 changes: 2 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TRAIN_CONFIG_KEYS,
TrainConfig,
log_config,
log_dataset_uri,
make_dataclass_and_log_config,
pop_config,
process_init_device,
Expand Down Expand Up @@ -530,6 +531,7 @@ def main(cfg: DictConfig) -> Trainer:
if train_cfg.log_config:
log.info('Logging config')
log_config(logged_cfg)
log_dataset_uri(logged_cfg)
torch.cuda.empty_cache()
gc.collect()

Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_mlflow_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest

from llmfoundry.utils.config_utils import (
_log_dataset_uri,
_parse_source_dataset,
log_dataset_uri,
)

mlflow = pytest.importorskip('mlflow')
Expand Down Expand Up @@ -84,10 +84,12 @@ def test_log_dataset_uri():
}},
source_dataset_train='huggingface/train_dataset',
source_dataset_eval='huggingface/eval_dataset',
loggers={'mlflow': {}},
)

with patch('mlflow.log_input') as mock_log_input:
_log_dataset_uri(cfg)
with patch('mlflow.log_input') as mock_log_input, \
patch('mlflow.active_run', return_value=True):
log_dataset_uri(cfg)
assert mock_log_input.call_count == 2
meta_dataset_calls = [
args[0] for args, _ in mock_log_input.call_args_list
Expand Down
Loading