From 537b39a4a6546cb95c4320ad11b5caa0d914ad69 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 28 Oct 2020 15:03:50 +0100 Subject: [PATCH] PyTorch 1.7.0 Compatibility Updates (#1233) * torch 1.7.0 compatibility updates * add inference verification --- hubconf.py | 8 ++++++++ models/experimental.py | 7 +++++++ models/yolo.py | 1 - utils/torch_utils.py | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/hubconf.py b/hubconf.py index cd14863ca8b4..cc210528c087 100644 --- a/hubconf.py +++ b/hubconf.py @@ -108,3 +108,11 @@ def yolov5x(pretrained=False, channels=3, classes=80): if __name__ == '__main__': model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example + model = model.fuse().eval().autoshape() # for autoshaping of PIL/cv2/np inputs and NMS + + # Verify inference + from PIL import Image + + img = Image.open('inference/images/zidane.jpg') + y = model(img) + print(y[0].shape) diff --git a/models/experimental.py b/models/experimental.py index 0b61027b9d2f..a2908a15cf32 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -136,6 +136,13 @@ def attempt_load(weights, map_location=None): attempt_download(w) model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model + # Compatibility updates + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: + m.inplace = True # pytorch 1.7.0 compatibility + elif type(m) is Conv: + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + if len(model) == 1: return model[-1] # return model else: diff --git a/models/yolo.py b/models/yolo.py index 0d46054ed21c..e1c30baa271d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -165,7 +165,6 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers print('Fusing layers... ') for m in self.model.modules(): if type(m) is Conv and hasattr(m, 'bn'): - m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm m.forward = m.fuseforward # update forward diff --git a/utils/torch_utils.py b/utils/torch_utils.py index f6818238452f..25eff07f3f44 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -74,7 +74,7 @@ def initialize_weights(model): elif t is nn.BatchNorm2d: m.eps = 1e-3 m.momentum = 0.03 - elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]: + elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: m.inplace = True