Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save Axolotl config as WandB artifact #716

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
Expand Down
24 changes: 24 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,27 @@ def log_table_from_dataloader(name: str, table_dataloader):
return control

return LogPredictionCallback


class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""

def __init__(self, cfg):
self.cfg = cfg

def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.cfg.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe log.debug is enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im indifferent, we log/print 20+ lines of Wandb initialization (and hundreds of lines on training begin) anyway, wouldn't it make sense to leave it at info to make people aware that their config is archived - how should anyone (except us) ever find/know it otherwise?
With an explicit (short) statement after successful archivation in the log, people then know that they can e.g safely overwrite their config.yml for a new experiment.

But happy to change that as well, whats your opinion @winglian ?

except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
4 changes: 4 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
Expand Down Expand Up @@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))

if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg))
jphme marked this conversation as resolved.
Show resolved Hide resolved

if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

Expand Down