Skip to content

Commit

Permalink
PyTorch Hub results.save('path/to/dir') (#2179)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Feb 11, 2021
1 parent a5d5f92 commit 404749a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file contains modules common to various models

import math
from pathlib import Path

import numpy as np
import requests
Expand Down Expand Up @@ -241,7 +242,7 @@ def __init__(self, imgs, pred, names=None):
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)

def display(self, pprint=False, show=False, save=False, render=False):
def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
colors = color_list()
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
Expand All @@ -259,7 +260,7 @@ def display(self, pprint=False, show=False, save=False, render=False):
if show:
img.show(f'image {i}') # show
if save:
f = f'results{i}.jpg'
f = Path(save_dir) / f'results{i}.jpg'
img.save(f) # save
print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
if render:
Expand All @@ -271,8 +272,8 @@ def print(self):
def show(self):
self.display(show=True) # show results

def save(self):
self.display(save=True) # save results
def save(self, save_dir=''):
self.display(save=True, save_dir=save_dir) # save results

def render(self):
self.display(render=True) # render results
Expand Down

0 comments on commit 404749a

Please sign in to comment.