Skip to content

Commit

Permalink
support onnx to tensorrt convert (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
linghu8812 committed Jul 21, 2022
1 parent 4f6e390 commit 96390ed
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def detect(save_img=False):

# Inference
t1 = time_synchronized()
pred = model(img, augment=opt.augment)[0]
pred = model(img, augment=opt.augment)

# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
Expand Down
19 changes: 19 additions & 0 deletions models/export.py → export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
Expand Down Expand Up @@ -68,6 +69,7 @@

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', '.onnx') # filename
model.eval()
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
Expand All @@ -76,6 +78,23 @@
# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model

# # Metadata
# d = {'stride': int(max(model.stride))}
# for k, v in d.items():
# meta = onnx_model.metadata_props.add()
# meta.key, meta.value = k, str(v)
# onnx.save(onnx_model, f)

if opt.simplify:
try:
import onnxsim

print('\nStarting to simplify ONNX...')
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplifier failure: {e}')
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
except Exception as e:
Expand Down
11 changes: 8 additions & 3 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,16 @@ def forward(self, x):
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
if not torch.onnx.is_in_onnx_export():
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else:
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x)
return x if self.training else torch.cat(z, 1)

@staticmethod
def _make_grid(nx=20, ny=20):
Expand Down

0 comments on commit 96390ed

Please sign in to comment.