Skip to content

Commit

Permalink
Save Axolotl config as WandB artifact (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
jphme committed Oct 11, 2023
1 parent 5855dde commit 490923f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
Expand Down
24 changes: 24 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,27 @@ def log_table_from_dataloader(name: str, table_dataloader):
return control

return LogPredictionCallback


class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
4 changes: 4 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
Expand Down Expand Up @@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))

if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))

if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

Expand Down

0 comments on commit 490923f

Please sign in to comment.