Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Oct 5, 2020
1 parent c5d2331 commit 372ad4c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 17 deletions.
4 changes: 1 addition & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def create(name, pretrained, channels, classes):
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
model.load_state_dict(state_dict, strict=False) # load

model.add_nms() # add NMS module
model.eval()
# model = model.autoshape() # cv2/PIL/np inference: predictions = model(cv2.imread('img.jpg'))
return model

except Exception as e:
Expand Down
31 changes: 27 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# This file contains modules common to various models
import math

import numpy as np
import torch
import torch.nn as nn
from utils.general import non_max_suppression

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -101,17 +104,37 @@ def forward(self, x):

class NMS(nn.Module):
# Non-Maximum Suppression (NMS) module
conf = 0.3 # confidence threshold
iou = 0.6 # IoU threshold
conf = 0.25 # confidence threshold
iou = 0.45 # IoU threshold
classes = None # (optional list) filter by class

def __init__(self, dimension=1):
def __init__(self):
super(NMS, self).__init__()

def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)


class autoShape(nn.Module):
# auto-reshape image size model wrapper
img_size = 640 # inference size (pixels)

def __init__(self, model):
super(autoShape, self).__init__()
self.model = model

def forward(self, x, shape=640, augment=False, profile=False): # x = cv2.imread('img.jpg')
x0shape = x.shape[:2]
p = next(self.model.parameters())
x, ratio, (dw, dh) = letterbox(x, new_shape=make_divisible(shape or max(x0shape), int(self.stride.max())))
x1shape = x.shape[:2]
x = np.ascontiguousarray(x[:, :, ::-1].transpose(2, 0, 1)) # BGR to RGB, to 3x640x640
x = torch.from_numpy(x).to(p.device).type_as(p).unsqueeze(0) / 255. # uint8 to fp16/32
x = self.model(x, augment, profile) # forward
x[0][:, :4] = scale_coords(x1shape, x[0][:, :4], x0shape)
return x


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
Expand Down
29 changes: 19 additions & 10 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import torch
import torch.nn as nn

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
from utils.torch_utils import (
time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr


class Detect(nn.Module):
Expand Down Expand Up @@ -140,6 +140,7 @@ def forward_once(self, x, profile=False):
return x

def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
# https://arxiv.org/abs/1708.02002 section 3.3
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
m = self.model[-1] # Detect() module
for mi, s in zip(m.m, m.stride): # from
Expand Down Expand Up @@ -170,15 +171,27 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
self.info()
return self

def add_nms(self): # fuse model Conv2d() + BatchNorm2d() layers
if type(self.model[-1]) is not NMS: # if missing NMS
print('Adding NMS module... ')
def nms(self, mode=True): # add or remove NMS module
present = type(self.model[-1]) is NMS # last layer is NMS
if mode and not present:
print('Adding NMS... ')
m = NMS() # module
m.f = -1 # from
m.i = self.model[-1].i + 1 # index
self.model.add_module(name='%s' % m.i, module=m) # add
self.eval()
elif not mode and present:
print('Removing NMS... ')
self.model = self.model[:-1] # remove
return self

def autoshape(self): # add autoShape module
print('Adding autoShape... ')
self.nms() # add NMS
m = autoShape(self) # wrap model
copy_attr(m, self, include=('names', 'stride', 'nc', 'autoshape'), exclude=()) # copy attributes
return m

def info(self, verbose=False): # print model information
model_info(self, verbose)

Expand Down Expand Up @@ -263,10 +276,6 @@ def parse_model(d, ch): # model_dict, input_channels(3)
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
# y = model(img, profile=True)

# ONNX export
# model.model[-1].export = True
# torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)

# Tensorboard
# from torch.utils.tensorboard import SummaryWriter
# tb_writer = SummaryWriter()
Expand Down

0 comments on commit 372ad4c

Please sign in to comment.