diff --git a/models/common.py b/models/common.py index 90bfef5124b3..e79b8a9d2644 100644 --- a/models/common.py +++ b/models/common.py @@ -365,6 +365,7 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None): self.s = shape # inference BCHW shape def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')): + crops = [] for i, (im, pred) in enumerate(zip(self.imgs, self.pred)): str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' if pred.shape[0]: @@ -376,7 +377,9 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False for *box, conf, cls in reversed(pred): # xyxy, confidence, class label = f'{self.names[int(cls)]} {conf:.2f}' if crop: - save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i]) + file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None + crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label, + 'im': save_one_box(box, im, file=file, save=save)}) else: # all others annotator.box_label(box, label, color=colors(cls)) im = annotator.im @@ -395,6 +398,10 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}") if render: self.imgs[i] = np.asarray(im) + if crop: + if save: + LOGGER.info(f'Saved results to {save_dir}\n') + return crops def print(self): self.display(pprint=True) # print results @@ -408,10 +415,9 @@ def save(self, save_dir='runs/detect/exp'): save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir self.display(save=True, save_dir=save_dir) # save results - def crop(self, save_dir='runs/detect/exp'): - save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir - self.display(crop=True, save_dir=save_dir) # crop results - LOGGER.info(f'Saved results to {save_dir}\n') + def crop(self, save=True, save_dir='runs/detect/exp'): + save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None + return self.display(crop=True, save=save, save_dir=save_dir) # crop results def render(self): self.display(render=True) # render results