From 0695b27faf111412592b39db2ff5a8e754760331 Mon Sep 17 00:00:00 2001 From: marui <984603294@qq.com> Date: Wed, 29 May 2024 22:13:23 +0800 Subject: [PATCH] Fix reshape_classifier_output function to correctly reshape the final output layer --- utils/torch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 4929d21cdf83..67437d208242 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -79,11 +79,11 @@ def reshape_classifier_output(model, n=1000): elif isinstance(m, nn.Sequential): types = [type(x) for x in m] if nn.Linear in types: - i = types.index(nn.Linear) # nn.Linear index + i = len(types) - 1 - types[::-1].index(nn.Linear) # Last nn.Linear index if m[i].out_features != n: m[i] = nn.Linear(m[i].in_features, n) elif nn.Conv2d in types: - i = types.index(nn.Conv2d) # nn.Conv2d index + i = len(types) - 1 - types[::-1].index(nn.Conv2d) # Last nn.Conv2d index if m[i].out_channels != n: m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)