Skip to content

Commit

Permalink
Add export friendly substitutions of SiLU (#69)
Browse files Browse the repository at this point in the history
* Refactor module importing

* Add ONNX-friendly substitutions of nn.SiLU

* Update notebook for ONNX exporting

* Update notebook for TVM exporting

* Add docs
  • Loading branch information
zhiqwang committed Feb 23, 2021
1 parent 5825161 commit 76f5a5d
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 84 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Optional list of dependencies required by the package
dependencies = ['torch', 'torchvision']

from yolort.models import yolov5s, yolov5m, yolov5l, yolov5s_r40, yolov5m_r40, yolov5l_r40
from yolort.models import yolov5s, yolov5m, yolov5l
54 changes: 33 additions & 21 deletions notebooks/export-onnx-inference-onnxruntime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"import onnx\n",
"import onnxruntime\n",
"\n",
"from yolort.models import yolov5_onnx\n",
"from yolort.models import yolov5s\n",
"\n",
"from yolort.utils.image_utils import read_image"
]
Expand All @@ -26,7 +26,7 @@
"import os\n",
"\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"5\"\n",
"\n",
"device = torch.device('cuda')"
]
Expand All @@ -44,7 +44,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = yolov5_onnx(pretrained=True, score_thresh=0.45)\n",
"model = yolov5s(upstream_version='v4.0', export_friendly=True, pretrained=True, score_thresh=0.45)\n",
"\n",
"model = model.eval()\n",
"model = model.to(device)"
Expand Down Expand Up @@ -100,8 +100,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 20 ms, sys: 0 ns, total: 20 ms\n",
"Wall time: 20.1 ms\n"
"CPU times: user 16 ms, sys: 0 ns, total: 16 ms\n",
"Wall time: 16.8 ms\n"
]
}
],
Expand All @@ -119,10 +119,10 @@
{
"data": {
"text/plain": [
"tensor([[ 48.4231, 401.9458, 237.0045, 897.8144],\n",
" [215.4538, 407.8977, 344.6994, 857.3773],\n",
" [ 13.1457, 225.1691, 801.7442, 736.7350],\n",
" [675.6570, 409.5675, 812.7283, 868.2495]], device='cuda:0')"
"tensor([[ 52.1687, 384.9377, 235.4150, 899.1040],\n",
" [223.6789, 406.9857, 346.8747, 862.1425],\n",
" [677.8205, 390.5674, 811.9033, 871.8314],\n",
" [ 9.4887, 227.6140, 799.6432, 766.6011]], device='cuda:0')"
]
},
"execution_count": 7,
Expand All @@ -142,7 +142,7 @@
{
"data": {
"text/plain": [
"tensor([0.8941, 0.8636, 0.8621, 0.7490], device='cuda:0')"
"tensor([0.8995, 0.8665, 0.8193, 0.8094], device='cuda:0')"
]
},
"execution_count": 8,
Expand All @@ -162,7 +162,7 @@
{
"data": {
"text/plain": [
"tensor([0, 0, 5, 0], device='cuda:0')"
"tensor([0, 0, 0, 5], device='cuda:0')"
]
},
"execution_count": 9,
Expand Down Expand Up @@ -224,17 +224,17 @@
" 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n",
"/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" dtype=torch.float32)).float())) for i in range(dim)]\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:31: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:31: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" stride = torch.as_tensor([stride], dtype=dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:50: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:50: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:344: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
"/mnt/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" for idx in range(batch_size): # image idx, image inference\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
"/mnt/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" for s, s_orig in zip(new_size, original_size)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/mnt/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" for s, s_orig in zip(new_size, original_size)\n",
"/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py:2378: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n",
" \"If indices include negative values, the exported graph will produce incorrect results.\")\n",
Expand Down Expand Up @@ -264,6 +264,17 @@
"## Simplify the exported `ONNX` model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Don't forget to install `onnx-simplifier`:\n",
"\n",
"```bash\n",
"pip install -U onnx-simplifier\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 13,
Expand All @@ -273,7 +284,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Starting simplifing with onnxsim 0.3.1\n"
"Starting simplifing with onnxsim 0.3.2\n"
]
}
],
Expand Down Expand Up @@ -361,6 +372,7 @@
"metadata": {},
"outputs": [],
"source": [
"# ort_session = onnxruntime.InferenceSession(export_onnx_name)\n",
"ort_session = onnxruntime.InferenceSession(onnx_simp_name)"
]
},
Expand All @@ -384,8 +396,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 16 ms, sys: 8 ms, total: 24 ms\n",
"Wall time: 22.4 ms\n"
"CPU times: user 2.33 s, sys: 0 ns, total: 2.33 s\n",
"Wall time: 77.9 ms\n"
]
}
],
Expand All @@ -411,7 +423,7 @@
],
"source": [
"for i in range(0, len(outputs)):\n",
" torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-05, atol=1e-07)\n",
" torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-04, atol=1e-07)\n",
"\n",
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
]
Expand Down
34 changes: 15 additions & 19 deletions notebooks/export-relay-inference-tvm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Compile PyTorch Object Detection Models\n",
"# Compile YOLOv5 Models\n",
"\n",
"This article is an introductory tutorial to deploy PyTorch object\n",
"detection models with Relay VM.\n",
"This article is an introductory tutorial to deploy PyTorch YOLOv5 models with Relay VM.\n",
"\n",
"For us to begin with, PyTorch should be installed.\n",
"TorchVision is also required since we will be using it as our model zoo.\n",
Expand Down Expand Up @@ -75,7 +74,7 @@
},
"outputs": [],
"source": [
"in_size = 300\n",
"in_size = 416\n",
"\n",
"input_shape = (1, 3, in_size, in_size)\n",
"\n",
Expand Down Expand Up @@ -110,7 +109,7 @@
"source": [
"from yolort.models import yolov5s\n",
"\n",
"model_func = yolov5s(pretrained=True)"
"model_func = yolov5s(upstream_version='v4.0', export_friendly=True, pretrained=True)"
]
},
{
Expand Down Expand Up @@ -142,7 +141,7 @@
" anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:344: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" for idx in range(batch_size): # image idx, image inference\n",
"/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" for s, s_orig in zip(new_size, original_size)\n",
Expand Down Expand Up @@ -171,12 +170,12 @@
"data": {
"text/plain": [
"graph(%self.1 : __torch__.TraceWrapper,\n",
" %images : Float(1:270000, 3:90000, 300:300, 300:1, requires_grad=0, device=cpu)):\n",
" %4620 : __torch__.yolort.models.yolo_module.YOLOModule = prim::GetAttr[name=\"model\"](%self.1)\n",
" %4999 : (Tensor, Tensor, Tensor) = prim::CallMethod[name=\"forward\"](%4620, %images)\n",
" %4996 : Float(300:4, 4:1, requires_grad=0, device=cpu), %4997 : Float(300:1, requires_grad=0, device=cpu), %4998 : Long(300:1, requires_grad=0, device=cpu) = prim::TupleUnpack(%4999)\n",
" %3728 : (Float(300:4, 4:1, requires_grad=0, device=cpu), Float(300:1, requires_grad=0, device=cpu), Long(300:1, requires_grad=0, device=cpu)) = prim::TupleConstruct(%4996, %4997, %4998)\n",
" return (%3728)"
" %images : Float(1:519168, 3:173056, 416:416, 416:1, requires_grad=0, device=cpu)):\n",
" %4495 : __torch__.yolort.models.yolo_module.YOLOModule = prim::GetAttr[name=\"model\"](%self.1)\n",
" %4874 : (Tensor, Tensor, Tensor) = prim::CallMethod[name=\"forward\"](%4495, %images)\n",
" %4871 : Float(300:4, 4:1, requires_grad=0, device=cpu), %4872 : Float(300:1, requires_grad=0, device=cpu), %4873 : Long(300:1, requires_grad=0, device=cpu) = prim::TupleUnpack(%4874)\n",
" %3611 : (Float(300:4, 4:1, requires_grad=0, device=cpu), Float(300:1, requires_grad=0, device=cpu), Long(300:1, requires_grad=0, device=cpu)) = prim::TupleConstruct(%4871, %4872, %4873)\n",
" return (%3611)"
]
},
"execution_count": 6,
Expand All @@ -201,7 +200,7 @@
"metadata": {},
"outputs": [],
"source": [
"img_path = 'test/assets/bus.jpg'\n",
"img_path = './test/assets/bus.jpg'\n",
"\n",
"img = cv2.imread(img_path).astype(\"float32\")\n",
"img = cv2.resize(img, (in_size, in_size))\n",
Expand Down Expand Up @@ -360,7 +359,6 @@
},
"outputs": [],
"source": [
"# Dummy run\n",
"ctx = tvm.cpu()\n",
"vm = VirtualMachine(vm_exec, ctx)\n",
"vm.set_input(\"main\", **{input_name: img})\n",
Expand All @@ -381,15 +379,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 528 ms, sys: 364 ms, total: 892 ms\n",
"Wall time: 22.3 ms\n"
"CPU times: user 684 ms, sys: 832 ms, total: 1.52 s\n",
"Wall time: 39.2 ms\n"
]
}
],
"source": [
"%%time\n",
"ctx = tvm.cpu()\n",
"vm = VirtualMachine(vm_exec, ctx)\n",
"vm.set_input(\"main\", **{input_name: img})\n",
"tvm_res = vm.run()"
]
Expand Down Expand Up @@ -454,4 +450,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
55 changes: 40 additions & 15 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import unittest
from torchvision.ops._register_onnx_ops import _onnx_opset_version

from yolort.models import yolov5_onnx
from yolort.models import yolov5s, yolov5m


@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
Expand All @@ -19,15 +19,23 @@ class ONNXExporterTester(unittest.TestCase):
def setUpClass(cls):
torch.manual_seed(123)

def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
def run_model(self, model, inputs_list, tolerate_small_mismatch=False,
do_constant_folding=True, dynamic_axes=None,
output_names=None, input_names=None):
model.eval()

onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
torch.onnx.export(
model,
inputs_list[0],
onnx_io,
do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
Expand Down Expand Up @@ -89,23 +97,40 @@ def get_test_images(self):
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))

images = [image]
test_images = [image2]
return images, test_images
images_one = [image]
images_two = [image2]
return images_one, images_two

def test_yolov5s(self):
images, test_images = self.get_test_images()
dummy_image = [torch.ones(3, 100, 100) * 0.3]
model = yolov5_onnx(pretrained=True)
def test_yolov5s_r31(self):
images_one, images_two = self.get_test_images()
images_dummy = [torch.ones(3, 100, 100) * 0.3]
model = yolov5s(upstream_version='v3.1', export_friendly=True, pretrained=True)
model.eval()
model(images)
model(images_one)
# Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"],
self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True)
# Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"],
self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True)

def test_yolov5m_r40(self):
    """Run the export-friendly yolov5m (upstream v4.0) through ``run_model``.

    Two scenarios are exercised, each exported from its first input and
    then run on every input in the list.
    """
    first_batch, second_batch = self.get_test_images()
    blank_batch = [torch.ones(3, 100, 100) * 0.3]

    model = yolov5m(upstream_version='v4.0', export_friendly=True, pretrained=True)
    model.eval()
    model(first_batch)

    scenarios = (
        # Test exported model on images of different size, or dummy input
        [(first_batch,), (second_batch,), (blank_batch,)],
        # Test exported model for an image with no detections on other images
        [(blank_batch,), (first_batch,)],
    )
    for inputs_list in scenarios:
        # Keyword literals are rebuilt on every iteration so each export
        # receives its own fresh name lists and dynamic_axes mapping.
        self.run_model(
            model,
            inputs_list,
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            tolerate_small_mismatch=True,
        )
Expand Down
Loading

0 comments on commit 76f5a5d

Please sign in to comment.