diff --git a/README.md b/README.md
index 60013df93..c2d8e7d8c 100644
--- a/README.md
+++ b/README.md
@@ -571,7 +571,7 @@ torch_compile_backend:  # Optional[str]
 # training hyperparameters
 gradient_accumulation_steps: 1
 micro_batch_size: 2
-eval_batch_size: 2
+eval_batch_size:
 num_epochs: 3
 warmup_steps: 100
 learning_rate: 0.00003
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index a8c41d95b..9503d838c 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -49,6 +49,8 @@ def normalize_config(cfg):
     cfg.batch_size = (
         cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
     )
+    if cfg.eval_batch_size is None:
+        cfg.eval_batch_size = cfg.micro_batch_size
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
@@ -157,6 +159,11 @@ def validate_config(cfg):
             "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
             "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
         )
+    if cfg.eval_batch_size != cfg.micro_batch_size:
+        LOG.warning(
+            "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+        )
+
     if cfg.load_4bit:
         raise ValueError("cfg.load_4bit parameter has been deprecated")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 80ee5c8c6..a10a2b0e7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -668,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
-        per_device_eval_batch_size=cfg.eval_batch_size
-        if cfg.eval_batch_size is not None
-        else cfg.micro_batch_size,
+        per_device_eval_batch_size=cfg.eval_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,