
Investigate if model_without_ddp is needed #4385

Closed
prabhat00155 opened this issue Sep 9, 2021 · 9 comments
@prabhat00155
Contributor

prabhat00155 commented Sep 9, 2021

🐛 Describe the bug

Investigate if we need model_without_ddp in the training script.

model_without_ddp = model

Versions

N/A

cc @datumbox

@prabhat00155
Contributor Author

While running a fairly limited training experiment for #4381, I observed that using model instead of model_without_ddp worked without any issue. Hence, it may be worth investigating whether we need model_without_ddp at all.

@datumbox
Contributor

datumbox commented Sep 9, 2021

@fmassa Prabhat investigated using model instead of model_without_ddp in the checkpoint mechanism, and saving/loading/resuming seemed to work fine. Could you please provide some additional information on why the non-parallelized version was needed in the first place? What is expected to fail if we use model instead of model_without_ddp? It might be that whatever limitation used to exist is no longer an issue.

@fmassa
Member

fmassa commented Sep 11, 2021

Hi,

The reason we have model_without_ddp is to be able to serialize the model weights without the module. prefix that gets attached to DDP models. Otherwise, you wouldn't be able to load the checkpoints from the trained model on the CPU or into a model without DDP, or you would have to rename the checkpoint keys to manually remove the prefix. The easiest thing to do was to keep around a version of the model that didn't have DDP, so that we don't have to deal with this.
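
To make the prefix concrete, here is a minimal sketch (not from the references; it assumes torch.distributed is already initialized and a GPU is available, as in the references' train.py):

import torch
import torchvision

model = torchvision.models.resnet18().cuda()
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[0])
model_without_ddp = ddp_model.module

next(iter(ddp_model.state_dict()))          # 'module.conv1.weight' -- carries the DDP prefix
next(iter(model_without_ddp.state_dict()))  # 'conv1.weight'        -- what we want to serialize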

@datumbox
Contributor

@fmassa Thanks for the info.

Prabhat found that using model instead of model_without_ddp in the following snippet works fine for EMA. He was able to save/load checkpoints without a problem:

model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

We did end up changing it to model_without_ddp just to be safe, but I wonder if the limitation that existed in the past is no longer there and we can replace model_without_ddp with model everywhere. Thoughts?

@prabhat00155
Contributor Author

I ran training on resnet18 for 100 epochs, where I cancelled the training around epoch 34 and resumed it, with and without model_without_ddp (I just set model_without_ddp = model after the if block).
Here is the diff:

diff --git a/references/classification/train.py b/references/classification/train.py
index a3e4c9ad..d49dba56 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -201,6 +201,7 @@ def main(args):
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
+    model_without_ddp = model
 
     model_ema = None
     if args.model_ema:

I didn't see any significant difference in the result.
training_with_changes1.txt
training_with_changes2.txt
training_with_no_changes1.txt
training_with_no_changes2.txt

@fmassa
Member

fmassa commented Sep 14, 2021

@prabhat00155 what you would have needed to test is loading the serialized checkpoint into a model that hasn't been wrapped in DDP yet.

Something as simple as

model = torchvision.models.resnet50()
model.load_state_dict(torch.load(path_to_checkpoint))

would fail due to the module. prefix that DDP adds, so you would need to use tools like torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which is pretty new and was added to PyTorch less than 6 months ago (pytorch/pytorch#53224).
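
For illustration, a rough sketch of how consume_prefix_in_state_dict_if_present could be used here (path_to_checkpoint is a placeholder, and the "model" key follows the references' checkpoint layout):

import torch
import torchvision
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

path_to_checkpoint = "checkpoint.pth"  # placeholder path
checkpoint = torch.load(path_to_checkpoint, map_location="cpu")
state_dict = checkpoint["model"]  # the references store the weights under the "model" key
consume_prefix_in_state_dict_if_present(state_dict, "module.")  # strips the prefix in place, if present

model = torchvision.models.resnet50()
model.load_state_dict(state_dict)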

I would be ok removing the current model_without_ddp in torchvision if we use a newer and better way provided by PyTorch, but I'm not sure that the current torch.nn.modules.utils.consume_prefix_in_state_dict_if_present is enough for that (at least it would need some thought to make sure all cases are handled properly).

@datumbox
Contributor

@fmassa Thanks for providing background on why this was added.

So basically, this workaround improves user-friendliness in how the weights are handled after training is completed (that is, outside of the train.py script).

Two thoughts on eliminating the non-parallelized version:

  1. For the users of our library who just take the pre-trained weights, this has no effect. It's us, the contributors, who train the models, prepare them (verify, produce hashes, upload to S3, etc.) and make them available. So we could easily adjust our process to do the extra step of removing the module. prefix without real issues.
  2. For the users of the references, this could become a source of frustration, as they would have to take the checkpoints, remove the module. prefix with the aforementioned method, and only then use the weights.

I don't have a very strong opinion on this, but I'm leaning towards keeping it for the time being. Yes, it's a bit annoying to keep the non-parallelized version around, but it does eliminate potential frustration for new users of the library. Thoughts?

@prabhat00155
Contributor Author

Thanks @fmassa and @datumbox! I think it makes sense to keep it then. I'll close this issue.

@fmassa
Member

fmassa commented Sep 16, 2021

I don't have a very strong opinion on this, but I'm leaning towards keeping it for the time being. Yes, it's a bit annoying to keep the non-parallelized version around, but it does eliminate potential frustration for new users of the library. Thoughts?

Yes, that's the main reason why I'd lean towards keeping it. The use-case I had in mind was indeed for a user to run evaluation on a checkpoint on a different machine with a single GPU (or without GPUs), and the current setup lets them do this seamlessly.
