From ab8bafe33071afb8ce75d53b3e48f502170e349e Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Wed, 3 Mar 2021 11:06:20 -0500 Subject: [PATCH] Add unittest for model features --- test/test_models.py | 49 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 16379af3..ecd24039 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -2,6 +2,7 @@ import torch from yolort.models.backbone_utils import darknet_pan_backbone +from yolort.models.yolotr import darknet_pan_tr_backbone from yolort.models.anchor_utils import AnchorGenerator from yolort.models.box_head import YoloHead, PostProcess, SetCriterion @@ -65,19 +66,61 @@ def _get_head_outputs(self, batch_size, h, w): return head_outputs - def _init_test_backbone_with_fpn(self): + def _init_test_backbone_with_pan_r3_1(self): backbone_name = 'darknet_s_r3_1' depth_multiple = 0.33 width_multiple = 0.5 backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple) return backbone_with_fpn - def test_backbone_with_fpn(self): + def test_backbone_with_pan_r3_1(self): N, H, W = 4, 416, 352 out_shape = self._get_feature_shapes(H, W) x = torch.rand(N, 3, H, W) - model = self._init_test_backbone_with_fpn() + model = self._init_test_backbone_with_pan_r3_1() + out = model(x) + + self.assertEqual(len(out), 3) + self.assertEqual(tuple(out[0].shape), (N, *out_shape[0])) + self.assertEqual(tuple(out[1].shape), (N, *out_shape[1])) + self.assertEqual(tuple(out[2].shape), (N, *out_shape[2])) + self.check_jit_scriptable(model, (x,)) + + def _init_test_backbone_with_pan_r4_0(self): + backbone_name = 'darknet_s_r4_0' + depth_multiple = 0.33 + width_multiple = 0.5 + backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple) + return backbone_with_fpn + + def test_backbone_with_pan_r4_0(self): + N, H, W = 4, 416, 352 + out_shape = self._get_feature_shapes(H, W) + + x = torch.rand(N, 3, H, W) + model = self._init_test_backbone_with_pan_r4_0() + out = model(x) + + self.assertEqual(len(out), 3) + self.assertEqual(tuple(out[0].shape), (N, *out_shape[0])) + self.assertEqual(tuple(out[1].shape), (N, *out_shape[1])) + self.assertEqual(tuple(out[2].shape), (N, *out_shape[2])) + self.check_jit_scriptable(model, (x,)) + + def _init_test_backbone_with_pan_tr(self): + backbone_name = 'darknet_s_r4_0' + depth_multiple = 0.33 + width_multiple = 0.5 + backbone_with_fpn_tr = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple) + return backbone_with_fpn_tr + + def test_backbone_with_pan_tr(self): + N, H, W = 4, 416, 352 + out_shape = self._get_feature_shapes(H, W) + + x = torch.rand(N, 3, H, W) + model = self._init_test_backbone_with_pan_tr() out = model(x) self.assertEqual(len(out), 3)