🐛 Bug

To Reproduce

In pl_examples/basic_examples/LightningTemplateModel.py, reorder the layer definitions in the __build_model method so that c_d2 is defined before the batch norm and dropout layers:

def __build_model(self):
    self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
                          out_features=self.hparams.hidden_dim)
    # move the layer definition up here
    self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
                          out_features=self.hparams.out_features)
    self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
    self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
We get an error message because the input size does not match (for this order).
Expected behavior
Input and output sizes should be computed in the order of execution, not the order of definition. This matters because PyTorch builds its graph dynamically on each forward pass, so the execution order of the layers is not known beforehand.
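As an illustration (a hypothetical module, not one from the Lightning examples), the layers below are defined in one order but executed in another, so their sizes can only be determined at forward time:

import torch.nn as nn

class OutOfOrder(nn.Module):
    def __init__(self):
        super().__init__()
        # defined first, but runs second
        self.second = nn.Linear(20, 5)
        # defined second, but runs first
        self.first = nn.Linear(10, 20)

    def forward(self, x):
        # execution order: first -> second
        return self.second(self.first(x))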
Proposed Fix
I propose installing a forward hook on each submodule and computing the sizes that way; a rough sketch of the idea follows. I have already started validating the fix and would like to submit a PR soon if you agree.
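A minimal sketch of the approach, assuming a standalone helper (summarize_sizes is a hypothetical name, not the actual Lightning API):

import torch
import torch.nn as nn

def summarize_sizes(model: nn.Module, sample_input: torch.Tensor):
    sizes = []    # filled in execution order, not definition order
    handles = []

    def hook(module, inputs, output):
        # record layer name plus input/output shapes as they execute
        sizes.append((module.__class__.__name__,
                      list(inputs[0].shape),
                      list(output.shape)))

    for module in model.modules():
        # hook only leaf modules so containers are not double-counted
        if not list(module.children()):
            handles.append(module.register_forward_hook(hook))

    with torch.no_grad():
        model(sample_input)

    for h in handles:    # remove the hooks so they don't linger
        h.remove()
    return sizes

Running this on the template model with a dummy batch, e.g. summarize_sizes(model, torch.randn(2, model.hparams.in_features)), would report the shapes in true execution order, regardless of where the layers are defined in __build_model.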
Additional Context
This error could be confusing to users; they might think something is wrong with their own code.