Skip to content

Commit

Permalink
Fix reshape_classifier_output function to correctly reshape the final…
Browse files Browse the repository at this point in the history
… output layer
  • Loading branch information
gokamisama committed May 29, 2024
1 parent 892e8a8 commit 0695b27
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def reshape_classifier_output(model, n=1000):
elif isinstance(m, nn.Sequential):
types = [type(x) for x in m]
if nn.Linear in types:
i = types.index(nn.Linear) # nn.Linear index
i = len(types) - 1 - types[::-1].index(nn.Linear) # Last nn.Linear index
if m[i].out_features != n:
m[i] = nn.Linear(m[i].in_features, n)
elif nn.Conv2d in types:
i = types.index(nn.Conv2d) # nn.Conv2d index
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # Last nn.Conv2d index
if m[i].out_channels != n:
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)

Expand Down

0 comments on commit 0695b27

Please sign in to comment.