feature: loss watchdog for terminating training runs that are failing
Karl-Johan Alm committed Nov 28, 2023
1 parent a48dbf6 commit 49c9cd7
Showing 4 changed files with 40 additions and 0 deletions.
README.md: 3 additions & 0 deletions
@@ -694,6 +694,9 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

loss_watchdog_threshold: # High loss value indicating that learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

# Save model as safetensors (require safetensors package)
save_safetensors:

examples/mistral/qlora.yml: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
eval_steps: 0.05
eval_table_size:
src/axolotl/core/trainer_builder.py: 4 additions & 0 deletions
@@ -25,6 +25,7 @@
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -430,6 +431,9 @@ def get_callbacks(self):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)

if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))

return callbacks

def get_post_trainer_create_callbacks(self, trainer):
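The builder only attaches the watchdog when a threshold is configured. Below is a minimal sketch of that registration pattern, not part of the commit; SimpleNamespace is a hypothetical stand-in for the axolotl config object.

from types import SimpleNamespace

from axolotl.utils.callbacks import LossWatchDogCallback

def build_callbacks(cfg):
    # Mirrors the conditional registration added to get_callbacks() above.
    callbacks = []
    if cfg.loss_watchdog_threshold is not None:
        callbacks.append(LossWatchDogCallback(cfg))
    return callbacks

# No threshold configured: the watchdog stays disabled.
cfg = SimpleNamespace(loss_watchdog_threshold=None, loss_watchdog_patience=None)
assert build_callbacks(cfg) == []

# Threshold configured: the watchdog is registered.
cfg = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)
assert len(build_callbacks(cfg)) == 1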
src/axolotl/utils/callbacks.py: 30 additions & 0 deletions
@@ -124,6 +124,36 @@ def on_step_end(
return control


class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3

def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control


def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
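To make the stopping behavior concrete, here is a small sketch, not part of the commit, that drives the callback directly with fabricated loss values. _Cfg is a hypothetical stand-in for the axolotl config, and None is passed for the args parameter since the callback never reads it (it is named _args above).

from transformers import TrainerControl, TrainerState

from axolotl.utils.callbacks import LossWatchDogCallback

class _Cfg:
    # Hypothetical config stand-in; only the two watchdog fields are needed.
    loss_watchdog_threshold = 5.0
    loss_watchdog_patience = 3

watchdog = LossWatchDogCallback(_Cfg())
state = TrainerState()      # log_history starts as an empty list
control = TrainerControl()

# Three logged losses in a row above the 5.0 threshold exhaust the patience of 3.
for loss in (6.1, 7.3, 9.8):
    state.log_history.append({"loss": loss})
    watchdog.on_step_end(None, state, control)

print(control.should_training_stop)  # True after the third violation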
