diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 35ef0116d263..6baa9d5061e5 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -90,7 +90,7 @@ def prune(model, amount=0.3):
     import torch.nn.utils.prune as prune
     print('Pruning model... ', end='')
     for name, m in model.named_modules():
-        if isinstance(m, torch.nn.Conv2d):
+        if isinstance(m, nn.Conv2d):
             prune.l1_unstructured(m, name='weight', amount=amount)  # prune
             prune.remove(m, 'weight')  # make permanent
     print(' %.3g global sparsity' % sparsity(model))
@@ -100,12 +100,12 @@ def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
         # init
-        fusedconv = torch.nn.Conv2d(conv.in_channels,
-                                    conv.out_channels,
-                                    kernel_size=conv.kernel_size,
-                                    stride=conv.stride,
-                                    padding=conv.padding,
-                                    bias=True)
+        fusedconv = nn.Conv2d(conv.in_channels,
+                              conv.out_channels,
+                              kernel_size=conv.kernel_size,
+                              stride=conv.stride,
+                              padding=conv.padding,
+                              bias=True).to(conv.weight.device)
 
         # prepare filters
         w_conv = conv.weight.clone().view(conv.out_channels, -1)
@@ -113,10 +113,7 @@ def fuse_conv_and_bn(conv, bn):
         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
 
         # prepare spatial bias
-        if conv.bias is not None:
-            b_conv = conv.bias
-        else:
-            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
 
@@ -159,8 +156,8 @@ def load_classifier(name='resnet101', n=2):
 
     # Reshape output to n classes
     filters = model.fc.weight.shape[1]
-    model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
-    model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
+    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
+    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
     model.fc.out_features = n
     return model
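
A minimal sanity check for the patched fuse_conv_and_bn (not part of the diff): it builds a throwaway Conv2d/BatchNorm2d pair, fuses them, and compares the fused output against the original conv -> BN path. The layer sizes, input shape, and tolerance below are illustrative assumptions, not values taken from the repository.

# Sketch: numerically verify conv+BN fusion (shapes and tolerance are assumptions)
import torch
import torch.nn as nn

from utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
bn = nn.BatchNorm2d(16)

# give BN non-trivial running statistics so the comparison is meaningful
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
bn.eval()  # fusion assumes fixed (eval-mode) running statistics

fused = fuse_conv_and_bn(conv, bn)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))    # original conv -> BN path
    y_fused = fused(x)     # single fused convolution

print(torch.allclose(y_ref, y_fused, atol=1e-5))  # expected: True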