Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: loss watchdog for terminating training runs that are failing #899

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,9 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

# Save model as safetensors (require safetensors package)
save_safetensors:

Expand Down
3 changes: 3 additions & 0 deletions examples/mistral/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
eval_steps: 0.05
eval_table_size:
Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
Expand Down Expand Up @@ -430,6 +431,9 @@ def get_callbacks(self):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)

if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))

return callbacks

def get_post_trainer_create_callbacks(self, trainer):
Expand Down
30 changes: 30 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,36 @@ def on_step_end(
return control


class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3

def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control


def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
Expand Down