Skip to content

Commit

Permalink
add NMS to pretrained pytorch hub models
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Sep 19, 2020
1 parent 5a9c5c1 commit c4cb785
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
7 changes: 7 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch

from models.common import NMS
from models.yolo import Model
from utils.google_utils import attempt_download

Expand All @@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes):
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
model.load_state_dict(state_dict, strict=False) # load

m = NMS()
m.f = -1 # from
m.i = model.model[-1].i + 1 # index
model.model.add_module(name='%s' % m.i, module=m) # add NMS
model.eval()
return model

except Exception as e:
Expand Down
14 changes: 14 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.nn as nn
from utils.general import non_max_suppression


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -98,6 +99,19 @@ def forward(self, x):
return torch.cat(x, self.d)


class NMS(nn.Module):
# Non-Maximum Suppression (NMS) module
conf = 0.3 # confidence threshold
iou = 0.6 # IoU threshold
classes = None # (optional list) filter by class

def __init__(self, dimension=1):
super(NMS, self).__init__()

def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
Expand Down

1 comment on commit c4cb785

@glenn-jocher
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enables simpler pytorch hub inference. See #36 (comment)

Please sign in to comment.