Skip to content

Commit

Permalink
Add: mlflow for experiment tracking (#1059) [skip ci]
Browse files Browse the repository at this point in the history
* Update requirements.txt

adding mlflow

* Update __init__.py

Imports for mlflow

* Update README.md

* Create mlflow_.py (#1)

* Update README.md

* fix precommits

* Update README.md

Update mlflow_tracking_uri

* Update trainer_builder.py

update trainer building

* chore: lint

* make ternary a bit more readable

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
JohanWork and winglian committed Jan 9, 2024
1 parent 651b7a3 commit 090c24d
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Features:
- Integrated with xformer, flash attention, rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb
- Log results and optionally checkpoints to wandb or mlflow
- And more!


Expand Down Expand Up @@ -695,6 +695,10 @@ wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name

# Where to save the full-finetuned model to
output_dir: ./completed-model

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
bert-score==0.3.13
evaluate==0.4.0
Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
Expand Down Expand Up @@ -289,6 +290,9 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
normalize_config(cfg)

setup_wandb_env_vars(cfg)

setup_mlflow_env_vars(cfg)

return cfg


Expand Down
7 changes: 6 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,12 @@ def build(self, total_num_steps):
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
report_to = None
if self.cfg.use_wandb:
report_to = "wandb"
if self.cfg.use_mlflow:
report_to = "mlflow"
training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)
Expand Down
18 changes: 18 additions & 0 deletions src/axolotl/utils/mlflow_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Module for mlflow utilities"""

import os

from axolotl.utils.dict import DictDefault


def setup_mlflow_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("mlflow_"):
value = cfg.get(key, "")

if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value

# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True

0 comments on commit 090c24d

Please sign in to comment.