Skip to content

Commit

Permalink
update fuse_conv_and_bn()
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jul 6, 2020
1 parent 6b95d6d commit 121d90b
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def prune(model, amount=0.3):
import torch.nn.utils.prune as prune
print('Pruning model... ', end='')
for name, m in model.named_modules():
if isinstance(m, torch.nn.Conv2d):
if isinstance(m, nn.Conv2d):
prune.l1_unstructured(m, name='weight', amount=amount) # prune
prune.remove(m, 'weight') # make permanent
print(' %.3g global sparsity' % sparsity(model))
Expand All @@ -100,23 +100,20 @@ def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad():
# init
fusedconv = torch.nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True)
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True).to(conv.weight.device)

# prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

# prepare spatial bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

Expand Down Expand Up @@ -159,8 +156,8 @@ def load_classifier(name='resnet101', n=2):

# Reshape output to n classes
filters = model.fc.weight.shape[1]
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
model.fc.out_features = n
return model

Expand Down

0 comments on commit 121d90b

Please sign in to comment.