From f2f717a03cda032513a7158046b2d08714007bac Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 15 Sep 2024 16:21:08 +0800 Subject: [PATCH 1/3] add UPerNet --- segmentation_models_pytorch/__init__.py | 3 + .../decoders/upernet/__init__.py | 3 + .../decoders/upernet/decoder.py | 134 ++++++++++++++++++ .../decoders/upernet/model.py | 91 ++++++++++++ 4 files changed, 231 insertions(+) create mode 100644 segmentation_models_pytorch/decoders/upernet/__init__.py create mode 100644 segmentation_models_pytorch/decoders/upernet/decoder.py create mode 100644 segmentation_models_pytorch/decoders/upernet/model.py diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index d3778ecc..5733e7b9 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -14,6 +14,7 @@ from .decoders.pspnet import PSPNet from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus from .decoders.pan import PAN +from .decoders.upernet import UPerNet from .base.hub_mixin import from_pretrained from .__version__ import __version__ @@ -48,6 +49,7 @@ def create_model( DeepLabV3, DeepLabV3Plus, PAN, + UPerNet, ] archs_dict = {a.__name__.lower(): a for a in archs} try: @@ -82,6 +84,7 @@ def create_model( "DeepLabV3", "DeepLabV3Plus", "PAN", + "UPerNet", "from_pretrained", "create_model", "__version__", diff --git a/segmentation_models_pytorch/decoders/upernet/__init__.py b/segmentation_models_pytorch/decoders/upernet/__init__.py new file mode 100644 index 00000000..012967b5 --- /dev/null +++ b/segmentation_models_pytorch/decoders/upernet/__init__.py @@ -0,0 +1,3 @@ +from .model import UPerNet + +__all__ = ["UPerNet"] diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py new file mode 100644 index 00000000..fe823822 --- /dev/null +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules as md + + +class PSPModule(nn.Module): + def __init__( + self, + in_channels, + out_channels, + sizes=(1, 2, 3, 6), + use_batchnorm=True, + ): + super().__init__() + self.blocks = nn.ModuleList( + [ + nn.Sequential( + nn.AdaptiveAvgPool2d(size), + md.Conv2dReLU( + in_channels, + in_channels // len(sizes), + kernel_size=1, + use_batchnorm=use_batchnorm, + ), + ) + for size in sizes + ] + ) + self.out_conv = md.Conv2dReLU( + in_channels=in_channels * 2, + out_channels=out_channels, + kernel_size=1, + use_batchnorm=True, + ) + + def forward(self, x): + _, _, h, w = x.shape + out = [x] + [ + F.interpolate(block(x), size=(h, w), mode="bilinear", align_corners=False) + for block in self.blocks + ] + out = self.out_conv(torch.cat(out, dim=1)) + return out + + +class FPNBlock(nn.Module): + def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True): + super().__init__() + self.skip_conv = ( + md.Conv2dReLU( + skip_channels, + pyramid_channels, + kernel_size=1, + use_batchnorm=use_bathcnorm, + ) + if skip_channels != 0 + else nn.Identity() + ) + + def forward(self, x, skip): + _, ch, h, w = skip.shape + x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False) + if ch != 0: + skip = self.skip_conv(skip) + x = x + skip + return x + + +class UPerNetDecoder(nn.Module): + def __init__( + self, + encoder_channels, + encoder_depth=5, + pyramid_channels=256, + segmentation_channels=64, + ): + super().__init__() + self.out_channels = segmentation_channels + if encoder_depth < 3: + raise ValueError( + "Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format( + encoder_depth + ) + ) + + encoder_channels = encoder_channels[::-1] + + # PSP Module + self.psp = PSPModule( + in_channels=encoder_channels[0], + out_channels=pyramid_channels, + sizes=(1, 2, 3, 6), + use_batchnorm=True, + ) + + # FPN Module + self.fpn_stages = nn.ModuleList( + [FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]] + ) + + self.fpn_bottleneck = md.Conv2dReLU( + in_channels=(len(encoder_channels) - 1) * pyramid_channels, + out_channels=segmentation_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + + def forward(self, *features): + # Resize all FPN features to the size of the largest feature + target_size = features[0].shape[2:] + + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + psp_out = self.psp(features[0]) + + fpn_features = [psp_out] + for feature, stage in zip(features[1:], self.fpn_stages): + fpn_feature = stage(fpn_features[-1], feature) + fpn_features.append(fpn_feature) + + resized_fpn_features = [] + for feature in fpn_features: + resized_feature = F.interpolate( + feature, size=target_size, mode="bilinear", align_corners=False + ) + resized_fpn_features.append(resized_feature) + + output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1)) + + return output diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py new file mode 100644 index 00000000..0de37a6c --- /dev/null +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -0,0 +1,91 @@ +from typing import Optional, Union + +from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) +from .decoder import UPerNetDecoder + + +class UPerNet(SegmentationModel): + """UPerNet is a unified perceptual parsing network for image segmentation. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 + decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: **UPerNet** + + .. _UPerNet: + https://arxiv.org/abs/1505.04597 + + """ + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_pyramid_channels: int = 256, + decoder_segmentation_channels: int = 64, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + ) + + self.decoder = UPerNetDecoder( + encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, + pyramid_channels=decoder_pyramid_channels, + segmentation_channels=decoder_segmentation_channels, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder.out_channels, + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "upernet-{}".format(encoder_name) + self.initialize() From 43000a3ba98e1250e5cf64d123889492031a00da Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 15 Sep 2024 19:02:32 +0800 Subject: [PATCH 2/3] update paper link --- segmentation_models_pytorch/decoders/upernet/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 0de37a6c..5523b72e 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -41,7 +41,7 @@ class UPerNet(SegmentationModel): ``torch.nn.Module``: **UPerNet** .. _UPerNet: - https://arxiv.org/abs/1505.04597 + https://arxiv.org/abs/1807.10221 """ From 748070ebd4ad7672821a1efda529c8454257442f Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 15 Sep 2024 19:23:05 +0800 Subject: [PATCH 3/3] update tests add UPerNet for test_models --- tests/test_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index f78f55d6..a1b5f2c6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -29,6 +29,7 @@ def get_sample(model_class): smp.PSPNet, smp.UnetPlusPlus, smp.MAnet, + smp.UPerNet, ]: sample = torch.ones([1, 3, 64, 64]) elif model_class == smp.PAN: @@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False): @pytest.mark.parametrize("encoder_name", ENCODERS) @pytest.mark.parametrize("encoder_depth", [3, 5]) @pytest.mark.parametrize( - "model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus] + "model_class", + [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet], ) def test_forward(model_class, encoder_name, encoder_depth, **kwargs): if (