Skip to content

Commit

Permalink
AutoShape explicit arguments fix (#9443)
Browse files Browse the repository at this point in the history
* AutoShape explicit arguments fix

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update common.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher committed Sep 16, 2022
1 parent 03f2ca8 commit 2ac4b63
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def forward(self, ims, size=640, augment=False, profile=False):
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast):
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference

# Pre-process
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
Expand Down Expand Up @@ -662,7 +662,7 @@ def forward(self, ims, size=640, augment=False, profile=False):
with amp.autocast(autocast):
# Inference
with dt[1]:
y = self.model(x, augment, profile) # forward
y = self.model(x, augment=augment) # forward

# Post-process
with dt[2]:
Expand Down Expand Up @@ -696,7 +696,7 @@ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred) # number of images (batch size)
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
self.s = shape # inference BCHW shape
self.s = tuple(shape) # inference BCHW shape

def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
crops = []
Expand Down Expand Up @@ -726,7 +726,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if pprint:
print(s.rstrip(', '))
LOGGER.info(s.rstrip(', '))
if show:
im.show(self.files[i]) # show
if save:
Expand All @@ -743,7 +743,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False

def print(self):
self.display(pprint=True) # print results
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t)

def show(self, labels=True):
self.display(show=True, labels=labels) # show results
Expand Down

0 comments on commit 2ac4b63

Please sign in to comment.