From 2469c74123d6bf0429ac4dfd765d2480a0a241fe Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Mar 2021 14:24:55 +0200 Subject: [PATCH 1/3] Speed profiling improvements --- hubconf.py | 7 ++++--- utils/torch_utils.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/hubconf.py b/hubconf.py index 0eaf70787e64..1e6b9c78ac6a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape): fname = f'{name}.pt' # checkpoint filename attempt_download(fname) # download if not found locally ckpt = torch.load(fname, map_location=torch.device('cpu')) # load - state_dict = ckpt['model'].float().state_dict() # to FP32 - state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter - model.load_state_dict(state_dict, strict=False) # load + msd = model.state_dict() # model state_dict + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter + model.load_state_dict(csd, strict=False) # load if len(ckpt['model'].names) == classes: model.names = ckpt['model'].names # set class names attribute if autoshape: diff --git a/utils/torch_utils.py b/utils/torch_utils.py index d6da0cae8945..d41b934d54ea 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -216,7 +216,7 @@ def model_info(model, verbose=False, img_size=640): from thop import profile stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input - flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS + flops = profile(model, inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS except (ImportError, Exception): From b7f5c0a13773179ee51cfd30db601ea76fb0d1cd Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Mar 2021 14:36:01 +0200 Subject: [PATCH 2/3] Update torch_utils.py deepcopy() required to avoid adding elements to model. --- utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index d41b934d54ea..d6da0cae8945 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -216,7 +216,7 @@ def model_info(model, verbose=False, img_size=640): from thop import profile stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input - flops = profile(model, inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS + flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS except (ImportError, Exception): From 32261195750ae464f0bf12207c157fd9fcda1b8a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 29 Mar 2021 15:07:28 +0200 Subject: [PATCH 3/3] Update torch_utils.py --- utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index d6da0cae8945..9991e5ec87d8 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -191,7 +191,7 @@ def fuse_conv_and_bn(conv, bn): # prepare filters w_conv = conv.weight.clone().view(conv.out_channels, -1) w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) # prepare spatial bias b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias