From 1f738f4f383b0a1eed64faa92dd975e71262eeaf Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang
Date: Tue, 12 Oct 2021 08:52:39 +0800
Subject: [PATCH] Introduce LastLevelP6 in PAN

---
Reviewer notes (this section sits after the three-dash line, so it is not
part of the commit message):

* BackboneWithPAN.__init__ now binds `extra_blocks` on both branches. It was
  assigned only under `if use_p6:` yet always passed to
  PathAggregationNetwork, so the default use_p6=False path read an
  unassigned local variable.
* `LastLevelP6()` is still called without the `in_channels` / `out_channels`
  arguments its __init__ declares, so use_p6=True still fails at
  construction. The intended channel counts are not derivable from this
  patch; a TODO marks the call site.
* NOTE(review): LastLevelP6 builds `C3(in_channels, out_channels, 3, 2, 1)`.
  Confirm against yolort.v5 that C3 takes these positional arguments as
  intended.
* NOTE(review): no hunk in this patch makes PathAggregationNetwork.forward
  use `self.extra_blocks`; confirm where the P6 output is meant to be
  produced.
* `use_p6` in darknet_pan_backbone is annotated `bool` to match the other
  parameters of that signature.
* Line breaks and indentation were reconstructed from the hunk line counts
  and the diffstat; blank-line placement, docstring wrap points and the
  nesting of a few context lines should be confirmed against the tree.

 yolort/models/backbone_utils.py           | 27 ++++++++--
 yolort/models/darknetv5.py                |  9 ++--
 yolort/models/darknetv6.py                |  7 ++-
 yolort/models/path_aggregation_network.py | 65 ++++++++++++++++++++++-
 4 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/yolort/models/backbone_utils.py b/yolort/models/backbone_utils.py
index 8db9ca24f..69fae9d7a 100644
--- a/yolort/models/backbone_utils.py
+++ b/yolort/models/backbone_utils.py
@@ -5,7 +5,7 @@
 from torchvision.models._utils import IntermediateLayerGetter
 
 from . import darknet
-from .path_aggregation_network import PathAggregationNetwork
+from .path_aggregation_network import PathAggregationNetwork, LastLevelP6
 
 
 class BackboneWithPAN(nn.Module):
@@ -25,21 +25,32 @@ class BackboneWithPAN(nn.Module):
             that is returned, in the order they are present in the OrderedDict
         depth_multiple (float): depth multiplier
         version (str): Module version released by ultralytics: ["r3.1", "r4.0", "r6.0"].
+        use_p6 (bool): Whether to use P6 layers.
 
     Attributes:
         out_channels (int): the number of channels in the PAN
     """
 
     def __init__(
-        self, backbone, return_layers, in_channels_list, depth_multiple, version
+        self,
+        backbone,
+        return_layers,
+        in_channels_list,
+        depth_multiple,
+        version,
+        use_p6=False,
     ):
         super().__init__()
 
+        # TODO(review): LastLevelP6 requires in_channels/out_channels arguments.
+        extra_blocks = LastLevelP6() if use_p6 else None
+
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.pan = PathAggregationNetwork(
             in_channels_list,
             depth_multiple,
             version=version,
+            extra_blocks=extra_blocks,
         )
         self.out_channels = in_channels_list
 
@@ -56,6 +67,7 @@ def darknet_pan_backbone(
     pretrained: Optional[bool] = False,
     returned_layers: Optional[List[int]] = None,
     version: str = "r6.0",
+    use_p6: bool = False,
 ):
     """
     Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of
@@ -81,6 +93,7 @@ def darknet_pan_backbone(
         pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
         version (str): Module version released by ultralytics. Possible values
             are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
+        use_p6 (bool): Whether to use P6 layers.
     """
     assert version in [
         "r3.1",
@@ -88,15 +101,19 @@ def darknet_pan_backbone(
         "r6.0",
     ], "Currently only supports version 'r3.1', 'r4.0' and 'r6.0'."
 
-    backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features
+    last_channel = 768 if use_p6 else 1024
+    backbone = darknet.__dict__[backbone_name](
+        pretrained=pretrained,
+        last_channel=last_channel,
+    ).features
 
     if returned_layers is None:
         returned_layers = [4, 6, 8]
 
     return_layers = {str(k): str(i) for i, k in enumerate(returned_layers)}
 
-    in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]]
+    in_channels_list = [int(gw * width_multiple) for gw in [256, 512, last_channel]]
 
     return BackboneWithPAN(
-        backbone, return_layers, in_channels_list, depth_multiple, version
+        backbone, return_layers, in_channels_list, depth_multiple, version, use_p6=use_p6
     )
diff --git a/yolort/models/darknetv5.py b/yolort/models/darknetv5.py
index f85451f49..f0ac14980 100644
--- a/yolort/models/darknetv5.py
+++ b/yolort/models/darknetv5.py
@@ -37,11 +37,14 @@ class DarkNetV5(nn.Module):
         depth_multiple (float): Depth multiplier
         width_multiple (float): Width multiplier - adjusts number of channels
             in each layer by this amount
-        version (str): ultralytics release version: r3.1 or r4.0
+        version (str): Module version released by ultralytics, set to r4.0.
         block: Module specifying inverted residual building block for darknet
+        stages_repeats (Optional[List[int]]): List of repeats number in the stages.
+        stages_out_channels (Optional[List[int]]): List of channels number in the stages.
+        num_classes (int): Number of classes
         round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
             Set to 1 to turn off rounding
-        num_classes (int): Number of classes
+        last_channel (int): Number of the last channel
     """
 
     def __init__(
@@ -54,6 +57,7 @@ def __init__(
         stages_out_channels: Optional[List[int]] = None,
         num_classes: int = 1000,
         round_nearest: int = 8,
+        last_channel: int = 1024,
     ) -> None:
         super().__init__()
 
@@ -65,7 +69,6 @@ def __init__(
             block = _block[version]
 
         input_channel = 64
-        last_channel = 1024
 
         if stages_repeats is None:
             stages_repeats = [3, 9, 9]
diff --git a/yolort/models/darknetv6.py b/yolort/models/darknetv6.py
index b4528e07f..6d0d02b71 100644
--- a/yolort/models/darknetv6.py
+++ b/yolort/models/darknetv6.py
@@ -33,9 +33,12 @@ class DarkNetV6(nn.Module):
             in each layer by this amount
         version (str): Module version released by ultralytics, set to r4.0.
         block: Module specifying inverted residual building block for darknet
+        stages_repeats (Optional[List[int]]): List of repeats number in the stages.
+        stages_out_channels (Optional[List[int]]): List of channels number in the stages.
+        num_classes (int): Number of classes
         round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
             Set to 1 to turn off rounding
-        num_classes (int): Number of classes
+        last_channel (int): Number of the last channel
     """
 
     def __init__(
@@ -48,6 +51,7 @@ def __init__(
         stages_out_channels: Optional[List[int]] = None,
         num_classes: int = 1000,
         round_nearest: int = 8,
+        last_channel: int = 1024,
     ) -> None:
         super().__init__()
 
@@ -59,7 +63,6 @@ def __init__(
             block = C3
 
         input_channel = 64
-        last_channel = 1024
 
         if stages_repeats is None:
             stages_repeats = [3, 6, 9]
diff --git a/yolort/models/path_aggregation_network.py b/yolort/models/path_aggregation_network.py
index 62aea9514..b51c70c8d 100644
--- a/yolort/models/path_aggregation_network.py
+++ b/yolort/models/path_aggregation_network.py
@@ -1,5 +1,5 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-from typing import Callable, List, Dict, Optional
+from typing import Callable, List, Dict, Optional, Tuple
 
 import torch
 from torch import nn, Tensor
@@ -7,6 +7,31 @@
 from yolort.v5 import Conv, BottleneckCSP, C3, SPPF
 
 
+class ExtraPANBlock(nn.Module):
+    """
+    Base class for the extra block in the PAN.
+
+    Args:
+        results (List[Tensor]): the result of the PAN
+        x (List[Tensor]): the original feature maps
+        names (List[str]): the names for each one of the
+            original feature maps
+
+    Returns:
+        results (List[Tensor]): the extended set of results
+            of the PAN
+        names (List[str]): the extended set of names for the results
+    """
+
+    def forward(
+        self,
+        results: List[Tensor],
+        x: List[Tensor],
+        names: List[str],
+    ) -> Tuple[List[Tensor], List[str]]:
+        pass
+
+
 class PathAggregationNetwork(nn.Module):
     """
     Module that adds a PAN from on top of a set of feature maps. This is based on
@@ -46,6 +71,7 @@ def __init__(
         depth_multiple: float,
         version: str = "r4.0",
         block: Optional[Callable[..., nn.Module]] = None,
+        extra_blocks: Optional[ExtraPANBlock] = None,
     ):
         super().__init__()
         assert len(in_channels_list) == 3, "Currently only supports length 3."
@@ -102,6 +128,10 @@ def __init__(
         ]
         self.layer_blocks = nn.ModuleList(layer_blocks)
 
+        if extra_blocks is not None:
+            assert isinstance(extra_blocks, ExtraPANBlock)
+        self.extra_blocks = extra_blocks
+
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@@ -190,3 +220,36 @@ def forward(self, x: Dict[str, Tensor]) -> List[Tensor]:
     "r4.0": C3,
     "r6.0": C3,
 }
+
+
+class LastLevelP6(ExtraPANBlock):
+    """
+    This module is used in YOLOv5 to generate extra P6 layers.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.p6 = C3(in_channels, out_channels, 3, 2, 1)
+        self.use_P5 = in_channels == out_channels
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03
+            elif isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)):
+                m.inplace = True
+
+    def forward(
+        self,
+        p: List[Tensor],
+        c: List[Tensor],
+        names: List[str],
+    ) -> Tuple[List[Tensor], List[str]]:
+        p5, c5 = p[-1], c[-1]
+        x = p5 if self.use_P5 else c5
+        p6 = self.p6(x)
+        p.extend([p6])
+        names.extend(["p6"])
+        return p, names