Skip to content

Commit

Permalink
Add unittest for onnx and libtorch exports
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Mar 3, 2021
1 parent b6cdb5d commit 41b8961
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
19 changes: 18 additions & 1 deletion test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import unittest
from torchvision.ops._register_onnx_ops import _onnx_opset_version

from yolort.models import yolov5s, yolov5m
from yolort.models import yolov5s, yolov5m, yolotr


@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
Expand Down Expand Up @@ -135,6 +135,23 @@ def test_yolov5m_r40(self):
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True)

def test_yolotr(self):
images_one, images_two = self.get_test_images()
images_dummy = [torch.ones(3, 100, 100) * 0.3]
model = yolotr(upstream_version='v4.0', export_friendly=True, pretrained=True)
model.eval()
model(images_one)
# Test exported model on images of different size, or dummy input
self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True)
# Test exported model for an image with no detections on other images
self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
tolerate_small_mismatch=True)


if __name__ == '__main__':
unittest.main()
16 changes: 15 additions & 1 deletion test/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from yolort.models import yolov5s, yolov5m, yolov5l
from yolort.models import yolov5s, yolov5m, yolov5l, yolotr


class TorchScriptTester(unittest.TestCase):
Expand Down Expand Up @@ -51,6 +51,20 @@ def test_yolov5l_script(self):
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))

def test_yolotr_script(self):
model = yolotr(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))

if __name__ == "__main__":
unittest.main()

0 comments on commit 41b8961

Please sign in to comment.