[feat] Adding UPerNet #926

Open · wants to merge 3 commits into main
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -48,6 +49,7 @@ def create_model(
DeepLabV3,
DeepLabV3Plus,
PAN,
UPerNet,
]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
@@ -82,6 +84,7 @@ def create_model(
"DeepLabV3",
"DeepLabV3Plus",
"PAN",
"UPerNet",
"from_pretrained",
"create_model",
"__version__",
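With this registration, the new architecture is also reachable through `smp.create_model`, which resolves architectures by lowercased class name. A minimal sketch, assuming this branch is installed (encoder and class count are illustrative):

```python
import segmentation_models_pytorch as smp

# "upernet" matches UPerNet.__name__.lower() in archs_dict
model = smp.create_model(
    "upernet",
    encoder_name="resnet34",
    encoder_weights=None,  # skip the weights download for a quick smoke test
    in_channels=3,
    classes=2,
)
```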
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/__init__.py
@@ -0,0 +1,3 @@
from .model import UPerNet

__all__ = ["UPerNet"]
134 changes: 134 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class PSPModule(nn.Module):
def __init__(
self,
in_channels,
out_channels,
sizes=(1, 2, 3, 6),
use_batchnorm=True,
):
super().__init__()
self.blocks = nn.ModuleList(
[
nn.Sequential(
nn.AdaptiveAvgPool2d(size),
md.Conv2dReLU(
in_channels,
in_channels // len(sizes),
kernel_size=1,
use_batchnorm=use_batchnorm,
),
)
for size in sizes
]
)
        self.out_conv = md.Conv2dReLU(
            # x (in_channels) is concatenated with len(sizes) pooled branches of
            # in_channels // len(sizes) channels each, giving in_channels * 2 in total
            in_channels=in_channels * 2,
            out_channels=out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )

def forward(self, x):
_, _, h, w = x.shape
out = [x] + [
F.interpolate(block(x), size=(h, w), mode="bilinear", align_corners=False)
for block in self.blocks
]
out = self.out_conv(torch.cat(out, dim=1))
return out


class FPNBlock(nn.Module):
    def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True):
super().__init__()
self.skip_conv = (
md.Conv2dReLU(
skip_channels,
pyramid_channels,
kernel_size=1,
                use_batchnorm=use_batchnorm,
)
if skip_channels != 0
else nn.Identity()
)

    def forward(self, x, skip):
        _, ch, h, w = skip.shape
        # upsample x to the skip's spatial size, then add the projected skip
        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
        if ch != 0:
            skip = self.skip_conv(skip)
            x = x + skip
        return x


class UPerNetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
encoder_depth=5,
pyramid_channels=256,
segmentation_channels=64,
):
super().__init__()
self.out_channels = segmentation_channels
if encoder_depth < 3:
raise ValueError(
"Encoder depth for UPerNet decoder cannot be less than 3, got {}.".format(
encoder_depth
)
)

encoder_channels = encoder_channels[::-1]

# PSP Module
self.psp = PSPModule(
in_channels=encoder_channels[0],
out_channels=pyramid_channels,
sizes=(1, 2, 3, 6),
use_batchnorm=True,
)

# FPN Module
self.fpn_stages = nn.ModuleList(
[FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]]
)

self.fpn_bottleneck = md.Conv2dReLU(
in_channels=(len(encoder_channels) - 1) * pyramid_channels,
out_channels=segmentation_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)

    def forward(self, *features):
        # all FPN features are later resized to the size of the largest feature
        target_size = features[0].shape[2:]

        features = features[1:]  # remove first skip, which has the same spatial resolution as the input
        features = features[::-1]  # reverse features to start from the deepest encoder stage

psp_out = self.psp(features[0])

fpn_features = [psp_out]
for feature, stage in zip(features[1:], self.fpn_stages):
fpn_feature = stage(fpn_features[-1], feature)
fpn_features.append(fpn_feature)

resized_fpn_features = []
for feature in fpn_features:
resized_feature = F.interpolate(
feature, size=target_size, mode="bilinear", align_corners=False
)
resized_fpn_features.append(resized_feature)

output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))

return output
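A quick shape check of the decoder in isolation, assuming this branch is installed. The dummy pyramid below mimics a depth-5 ResNet-34 encoder's outputs for a 64x64 input (channel counts as in `encoder.out_channels`):

```python
import torch

from segmentation_models_pytorch.decoders.upernet.decoder import UPerNetDecoder

# channels (3, 64, 64, 128, 256, 512) at strides (1, 2, 4, 8, 16, 32)
channels = (3, 64, 64, 128, 256, 512)
features = [
    torch.randn(1, ch, 64 // 2**i, 64 // 2**i) for i, ch in enumerate(channels)
]

decoder = UPerNetDecoder(encoder_channels=channels, encoder_depth=5)
output = decoder(*features)
print(output.shape)  # torch.Size([1, 64, 64, 64]): segmentation_channels at input resolution
```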
91 changes: 91 additions & 0 deletions segmentation_models_pytorch/decoders/upernet/model.py
@@ -0,0 +1,91 @@
from typing import Optional, Union

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
)
from .decoder import UPerNetDecoder


class UPerNet(SegmentationModel):
"""UPerNet is a unified perceptual parsing network for image segmentation.

Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
            to extract features of different spatial resolutions
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W)], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of the encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)

Returns:
``torch.nn.Module``: **UPerNet**

.. _UPerNet:
https://arxiv.org/abs/1807.10221

"""

def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_pyramid_channels: int = 256,
decoder_segmentation_channels: int = 64,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
):
super().__init__()

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)

self.decoder = UPerNetDecoder(
encoder_channels=self.encoder.out_channels,
encoder_depth=encoder_depth,
pyramid_channels=decoder_pyramid_channels,
segmentation_channels=decoder_segmentation_channels,
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=3,
)

if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

self.name = "upernet-{}".format(encoder_name)
self.initialize()
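End to end, the model follows the usual smp API. A minimal usage sketch, assuming this branch is installed (`encoder_weights=None` avoids a weights download):

```python
import torch

import segmentation_models_pytorch as smp

model = smp.UPerNet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=3,
)
model.eval()

with torch.inference_mode():
    mask = model(torch.randn(1, 3, 64, 64))  # H and W must be divisible by 32

print(mask.shape)  # torch.Size([1, 3, 64, 64]): one channel per class
```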
4 changes: 3 additions & 1 deletion tests/test_models.py
@@ -29,6 +29,7 @@ def get_sample(model_class):
smp.PSPNet,
smp.UnetPlusPlus,
smp.MAnet,
smp.UPerNet,
]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
@@ -57,7 +58,8 @@ def _test_forward_backward(model, sample, test_shape=False):
@pytest.mark.parametrize("encoder_name", ENCODERS)
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize(
"model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus]
"model_class",
[smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.UPerNet],
)
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
if (