From 41b896162e9b0f83a8fdbca4821efbb7a85670ba Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Wed, 3 Mar 2021 10:57:17 -0500 Subject: [PATCH] Add unittest for onnx and libtorch exports --- test/test_onnx.py | 19 ++++++++++++++++++- test/test_torchscript.py | 16 +++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index f1ebca32..363254db 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -10,7 +10,7 @@ import unittest from torchvision.ops._register_onnx_ops import _onnx_opset_version -from yolort.models import yolov5s, yolov5m +from yolort.models import yolov5s, yolov5m, yolotr @unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') @@ -135,6 +135,23 @@ def test_yolov5m_r40(self): dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, tolerate_small_mismatch=True) + def test_yolotr(self): + images_one, images_two = self.get_test_images() + images_dummy = [torch.ones(3, 100, 100) * 0.3] + model = yolotr(upstream_version='v4.0', export_friendly=True, pretrained=True) + model.eval() + model(images_one) + # Test exported model on images of different size, or dummy input + self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + # Test exported model for an image with no detections on other images + self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + if __name__ == '__main__': unittest.main() diff --git a/test/test_torchscript.py b/test/test_torchscript.py index 3820c5a4..632f7bbf 100644 --- a/test/test_torchscript.py +++ b/test/test_torchscript.py @@ -2,7 +2,7 @@ import torch -from yolort.models import yolov5s, yolov5m, yolov5l +from yolort.models import yolov5s, yolov5m, yolov5l, yolotr class TorchScriptTester(unittest.TestCase): @@ -51,6 +51,20 @@ def test_yolov5l_script(self): self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) + def test_yolotr_script(self): + model = yolotr(pretrained=True) + model.eval() + + scripted_model = torch.jit.script(model) + scripted_model.eval() + + x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] + + out = model(x) + out_script = scripted_model(x) + self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"])) + self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) + self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) if __name__ == "__main__": unittest.main()