From 8b9f511522bba6aae56a9798f27069980815d422 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 10 Feb 2021 16:10:43 -0800 Subject: [PATCH] PyTorch Hub results.save('path/to/dir') (#2179) --- models/common.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models/common.py b/models/common.py index e8adb66293d5..7cfea01f223e 100644 --- a/models/common.py +++ b/models/common.py @@ -1,6 +1,7 @@ # This file contains modules common to various models import math +from pathlib import Path import numpy as np import requests @@ -241,7 +242,7 @@ def __init__(self, imgs, pred, names=None): self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.n = len(self.pred) - def display(self, pprint=False, show=False, save=False, render=False): + def display(self, pprint=False, show=False, save=False, render=False, save_dir=''): colors = color_list() for i, (img, pred) in enumerate(zip(self.imgs, self.pred)): str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' @@ -259,7 +260,7 @@ def display(self, pprint=False, show=False, save=False, render=False): if show: img.show(f'image {i}') # show if save: - f = f'results{i}.jpg' + f = Path(save_dir) / f'results{i}.jpg' img.save(f) # save print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n') if render: @@ -271,8 +272,8 @@ def print(self): def show(self): self.display(show=True) # show results - def save(self): - self.display(save=True) # save results + def save(self, save_dir=''): + self.display(save=True, save_dir=save_dir) # save results def render(self): self.display(render=True) # render results