From 78ed66fc3416c622cb757a96e7ee2bbe63701f99 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sun, 10 Oct 2021 21:43:27 +0800 Subject: [PATCH] Resolve version conflicts (#195) * Rename yolotr to yolov5t * Remove redundant code and resolve version conflicts * Rename yolov5t to yolov5ts * Resolve version conflicts * Fixing docstrings * Resolve version conflicts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing TypeError in darknet_tan_backbone * Fixing docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing test_load_from_yolov5 * Fixing test_load_from_yolov5 * Rename to up_version Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- hubconf.py | 2 +- test/test_models.py | 30 +++---- test/test_onnx.py | 2 +- yolort/models/__init__.py | 4 +- yolort/models/_utils.py | 12 ++- yolort/models/backbone_utils.py | 24 ++--- yolort/models/darknetv5.py | 26 ++++-- yolort/models/darknetv6.py | 18 ++-- yolort/models/path_aggregation_network.py | 18 +++- yolort/models/transformer.py | 101 +++------------------- yolort/models/yolo.py | 13 ++- yolort/models/yolo_module.py | 14 ++- yolort/utils/update_module_state.py | 2 +- yolort/v5/models/__init__.py | 2 + yolort/v5/models/common.py | 62 ++++++++----- 15 files changed, 166 insertions(+), 164 deletions(-) diff --git a/hubconf.py b/hubconf.py index f91e9439..a583b7c1 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,4 @@ # Optional list of dependencies required by the package dependencies = ["torch", "torchvision"] -from yolort.models import yolov5s, yolov5m, yolov5l +from yolort.models import yolov5s, yolov5m, yolov5l, yolov5ts diff --git a/test/test_models.py b/test/test_models.py index 477faeaa..028f2b9f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -131,10 +131,10 @@ def _init_test_backbone_with_pan_r3_1(self): backbone_name = "darknet_s_r3_1" depth_multiple = 0.33 width_multiple = 0.5 - backbone_with_fpn = darknet_pan_backbone( + backbone_with_pan = darknet_pan_backbone( backbone_name, depth_multiple, width_multiple ) - return backbone_with_fpn + return backbone_with_pan def test_backbone_with_pan_r3_1(self): N, H, W = 4, 416, 352 @@ -154,10 +154,10 @@ def _init_test_backbone_with_pan_r4_0(self): backbone_name = "darknet_s_r4_0" depth_multiple = 0.33 width_multiple = 0.5 - backbone_with_fpn = darknet_pan_backbone( + backbone_with_pan = darknet_pan_backbone( backbone_name, depth_multiple, width_multiple ) - return backbone_with_fpn + return backbone_with_pan def test_backbone_with_pan_r4_0(self): N, H, W = 4, 416, 352 @@ -173,21 +173,21 @@ def test_backbone_with_pan_r4_0(self): assert tuple(out[2].shape) == (N, *out_shape[2]) _check_jit_scriptable(model, (x,)) - def _init_test_backbone_with_pan_tr(self): + def _init_test_backbone_with_tan_r4_0(self): backbone_name = "darknet_s_r4_0" depth_multiple = 0.33 width_multiple = 0.5 - backbone_with_fpn_tr = darknet_tan_backbone( + backbone_with_tan = darknet_tan_backbone( backbone_name, depth_multiple, width_multiple ) - return backbone_with_fpn_tr + return backbone_with_tan - def test_backbone_with_pan_tr(self): + def test_backbone_with_tan_r4_0(self): N, H, W = 4, 416, 352 out_shape = self._get_feature_shapes(H, W) x = torch.rand(N, 3, H, W) - model = self._init_test_backbone_with_pan_tr() + model = self._init_test_backbone_with_tan_r4_0() out = model(x) assert len(out) == 3 @@ -278,7 +278,7 @@ def test_criterion(self): 
assert isinstance(losses["objectness"], Tensor) -@pytest.mark.parametrize("arch", ["yolov5s", "yolov5m", "yolov5l", "yolotr"]) +@pytest.mark.parametrize("arch", ["yolov5s", "yolov5m", "yolov5l", "yolov5ts"]) def test_torchscript(arch): model = models.__dict__[arch](pretrained=True, size=(320, 320), score_thresh=0.45) model.eval() @@ -303,21 +303,21 @@ def test_torchscript(arch): @pytest.mark.parametrize( - "arch, version, hash_prefix", [("yolov5s", "v4.0", "9ca9a642")] + "arch, up_version, hash_prefix", [("yolov5s", "v4.0", "9ca9a642")] ) -def test_load_from_yolov5(arch, version, hash_prefix): +def test_load_from_yolov5(arch, up_version, hash_prefix): img_path = "test/assets/bus.jpg" yolov5s_r40_path = Path(f"{arch}.pt") if not yolov5s_r40_path.exists(): torch.hub.download_url_to_file( - f"https://github.com/ultralytics/yolov5/releases/download/{version}/{arch}.pt", + f"https://github.com/ultralytics/yolov5/releases/download/{up_version}/{arch}.pt", yolov5s_r40_path, hash_prefix=hash_prefix, ) - yolov5 = YOLOv5() - model_yolov5 = yolov5.load_from_yolov5(yolov5s_r40_path, score_thresh=0.25) + version = up_version.replace("v", "r") + model_yolov5 = YOLOv5.load_from_yolov5(yolov5s_r40_path, version=version) model_yolov5.eval() out_from_yolov5 = model_yolov5.predict(img_path) assert isinstance(out_from_yolov5[0], dict) diff --git a/test/test_onnx.py b/test/test_onnx.py index eb777690..fa3077ea 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -116,7 +116,7 @@ def get_test_images(self): [ ("yolov5s", "r3.1"), ("yolov5m", "r4.0"), - # ('yolotr', 'r4.0'), + # ("yolov5ts", "r4.0"), ], ) def test_yolort_export_onnx(self, arch, upstream_version): diff --git a/yolort/models/__init__.py b/yolort/models/__init__.py index 33b5ddb8..0bbf6b08 100644 --- a/yolort/models/__init__.py +++ b/yolort/models/__init__.py @@ -8,7 +8,7 @@ from .yolo import YOLO from .yolo_module import YOLOv5 -__all__ = ["YOLO", "YOLOv5", "yolov5s", "yolov5m", "yolov5l", "yolotr"] +__all__ = ["YOLO", "YOLOv5", "yolov5s", "yolov5m", "yolov5l", "yolov5ts"] def yolov5s( @@ -80,7 +80,7 @@ def yolov5l( return model -def yolotr( +def yolov5ts( upstream_version: str = "r4.0", export_friendly: bool = False, **kwargs: Any ): """ diff --git a/yolort/models/_utils.py b/yolort/models/_utils.py index a40fe4a7..fbde3f8b 100644 --- a/yolort/models/_utils.py +++ b/yolort/models/_utils.py @@ -9,15 +9,21 @@ from yolort.v5 import load_yolov5_model, get_yolov5_size -def load_from_ultralytics(checkpoint_path: str, version: str = "r4.0"): +def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"): """ Load YOLOv5 state from the checkpoint trained from the ultralytics. Args: checkpoint_path (str): Path of the YOLOv5 checkpoint model. - version (str): upstream version released by the ultralytics/yolov5, - versions r3.1 and r4.0 are currently supported. + version (str): upstream version released by the ultralytics/yolov5, Possible + values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". """ + + assert version in [ + "r3.1", + "r4.0", + "r6.0", + ], "Currently only supports version 'r3.1', 'r4.0' and 'r6.0'." 
checkpoint_yolov5 = load_yolov5_model(checkpoint_path) num_classes = checkpoint_yolov5.yaml["nc"] anchor_grids = checkpoint_yolov5.yaml["anchors"] diff --git a/yolort/models/backbone_utils.py b/yolort/models/backbone_utils.py index d858bf9e..8db9ca24 100644 --- a/yolort/models/backbone_utils.py +++ b/yolort/models/backbone_utils.py @@ -24,7 +24,7 @@ class BackboneWithPAN(nn.Module): in_channels_list (List[int]): number of channels for each feature map that is returned, in the order they are present in the OrderedDict depth_multiple (float): depth multiplier - version (str): ultralytics release version: ["r3.1", "r4.0", "r6.0"] + version (str): Module version released by ultralytics: ["r3.1", "r4.0", "r6.0"]. Attributes: out_channels (int): the number of channels in the PAN @@ -55,7 +55,7 @@ def darknet_pan_backbone( width_multiple: float, pretrained: Optional[bool] = False, returned_layers: Optional[List[int]] = None, - version: str = "r4.0", + version: str = "r6.0", ): """ Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of @@ -64,7 +64,7 @@ def darknet_pan_backbone( Examples: >>> from models.backbone_utils import darknet_pan_backbone - >>> backbone = darknet_pan_backbone('darknet3_1', pretrained=True, trainable_layers=3) + >>> backbone = darknet_pan_backbone("darknet_s_r4_0") >>> # get some dummy image >>> x = torch.rand(1, 3, 64, 64) >>> # compute the output @@ -75,15 +75,19 @@ def darknet_pan_backbone( ('2', torch.Size([1, 512, 2, 2]))] Args: - backbone_name (string): darknet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1', - 'darknet_m_r3_1', 'darknet_l_r3_1', 'darknet_s_r4_0', 'darknet_m_r4_0', 'darknet_l_r4_0' - norm_layer (torchvision.ops): it is recommended to use the default value. For details visit: - (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) + backbone_name (string): darknet architecture. Possible values are "darknet_s_r3_1", + "darknet_m_r3_1", "darknet_l_r3_1", "darknet_s_r4_0", "darknet_m_r4_0", + "darknet_l_r4_0", "darknet_s_r6_0", "darknet_m_r6_0", and "darknet_l_r6_0". pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_layers (int): number of trainable (not frozen) darknet layers starting from final block. - Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. - version (str): ultralytics release version: ["r3.1", "r4.0", "r6.0"] + version (str): Module version released by ultralytics. Possible values + are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". """ + assert version in [ + "r3.1", + "r4.0", + "r6.0", + ], "Currently only supports version 'r3.1', 'r4.0' and 'r6.0'." 
+ backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features if returned_layers is None: diff --git a/yolort/models/darknetv5.py b/yolort/models/darknetv5.py index f95648bc..f85451f4 100644 --- a/yolort/models/darknetv5.py +++ b/yolort/models/darknetv5.py @@ -57,6 +57,10 @@ def __init__( ) -> None: super().__init__() + assert version in ["r3.1", "r4.0"], ( + "Currently the module version used in DarkNetV5 is r3.1 or r4.0", + ) + if block is None: block = _block[version] @@ -128,7 +132,7 @@ def forward(self, x: Tensor) -> Tensor: } -def _darknet( +def _darknetv5( arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any ) -> DarkNetV5: """ @@ -162,7 +166,9 @@ def darknet_s_r3_1( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_s_r3.1", pretrained, progress, 0.33, 0.5, "r3.1", **kwargs) + return _darknetv5( + "darknet_s_r3.1", pretrained, progress, 0.33, 0.5, "r3.1", **kwargs + ) def darknet_m_r3_1( @@ -176,7 +182,7 @@ def darknet_m_r3_1( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet( + return _darknetv5( "darknet_m_r3.1", pretrained, progress, 0.67, 0.75, "r3.1", **kwargs ) @@ -192,7 +198,9 @@ def darknet_l_r3_1( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_l_r3.1", pretrained, progress, 1.0, 1.0, "r3.1", **kwargs) + return _darknetv5( + "darknet_l_r3.1", pretrained, progress, 1.0, 1.0, "r3.1", **kwargs + ) def darknet_s_r4_0( @@ -206,7 +214,9 @@ def darknet_s_r4_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_s_r4.0", pretrained, progress, 0.33, 0.5, "r4.0", **kwargs) + return _darknetv5( + "darknet_s_r4.0", pretrained, progress, 0.33, 0.5, "r4.0", **kwargs + ) def darknet_m_r4_0( @@ -220,7 +230,7 @@ def darknet_m_r4_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet( + return _darknetv5( "darknet_m_r4.0", pretrained, progress, 0.67, 0.75, "r4.0", **kwargs ) @@ -236,4 +246,6 @@ def darknet_l_r4_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_l_r4.0", pretrained, progress, 1.0, 1.0, "r4.0", **kwargs) + return _darknetv5( + "darknet_l_r4.0", pretrained, progress, 1.0, 1.0, "r4.0", **kwargs + ) diff --git a/yolort/models/darknetv6.py b/yolort/models/darknetv6.py index 30645b6a..b4528e07 100644 --- a/yolort/models/darknetv6.py +++ b/yolort/models/darknetv6.py @@ -31,7 +31,7 @@ class DarkNetV6(nn.Module): depth_multiple (float): Depth multiplier width_multiple (float): Width multiplier - adjusts number of channels in each layer by this amount - version (str): ultralytics release version: r3.1 or r4.0 + version (str): Module version released by ultralytics, set to r4.0. block: Module specifying inverted residual building block for darknet round_nearest (int): Round the number of channels in each layer to be a multiple of this number. 
Set to 1 to turn off rounding @@ -42,7 +42,7 @@ def __init__( self, depth_multiple: float, width_multiple: float, - version: str, + version: str = "r4.0", block: Optional[Callable[..., nn.Module]] = None, stages_repeats: Optional[List[int]] = None, stages_out_channels: Optional[List[int]] = None, @@ -51,6 +51,10 @@ def __init__( ) -> None: super().__init__() + assert version == "r4.0", ( + "Currently the module version used in DarkNetV6 is r4.0", + ) + if block is None: block = C3 @@ -117,7 +121,7 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _darknet( +def _darknetv6( arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any ) -> DarkNetV6: """ @@ -151,7 +155,7 @@ def darknet_s_r6_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_s_r6.0", pretrained, progress, 0.33, 0.5, "r6.0", **kwargs) + return _darknetv6("darknet_s_r6.0", pretrained, progress, 0.33, 0.5, **kwargs) def darknet_m_r6_0( @@ -165,9 +169,7 @@ def darknet_m_r6_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet( - "darknet_m_r6.0", pretrained, progress, 0.67, 0.75, "r6.0", **kwargs - ) + return _darknetv6("darknet_m_r6.0", pretrained, progress, 0.67, 0.75, **kwargs) def darknet_l_r6_0( @@ -181,4 +183,4 @@ def darknet_l_r6_0( pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_l_r6.0", pretrained, progress, 1.0, 1.0, "r6.0", **kwargs) + return _darknetv6("darknet_l_r6.0", pretrained, progress, 1.0, 1.0, **kwargs) diff --git a/yolort/models/path_aggregation_network.py b/yolort/models/path_aggregation_network.py index d4e7ce54..62aea951 100644 --- a/yolort/models/path_aggregation_network.py +++ b/yolort/models/path_aggregation_network.py @@ -57,21 +57,27 @@ def __init__( if version == "r6.0": init_block = SPPF(in_channels_list[2], in_channels_list[2], k=5) + module_version = "r4.0" elif version in ["r3.1", "r4.0"]: init_block = block( in_channels_list[2], in_channels_list[2], n=depth_gain, shortcut=False ) + module_version = version else: raise NotImplementedError(f"Version {version} is not implemented yet.") inner_blocks = [ init_block, - Conv(in_channels_list[2], in_channels_list[1], 1, 1, version=version), + Conv( + in_channels_list[2], in_channels_list[1], 1, 1, version=module_version + ), nn.Upsample(scale_factor=2), block( in_channels_list[2], in_channels_list[1], n=depth_gain, shortcut=False ), - Conv(in_channels_list[1], in_channels_list[0], 1, 1, version=version), + Conv( + in_channels_list[1], in_channels_list[0], 1, 1, version=module_version + ), nn.Upsample(scale_factor=2), ] @@ -81,11 +87,15 @@ def __init__( block( in_channels_list[1], in_channels_list[0], n=depth_gain, shortcut=False ), - Conv(in_channels_list[0], in_channels_list[0], 3, 2, version=version), + Conv( + in_channels_list[0], in_channels_list[0], 3, 2, version=module_version + ), block( in_channels_list[1], in_channels_list[1], n=depth_gain, shortcut=False ), - Conv(in_channels_list[1], in_channels_list[1], 3, 2, version=version), + Conv( + in_channels_list[1], in_channels_list[1], 3, 2, version=module_version + ), block( in_channels_list[2], in_channels_list[2], n=depth_gain, shortcut=False ), diff --git a/yolort/models/transformer.py 
b/yolort/models/transformer.py index e93d7209..5b62b5ce 100644 --- a/yolort/models/transformer.py +++ b/yolort/models/transformer.py @@ -8,7 +8,7 @@ from torch import nn -from yolort.v5 import Conv, C3 +from yolort.v5 import Conv, C3, C3TR from . import darknet from .backbone_utils import BackboneWithPAN from .path_aggregation_network import PathAggregationNetwork @@ -26,10 +26,10 @@ def darknet_tan_backbone( Constructs a specified DarkNet backbone with TAN on top. Freezes the specified number of layers in the backbone. - Examples:: + Examples: >>> from models.backbone_utils import darknet_tan_backbone - >>> backbone = darknet_tan_backbone('darknet3_1', pretrained=True, trainable_layers=3) + >>> backbone = darknet_tan_backbone("darknet_s_r4_0") >>> # get some dummy image >>> x = torch.rand(1, 3, 64, 64) >>> # compute the output @@ -41,15 +41,14 @@ def darknet_tan_backbone( >>> ('2', torch.Size([1, 512, 2, 2]))] Args: - backbone_name (string): darknet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1', - 'darknet_m_r3_1', 'darknet_l_r3_1', 'darknet_s_r4_0', 'darknet_m_r4_0', 'darknet_l_r4_0' + backbone_name (string): darknet architecture. Possible values are "darknet_s_r4_0" Now. norm_layer (torchvision.ops): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_layers (int): number of trainable (not frozen) darknet layers starting from final block. - Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. - version (str): ultralytics release version, currently only supports r3.1 or r4.0 + version (str): Module version released by ultralytics, set to "r4.0". """ + assert version == "r4.0", "Currently only supports version r4.0." + backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features if returned_layers is None: @@ -59,9 +58,7 @@ def darknet_tan_backbone( in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]] - return BackboneWithTAN( - backbone, return_layers, in_channels_list, depth_multiple, version - ) + return BackboneWithTAN(backbone, return_layers, in_channels_list, depth_multiple) class BackboneWithTAN(BackboneWithPAN): @@ -69,16 +66,14 @@ class BackboneWithTAN(BackboneWithPAN): Adds a TAN on top of a model. """ - def __init__( - self, backbone, return_layers, in_channels_list, depth_multiple, version - ): + def __init__(self, backbone, return_layers, in_channels_list, depth_multiple): super().__init__( - backbone, return_layers, in_channels_list, depth_multiple, version + backbone, return_layers, in_channels_list, depth_multiple, "r4.0" ) self.pan = TransformerAttentionNetwork( in_channels_list, depth_multiple, - version=version, + version="r4.0", ) @@ -87,11 +82,12 @@ def __init__( self, in_channels_list: List[int], depth_multiple: float, - version: str, + version: str = "r4.0", block: Optional[Callable[..., nn.Module]] = None, ): super().__init__(in_channels_list, depth_multiple, version=version, block=block) - assert len(in_channels_list) == 3, "currently only support length 3." + assert len(in_channels_list) == 3, "Currently only supports length 3." + assert version == "r4.0", "Currently only supports version r4.0." 
if block is None: block = C3 @@ -121,72 +117,3 @@ def __init__( m.momentum = 0.03 elif isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)): m.inplace = True - - -class C3TR(C3): - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): - super().__init__(c1, c2, n, shortcut, g, e) - c_ = int(c2 * e) - self.m = TransformerBlock(c_, c_, 4, n) - - -class TransformerLayer(nn.Module): - def __init__(self, c, num_heads): - """ - Args: - c (int): number of channels - num_heads: number of heads - """ - super().__init__() - self.q = nn.Linear(c, c, bias=False) - self.k = nn.Linear(c, c, bias=False) - self.v = nn.Linear(c, c, bias=False) - - self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) - self.fc1 = nn.Linear(c, c, bias=False) - self.fc2 = nn.Linear(c, c, bias=False) - - def forward(self, x): - x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x - x = self.fc2(self.fc1(x)) + x - return x - - -class TransformerBlock(nn.Module): - def __init__(self, c1, c2, num_heads, num_layers): - """ - Args: - c1 (int): number of input channels - c2 (int): number of output channels - num_heads: number of heads - num_layers: number of layers - """ - super().__init__() - - self.conv = None - if c1 != c2: - self.conv = Conv(c1, c2) - self.linear = nn.Linear(c2, c2) - self.tr = nn.Sequential( - *[TransformerLayer(c2, num_heads) for _ in range(num_layers)] - ) - self.c2 = c2 - - def forward(self, x): - if self.conv is not None: - x = self.conv(x) - - b, _, w, h = x.shape - p = x.flatten(2) - p = p.unsqueeze(0) - p = p.transpose(0, 3) - p = p.squeeze(3) - - e = self.linear(p) - x = p + e - - x = self.tr(x) - x = x.unsqueeze(3) - x = x.transpose(0, 3) - x = x.reshape(b, self.c2, w, h) - return x diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py index cc99b1fa..5390e6e7 100644 --- a/yolort/models/yolo.py +++ b/yolort/models/yolo.py @@ -183,10 +183,17 @@ def load_from_yolov5( checkpoint_path: str, score_thresh: float = 0.25, nms_thresh: float = 0.45, - version: str = "r4.0", + version: str = "r6.0", ): """ Load model state from the checkpoint trained by YOLOv5. + + Args: + checkpoint_path (str): Path of the YOLOv5 checkpoint model. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + version (str): upstream version released by the ultralytics/yolov5, Possible + values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". """ model_info = load_from_ultralytics(checkpoint_path, version=version) backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}" @@ -243,9 +250,11 @@ def _yolov5_darknet_pan( >>> x = torch.rand(4, 3, 416, 320) >>> predictions = model(x) - Arguments: + Args: pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr + version (str): Module version released by ultralytics. Possible values + are ["r3.1", "r4.0", "r6.0"]. """ backbone = darknet_pan_backbone( backbone_name, depth_multiple, width_multiple, version=version diff --git a/yolort/models/yolo_module.py b/yolort/models/yolo_module.py index 87312f5d..6d19816e 100644 --- a/yolort/models/yolo_module.py +++ b/yolort/models/yolo_module.py @@ -298,12 +298,22 @@ def load_from_yolov5( size: Tuple[int, int] = (640, 640), score_thresh: float = 0.25, nms_thresh: float = 0.45, - version: str = "r4.0", + version: str = "r6.0", ): """ Load model state from the checkpoint trained by YOLOv5. 
+ + Args: + checkpoint_path (str): Path of the YOLOv5 checkpoint model. + lr (float): The initial learning rate + size: (Tuple[int, int]): the width and height to which images will be rescaled + before feeding them to the backbone. Default: (640, 640). + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + version (str): upstream version released by the ultralytics/yolov5, Possible + values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". """ - model_info = load_from_ultralytics(checkpoint_path) + model_info = load_from_ultralytics(checkpoint_path, version=version) arch = f"yolov5_darknet_pan_{model_info['size']}_{version.replace('.', '')}" yolov5 = cls( lr=lr, diff --git a/yolort/utils/update_module_state.py b/yolort/utils/update_module_state.py index 046bde59..80e7f8f4 100644 --- a/yolort/utils/update_module_state.py +++ b/yolort/utils/update_module_state.py @@ -75,7 +75,7 @@ def __init__( arch: Optional[str] = "yolov5_darknet_pan_s_r31", depth_multiple: Optional[float] = None, width_multiple: Optional[float] = None, - version: str = "r4.0", + version: str = "r6.0", num_classes: int = 80, inner_block_maps: Optional[Dict[str, str]] = None, layer_block_maps: Optional[Dict[str, str]] = None, diff --git a/yolort/v5/models/__init__.py b/yolort/v5/models/__init__.py index 95dc795c..11b93328 100644 --- a/yolort/v5/models/__init__.py +++ b/yolort/v5/models/__init__.py @@ -7,6 +7,7 @@ Focus, BottleneckCSP, C3, + C3TR, Concat, GhostConv, GhostBottleneck, @@ -28,6 +29,7 @@ "Focus", "BottleneckCSP", "C3", + "C3TR", "Concat", "GhostConv", "GhostBottleneck", diff --git a/yolort/v5/models/common.py b/yolort/v5/models/common.py index 9df71563..b505d4f6 100644 --- a/yolort/v5/models/common.py +++ b/yolort/v5/models/common.py @@ -54,14 +54,14 @@ class Conv(nn.Module): g (int): groups act (bool or nn.Module): determine the activation function version (str): Module version released by ultralytics. Possible values - are ["r3.1", "r4.0", "r5.0", "r6.0"]. Default: "r6.0". + are ["r3.1", "r4.0"]. Default: "r4.0". """ - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r6.0"): + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r4.0"): super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) - if version in ["r4.0", "r5.0", "r6.0"]: + if version == "r4.0": self.act = ( nn.SiLU() if act @@ -74,7 +74,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r6.0"): else (act if isinstance(act, nn.Module) else nn.Identity()) ) else: - raise NotImplementedError("Currently only supports version above r3.1") + raise NotImplementedError(f"Currently doesn't support version {version}.") def forward(self, x: Tensor) -> Tensor: return self.act(self.bn(self.conv(x))) @@ -94,10 +94,10 @@ class DWConv(Conv): s (int): stride act (bool or nn.Module): determine the activation function version (str): Module version released by ultralytics. Possible values - are ["r3.1", "r4.0", "r5.0", "r6.0"]. Default: "r6.0". + are ["r3.1", "r4.0"]. Default: "r4.0". """ - def __init__(self, c1, c2, k=1, s=1, act=True, version="r6.0"): + def __init__(self, c1, c2, k=1, s=1, act=True, version="r4.0"): super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act, version=version) @@ -112,10 +112,10 @@ class Bottleneck(nn.Module): g (int): groups e (float): expansion version (str): Module version released by ultralytics. 
Possible values - are ["r3.1", "r4.0", "r5.0", "r6.0"]. Default: "r6.0". + are ["r3.1", "r4.0"]. Default: "r4.0". """ - def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, version="r6.0"): + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, version="r4.0"): super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1, version=version) @@ -169,16 +169,18 @@ class C3(nn.Module): shortcut (bool): shortcut g (int): groups e (float): expansion + version (str): Module version released by ultralytics. Possible values + are ["r4.0"]. Default: "r4.0". """ - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, version="r4.0"): super().__init__() c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c1, c_, 1, 1) - self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.cv1 = Conv(c1, c_, 1, 1, version=version) + self.cv2 = Conv(c1, c_, 1, 1, version=version) + self.cv3 = Conv(2 * c_, c2, 1, version=version) # act=FReLU(c2) self.m = nn.Sequential( - *[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)] + *[Bottleneck(c_, c_, shortcut, g, e=1.0, version=version) for _ in range(n)] ) def forward(self, x): @@ -187,7 +189,7 @@ def forward(self, x): class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP - def __init__(self, c1, c2, k=(5, 9, 13), version="r6.0"): + def __init__(self, c1, c2, k=(5, 9, 13), version="r4.0"): super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1, version=version) @@ -206,7 +208,7 @@ class SPPF(nn.Module): Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher """ - def __init__(self, c1, c2, k=5, version="r6.0"): + def __init__(self, c1, c2, k=5, version="r4.0"): # Equivalent to SPP(k=(5, 9, 13)) when k=5 super().__init__() c_ = c1 // 2 # hidden channels @@ -234,10 +236,10 @@ class Focus(nn.Module): g (int): groups act (bool or nn.Module): determine the activation function version (str): Module version released by ultralytics. Possible values - are ["r3.1", "r4.0", "r5.0"]. Default: "r5.0". + are ["r3.1", "r4.0"]. Default: "r4.0". """ - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r5.0"): + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r4.0"): super().__init__() self.conv = Conv(c1 * 4, c2, k, s, p, g, act, version=version) @@ -289,12 +291,21 @@ def forward(x): class TransformerLayer(nn.Module): - # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) + """ + Transformer layer . + Remove the LayerNorm layers for better performance + + Args: + c (int): number of channels + num_heads: number of heads + """ + def __init__(self, c, num_heads): super().__init__() self.q = nn.Linear(c, c, bias=False) self.k = nn.Linear(c, c, bias=False) self.v = nn.Linear(c, c, bias=False) + self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) self.fc1 = nn.Linear(c, c, bias=False) self.fc2 = nn.Linear(c, c, bias=False) @@ -306,12 +317,21 @@ def forward(self, x): class TransformerBlock(nn.Module): - # Vision Transformer https://arxiv.org/abs/2010.11929 + """ + Vision Transformer . 
+ + Args: + c1 (int): number of input channels + c2 (int): number of output channels + num_heads: number of heads + num_layers: number of layers + """ + def __init__(self, c1, c2, num_heads, num_layers): super().__init__() self.conv = None if c1 != c2: - self.conv = Conv(c1, c2) + self.conv = Conv(c1, c2, version="r4.0") self.linear = nn.Linear(c2, c2) # learnable position embedding self.tr = nn.Sequential( *[TransformerLayer(c2, num_heads) for _ in range(num_layers)] @@ -334,7 +354,7 @@ def forward(self, x): class C3TR(C3): # C3 module with TransformerBlock() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): - super().__init__(c1, c2, n, shortcut, g, e) + super().__init__(c1, c2, n, shortcut, g, e, version="r4.0") c_ = int(c2 * e) self.m = TransformerBlock(c_, c_, 4, n)
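
Usage note (not part of the patch itself): a minimal sketch of the renamed public entry points, mirroring what the updated tests exercise. It assumes pretrained `yolov5ts` weights are downloadable and that a `yolov5s.pt` checkpoint from the upstream v4.0 release is already available locally; the paths below are placeholders.

    from yolort.models import YOLOv5, yolov5ts

    # The transformer-augmented model is now exposed as `yolov5ts` (formerly `yolotr`),
    # matching the parametrization in test_torchscript.
    model = yolov5ts(pretrained=True, size=(320, 320), score_thresh=0.45)
    model.eval()

    # Loading an upstream ultralytics/yolov5 checkpoint is a classmethod that takes an
    # explicit `version`; supported values are "r3.1", "r4.0" and "r6.0".
    model_from_ckpt = YOLOv5.load_from_yolov5("yolov5s.pt", version="r4.0")
    model_from_ckpt.eval()
    predictions = model_from_ckpt.predict("test/assets/bus.jpg")  # list of dicts

The version string is derived from the upstream release tag by replacing the leading "v" with "r", which is exactly the `up_version.replace("v", "r")` mapping now used in test_load_from_yolov5.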
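
Also for illustration only: the transformer building blocks (`TransformerLayer`, `TransformerBlock`, `C3TR`) now live in yolort/v5/models/common.py and are re-exported from `yolort.v5`, so `yolort/models/transformer.py` simply imports them. The channel and spatial sizes in this sketch are arbitrary examples.

    import torch
    from yolort.v5 import C3TR

    # C3 block whose bottleneck stack is replaced by a TransformerBlock (4 heads).
    block = C3TR(c1=128, c2=128, n=1)
    block.eval()

    x = torch.rand(1, 128, 16, 16)
    with torch.no_grad():
        y = block(x)
    assert y.shape == (1, 128, 16, 16)  # output channels and spatial size are preserved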