Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable HF SpeedMonitor #997

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open

Enable HF SpeedMonitor #997

wants to merge 8 commits into from

Conversation

rlrs
Copy link

@rlrs rlrs commented Feb 26, 2024

Enable SpeedMonitor on HF models by using PyTorch FlopCounterMode to calculate model FLOPs.

@rlrs rlrs requested a review from a team as a code owner February 26, 2024 18:27
@rlrs
Copy link
Author

rlrs commented Feb 26, 2024

Oops, some of these changes are for our internal use. Will remove them from here.

@dakinggg
Copy link
Collaborator

Hey @rlrs, thanks for the contribution! I didn't know about this PyTorch flop counter! We'll want to do a bit of testing to make sure that this reports the correct number and doesn't cause any issues with (1) speed (2) memory usage or (3) bad interactions with distributed training strategies like FSDP. What testing of this have you been able to do yourself?

@rlrs
Copy link
Author

rlrs commented Feb 27, 2024

Apologies for the lack of explanation or tests, I rushed this a bit.

So far I've used this with Mistral 7B, comparing against the standard Transformer Math 6PD calculation, and the results are quite close - well, I also rely on one of the same assumptions, namely that the backward pass is 2x the forward pass. It is possible to wrap fwd+bwd in FlopCounterMode instead of just fwd. To me, that seems more complicated since that code has to live outside the HF model wrapper, from where the model FLOPs have to be returned.

One uncertainty I have is how the FLOP counter interacts with non-PyTorch constructs like Flash Attention. I suspect that it might be necessary to register such code manually in order to get the correct result. If so, it might be silently underreporting FLOPs right now.

@dakinggg
Copy link
Collaborator

dakinggg commented Mar 1, 2024

No worries @rlrs! If you're able to do some testing (and add some unit tests) that would be great! Otherwise we'll look into it when we get a chance and appreciate the suggestion!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants