From 36e53c7442cc64aa8fc4202110b9cf6cd6213111 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 13 Sep 2023 11:37:23 -0400 Subject: [PATCH] improve how we setup eval/save strategies and steps (#547) * setup save and eval strategies to be consistent with trainer logic * add comments * better eval handling --- src/axolotl/utils/trainer.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 69a633f16..9685176b6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -567,21 +567,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ "sample_packing_efficiency" ] = cfg.sample_packing_eff_est - if cfg.val_set_size == 0: + if cfg.eval_steps and cfg.evaluation_strategy: + # assume if the user set both, they know what they're doing + training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy + training_arguments_kwargs["eval_steps"] = cfg.eval_steps + elif cfg.val_set_size == 0: + # no eval set, so don't eval training_arguments_kwargs["evaluation_strategy"] = "no" + elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]: + # if explicitly set for epoch, just set, and eval steps don't matter + training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy elif cfg.eval_steps: + # steps isn't used w/ epochs training_arguments_kwargs["evaluation_strategy"] = "steps" training_arguments_kwargs["eval_steps"] = cfg.eval_steps else: - # we have an eval set, but no steps defined, use epoch + # we have an eval set, but no steps defined, default to use epoch training_arguments_kwargs["evaluation_strategy"] = "epoch" - if cfg.save_strategy: + if cfg.save_steps: + # save_steps implies save_strategy of steps + training_arguments_kwargs["save_strategy"] = "steps" + training_arguments_kwargs["save_steps"] = cfg.save_steps + elif cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy else: - training_arguments_kwargs["save_strategy"] = ( - "steps" if cfg.save_steps else "epoch" - ) + # default to saving each epoch if not defined + training_arguments_kwargs["save_strategy"] = "epoch" if cfg.do_bench_eval: training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval