make sure to have parity for model parameters passed to optimizer
winglian committed Jul 9, 2024
1 parent 33a9fd6 commit dfb58f9
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions src/axolotl/core/trainer_builder.py
@@ -293,12 +293,32 @@ def __init__(
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
-            and self.args.alternate_optimizer is None
+            and self.args.alternate_optimizer != "optimi_adamw"
         ):
             return super().create_optimizer()
 
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
 
         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
             self.args,
             opt_model,
Expand All @@ -316,11 +336,11 @@ def create_optimizer(self):
loraplus_lr_ratio,
loraplus_lr_embedding,
)
else:
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW

self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(opt_model.parameters(), **optimizer_kwargs)
AdamW(optimizer_grouped_parameters, **optimizer_kwargs)
)

if is_sagemaker_mp_enabled():
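For context on the change: the stock Trainer.create_optimizer path already splits parameters into a weight-decay group and a no-decay group (biases and normalization parameters, via get_decay_parameter_names), so the commit builds the same optimizer_grouped_parameters and hands them to optimi's AdamW to keep the "optimi_adamw" branch at parity with the default path. Below is a minimal, self-contained sketch of that grouping outside the Trainer; the toy model, the hyperparameter values, and the module-based no-decay rule are illustrative stand-ins (not axolotl or transformers code), and it assumes the optimi package is installed.

import torch.nn as nn
from optimi import AdamW  # the `optimi` package (pip name is typically torch-optimi)

# Illustrative toy model standing in for opt_model.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

# Approximate HF Trainer's rule: biases and normalization parameters get no decay.
no_decay = set()
for mod_name, mod in model.named_modules():
    for param_name, _ in mod.named_parameters(prefix=mod_name, recurse=False):
        if isinstance(mod, nn.LayerNorm) or param_name.endswith("bias"):
            no_decay.add(param_name)

optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if n not in no_decay and p.requires_grad
        ],
        "weight_decay": 0.01,  # placeholder value
    },
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if n in no_decay and p.requires_grad
        ],
        "weight_decay": 0.0,
    },
]

# Every trainable parameter lands in exactly one group, which is the parity
# the commit enforces when the groups are handed to optimi's AdamW.
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)

A quick sanity check is that the two groups together cover len(list(model.parameters())) tensors; that full coverage is what the commit message refers to as parity for the parameters passed to the optimizer.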
