When I combine device_train_microbatch_size="auto" with a model that has been compiled with torch.compile(...), I get bad behavior in two predictable ways:
1. Because each separate batch size requires recompiling the model, launch time is very slow (a sketch of this is below).
2. I get OOM errors after several compilations have failed. I'm guessing torch caches something on the GPU from each compilation attempt, but I'm not sure exactly what is going on.
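To make the recompilation behavior concrete, here is a minimal sketch (not Composer-specific, and forcing static shapes with dynamic=False for illustration) of how probing several batch sizes, the way the auto-microbatch search does, pays a compile for each new size:

```python
import torch

model = torch.nn.Linear(512, 512)
# dynamic=False forces shape specialization, so every new batch size
# produces a fresh compilation rather than reusing a dynamic graph.
compiled = torch.compile(model, dynamic=False)

# The "auto" microbatch search effectively probes a sequence of batch sizes
# (halving on OOM). The first call at each new size below triggers a compile.
for bs in (64, 32, 16, 8):
    x = torch.randn(bs, 512)
    compiled(x)
```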
My best guess at a clean solution is to determine the microbatch size with the uncompiled model, so that only one compilation is needed. Another possible solution is to wait for torch to improve its dynamic-shape support and treat the batch size as a dynamic dimension, which would likewise require only a single compilation.
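For the second option, a rough sketch of what treating the batch dimension as dynamic could look like with current PyTorch (whether this is robust enough today is exactly the open question):

```python
import torch

model = torch.nn.Linear(512, 512)

# dynamic=True asks the compiler to trace with a symbolic batch dimension,
# so one compiled graph can serve every microbatch size that "auto" probes.
compiled = torch.compile(model, dynamic=True)

for bs in (64, 32, 16, 8):
    compiled(torch.randn(bs, 512))  # no per-size recompilation expected

# An alternative is to mark only the batch dimension as dynamic on the inputs:
#   torch._dynamo.mark_dynamic(x, 0)
```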
For now, I'm avoiding this issue by just manually specifying a microbatch size.
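For reference, a minimal sketch of that workaround, assuming Composer's Trainer and ComposerClassifier wrappers (the model, dataset, and sizes here are placeholders, and the wrapper signature may differ between Composer versions):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from composer import Trainer
from composer.models import ComposerClassifier

# Toy data and model, purely for illustration.
dataset = TensorDataset(torch.randn(1024, 512), torch.randint(0, 10, (1024,)))
dataloader = DataLoader(dataset, batch_size=64)

module = torch.compile(torch.nn.Linear(512, 10))  # compile once up front

trainer = Trainer(
    model=ComposerClassifier(module, num_classes=10),
    train_dataloader=dataloader,
    max_duration="1ep",
    # A fixed microbatch size skips the "auto" search, so the compiled model
    # only ever sees one batch shape and therefore compiles only once.
    device_train_microbatch_size=16,
)
trainer.fit()
```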
My best guess at a clean solution is to determine the microbatch size with the uncompiled model, so that only one compilation is needed. Another possible solution is to wait for torch to improve its dynamic-shape support and treat the batch size as a dynamic dimension, which would likewise require only a single compilation.
I think this seems like a reasonable path forward. We have put this on our roadmap but may not get to it for a few weeks.
We also welcome contributions, and you're more than welcome to open a PR implementing option 1!