One cycle lr (#1803)
* refactor one_cycle lr scheduler so it's reusable in more situations

* fix validation for lr_scheduler

* default to cosine anneal strategy

* one cycle lr expects cos
winglian authored Aug 5, 2024
1 parent b7665c2 commit ecdda00
Showing 3 changed files with 35 additions and 43 deletions.
.pre-commit-config.yaml (2 additions, 0 deletions)
@@ -8,6 +8,8 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
+ - id: no-commit-to-branch
+   args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
src/axolotl/core/trainer_builder.py (32 additions, 42 deletions)
@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
+alternate_lr_scheduler_type: Optional[str] = field(
+    default=None,
+    metadata={
+        "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+    },
+)


@dataclass
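The new alternate_lr_scheduler_type field follows the same pattern as the existing alternate_optimizer workaround: an extra dataclass field that rides along on the HF training arguments so the custom trainer can read it later. A minimal sketch of that pattern, assuming a direct subclass of transformers.TrainingArguments (the repo attaches the field to the AxolotlTrainingMixins dataclass instead, and the class name below is illustrative):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class ExtendedTrainingArguments(TrainingArguments):
    # Extra field carried alongside the HF arguments; the trainer subclass reads it
    # in create_scheduler() instead of routing it through HF's lr_scheduler_type.
    alternate_lr_scheduler_type: Optional[str] = field(
        default=None,
        metadata={"help": "workaround to pass an alternate lr scheduler to the HF trainer"},
    )


args = ExtendedTrainingArguments(output_dir="out", alternate_lr_scheduler_type="one_cycle")
print(args.alternate_lr_scheduler_type)  # -> "one_cycle"
```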
@@ -318,7 +324,23 @@ def create_scheduler(
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
-if use_cosine_quadratic:
+if self.args.alternate_lr_scheduler_type == "one_cycle":
+    num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+    pct_start = num_warmup_steps / num_training_steps
+    extra_lr_kwargs = {}
+    if "pct_start" not in self.args.lr_scheduler_kwargs:
+        extra_lr_kwargs["pct_start"] = pct_start
+    if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
+        extra_lr_kwargs["anneal_strategy"] = "cos"
+
+    self.lr_scheduler = OneCycleLR(
+        optimizer,
+        max_lr=self.args.learning_rate,
+        total_steps=num_training_steps,
+        **extra_lr_kwargs,
+        **self.args.lr_scheduler_kwargs,
+    )
+elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
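The added branch above is the core of the change: warmup steps are translated into OneCycleLR's pct_start, and anneal_strategy defaults to "cos" unless the user overrides either via lr_scheduler_kwargs. A standalone sketch of the same wiring with a toy model (the step counts and learning rate are illustrative, not taken from the repo):

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

num_training_steps = 1000
num_warmup_steps = 100  # e.g. what args.get_warmup_steps() returns for warmup_ratio=0.1
pct_start = num_warmup_steps / num_training_steps  # fraction of the cycle spent ramping up

scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-4,                  # peak lr, i.e. args.learning_rate in the trainer
    total_steps=num_training_steps,
    pct_start=pct_start,          # rise to max_lr over the warmup fraction
    anneal_strategy="cos",        # cosine decay back down, the default this PR settles on
)

lrs = []
for _ in range(num_training_steps):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(max(lrs))  # peaks near max_lr around step ~100, then anneals toward ~0
```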

@@ -876,37 +898,6 @@ def compute_loss(
return lm_loss


-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
-    """
-    Trainer subclass that uses the OneCycleLR scheduler
-    """
-
-    tag_names = ["axolotl", "onecycle"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-    ):
-        optimizer = self.optimizer if optimizer is None else optimizer
-        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        pct_start = num_warmup_steps / num_training_steps
-
-        self.lr_scheduler = OneCycleLR(
-            optimizer,
-            max_lr=self.args.learning_rate,
-            total_steps=num_training_steps,
-            pct_start=pct_start,
-            div_factor=6,
-        )
-
-        return self.lr_scheduler
-
-
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ def get_post_trainer_create_callbacks(self, trainer):
return callbacks

def _get_trainer_cls(self):
-if self.cfg.lr_scheduler == "one_cycle" and (
-    self.cfg.fsdp or self.cfg.adapter == "qlora"
-):
-    return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
@@ -1443,12 +1430,15 @@ def build(self, total_num_steps):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
-training_arguments_kwargs["lr_scheduler_type"] = (
-    self.cfg.lr_scheduler
-    if self.cfg.lr_scheduler
-    and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
-    else "cosine"
-)
+if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
+    training_arguments_kwargs["lr_scheduler_type"] = "cosine"
+    training_arguments_kwargs[
+        "alternate_lr_scheduler_type"
+    ] = self.cfg.lr_scheduler
+else:
+    training_arguments_kwargs["lr_scheduler_type"] = (
+        self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+    )
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
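Since HF's SchedulerType has no entry for these custom schedulers, build() now maps them onto a plain cosine lr_scheduler_type and passes the real choice through alternate_lr_scheduler_type. The same mapping in isolation, as a hedged sketch (the helper name and return shape are illustrative, not part of the repo):

```python
from typing import Any, Dict, Optional


def resolve_scheduler_kwargs(
    lr_scheduler: Optional[str],
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Mirror the build() logic above: schedulers HF does not know about are
    routed through alternate_lr_scheduler_type, everything else goes straight
    to lr_scheduler_type (defaulting to cosine)."""
    kwargs: Dict[str, Any] = {}
    if lr_scheduler in ("one_cycle", "log_sweep"):
        kwargs["lr_scheduler_type"] = "cosine"
        kwargs["alternate_lr_scheduler_type"] = lr_scheduler
    else:
        kwargs["lr_scheduler_type"] = lr_scheduler or "cosine"
    kwargs["lr_scheduler_kwargs"] = lr_scheduler_kwargs or {}
    return kwargs


assert resolve_scheduler_kwargs("one_cycle") == {
    "lr_scheduler_type": "cosine",
    "alternate_lr_scheduler_type": "one_cycle",
    "lr_scheduler_kwargs": {},
}
assert resolve_scheduler_kwargs(None)["lr_scheduler_type"] == "cosine"
```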
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (1 addition, 1 deletion)
@@ -378,7 +378,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
-lr_scheduler: Optional[SchedulerType] = "cosine"
+lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
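The one-line schema change widens the accepted values so that "one_cycle" validates even though it is not a member of transformers' SchedulerType enum. A quick check of what the new annotation accepts, assuming pydantic v2 and transformers are installed (the model name below is illustrative):

```python
from typing import Any, Dict, Literal, Optional, Union

from pydantic import BaseModel
from transformers import SchedulerType


class SchedulerConfig(BaseModel):
    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None


print(SchedulerConfig(lr_scheduler="one_cycle").lr_scheduler)  # accepted via the Literal arm
print(SchedulerConfig(lr_scheduler="cosine").lr_scheduler)     # accepted via the SchedulerType arm
```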
