
About torch.nn.utils clipping functions #2884

Open
carmocca opened this issue Apr 15, 2021 · 3 comments
Labels
nostale Do not consider for staleness

Comments

carmocca commented Apr 15, 2021

❓ Questions and Help

Hello,

I noticed that xla patches the torch.nn.utils.clip_grad_norm_ function, allegedly due to performance issues. From https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md:

Use torch.where to substitute control flow when applicable. E.g. The control flow with item() used in clip_grad_norm_ is problematic and impacts performance, so we have patched clip_grad_norm_ by calling torch.where instead, which gives us a dramatic performance improvement.

PyTorch's implementation of clip_grad_norm_ was updated in pytorch/pytorch#32020 (which made it into 1.5.0), so the computation no longer relies on .item().
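To illustrate the difference, here is a rough sketch of the two styles being contrasted (not the actual torch or torch_xla implementations; the function and argument names are mine, and total_norm stands for a 0-dim tensor holding the overall gradient norm):

import torch

def clip_with_item(grads, total_norm, max_norm):
    # Host-side control flow: reading the tensor with .item() forces a
    # device-to-host sync, which is expensive on XLA because it cuts the graph.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef.item() < 1.0:
        for g in grads:
            g.mul_(clip_coef)

def clip_with_where(grads, total_norm, max_norm):
    # Device-side alternative: torch.where keeps the decision inside the graph,
    # so nothing has to be materialized on the host.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef = torch.where(clip_coef < 1.0, clip_coef, torch.ones_like(clip_coef))
    for g in grads:
        g.mul_(clip_coef)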

Does that mean xla's patch is no longer necessary on torch >= 1.5.0?

Additionally, are there any known issues for torch.nn.utils.clip_grad_value_? I am assuming there are none, since there are no .item() calls, but I could not find confirmation anywhere.
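For reference, this is roughly what I understand element-wise value clipping to do (a minimal sketch, not the library source):

import torch

def clip_grad_value_sketch(parameters, clip_value):
    # Clamp each gradient in place; there is no norm reduction and no .item(),
    # so no host-side control flow should be involved.
    for p in parameters:
        if p.grad is not None:
            p.grad.clamp_(min=-clip_value, max=clip_value)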

Thanks!

JackCaoG (Collaborator) commented

Thanks! We meant to test this out. I agree that torch.nn.utils.clip_grad_norm_ seems fine now. We need to double-check the speed on TPU before removing the patch. @ailzhang tried this once and ran into a weird error last time.

stale bot commented Jun 11, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the "stale (Has not had recent activity)" label on Jun 11, 2021
stale bot closed this as completed on Jun 22, 2021
JackCaoG added the "nostale (Do not consider for staleness)" label on Aug 22, 2021
JackCaoG reopened this on Aug 22, 2021
stale bot removed the "stale (Has not had recent activity)" label on Aug 22, 2021
carmocca (Author) commented May 5, 2023

It would be nice to revisit this.

In Lightning, we are seeing this error when torch_xla is imported:

E       RuntimeError: The norm of order 2.0 for a gradient from `parameters` is non-finite, so it cannot be clipped. This error can be disabled with `error_if_nonfinite=False`

The same tests run and pass normally when torch_xla is not imported.

As a workaround, I'm doing:

if hasattr(torch.nn.utils.clip_grad_norm_, "_orig"):
    # hacky workaround to https://github.com/pytorch/xla/issues/2884: undo xla patching on import
    torch.nn.utils.clip_grad_norm_ = torch.nn.utils.clip_grad_norm_._orig

(in Lightning-AI/pytorch-lightning#17519)

This global patching is particularly problematic because it is done regardless of whether you actually end up using XLA at all.
