Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace usage of copy.deepcopy() - Convolution/Batch Norm Fuser in FX #2645

Closed
wants to merge 2 commits into from

Conversation

MirMustafaAli
Copy link

@MirMustafaAli MirMustafaAli commented Nov 4, 2023

Fixes #2331

Description

Replacing the use of copy.deepcopy() in Convolution/Batch Norm Fuser in FX tutorials with use of load_state_dict as mentioned in Deep copying PyTorch modules.

cc @eellison @suo @gmagogsfm @jamesr66a @msaroufim @SherlockNoMad @albanD @sekyondaMeta @svekars @carljparker @NicolasHug @kit1980 @subramen

Copy link

pytorch-bot bot commented Nov 4, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2645

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c8436a4 with merge base f05f050 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@MirMustafaAli MirMustafaAli marked this pull request as draft November 4, 2023 07:49
@MirMustafaAli MirMustafaAli reopened this Nov 4, 2023
@MirMustafaAli MirMustafaAli marked this pull request as ready for review November 4, 2023 08:04
@svekars svekars added the fx issues related to fx label Nov 6, 2023
@github-actions github-actions bot removed the fx issues related to fx label Nov 6, 2023
@@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn):
module `C` such that C(x) == B(A(x)) in inference mode.
"""
assert(not (conv.training or bn.training)), "Fusion only for eval!"
fused_conv = copy.deepcopy(conv)
fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
Copy link
Member

@msaroufim msaroufim Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems weird? The right to do feels like its implementing a proper __deepcopy__() for nn modules? @albanD

This popular thread seems to validate this fix https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514 but idk if this is what we want people to actually do?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems weird? The right to do feels like its implementing a proper __deepcopy__() for nn modules? @albanD

This popular thread seems to validate this fix https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514 but idk if this is what we want people to actually do?

we can save and load the model, found from 2385 . Other than this, is there other way which i am missing that will help me make a plausible fix ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that Module is a complex enough class that deepcopying it is very challenging (the same way we don't recommend you serialize it as-is but only the state_dict).
deepcopy() work in most simple cases but it is expected to fail sometimes.
If you only have a regular Conv2d kernel, doing deepcopy or a new constructor is pretty much the same thing though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how would you want me to proceed with the PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what @albanD is saying is that in this specific case deepcopy-ing a conv layer is just fine, i.e. the original code probably doesn't need to be changed.

@github-actions github-actions bot added fx issues related to fx core Tutorials of any level of difficulty related to the core pytorch functionality and removed cla signed labels Nov 12, 2023
@@ -150,7 +152,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc


def fuse(model: torch.nn.Module) -> torch.nn.Module:
model = copy.deepcopy(model)
model, state_dict = type(model)(), model.state_dict()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's only going to work for models that do not take any parameters to __init__().

@svekars with this + https://github.com/pytorch/tutorials/pull/2645/files#r1391396959, I'm tempted to think that the originally issue is probably irrelevant for this tutorial. Even if copy.deepcopy(model) may not be perfect, it's still better than any alternative that has been proposed so far. Perhaps we could close the original issue and still provide credits to the contributor for their efforts?

@svekars
Copy link
Contributor

svekars commented Nov 14, 2023

Closing this and the issue and will give half credit.

@svekars svekars closed this Nov 14, 2023
@MirMustafaAli MirMustafaAli deleted the fx_conv_deepcopy branch November 14, 2023 21:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed core Tutorials of any level of difficulty related to the core pytorch functionality docathon-h2-2023 easy fx issues related to fx
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Update tutorial to avoid use of copy.deepcopy()- Convolution/Batch Norm Fuser in FX
6 participants