diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 4929d21cdf83..67437d208242 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -79,11 +79,11 @@ def reshape_classifier_output(model, n=1000): elif isinstance(m, nn.Sequential): types = [type(x) for x in m] if nn.Linear in types: - i = types.index(nn.Linear) # nn.Linear index + i = len(types) - 1 - types[::-1].index(nn.Linear) # Last nn.Linear index if m[i].out_features != n: m[i] = nn.Linear(m[i].in_features, n) elif nn.Conv2d in types: - i = types.index(nn.Conv2d) # nn.Conv2d index + i = len(types) - 1 - types[::-1].index(nn.Conv2d) # Last nn.Conv2d index if m[i].out_channels != n: m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)