class UNet(nn.Module):
    '''
    U-Net style encoder/decoder network.

    Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
    Link - https://arxiv.org/abs/1505.04597

    The contracting path is a DoubleConv stem followed by four Down stages
    (3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels). The expanding path is four
    Up stages walking the same widths backwards, each one also fed the
    contracting-path output of the matching width. A final 1x1 convolution
    maps the 64 remaining channels to `num_classes` channels.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        bilinear: forwarded unchanged to every Up stage.
    '''
    def __init__(self, num_classes=19, bilinear=False):
        super().__init__()
        self.bilinear = bilinear
        self.num_classes = num_classes

        # Channel width at each resolution, from the stem down to the bottleneck.
        widths = (64, 128, 256, 512, 1024)

        # Stages are registered as layer1 .. layer10, in this exact order.
        self.layer1 = DoubleConv(3, widths[0])

        # layer2 .. layer5: each contracting stage moves to the next width.
        for index, (narrow, wide) in enumerate(zip(widths, widths[1:]), start=2):
            setattr(self, 'layer{}'.format(index), Down(narrow, wide))

        # layer6 .. layer9: each expanding stage moves back to the previous width.
        for index, (wide, narrow) in enumerate(zip(widths[::-1], widths[-2::-1]), start=6):
            setattr(self, 'layer{}'.format(index), Up(wide, narrow, bilinear=self.bilinear))

        self.layer10 = nn.Conv2d(widths[0], self.num_classes, kernel_size=1)

    def forward(self, x):
        '''
        Run the contracting path, then the expanding path with skip
        connections, and return the output of the final 1x1 convolution.
        '''
        contracting = (self.layer2, self.layer3, self.layer4, self.layer5)
        expanding = (self.layer6, self.layer7, self.layer8, self.layer9)

        # Keep the input of every contracting stage; the expanding stages
        # consume them in reverse order (deepest first).
        skips = []
        features = self.layer1(x)
        for stage in contracting:
            skips.append(features)
            features = stage(features)

        for stage in expanding:
            features = stage(features, skips.pop())

        return self.layer10(features)
class DoubleConv(nn.Module):
    '''
    Double Convolution and BN and ReLU
    (3x3 conv -> BN -> ReLU) ** 2

    Both convolutions use padding=1, so height and width are preserved;
    only the channel count changes (in_ch -> out_ch -> out_ch).

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by both convolutions.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        '''Apply the two conv -> BN -> ReLU stages to x.'''
        return self.net(x)


class Down(nn.Module):
    '''
    Combination of MaxPool2d and DoubleConv in series

    The 2x2 max-pool with stride 2 halves height and width (rounding down),
    then DoubleConv changes the channel count from in_ch to out_ch.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels of the output tensor.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        '''Downsample x by 2 and run it through the double convolution.'''
        return self.net(x)


class Up(nn.Module):
    '''
    Upsampling (by either bilinear interpolation or transpose convolutions)
    followed by concatenation of feature map from contracting path,
    followed by double 3x3 convolution.

    In both modes the upsampling step doubles height and width and reduces
    the channel count from in_ch to in_ch // 2. The skip tensor is expected
    to supply the remaining in_ch - in_ch // 2 channels so that the
    concatenation has exactly in_ch channels for the double convolution.

    Args:
        in_ch: number of channels of the tensor to be upsampled.
        out_ch: number of channels of the output tensor.
        bilinear: if True, upsample with bilinear interpolation followed by
            a 1x1 convolution; otherwise use a 2x2 transposed convolution
            with stride 2.
    '''
    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        if bilinear:
            # nn.Upsample leaves the channel count untouched, so a 1x1
            # convolution is needed to halve it. Without it the concatenated
            # tensor would have in_ch + in_ch // 2 channels and would not fit
            # the DoubleConv below.
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1)
            )
        else:
            self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        '''
        Upsample x1, align it with x2 and fuse both.

        Args:
            x1: tensor from the previous (coarser) stage, in_ch channels.
            x2: skip tensor from the contracting path.

        Returns:
            Tensor with out_ch channels and the height and width of x2.
        '''
        x1 = self.upsample(x1)

        # Pad x1 to the size of x2 (sizes differ when x2 had an odd height or
        # width before pooling).
        diff_h = x2.shape[2] - x1.shape[2]
        diff_w = x2.shape[3] - x1.shape[3]

        # F.pad takes the padding for the last dimension first:
        # (left, right, top, bottom). Any odd remainder goes to right/bottom.
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

        # Concatenate along the channels axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)