val.py refactor (#4053)
* val.py refactor

* cleanup

* cleanup

* cleanup

* cleanup

* save after eval

* opt.imgsz bug fix

* wandb refactor

* dataloader to train_loader

* capitalize global variables

* runs/hub/exp to runs/detect/exp

* refactor wandb logging

* Refactor wandb operations (#4061)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
glenn-jocher and AyushExel committed Jul 19, 2021
1 parent 9dd33fd commit f7d8562
Showing 8 changed files with 220 additions and 203 deletions.
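
Most of the hunks below apply one convention: bare print() calls and the lower-case module logger are replaced by a capitalized, module-level LOGGER ("capitalize global variables" in the commit message). A minimal sketch of that pattern, written for illustration rather than copied from this commit:

import logging

LOGGER = logging.getLogger(__name__)  # module-level logger, named after the module

def fuse_example():
    LOGGER.info('Fusing layers... ')  # replaces a bare print('Fusing layers... ')

if __name__ == '__main__':
    logging.basicConfig(format='%(message)s', level=logging.INFO)  # make LOGGER.info output visible
    fuse_example()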
detect.py: 6 changes (3 additions, 3 deletions)
@@ -21,7 +21,7 @@
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import colors, plot_one_box
- from utils.torch_utils import select_device, load_classifier, time_synchronized
+ from utils.torch_utils import select_device, load_classifier, time_sync


@torch.no_grad()
@@ -100,14 +100,14 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
img = img.unsqueeze(0)

# Inference
- t1 = time_synchronized()
+ t1 = time_sync()
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]

# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
- t2 = time_synchronized()
+ t2 = time_sync()

# Apply Classifier
if classify:
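The only functional change in detect.py is the rename time_synchronized -> time_sync. The function's definition is not part of the hunks shown here; the sketch below assumes it keeps the usual behaviour of a device-aware timer (synchronize pending CUDA work, then read the wall clock), which is an assumption, not code from this commit:

import time

import torch

def time_sync():
    # assumed behaviour: wait for queued CUDA kernels so GPU timings are accurate
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

t1 = time_sync()
# ... inference and NMS would run here ...
t2 = time_sync()
print(f'inference + NMS took {(t2 - t1) * 1000:.1f} ms')
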
models/common.py: 33 changes (19 additions, 14 deletions)
@@ -1,5 +1,6 @@
# YOLOv5 common modules

+ import logging
from copy import copy
from pathlib import Path, PosixPath

@@ -15,7 +16,9 @@
from utils.datasets import exif_transpose, letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import colors, plot_one_box
- from utils.torch_utils import time_synchronized
+ from utils.torch_utils import time_sync
+
+ LOGGER = logging.getLogger(__name__)


def autopad(k, p=None): # kernel, padding
@@ -226,7 +229,7 @@ def __init__(self, model):
self.model = model.eval()

def autoshape(self):
- print('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
+ LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
return self

@torch.no_grad()
@@ -240,7 +243,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

- t = [time_synchronized()]
+ t = [time_sync()]
p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch
with amp.autocast(enabled=p.device.type != 'cpu'):
@@ -270,19 +273,19 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
- t.append(time_synchronized())
+ t.append(time_sync())

with amp.autocast(enabled=p.device.type != 'cpu'):
# Inference
y = self.model(x, augment, profile)[0] # forward
- t.append(time_synchronized())
+ t.append(time_sync())

# Post-process
y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

- t.append(time_synchronized())
+ t.append(time_sync())
return Detections(imgs, y, files, t, self.names, x.shape)


@@ -323,31 +326,33 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if pprint:
- print(str.rstrip(', '))
+ LOGGER.info(str.rstrip(', '))
if show:
im.show(self.files[i]) # show
if save:
f = self.files[i]
im.save(save_dir / f) # save
print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
if i == self.n - 1:
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to '{save_dir}'")
if render:
self.imgs[i] = np.asarray(im)

def print(self):
self.display(pprint=True) # print results
- print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
+ self.t)

def show(self):
self.display(show=True) # show results

- def save(self, save_dir='runs/hub/exp'):
- save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+ def save(self, save_dir='runs/detect/exp'):
+ save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
self.display(save=True, save_dir=save_dir) # save results

- def crop(self, save_dir='runs/hub/exp'):
- save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+ def crop(self, save_dir='runs/detect/exp'):
+ save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
self.display(crop=True, save_dir=save_dir) # crop results
- print(f'Saved results to {save_dir}\n')
+ LOGGER.info(f'Saved results to {save_dir}\n')

def render(self):
self.display(render=True) # render results
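Two user-visible changes in models/common.py: Detections.save() and Detections.crop() now default to runs/detect/exp instead of runs/hub/exp, and result summaries go through LOGGER instead of print. A hedged usage sketch; the torch.hub entry point and sample image URL are taken from the YOLOv5 README, not from this diff:

import torch

model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # returns an AutoShape-wrapped model
results = model('https://ultralytics.com/images/zidane.jpg')
results.print()  # speed summary is now emitted via LOGGER.info
results.save()   # saves to runs/detect/exp, runs/detect/exp2, ... (previously runs/hub/exp)
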
models/yolo.py: 43 changes (21 additions, 22 deletions)
@@ -5,7 +5,6 @@
"""

import argparse
- import logging
import sys
from copy import deepcopy
from pathlib import Path
@@ -18,15 +17,15 @@
from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging
from utils.plots import feature_visualization
- from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
+ from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr

try:
import thop # for FLOPs computation
except ImportError:
thop = None

- logger = logging.getLogger(__name__)
+ LOGGER = logging.getLogger(__name__)


class Detect(nn.Module):
@@ -90,15 +89,15 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, i
# Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
if anchors:
- logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
+ LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
self.yaml['anchors'] = round(anchors) # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
self.inplace = self.yaml.get('inplace', True)
- # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
+ # LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

# Build strides, anchors
m = self.model[-1] # Detect()
@@ -110,12 +109,12 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, i
check_anchor_order(m)
self.stride = m.stride
self._initialize_biases() # only run once
- # logger.info('Strides: %s' % m.stride.tolist())
+ # LOGGER.info('Strides: %s' % m.stride.tolist())

# Init weights, biases
initialize_weights(self)
self.info()
- logger.info('')
+ LOGGER.info('')

def forward(self, x, augment=False, profile=False, visualize=False):
if augment:
@@ -143,13 +142,13 @@ def forward_once(self, x, profile=False, visualize=False):

if profile:
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
- t = time_synchronized()
+ t = time_sync()
for _ in range(10):
_ = m(x)
- dt.append((time_synchronized() - t) * 100)
+ dt.append((time_sync() - t) * 100)
if m == self.model[0]:
logger.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')

x = m(x) # run
y.append(x if m.i in self.save else None) # save output
@@ -158,7 +157,7 @@ def forward_once(self, x, profile=False, visualize=False):
feature_visualization(x, m.type, m.i, save_dir=visualize)

if profile:
- logger.info('%.1fms total' % sum(dt))
+ LOGGER.info('%.1fms total' % sum(dt))
return x

def _descale_pred(self, p, flips, scale, img_size):
@@ -192,16 +191,16 @@ def _print_biases(self):
m = self.model[-1] # Detect() module
for mi in m.m: # from
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
- logger.info(
+ LOGGER.info(
('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

# def _print_weights(self):
# for m in self.model.modules():
# if type(m) is Bottleneck:
- # logger.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
+ # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights

def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
- logger.info('Fusing layers... ')
+ LOGGER.info('Fusing layers... ')
for m in self.model.modules():
if type(m) is Conv and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
@@ -213,19 +212,19 @@ def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
def nms(self, mode=True): # add or remove NMS module
present = type(self.model[-1]) is NMS # last layer is NMS
if mode and not present:
- logger.info('Adding NMS... ')
+ LOGGER.info('Adding NMS... ')
m = NMS() # module
m.f = -1 # from
m.i = self.model[-1].i + 1 # index
self.model.add_module(name='%s' % m.i, module=m) # add
self.eval()
elif not mode and present:
- logger.info('Removing NMS... ')
+ LOGGER.info('Removing NMS... ')
self.model = self.model[:-1] # remove
return self

def autoshape(self): # add AutoShape module
- logger.info('Adding AutoShape... ')
+ LOGGER.info('Adding AutoShape... ')
m = AutoShape(self) # wrap model
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
return m
@@ -235,7 +234,7 @@ def info(self, verbose=False, img_size=640):  # print model information


def parse_model(d, ch): # model_dict, input_channels(3)
- logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
+ LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
@@ -279,7 +278,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
- logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args))  # print
+ LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args))  # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
@@ -308,5 +307,5 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
# from torch.utils.tensorboard import SummaryWriter
# tb_writer = SummaryWriter('.')
# logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
# LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
# tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
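
With profile=True, forward_once() above now reports its per-layer timing table through LOGGER rather than the old lower-case logger. A rough driving sketch; it assumes the __init__ shown above belongs to the Model class in models/yolo.py, that the script runs from a directory where 'yolov5s.yaml' resolves, and that thop is optional (GFLOPs are reported as 0 without it):

import logging

import torch

from models.yolo import Model  # assumption: the class defining the __init__ shown above

logging.basicConfig(format='%(message)s', level=logging.INFO)  # surface LOGGER.info output
model = Model(cfg='yolov5s.yaml')   # builds the network and logs the parse_model table
img = torch.zeros(1, 3, 640, 640)   # dummy BCHW input
_ = model(img, profile=True)        # logs time (ms), GFLOPs and params per module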
