Incompatibility with torch.fx #188
Comments
Hmm, einops requires some information about tensor shapes, but torch.fx does not provide it during tracing (... breaking even more rules). Even the number of dimensions is not available for validation, so this is unlikely to work. I think an OK workaround would be to pre-script these layers first, somewhat like:
... but this does not work with torch.fx, because torch.jit.ScriptModule has no implementation for serialization/deserialization. Another, less appealing, way is to trace the modules:
This works... kinda. After tracing, the modules are really just a bunch of simple operations like torch.repeat, torch.transpose, etc., but torch.fx does not understand them and stores the modules as pickles. I don't see a reasonable way forward here.
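A minimal sketch of the underlying problem, using only torch: during `torch.fx.symbolic_trace` every tensor argument is a `torch.fx.Proxy`, so even `x.ndim` is a graph node rather than an int, and any Python branch on it fails. The function `needs_ndim` is hypothetical; it only stands in for the kind of shape validation einops performs and is not einops code.

```python
import torch
import torch.fx

observed = {}


def needs_ndim(x):
    # During symbolic tracing, x is a torch.fx.Proxy; attribute access such as
    # x.ndim returns another Proxy instead of a concrete int.
    observed["ndim_is_proxy"] = isinstance(x.ndim, torch.fx.Proxy)
    try:
        # Hypothetical shape check: branching needs a concrete bool, which a
        # Proxy cannot provide, so fx raises a TraceError here.
        if x.ndim != 2:
            raise ValueError("expected a 2-D tensor")
    except torch.fx.proxy.TraceError as err:
        observed["validation_error"] = str(err)
    return x * 2


traced = torch.fx.symbolic_trace(needs_ndim)
print(observed)
```

einops hits the same wall earlier: it cannot even pick a backend for a `Proxy`, which is the `Tensor type unknown to einops` error in the report.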
Hi, I stumbled across this bug too. Is the current fix to rewrite the model without einops?
Yes. (Updated much later: maybe not; see the proposals below.)
Okay. Thank you for the quick response!
Hello, all. I am using You may pass Please refer to the following links for example usage:
Hope this can help 😄
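Thank you @Jongchan!

```python
import torch
import torch.fx
from einops import rearrange

# Must run at module top level: during fx tracing, the global name "rearrange"
# is swapped for a wrapper that records one leaf call instead of running einops on a Proxy.
torch.fx.wrap('rearrange')
```

For some reason, it fixed the error during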
Describe the bug
When calling torchvision.models.feature_extraction.get_graph_node_names on a model that contains an einops operation, the call fails with the following error:
Traceback (most recent call last):
File "/home/Documents/project/run.py", line 96, in run
print(get_graph_node_names(model.encoder))
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/torchvision/models/feature_extraction.py", line 239, in get_graph_node_names
train_tracer.trace(model.train())
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 566, in trace
self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
File "/home/Documents/project/encoder.py", line 192, in forward
latents = repeat(self.latents, " l d -> b l d", b=batch_size) * self.lr_mul
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/einops/einops.py", line 537, in repeat
return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
File "/home/anaconda3/envs/base/lib/python3.9/site-packages/einops/einops.py", line 410, in reduce
return _apply_recipe(recipe, tensor, reduction_type=reduction)
File "/home/anaconda3/base/scam/lib/python3.9/site-packages/einops/einops.py", line 231, in _apply_recipe
backend = get_backend(tensor)
File "/home/anaconda3/base/scam/lib/python3.9/site-packages/einops/_backends.py", line 52, in get_backend
raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
RuntimeError: Tensor type unknown to einops <class 'torch.fx.proxy.Proxy'>
Reproduction steps
Steps to reproduce the behavior:
Create a model that uses an einops repeat operation, then call get_graph_node_names on it.
Expected behavior
einops operations such as rearrange and repeat should work with torch.fx.
Your platform
einops 0.4.1, torch 1.11, torchvision 0.12, python 3.9