Support broadcast_buffers in OssDdp #68
@myleott just to make sure I understand the motivation: the batch norm problem arises when using OSS in conjunction with a model parallel technique, right?
Nope, this affects any model that uses batch norm with data parallel training. In particular, batch norm keeps running stats which should be synchronized across data parallel workers. Here's an interesting discussion about a more flexible version of this (not saying we need it, but we should at least have the on/off version): pytorch/pytorch#30718
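To make the problem concrete, here is a minimal sketch of the divergence being described: two identical BatchNorm layers stand in for two data parallel workers that see differently distributed shards of the data. The trainable parameters stay identical (nothing has updated them), but the running stats drift apart because nothing syncs the buffers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two copies of the same BN layer, standing in for two data parallel workers.
bn_rank0 = nn.BatchNorm1d(4)
bn_rank1 = nn.BatchNorm1d(4)

# Each "worker" sees a differently distributed shard of the data.
bn_rank0(torch.randn(32, 4))
bn_rank1(torch.randn(32, 4) + 5.0)

# The trainable params are still identical (no optimizer step ran)...
params_identical = torch.equal(bn_rank0.weight, bn_rank1.weight)

# ...but the running stats have drifted apart: nothing synced the buffers.
stats_diverged = not torch.allclose(bn_rank0.running_mean, bn_rank1.running_mean)
```

At eval time each worker would then normalize with different statistics, which is exactly what broadcasting the buffers prevents.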
Finally got the time to get back to this @myleott, sorry for being slow. There's still something I don't get: we already forcefully sync the model between the ranks after each step, by virtue of each rank being responsible for a shard's worth of the update (so it needs to be synced to the other ranks; for now OSS only shards the optimizer state, and the full model state lives on each rank). It feels like that already covers this broadcast_buffers need, but I must be missing something.
I think the module buffers are a separate list of tensors, distinct from the module params, and they are not updated by the optimizer. Check the source of the original DDP code linked by Myle. The buffers are part of the module's state but NOT in the params list, so the optimizer never touches them. They are part of the checkpoint and are updated by the layers themselves (like BN, but without backprop).
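The params/buffers split described above can be inspected directly on a BatchNorm layer: the optimizer only ever sees `named_parameters()`, while the running stats live in `named_buffers()`, update during `forward()` with no backward pass, and still land in the checkpoint via `state_dict()`.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)

param_names = {name for name, _ in bn.named_parameters()}
buffer_names = {name for name, _ in bn.named_buffers()}
# param_names:  {'weight', 'bias'}              -> updated by the optimizer
# buffer_names: {'running_mean', 'running_var',
#                'num_batches_tracked'}         -> updated by the layer itself

# Buffers update during forward(), with no optimizer and no backward pass:
before = bn.running_mean.clone()
bn(torch.randn(16, 4))
updated_without_optimizer = not torch.equal(before, bn.running_mean)

# Both params and buffers end up in the checkpoint:
state_keys = set(bn.state_dict().keys())
```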
Ah thanks, yes, that makes a lot of sense; I read too fast and was thinking about the model params. OK, this is indeed not synced.
🚀 Feature

We should add support for the `broadcast_buffers` flag to OssDdp.

Motivation

Distributed training with BatchNorm requires it. We removed it from the fairseq implementation because it slows things down a bit, but for the generalized implementation here we should add it back (as a configurable option).
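For reference, the core of what the flag enables is small: before the forward pass, push rank 0's buffers to every rank. The sketch below is a hypothetical standalone helper (the name `broadcast_buffers` and the single-process group setup are illustrative, not fairscale's actual implementation), using a `world_size=1` gloo group only so it runs end to end.

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn

# A single-process "gloo" group, just so this sketch runs end to end;
# in real training every rank would join the same group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.BatchNorm1d(4)

def broadcast_buffers(module: nn.Module, src: int = 0) -> None:
    # The gist of DDP's broadcast_buffers option: before each forward,
    # replicate rank `src`'s buffers (e.g. BN running stats) to all ranks.
    for buf in module.buffers():
        dist.broadcast(buf, src=src)

broadcast_buffers(model)
dist.destroy_process_group()
```

Making this a configurable on/off option matters because the extra broadcast costs a round of communication per step, which is why it was dropped from the fairseq version.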
Additional context

See the documentation for `broadcast_buffers` in the main DDP module: https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html