Skip to content

Commit

Permalink
Fix segmentation example (#876)
Browse files Browse the repository at this point in the history
* removed torchvision model and added custom model

* minor fix

* Fixed relative imports issue
  • Loading branch information
akshaykulkarni07 committed Feb 17, 2020
1 parent 6029fad commit 43ac63f
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# For relative imports to work in Python 3.6
import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from parts import DoubleConv, Down, Up


class UNet(nn.Module):
'''
Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
Link - https://arxiv.org/abs/1505.04597
'''
def __init__(self, num_classes=19, bilinear=False):
super().__init__()
self.bilinear = bilinear
self.num_classes = num_classes
self.layer1 = DoubleConv(3, 64)
self.layer2 = Down(64, 128)
self.layer3 = Down(128, 256)
self.layer4 = Down(256, 512)
self.layer5 = Down(512, 1024)

self.layer6 = Up(1024, 512, bilinear=self.bilinear)
self.layer7 = Up(512, 256, bilinear=self.bilinear)
self.layer8 = Up(256, 128, bilinear=self.bilinear)
self.layer9 = Up(128, 64, bilinear=self.bilinear)

self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1)

def forward(self, x):
x1 = self.layer1(x)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)

x6 = self.layer6(x5, x4)
x6 = self.layer7(x6, x3)
x6 = self.layer8(x6, x2)
x6 = self.layer9(x6, x1)

return self.layer10(x6)
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
'''
Double Convolution and BN and ReLU
(3x3 conv -> BN -> ReLU) ** 2
'''
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.net(x)


class Down(nn.Module):
'''
Combination of MaxPool2d and DoubleConv in series
'''
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
DoubleConv(in_ch, out_ch)
)

def forward(self, x):
return self.net(x)


class Up(nn.Module):
'''
Upsampling (by either bilinear interpolation or transpose convolutions)
followed by concatenation of feature map from contracting path,
followed by double 3x3 convolution.
'''
def __init__(self, in_ch, out_ch, bilinear=False):
super().__init__()
self.upsample = None
if bilinear:
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

self.conv = DoubleConv(in_ch, out_ch)

def forward(self, x1, x2):
x1 = self.upsample(x1)

# Pad x1 to the size of x2
diff_h = x2.shape[2] - x1.shape[2]
diff_w = x2.shape[3] - x1.shape[3]

x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

# Concatenate along the channels axis
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
8 changes: 3 additions & 5 deletions pl_examples/full_examples/semantic_segmentation/semseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models.segmentation import fcn_resnet50

import pytorch_lightning as pl
from models.unet.model import UNet


class KITTI(Dataset):
Expand Down Expand Up @@ -128,9 +128,7 @@ def __init__(self, hparams):
self.root_path = hparams.root
self.batch_size = hparams.batch_size
self.learning_rate = hparams.lr
self.net = torchvision.models.segmentation.fcn_resnet50(pretrained=False,
progress=True,
num_classes=19)
self.net = UNet(num_classes=19)
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
Expand All @@ -147,7 +145,7 @@ def training_step(self, batch, batch_nb):
img = img.float()
mask = mask.long()
out = self.forward(img)
loss_val = F.cross_entropy(out['out'], mask, ignore_index=250)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
return {'loss': loss_val}

def configure_optimizers(self):
Expand Down

0 comments on commit 43ac63f

Please sign in to comment.