Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add warmup_ratio #893

Merged
merged 2 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ micro_batch_size: 2
eval_batch_size:
num_epochs: 4
warmup_steps: 100
warmup_ratio: 0.05
NanoCode012 marked this conversation as resolved.
Show resolved Hide resolved
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
Expand Down
13 changes: 8 additions & 5 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,14 @@ def _get_trainer_cls(self):
return AxolotlTrainer

def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)

logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")

if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
30 changes: 30 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,33 @@ def test_load_in_x_bit_without_adapter(self):
)

validate_config(cfg)

def test_warmup_step_no_conflict(self):
cfg = DictDefault(
{
"warmup_steps": 10,
"warmup_ratio": 0.1,
}
)

with pytest.raises(
ValueError,
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
):
validate_config(cfg)

cfg = DictDefault(
{
"warmup_steps": 10,
}
)

validate_config(cfg)

cfg = DictDefault(
{
"warmup_ratio": 0.1,
}
)

validate_config(cfg)