Skip to content

Commit

Permalink
Introduce LastLevelP6 in PAN
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 12, 2021
1 parent 3a39c55 commit 1f738f4
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 11 deletions.
27 changes: 22 additions & 5 deletions yolort/models/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision.models._utils import IntermediateLayerGetter

from . import darknet
from .path_aggregation_network import PathAggregationNetwork
from .path_aggregation_network import PathAggregationNetwork, LastLevelP6


class BackboneWithPAN(nn.Module):
Expand All @@ -25,21 +25,32 @@ class BackboneWithPAN(nn.Module):
that is returned, in the order they are present in the OrderedDict
depth_multiple (float): depth multiplier
version (str): Module version released by ultralytics: ["r3.1", "r4.0", "r6.0"].
use_p6 (bool): Whether to use P6 layers.
Attributes:
out_channels (int): the number of channels in the PAN
"""

def __init__(
self, backbone, return_layers, in_channels_list, depth_multiple, version
self,
backbone,
return_layers,
in_channels_list,
depth_multiple,
version,
use_p6=False,
):
super().__init__()

if use_p6:
extra_blocks = LastLevelP6()

self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.pan = PathAggregationNetwork(
in_channels_list,
depth_multiple,
version=version,
extra_blocks=extra_blocks,
)
self.out_channels = in_channels_list

Expand All @@ -56,6 +67,7 @@ def darknet_pan_backbone(
pretrained: Optional[bool] = False,
returned_layers: Optional[List[int]] = None,
version: str = "r6.0",
use_p6=False,
):
"""
Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of
Expand All @@ -81,22 +93,27 @@ def darknet_pan_backbone(
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
version (str): Module version released by ultralytics. Possible values
are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
use_p6 (bool): Whether to use P6 layers.
"""
assert version in [
"r3.1",
"r4.0",
"r6.0",
], "Currently only supports version 'r3.1', 'r4.0' and 'r6.0'."

backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features
last_channel = 768 if use_p6 else 1024
backbone = darknet.__dict__[backbone_name](
pretrained=pretrained,
last_channel=last_channel,
).features

if returned_layers is None:
returned_layers = [4, 6, 8]

return_layers = {str(k): str(i) for i, k in enumerate(returned_layers)}

in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]]
in_channels_list = [int(gw * width_multiple) for gw in [256, 512, last_channel]]

return BackboneWithPAN(
backbone, return_layers, in_channels_list, depth_multiple, version
backbone, return_layers, in_channels_list, depth_multiple, version, use_p6=use_p6
)
9 changes: 6 additions & 3 deletions yolort/models/darknetv5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ class DarkNetV5(nn.Module):
depth_multiple (float): Depth multiplier
width_multiple (float): Width multiplier - adjusts number of channels
in each layer by this amount
version (str): ultralytics release version: r3.1 or r4.0
version (str): Module version released by ultralytics, set to r4.0.
block: Module specifying inverted residual building block for darknet
stages_repeats (Optional[List[int]]): List of repeats number in the stages.
stages_out_channels (Optional[List[int]]): List of channels number in the stages.
num_classes (int): Number of classes
round_nearest (int): Round the number of channels in each layer to be
a multiple of this number. Set to 1 to turn off rounding
num_classes (int): Number of classes
last_channel (int): Number of the last channel
"""

def __init__(
Expand All @@ -54,6 +57,7 @@ def __init__(
stages_out_channels: Optional[List[int]] = None,
num_classes: int = 1000,
round_nearest: int = 8,
last_channel: int = 1024,
) -> None:
super().__init__()

Expand All @@ -65,7 +69,6 @@ def __init__(
block = _block[version]

input_channel = 64
last_channel = 1024

if stages_repeats is None:
stages_repeats = [3, 9, 9]
Expand Down
7 changes: 5 additions & 2 deletions yolort/models/darknetv6.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ class DarkNetV6(nn.Module):
in each layer by this amount
version (str): Module version released by ultralytics, set to r4.0.
block: Module specifying inverted residual building block for darknet
stages_repeats (Optional[List[int]]): List of repeats number in the stages.
stages_out_channels (Optional[List[int]]): List of channels number in the stages.
num_classes (int): Number of classes
round_nearest (int): Round the number of channels in each layer to be
a multiple of this number. Set to 1 to turn off rounding
num_classes (int): Number of classes
last_channel (int): Number of the last channel
"""

def __init__(
Expand All @@ -48,6 +51,7 @@ def __init__(
stages_out_channels: Optional[List[int]] = None,
num_classes: int = 1000,
round_nearest: int = 8,
last_channel: int = 1024,
) -> None:
super().__init__()

Expand All @@ -59,7 +63,6 @@ def __init__(
block = C3

input_channel = 64
last_channel = 1024

if stages_repeats is None:
stages_repeats = [3, 6, 9]
Expand Down
65 changes: 64 additions & 1 deletion yolort/models/path_aggregation_network.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Callable, List, Dict, Optional
from typing import Callable, List, Dict, Optional, Tuple

import torch
from torch import nn, Tensor

from yolort.v5 import Conv, BottleneckCSP, C3, SPPF


class ExtraPANBlock(nn.Module):
"""
Base class for the extra block in the PAN.
Args:
results (List[Tensor]): the result of the PAN
x (List[Tensor]): the original feature maps
names (List[str]): the names for each one of the
original feature maps
Returns:
results (List[Tensor]): the extended set of results
of the PAN
names (List[str]): the extended set of names for the results
"""

def forward(
self,
results: List[Tensor],
x: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
pass


class PathAggregationNetwork(nn.Module):
"""
Module that adds a PAN from on top of a set of feature maps. This is based on
Expand Down Expand Up @@ -46,6 +71,7 @@ def __init__(
depth_multiple: float,
version: str = "r4.0",
block: Optional[Callable[..., nn.Module]] = None,
extra_blocks: Optional[ExtraPANBlock] = None,
):
super().__init__()
assert len(in_channels_list) == 3, "Currently only supports length 3."
Expand Down Expand Up @@ -102,6 +128,10 @@ def __init__(
]
self.layer_blocks = nn.ModuleList(layer_blocks)

if extra_blocks is not None:
assert isinstance(extra_blocks, ExtraPANBlock)
self.extra_blocks = extra_blocks

for m in self.modules():
if isinstance(m, nn.Conv2d):
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
Expand Down Expand Up @@ -190,3 +220,36 @@ def forward(self, x: Dict[str, Tensor]) -> List[Tensor]:
"r4.0": C3,
"r6.0": C3,
}


class LastLevelP6(ExtraPANBlock):
"""
This module is used in YOLOv5 to generate extra P6 layers.
"""

def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.p6 = C3(in_channels, out_channels, 3, 2, 1)
self.use_P5 = in_channels == out_channels

for m in self.modules():
if isinstance(m, nn.Conv2d):
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.eps = 1e-3
m.momentum = 0.03
elif isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)):
m.inplace = True

def forward(
self,
p: List[Tensor],
c: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
p.extend([p6])
names.extend(["p6"])
return p, names

0 comments on commit 1f738f4

Please sign in to comment.