Cosine learning rate schedule - minimum learning rate (#1062)
* Cosine min lr

* Cosine min lr - warn if using deepspeed

* cosine_min_lr_ratio readme

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
RicardoDominguez and winglian committed Jan 9, 2024
1 parent c3e8165 commit 04b978b
Showing 3 changed files with 61 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -755,6 +755,7 @@ early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr

# For one_cycle optim
lr_div_factor: # Learning rate div factor
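For concreteness, a small worked example of the new option (the numbers below are illustrative settings, not axolotl defaults): with learning_rate: 2e-4 and cosine_min_lr_ratio: 0.1, the cosine schedule anneals from the 2e-4 peak down to a floor of 2e-5 rather than decaying toward zero.

# Illustrative arithmetic only; these values are example settings, not defaults.
peak_lr = 2e-4             # learning_rate in the config
cosine_min_lr_ratio = 0.1  # keep 10% of the peak at the end of training
final_lr = cosine_min_lr_ratio * peak_lr
print(final_lr)  # ~2e-05: the schedule bottoms out here instead of at 0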
21 changes: 20 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -38,7 +38,10 @@
    MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
from axolotl.utils.schedulers import (
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
)

try:
    import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -120,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
        default=None,
        metadata={"help": "prefetch_factor argument to the dataloader"},
    )
    cosine_min_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
    )


class AxolotlTrainer(Trainer):
@@ -159,6 +166,17 @@ def create_scheduler(
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                if self.args.deepspeed:
                    LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
                        in the deepspeed JSON")
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer)
        return self.lr_scheduler
@@ -745,6 +763,7 @@ def build(self, total_num_steps):
        training_arguments_kwargs["lr_scheduler_kwargs"] = (
            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
        )
        training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
        training_arguments_kwargs["weight_decay"] = (
            self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
        )
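For readers skimming the hunk above, here is a simplified, standalone restatement of the new scheduler selection logic; the function name and return labels are invented for illustration and are not axolotl APIs.

from typing import Optional

# Sketch of the dispatch added to AxolotlTrainer.create_scheduler: the min-lr
# cosine schedule is used only when the scheduler type is "cosine" AND
# cosine_min_lr_ratio is set; anything else falls back to the stock
# transformers scheduler.
def choose_scheduler(lr_scheduler_type: str, cosine_min_lr_ratio: Optional[float]) -> str:
    if lr_scheduler_type == "cosine" and cosine_min_lr_ratio is not None:
        assert 0 <= cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
        return "cosine_with_min_lr"
    return "transformers_default"

print(choose_scheduler("cosine", 0.1))   # cosine_with_min_lr
print(choose_scheduler("cosine", None))  # transformers_default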
40 changes: 40 additions & 0 deletions src/axolotl/utils/schedulers.py
@@ -100,3 +100,43 @@ def get_cosine_schedule_with_quadratic_warmup(
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_cosine_schedule_with_min_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float
):
    # Warm up
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    # Cosine learning rate decay
    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
    return (1 - min_lr_ratio) * scaling + min_lr_ratio


def get_cosine_schedule_with_min_lr(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
):
    """
    Create a learning rate schedule which has:
        - linear warmup from 0 -> `max_lr` over `num_warmup_steps`
        - cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`
    """

    lr_lambda = partial(
        _get_cosine_schedule_with_min_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        min_lr_ratio=min_lr_ratio,
    )
    return LambdaLR(optimizer, lr_lambda)
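A minimal usage sketch of the new schedule (not part of the commit; it assumes axolotl and torch are installed and the functions above are importable as written): warm up for 10 of 100 steps at a peak lr of 2e-4 with min_lr_ratio=0.1, then confirm the learning rate ends at roughly 2e-5 rather than 0.

import torch
from axolotl.utils.schedulers import get_cosine_schedule_with_min_lr

# Dummy model/optimizer just to drive the scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)  # peak lr

scheduler = get_cosine_schedule_with_min_lr(
    optimizer, num_warmup_steps=10, num_training_steps=100, min_lr_ratio=0.1
)

for _ in range(100):
    optimizer.step()
    scheduler.step()

# The multiplier reaches 1.0 at the end of warmup and decays to min_lr_ratio
# (0.1) by the final step, so the last lr is ~0.1 * 2e-4 = 2e-5 instead of 0.
print(scheduler.get_last_lr())

Expressing the floor as a ratio of the peak, rather than as an absolute value, lets the same setting scale with whatever learning_rate the config uses.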
