From 7e5a3338213852c6109ad5e10ae19aaf8bfd184f Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 4 Mar 2021 00:24:13 +0800 Subject: [PATCH] Apply Transformer to YOLO (#75) * Add yolotr model structure * Add updated model checkpoint from dingyiwei * Fix docs and copyright statements * Add unittest for onnx and libtorch exports * Add unittest for model features --- test/test_models.py | 49 ++++++++- test/test_onnx.py | 19 +++- test/test_torchscript.py | 16 ++- yolort/models/__init__.py | 23 ++++- yolort/models/backbone_utils.py | 7 +- yolort/models/experimental.py | 2 +- yolort/models/yolo.py | 38 ++++++- yolort/models/yolotr.py | 176 ++++++++++++++++++++++++++++++++ 8 files changed, 314 insertions(+), 16 deletions(-) create mode 100644 yolort/models/yolotr.py diff --git a/test/test_models.py b/test/test_models.py index 16379af3..ecd24039 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -2,6 +2,7 @@ import torch from yolort.models.backbone_utils import darknet_pan_backbone +from yolort.models.yolotr import darknet_pan_tr_backbone from yolort.models.anchor_utils import AnchorGenerator from yolort.models.box_head import YoloHead, PostProcess, SetCriterion @@ -65,19 +66,61 @@ def _get_head_outputs(self, batch_size, h, w): return head_outputs - def _init_test_backbone_with_fpn(self): + def _init_test_backbone_with_pan_r3_1(self): backbone_name = 'darknet_s_r3_1' depth_multiple = 0.33 width_multiple = 0.5 backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple) return backbone_with_fpn - def test_backbone_with_fpn(self): + def test_backbone_with_pan_r3_1(self): N, H, W = 4, 416, 352 out_shape = self._get_feature_shapes(H, W) x = torch.rand(N, 3, H, W) - model = self._init_test_backbone_with_fpn() + model = self._init_test_backbone_with_pan_r3_1() + out = model(x) + + self.assertEqual(len(out), 3) + self.assertEqual(tuple(out[0].shape), (N, *out_shape[0])) + self.assertEqual(tuple(out[1].shape), (N, *out_shape[1])) + self.assertEqual(tuple(out[2].shape), (N, *out_shape[2])) + self.check_jit_scriptable(model, (x,)) + + def _init_test_backbone_with_pan_r4_0(self): + backbone_name = 'darknet_s_r4_0' + depth_multiple = 0.33 + width_multiple = 0.5 + backbone_with_fpn = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple) + return backbone_with_fpn + + def test_backbone_with_pan_r4_0(self): + N, H, W = 4, 416, 352 + out_shape = self._get_feature_shapes(H, W) + + x = torch.rand(N, 3, H, W) + model = self._init_test_backbone_with_pan_r4_0() + out = model(x) + + self.assertEqual(len(out), 3) + self.assertEqual(tuple(out[0].shape), (N, *out_shape[0])) + self.assertEqual(tuple(out[1].shape), (N, *out_shape[1])) + self.assertEqual(tuple(out[2].shape), (N, *out_shape[2])) + self.check_jit_scriptable(model, (x,)) + + def _init_test_backbone_with_pan_tr(self): + backbone_name = 'darknet_s_r4_0' + depth_multiple = 0.33 + width_multiple = 0.5 + backbone_with_fpn_tr = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple) + return backbone_with_fpn_tr + + def test_backbone_with_pan_tr(self): + N, H, W = 4, 416, 352 + out_shape = self._get_feature_shapes(H, W) + + x = torch.rand(N, 3, H, W) + model = self._init_test_backbone_with_pan_tr() out = model(x) self.assertEqual(len(out), 3) diff --git a/test/test_onnx.py b/test/test_onnx.py index f1ebca32..363254db 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -10,7 +10,7 @@ import unittest from torchvision.ops._register_onnx_ops import _onnx_opset_version -from yolort.models import yolov5s, yolov5m +from yolort.models import yolov5s, yolov5m, yolotr @unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') @@ -135,6 +135,23 @@ def test_yolov5m_r40(self): dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, tolerate_small_mismatch=True) + def test_yolotr(self): + images_one, images_two = self.get_test_images() + images_dummy = [torch.ones(3, 100, 100) * 0.3] + model = yolotr(upstream_version='v4.0', export_friendly=True, pretrained=True) + model.eval() + model(images_one) + # Test exported model on images of different size, or dummy input + self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + # Test exported model for an image with no detections on other images + self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + if __name__ == '__main__': unittest.main() diff --git a/test/test_torchscript.py b/test/test_torchscript.py index 3820c5a4..632f7bbf 100644 --- a/test/test_torchscript.py +++ b/test/test_torchscript.py @@ -2,7 +2,7 @@ import torch -from yolort.models import yolov5s, yolov5m, yolov5l +from yolort.models import yolov5s, yolov5m, yolov5l, yolotr class TorchScriptTester(unittest.TestCase): @@ -51,6 +51,20 @@ def test_yolov5l_script(self): self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) + def test_yolotr_script(self): + model = yolotr(pretrained=True) + model.eval() + + scripted_model = torch.jit.script(model) + scripted_model.eval() + + x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] + + out = model(x) + out_script = scripted_model(x) + self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"])) + self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) + self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) if __name__ == "__main__": unittest.main() diff --git a/yolort/models/__init__.py b/yolort/models/__init__.py index 9abc75b9..f933888a 100644 --- a/yolort/models/__init__.py +++ b/yolort/models/__init__.py @@ -9,7 +9,7 @@ from typing import Any -def yolov5s(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): +def yolov5s(upstream_version: str = 'v3.1', export_friendly: bool = False, **kwargs: Any): """ Args: upstream_version (str): Determine the upstream YOLOv5 version. @@ -28,7 +28,7 @@ def yolov5s(upstream_version: str ='v3.1', export_friendly: bool = False, **kwar return model -def yolov5m(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): +def yolov5m(upstream_version: str = 'v3.1', export_friendly: bool = False, **kwargs: Any): """ Args: upstream_version (str): Determine the upstream YOLOv5 version. @@ -47,7 +47,7 @@ def yolov5m(upstream_version: str ='v3.1', export_friendly: bool = False, **kwar return model -def yolov5l(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): +def yolov5l(upstream_version: str = 'v3.1', export_friendly: bool = False, **kwargs: Any): """ Args: upstream_version (str): Determine the upstream YOLOv5 version. @@ -66,6 +66,23 @@ def yolov5l(upstream_version: str ='v3.1', export_friendly: bool = False, **kwar return model +def yolotr(upstream_version: str = 'v4.0', export_friendly: bool = False, **kwargs: Any): + """ + Args: + upstream_version (str): Determine the upstream YOLOv5 version. + export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode. + """ + if upstream_version == 'v4.0': + model = YOLOModule(arch="yolov5_darknet_pan_s_tr", **kwargs) + else: + raise NotImplementedError("Currently only supports v4.0 versions") + + if export_friendly: + _export_module_friendly(model) + + return model + + def _export_module_friendly(model): for m in model.modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility diff --git a/yolort/models/backbone_utils.py b/yolort/models/backbone_utils.py index 4b1b87ea..0f29dd0c 100644 --- a/yolort/models/backbone_utils.py +++ b/yolort/models/backbone_utils.py @@ -4,7 +4,6 @@ from . import darknet from .path_aggregation_network import PathAggregationNetwork -from .common import BottleneckCSP, C3 from typing import List, Optional @@ -53,7 +52,7 @@ def darknet_pan_backbone( version: str = 'v4.0', ): """ - Constructs a specified ResNet backbone with PAN on top. Freezes the specified number of + Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of layers in the backbone. Examples:: @@ -71,12 +70,12 @@ def darknet_pan_backbone( >>> ('2', torch.Size([1, 512, 2, 2]))] Args: - backbone_name (string): resnet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1', + backbone_name (string): darknet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1', 'darknet_m_r3_1', 'darknet_l_r3_1', 'darknet_s_r4_0', 'darknet_m_r4_0', 'darknet_l_r4_0' norm_layer (torchvision.ops): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. + trainable_layers (int): number of trainable (not frozen) darknet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. version (str): ultralytics release version: v3.1 or v4.0 """ diff --git a/yolort/models/experimental.py b/yolort/models/experimental.py index 4e1d87cd..14ed3b6d 100644 --- a/yolort/models/experimental.py +++ b/yolort/models/experimental.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from ..models.common import Conv, DWConv +from .common import Conv, DWConv class CrossConv(nn.Module): diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py index 0f00d5cd..0af3cce3 100644 --- a/yolort/models/yolo.py +++ b/yolort/models/yolo.py @@ -1,5 +1,4 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# Modified by Zhiqiang Wang (me@zhiqwang.com) +# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. import warnings import torch @@ -8,13 +7,15 @@ from torchvision.models.utils import load_state_dict_from_url from .backbone_utils import darknet_pan_backbone +from .yolotr import darknet_pan_tr_backbone from .anchor_utils import AnchorGenerator from .box_head import YoloHead, SetCriterion, PostProcess from typing import Tuple, Any, List, Dict, Optional __all__ = ['YOLO', 'yolov5_darknet_pan_s_r31', 'yolov5_darknet_pan_m_r31', 'yolov5_darknet_pan_l_r31', - 'yolov5_darknet_pan_s_r40', 'yolov5_darknet_pan_m_r40', 'yolov5_darknet_pan_l_r40'] + 'yolov5_darknet_pan_s_r40', 'yolov5_darknet_pan_m_r40', 'yolov5_darknet_pan_l_r40', + 'yolov5_darknet_pan_s_tr'] class YOLO(nn.Module): @@ -133,6 +134,7 @@ def forward( 'yolov5_darknet_pan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt', 'yolov5_darknet_pan_m_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r40_coco-d295cb02.pt', 'yolov5_darknet_pan_l_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r40_coco-4416841f.pt', + 'yolov5_darknet_pan_s_tr_coco': f'{model_urls_root}/yolov5_darknet_pan_s_tr_coco-f09f21f7.pt', } @@ -299,3 +301,33 @@ def yolov5_darknet_pan_l_r40(pretrained: bool = False, progress: bool = True, nu version = 'v4.0' return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) + + +def yolov5_darknet_pan_s_tr(pretrained: bool = False, progress: bool = True, num_classes: int = 80, + **kwargs: Any) -> YOLO: + r"""yolov5 small with a transformer block model from + `"dingyiwei/yolov5" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + backbone_name = 'darknet_s_r4_0' + weights_name = 'yolov5_darknet_pan_s_tr_coco' + depth_multiple = 0.33 + width_multiple = 0.5 + version = 'v4.0' + + backbone = darknet_pan_tr_backbone(backbone_name, depth_multiple, width_multiple, version=version) + + anchor_grids = [[10, 13, 16, 30, 33, 23], + [30, 61, 62, 45, 59, 119], + [116, 90, 156, 198, 373, 326]] + + model = YOLO(backbone, num_classes, anchor_grids, **kwargs) + if pretrained: + if model_urls.get(weights_name, None) is None: + raise ValueError(f"No checkpoint is available for model {weights_name}") + state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) + model.load_state_dict(state_dict) + + return model diff --git a/yolort/models/yolotr.py b/yolort/models/yolotr.py new file mode 100644 index 00000000..070929b9 --- /dev/null +++ b/yolort/models/yolotr.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. +from torch import nn + +from .common import Conv, C3 +from .path_aggregation_network import PathAggregationNetwork +from .backbone_utils import BackboneWithPAN + +from . import darknet + +from typing import Callable, List, Optional + + +def darknet_pan_tr_backbone( + backbone_name: str, + depth_multiple: float, + width_multiple: float, + pretrained: Optional[bool] = False, + returned_layers: Optional[List[int]] = None, + version: str = 'v4.0', +): + """ + Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of + layers in the backbone. + + Examples:: + + >>> from models.backbone_utils import darknet_pan_tr_backbone + >>> backbone = darknet_pan_tr_backbone('darknet3_1', pretrained=True, trainable_layers=3) + >>> # get some dummy image + >>> x = torch.rand(1, 3, 64, 64) + >>> # compute the output + >>> output = backbone(x) + >>> print([(k, v.shape) for k, v in output.items()]) + >>> # returns + >>> [('0', torch.Size([1, 128, 8, 8])), + >>> ('1', torch.Size([1, 256, 4, 4])), + >>> ('2', torch.Size([1, 512, 2, 2]))] + + Args: + backbone_name (string): darknet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1', + 'darknet_m_r3_1', 'darknet_l_r3_1', 'darknet_s_r4_0', 'darknet_m_r4_0', 'darknet_l_r4_0' + norm_layer (torchvision.ops): it is recommended to use the default value. For details visit: + (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) + pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_layers (int): number of trainable (not frozen) darknet layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. + version (str): ultralytics release version, currently only supports v3.1 or v4.0 + """ + backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features + + if returned_layers is None: + returned_layers = [4, 6, 8] + + return_layers = {str(k): str(i) for i, k in enumerate(returned_layers)} + + in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]] + + return BackboneWithPANTranformer(backbone, return_layers, in_channels_list, depth_multiple, version) + + +class BackboneWithPANTranformer(BackboneWithPAN): + def __init__(self, backbone, return_layers, in_channels_list, depth_multiple, version): + super().__init__(backbone, return_layers, in_channels_list, depth_multiple, version) + self.pan = PathAggregationNetworkTransformer( + in_channels_list, + depth_multiple, + version=version, + ) + + +class PathAggregationNetworkTransformer(PathAggregationNetwork): + def __init__( + self, + in_channels_list: List[int], + depth_multiple: float, + version: str, + block: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__(in_channels_list, depth_multiple, version=version, block=block) + assert len(in_channels_list) == 3, "currently only support length 3." + + if block is None: + block = C3 + + depth_gain = max(round(3 * depth_multiple), 1) + + inner_blocks = [ + C3TR(in_channels_list[2], in_channels_list[2], n=depth_gain, shortcut=False), + Conv(in_channels_list[2], in_channels_list[1], 1, 1, version=version), + nn.Upsample(scale_factor=2), + block(in_channels_list[2], in_channels_list[1], n=depth_gain, shortcut=False), + Conv(in_channels_list[1], in_channels_list[0], 1, 1, version=version), + nn.Upsample(scale_factor=2), + ] + + self.inner_blocks = nn.ModuleList(inner_blocks) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + elif isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)): + m.inplace = True + + +class C3TR(C3): + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = TransformerBlock(c_, c_, 4, n) + + +class TransformerLayer(nn.Module): + def __init__(self, c, num_heads): + """ + Args: + c (int): + num_heads: + """ + super().__init__() + + self.ln1 = nn.LayerNorm(c) + self.q = nn.Linear(c, c, bias=False) + self.k = nn.Linear(c, c, bias=False) + self.v = nn.Linear(c, c, bias=False) + self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) + self.ln2 = nn.LayerNorm(c) + self.fc1 = nn.Linear(c, c, bias=False) + self.fc2 = nn.Linear(c, c, bias=False) + + def forward(self, x): + x_ = self.ln1(x) + x = self.ma(self.q(x_), self.k(x_), self.v(x_))[0] + x + x = self.ln2(x) + x = self.fc2(self.fc1(x)) + x + return x + + +class TransformerBlock(nn.Module): + def __init__(self, c1, c2, num_heads, num_layers): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + num_heads: + num_layers: + """ + super().__init__() + + self.conv = None + if c1 != c2: + self.conv = Conv(c1, c2) + self.linear = nn.Linear(c2, c2) + self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)]) + self.c2 = c2 + + def forward(self, x): + if self.conv is not None: + x = self.conv(x) + + b, _, w, h = x.shape + p = x.flatten(2) + p = p.unsqueeze(0) + p = p.transpose(0, 3) + p = p.squeeze(3) + + e = self.linear(p) + x = p + e + + x = self.tr(x) + x = x.unsqueeze(3) + x = x.transpose(0, 3) + x = x.reshape(b, self.c2, w, h) + return x