diff --git a/hubconf.py b/hubconf.py index 0eaf70787e64..1e6b9c78ac6a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape): fname = f'{name}.pt' # checkpoint filename attempt_download(fname) # download if not found locally ckpt = torch.load(fname, map_location=torch.device('cpu')) # load - state_dict = ckpt['model'].float().state_dict() # to FP32 - state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter - model.load_state_dict(state_dict, strict=False) # load + msd = model.state_dict() # model state_dict + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter + model.load_state_dict(csd, strict=False) # load if len(ckpt['model'].names) == classes: model.names = ckpt['model'].names # set class names attribute if autoshape: diff --git a/utils/torch_utils.py b/utils/torch_utils.py index d6da0cae8945..9991e5ec87d8 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -191,7 +191,7 @@ def fuse_conv_and_bn(conv, bn): # prepare filters w_conv = conv.weight.clone().view(conv.out_channels, -1) w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) # prepare spatial bias b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias