Skip to content

Commit

Permalink
Add unittest for model features
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Mar 3, 2021
1 parent 41b8961 commit ab8bafe
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.yolotr import darknet_pan_tr_backbone
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.box_head import YoloHead, PostProcess, SetCriterion

Expand Down Expand Up @@ -65,19 +66,61 @@ def _get_head_outputs(self, batch_size, h, w):

return head_outputs

def _init_test_backbone_with_fpn(self):
def _init_test_backbone_with_pan_r3_1(self):
backbone_name = 'darknet_s_r3_1'
depth_multiple = 0.33
width_multiple = 0.5
backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple)
return backbone_with_fpn

def test_backbone_with_fpn(self):
def test_backbone_with_pan_r3_1(self):
N, H, W = 4, 416, 352
out_shape = self._get_feature_shapes(H, W)

x = torch.rand(N, 3, H, W)
model = self._init_test_backbone_with_fpn()
model = self._init_test_backbone_with_pan_r3_1()
out = model(x)

self.assertEqual(len(out), 3)
self.assertEqual(tuple(out[0].shape), (N, *out_shape[0]))
self.assertEqual(tuple(out[1].shape), (N, *out_shape[1]))
self.assertEqual(tuple(out[2].shape), (N, *out_shape[2]))
self.check_jit_scriptable(model, (x,))

def _init_test_backbone_with_pan_r4_0(self):
backbone_name = 'darknet_s_r4_0'
depth_multiple = 0.33
width_multiple = 0.5
backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple)
return backbone_with_fpn

def test_backbone_with_pan_r4_0(self):
N, H, W = 4, 416, 352
out_shape = self._get_feature_shapes(H, W)

x = torch.rand(N, 3, H, W)
model = self._init_test_backbone_with_pan_r4_0()
out = model(x)

self.assertEqual(len(out), 3)
self.assertEqual(tuple(out[0].shape), (N, *out_shape[0]))
self.assertEqual(tuple(out[1].shape), (N, *out_shape[1]))
self.assertEqual(tuple(out[2].shape), (N, *out_shape[2]))
self.check_jit_scriptable(model, (x,))

def _init_test_backbone_with_pan_tr(self):
backbone_name = 'darknet_s_r4_0'
depth_multiple = 0.33
width_multiple = 0.5
backbone_with_fpn_tr = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple)
return backbone_with_fpn_tr

def test_backbone_with_pan_tr(self):
N, H, W = 4, 416, 352
out_shape = self._get_feature_shapes(H, W)

x = torch.rand(N, 3, H, W)
model = self._init_test_backbone_with_pan_tr()
out = model(x)

self.assertEqual(len(out), 3)
Expand Down

0 comments on commit ab8bafe

Please sign in to comment.