Skip to content

Commit

Permalink
End2end (WongKinYiu#61)
Browse files Browse the repository at this point in the history
* export end2end onnx model

* fixbug

* add web demo (WongKinYiu#58)

* Update README.md

* main code

update yolov7-tiny deploy cfg

* main code

update yolov7-tiny training cfg

* main code

@liguagua752109150 WongKinYiu#33 (comment)

* main code

@albertfaromatics WongKinYiu#35 (comment)

* main code

update link

* main code

add custom hyp

* main code

update default activation function

* main code

update path

* main figure

add more tasks

* main code

update readme

* main code

update reparameterization

* Update README.md

* main code

update readme

* main code

update aux training

* main code

update aux training

* main code

update aux training

* main figure

update yolov7 prediction

* main code

update readme

* main code

rename

* main code

rename

* main code

rename

* main code

rename

* main code

update readme

* main code

update visualization

* main code

fix gain for train_aux

* main code

update loss

* main code

update instance segmentation demo

* main code

update keypoint detection demo

* main code

update pose demo

* main code

update pose

* main code

update pose

* main code

update pose

* main code

update pose

* main code

update trace

* Update README.md

* main code

fix ciou

* main code

fix nan of aux training WongKinYiu#250 (comment) @hudingding

* support onnx to tensorrt convert (WongKinYiu#114)

* fuse IDetect (WongKinYiu#148)

* Fixes WongKinYiu#199 (WongKinYiu#203)

* minor fix

* resolve conflict

* resolve conflict

* resolve conflict

* resolve conflict

* resolve conflict

* resolve

* resolve

* resolve

* resolve

Co-authored-by: AK391 <81195143+AK391@users.noreply.github.com>
Co-authored-by: Alexey <AlexeyAB@users.noreply.github.com>
Co-authored-by: Kin-Yiu, Wong <102582011@cc.ncu.edu.tw>
Co-authored-by: linghu8812 <36389436+linghu8812@users.noreply.github.com>
Co-authored-by: Alexander <84590713+SashaAlderson@users.noreply.github.com>
Co-authored-by: Ben Raymond <ben@theraymonds.org>
Co-authored-by: AlexeyAB84 <alexeyab84@gmail.com>
  • Loading branch information
8 people committed Jul 22, 2022
1 parent 1a01e35 commit fe38440
Show file tree
Hide file tree
Showing 12 changed files with 284 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions .idea/yolov7.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inferen
</a>
</div>


## Export
Use the args `--include-nms` can to export end to end onnx model which include the `EfficientNMS`.
```shell
python models/export.py --weights yolov7.pt --grid --include-nms
```

## Citation

```
Expand Down
35 changes: 22 additions & 13 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.torch_utils import select_device
from utils.add_nms import RegisterNMS

if __name__ == '__main__':
parser = argparse.ArgumentParser()
Expand All @@ -22,6 +23,7 @@
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
Expand Down Expand Up @@ -52,7 +54,9 @@
# m.forward = m.forward_export # assign forward (optional)
model.model[-1].export = not opt.grid # set Detect() layer grid export
y = model(img) # dry run

if opt.include_nms:
model.model[-1].include_nms = True
y = None
# TorchScript export
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
Expand All @@ -75,16 +79,23 @@
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)

# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model

# # Metadata
# d = {'stride': int(max(model.stride))}
# for k, v in d.items():
# meta = onnx_model.metadata_props.add()
# meta.key, meta.value = k, str(v)
# onnx.save(onnx_model, f)
if opt.include_nms:
print('Registering NMS plugin...')
mo = RegisterNMS(f)
mo.register_nms()
mo.save(f)
else:
# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model

# # Metadata
# d = {'stride': int(max(model.stride))}
# for k, v in d.items():
# meta = onnx_model.metadata_props.add()
# meta.key, meta.value = k, str(v)
# onnx.save(onnx_model, f)

if opt.simplify:
try:
Expand All @@ -95,11 +106,9 @@
assert check, 'assert check failed'
except Exception as e:
print(f'Simplifier failure: {e}')
# print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
except Exception as e:
print('ONNX export failure: %s' % e)

# CoreML export
try:
import coremltools as ct
Expand Down
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def forward(self, x):
class ResX(Res):
# ResNet bottleneck
def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
super().__init__(c1, c2, shortcu, g, e)
super().__init__(c1, c2, shortcut, g, e)
c_ = int(c2 * e) # hidden channels


Expand Down
22 changes: 18 additions & 4 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import torch
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
Expand All @@ -23,7 +23,7 @@
class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export

include_nms = False
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(Detect, self).__init__()
self.nc = nc # number of classes
Expand All @@ -48,7 +48,6 @@ def forward(self, x):
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
if not torch.onnx.is_in_onnx_export():
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
Expand All @@ -59,13 +58,28 @@ def forward(self, x):
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x)
if self.include_nms:
z = self.convert(z)

return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)

@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

def convert(self, z):
z = torch.cat(z, 1)
box = z[:, :, :4]
conf = z[:, :, 4:5]
score = z[:, :, 5:]
score *= conf
convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=z.device)
box @= convert_matrix
return (box, score)


class IDetect(nn.Module):
stride = None # strides computed during build
Expand Down
Loading

0 comments on commit fe38440

Please sign in to comment.