From 40996d11e0b38841aefbd70a17bf44df041343aa Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 10 Mar 2022 12:41:06 +0100 Subject: [PATCH] PyTorch 1.11.0 compatibility updates (#6932) Resolves `AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'` first raised in https://github.com/ultralytics/yolov5/issues/5499 --- models/experimental.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/models/experimental.py b/models/experimental.py index 463e5514a06e..01bdfe72db4f 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -94,21 +94,22 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: ckpt = torch.load(attempt_download(w), map_location=map_location) # load - if fuse: - model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model - else: - model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse + ckpt = (ckpt['ema'] or ckpt['model']).float() # FP32 model + model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode # Compatibility updates for m in model.modules(): - if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: - m.inplace = inplace # pytorch 1.7.0 compatibility - if type(m) is Detect: + t = type(m) + if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model): + m.inplace = inplace # torch 1.7.0 compatibility + if t is Detect: if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility delattr(m, 'anchor_grid') setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) - elif type(m) is Conv: - m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + elif t is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + elif t is Conv: + m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility if len(model) == 1: return model[-1] # return model