
schedulefree optimizers #30079

Open

wants to merge 7 commits into base: main
Conversation

winglian
Contributor

@winglian winglian commented Apr 6, 2024

What does this PR do?

Integrates Meta's schedule_free (https://github.com/facebookresearch/schedule_free) optimizers for AdamW & SGD.

https://twitter.com/aaron_defazio/status/1776320004465582331
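For context, a minimal sketch of the underlying schedulefree API this PR wraps, based on the usage described in the facebookresearch/schedule_free README (this is not the Trainer integration itself). The key point for the discussion below is that the optimizer exposes train()/eval() methods that have to be toggled alongside the model:

import torch
import schedulefree

model = torch.nn.Linear(10, 2)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=2.5e-3)

# Schedule-free optimizers keep an averaged and a non-averaged parameter sequence,
# so the optimizer itself must be switched between train and eval modes.
model.train()
optimizer.train()
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()

# Before evaluation or checkpointing, switch the optimizer to eval mode as well.
model.eval()
optimizer.eval()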

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @younesbelkada @pacman100

@muellerzr
Contributor

muellerzr commented Apr 6, 2024

FYI this will need huggingface/accelerate#2631 as we need to upstream accelerate's ability to call train/eval on a wrapped optimizer
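Roughly, the issue is that Accelerate wraps the optimizer, so the Trainer can no longer reach train()/eval() directly. Below is a simplified, hypothetical illustration of the forwarding the linked accelerate PR adds (a sketch only, not Accelerate's actual AcceleratedOptimizer implementation):

class OptimizerWrapper:
    """Sketch of an optimizer wrapper that forwards schedule-free mode switches."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, *args, **kwargs):
        return self.optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optimizer.zero_grad(*args, **kwargs)

    def train(self):
        # Only schedule-free optimizers define train()/eval(), so guard the call.
        if hasattr(self.optimizer, "train"):
            self.optimizer.train()

    def eval(self):
        if hasattr(self.optimizer, "eval"):
            self.optimizer.eval()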

@danielhanchen
Contributor

Some thoughts:

  • I was trying to ask Aaron et al on Twitter if they did any transformer experiments, but to no avail. They said a paper will come in 1 or 2 months.
  • Aaron et al.'s past work on D-Adaptation won a best ICML paper, with their follow-up work being Prodigy - but on transformers both did similarly to or worse than AdamW. https://twitter.com/danielhanchen/status/1775547139248341125
  • Superconvergence + the LR range finder + fast.ai's Ranger21 optimizer was the go-to combination for CNNs, and worked fabulously well, but on transformers the learning rate range finder said 1e-3 was the best, whilst 1e-5 was better. However, the 1-cycle learning rate schedule stuck. Learning rate finder for the trainer  #16013
  • A huge issue is that this needs tuning?! But how about a well-tuned AdamW? E.g. see https://twitter.com/kellerjordan0/status/1776716388037529843, which outperformed it using a tuned SGD.

I'm just a little bit reserved for now since the authors themselves aren't providing any transformer benchmarks, nor have they compared their CNN baselines to superconvergence, which is the go-to standard for fast CNN training. Likewise, https://parameterfree.com/2023/08/30/yet-another-icml-award-fiasco/ wasn't pleasant.

@PhilipMay
Contributor

Should be very easy to test this on Phi-2 or TinyLlama when the implementation works?

Contributor

@younesbelkada younesbelkada left a comment

Great work @winglian ! 🤩 I left one minor comment, wdyt?

@@ -3117,6 +3145,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if "ScheduleFree" in self.optimizer.__class__.__name__:
Contributor

maybe instead of checking the class name here we could inject an attribute _hf_schedule_free_optim to make sure we can support other schedule-free optimizers in the future, what do you think?

Contributor Author

that would be on the Trainer class, right?

Contributor Author

so the place that makes the most sense to set that would be in get_optimizer_cls_and_kwargs, but that is a @staticmethod so it has no access to the trainer object. We could do something along the lines of

setattr(self.optimizer, "_hf_schedule_free_optim", True)

after we instantiate the optimizer_cls, but we would still have to do some sort of class-name detection.

Alternatively, we could pass another value in the return tuple specific to schedule_free optimizers (but that feels worse)
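A hypothetical sketch of the setattr alternative described above (names and call pattern are illustrative only, not what the PR ships):

# Illustrative only: flag the instantiated optimizer so later code can check an
# attribute instead of the class name. Class-name detection is still needed at
# this single point.
import schedulefree
import torch

model = torch.nn.Linear(4, 4)
optimizer_cls, optimizer_kwargs = schedulefree.AdamWScheduleFree, {"lr": 1e-3}
optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
if "ScheduleFree" in optimizer_cls.__name__:
    setattr(optimizer, "_hf_schedule_free_optim", True)

assert getattr(optimizer, "_hf_schedule_free_optim", False)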

Contributor

ahh good point yeah, in that case this is probably already fine I would say, thanks for investigating @winglian !

Collaborator

Rather than have it as a stateful attribute, could we instead move this logic out to a module-level function e.g.:

def _is_schedule_free_optimizer(optimizer):
    return "ScheduleFree" in optimizer.__class__.__name__

?

This way:

  • The check is a bit more explicit within the code logic
  • We can easily adapt the check in one place, rather than throughout the code, if we end up introducing e.g. an _is_schedule_free attribute or there are schedule-free optimizers with slightly different names

@PhilipMay
Contributor

This PR should maybe also add a few lines to the README about "how to use this".
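For example, something along these lines could go in the docs (a sketch; the optim value names used here, schedule_free_adamw and schedule_free_sgd, are assumed from the PR and should be checked against the final OptimizerNames values):

from transformers import Trainer, TrainingArguments

# model and train_dataset are assumed to be defined as usual.
args = TrainingArguments(
    output_dir="out",
    optim="schedule_free_adamw",   # assumed name; "schedule_free_sgd" for the SGD variant
    lr_scheduler_type="constant",  # schedule-free optimizers replace the LR schedule
    learning_rate=2.5e-3,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()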

@muellerzr
Contributor

We've merged the accelerate portion in, so if anyone is trying this out in a distributed fashion, you can do pip install git+https://github.com/huggingface/accelerate :)

src/transformers/trainer.py (outdated review comment, resolved)
@bratao

bratao commented Apr 14, 2024

Is there any chance of this making it into the main branch? I and others have confirmed that the results are real. Thank you @winglian

Contributor

@pacman100 pacman100 left a comment

Super useful addition of schedule-free optimizers @winglian! It would be great to document the usage along with a minimal example.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@CoffeeVampir3

Is there any remaining work I could contribute towards getting this PR merged?

Cheers

@winglian
Contributor Author

@pacman100 @muellerzr @younesbelkada Can we get a new review to get this merged? Since the last check, I rebased, added some fixes and docs.

Contributor

@muellerzr muellerzr left a comment

Thanks! Overall LGTM; let's pin schedulefree as a >= requirement, however.

Can you also run the quality checks? Afterwards, at least from my end, this looks good to merge.

setup.py (outdated review comment, resolved)
@winglian
Contributor Author

winglian commented Jun 1, 2024

@muellerzr I ran make quality/lint and also added a smoke test to the test suite for schedule-free AdamW
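For reference, a smoke test along these lines (purely a sketch; the actual test added in the PR lives in the Trainer test suite and may look different):

import pytest
import torch


def test_schedule_free_adamw_smoke():
    # Skip cleanly when the optional dependency is not installed.
    schedulefree = pytest.importorskip("schedulefree")
    model = torch.nn.Linear(4, 1)
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-2)
    optimizer.train()
    x = torch.randn(16, 4)
    y = x.sum(dim=1, keepdim=True)
    for _ in range(20):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    optimizer.eval()
    # Smoke check only: training ran to completion and produced a finite loss.
    assert torch.isfinite(loss)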

Contributor

@muellerzr muellerzr left a comment

Thanks a bunch! cc @LysandreJik for final review

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot !

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for adding!

Main comment is about the getattr logic in get_optimizer_cls_and_kwargs


additional_optim_kwargs["warmup_steps"] = args.warmup_steps
additional_optim_kwargs.update(
{
"weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
Collaborator

This doesn't seem right:

  • If we get "weight_lr_power" from optim_args I'm presuming it's a float as string e.g. "2.0"? I don't think torch.2.0 exists?
  • If optim_args doesn't have "weight_lr_power", then the second argument to getattr is a float, which isn't compatible

additional_optim_kwargs.update(
{
"weight_lr_power": float(getattr(torch, optim_args.get("weight_lr_power", 2.0))),
"r": float(getattr(torch, optim_args.get("r", 0.0))),
Collaborator

Same here
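A sketch of what a fix for the two comments above could look like (the names are taken from the quoted diff; this is one possible correction, not necessarily what ended up in the PR). Since optim_args values arrive as strings, they can be cast directly instead of being looked up as attributes on torch:

# Illustrative example: optim_args comes from parsing --optim_args, so values are strings.
optim_args = {"weight_lr_power": "2.0"}
additional_optim_kwargs = {}

additional_optim_kwargs.update(
    {
        "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
        "r": float(optim_args.get("r", 0.0)),
    }
)

assert additional_optim_kwargs == {"weight_lr_power": 2.0, "r": 0.0}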


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@winglian
Contributor Author

Will get back to this soon. Not stale 😅


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jul 31, 2024
@bratao

bratao commented Jul 31, 2024

@winglian please don't let it die

@amyeroberts amyeroberts reopened this Jul 31, 2024