Skip to content

Commit

Permalink
PyTorch Hub results.render() (ultralytics#1897)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jan 11, 2021
1 parent 0d36eec commit 813f8b6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
19 changes: 13 additions & 6 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file contains modules common to various models

import math

import numpy as np
import requests
import torch
Expand Down Expand Up @@ -240,27 +241,29 @@ def __init__(self, imgs, pred, names=None):
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)

def display(self, pprint=False, show=False, save=False):
def display(self, pprint=False, show=False, save=False, render=False):
colors = color_list()
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
if pred is not None:
for c in pred[:, -1].unique():
n = (pred[:, -1] == c).sum() # detections per class
str += f'{n} {self.names[int(c)]}s, ' # add to string
if show or save:
if show or save or render:
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
for *box, conf, cls in pred: # xyxy, confidence, class
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
if pprint:
print(str)
if show:
img.show(f'Image {i}') # show
if save:
f = f'results{i}.jpg'
str += f"saved to '{f}'"
img.save(f) # save
if show:
img.show(f'Image {i}') # show
if pprint:
print(str)
if render:
self.imgs[i] = np.asarray(img)

def print(self):
self.display(pprint=True) # print results
Expand All @@ -271,6 +274,10 @@ def show(self):
def save(self):
self.display(save=True) # save results

def render(self):
self.display(render=True) # render results
return self.imgs

def __len__(self):
return self.n

Expand Down
11 changes: 6 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle
check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
Expand All @@ -44,7 +44,7 @@


def train(hyp, opt, device, tb_writer=None, wandb=None):
logger.info(f'Hyperparameters {hyp}')
logger.info(colorstr('blue', 'bold', 'Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

Expand Down Expand Up @@ -233,9 +233,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
logger.info('Image sizes %g train, %g test\n'
'Using %g dataloader workers\nLogging results to %s\n'
'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
f'Using {dataloader.num_workers} dataloader workers\n'
f'Logging results to {save_dir}\n'
f'Starting training for {epochs} epochs...')
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
model.train()

Expand Down
7 changes: 4 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads


def set_logging(rank=-1):
Expand Down Expand Up @@ -117,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):

def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*prefix, str = input # color arguments, string
*prefix, string = input # color arguments, string
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
Expand All @@ -136,9 +137,9 @@ def colorstr(*input):
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'undelrine': '\033[4m'}
'underline': '\033[4m'}

return ''.join(colors[x] for x in prefix) + str + colors['end']
return ''.join(colors[x] for x in prefix) + f'{string}' + colors['end']


def labels_to_class_weights(labels, nc=80):
Expand Down

0 comments on commit 813f8b6

Please sign in to comment.