From 2ac4b634c745cc46c4728e682c6da66f79f6416a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Sep 2022 17:25:44 +0200 Subject: [PATCH] AutoShape explicit arguments fix (#9443) * AutoShape explicit arguments fix Signed-off-by: Glenn Jocher * Update common.py Signed-off-by: Glenn Jocher Signed-off-by: Glenn Jocher --- models/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index debbc2d03f60..85b82e10a4e1 100644 --- a/models/common.py +++ b/models/common.py @@ -633,7 +633,7 @@ def forward(self, ims, size=640, augment=False, profile=False): autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference if isinstance(ims, torch.Tensor): # torch with amp.autocast(autocast): - return self.model(ims.to(p.device).type_as(p), augment, profile) # inference + return self.model(ims.to(p.device).type_as(p), augment=augment) # inference # Pre-process n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images @@ -662,7 +662,7 @@ def forward(self, ims, size=640, augment=False, profile=False): with amp.autocast(autocast): # Inference with dt[1]: - y = self.model(x, augment, profile) # forward + y = self.model(x, augment=augment) # forward # Post-process with dt[2]: @@ -696,7 +696,7 @@ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.n = len(self.pred) # number of images (batch size) self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms) - self.s = shape # inference BCHW shape + self.s = tuple(shape) # inference BCHW shape def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): crops = [] @@ -726,7 +726,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np if pprint: - print(s.rstrip(', ')) + LOGGER.info(s.rstrip(', ')) if show: im.show(self.files[i]) # show if save: @@ -743,7 +743,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False def print(self): self.display(pprint=True) # print results - print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t) + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t) def show(self, labels=True): self.display(show=True, labels=labels) # show results