feat(semseg): allow model customization (Lightning-AI#1371)
* feat(semantic_segmentation): allow customization of unet

* feat(semseg): allow model customization

* style(semseg): format to PEP8

* fix(semseg): rename logger

* docs(changelog): updated semantic segmentation example

* suggestions

* suggestions

* flake8

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
2 people authored and tullie committed May 6, 2020
1 parent a2027f7 commit 4c2a135
Showing 3 changed files with 121 additions and 67 deletions.
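Before the diff itself, a minimal sketch of how the reworked example can be driven once this commit lands. The flag names come from the parser added below; the KITTI path is a placeholder and must point at a real extract for the dataset file lists to build. This is an illustration, not part of the commit:

    # hypothetical CLI call:
    #   python pl_examples/domain_templates/semantic_segmentation.py \
    #       --data_path /data/kitti --num_layers 4 --features_start 32 --bilinear --log_wandb
    # the same hyperparameters, built programmatically:
    from argparse import Namespace

    hparams = Namespace(
        data_path='/data/kitti',  # placeholder KITTI root containing training/image_2 and training/semantic
        gpus=-1, distributed_backend='dp', use_amp=False,
        batch_size=4, lr=0.001,
        num_layers=4, features_start=32, bilinear=True,  # the new U-Net knobs
        grad_batches=1, epochs=20, log_wandb=False,
    )
    model = SegModel(hparams)  # builds UNet(num_classes=19, num_layers=4, features_start=32, bilinear=True)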
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -17,14 +17,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))

- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))

-

### Deprecated

-

-


### Removed

123 changes: 85 additions & 38 deletions pl_examples/domain_templates/semantic_segmentation.py
@@ -7,9 +7,14 @@
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+import random

import pytorch_lightning as pl
from pl_examples.models.unet import UNet
+from pytorch_lightning.loggers import WandbLogger

+DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
+DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KITTI(Dataset):
@@ -34,37 +39,40 @@ class KITTI(Dataset):
    encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
    (mask does not usually require transforms, but they can be implemented in a similar way).
    """
+    IMAGE_PATH = os.path.join('training', 'image_2')
+    MASK_PATH = os.path.join('training', 'semantic')

    def __init__(
        self,
-        root_path,
-        split='test',
-        img_size=(1242, 376),
-        void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1],
-        valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
+        data_path: str,
+        split: str,
+        img_size: tuple = (1242, 376),
+        void_labels: list = DEFAULT_VOID_LABELS,
+        valid_labels: list = DEFAULT_VALID_LABELS,
        transform=None
    ):
        self.img_size = img_size
        self.void_labels = void_labels
        self.valid_labels = valid_labels
        self.ignore_index = 250
        self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
-        self.split = split
-        self.root = root_path
-        if self.split == 'train':
-            self.img_path = os.path.join(self.root, 'training/image_2')
-            self.mask_path = os.path.join(self.root, 'training/semantic')
-        else:
-            self.img_path = os.path.join(self.root, 'testing/image_2')
-            self.mask_path = None

        self.transform = transform

+        self.split = split
+        self.data_path = data_path
+        self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
+        self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
        self.img_list = self.get_filenames(self.img_path)
+        self.mask_list = self.get_filenames(self.mask_path)
+
+        # Split between train and valid set (80/20)
+        random_inst = random.Random(12345)  # for repeatability
+        n_items = len(self.img_list)
+        idxs = random_inst.sample(range(n_items), n_items // 5)
        if self.split == 'train':
-            self.mask_list = self.get_filenames(self.mask_path)
-        else:
-            self.mask_list = None
+            idxs = [idx for idx in range(n_items) if idx not in idxs]
+        self.img_list = [self.img_list[i] for i in idxs]
+        self.mask_list = [self.mask_list[i] for i in idxs]

    def __len__(self):
        return len(self.img_list)
@@ -74,19 +82,15 @@ def __getitem__(self, idx):
        img = img.resize(self.img_size)
        img = np.array(img)

-        if self.split == 'train':
-            mask = Image.open(self.mask_list[idx]).convert('L')
-            mask = mask.resize(self.img_size)
-            mask = np.array(mask)
-            mask = self.encode_segmap(mask)
+        mask = Image.open(self.mask_list[idx]).convert('L')
+        mask = mask.resize(self.img_size)
+        mask = np.array(mask)
+        mask = self.encode_segmap(mask)

        if self.transform:
            img = self.transform(img)

-        if self.split == 'train':
-            return img, mask
-        else:
-            return img
+        return img, mask

    def encode_segmap(self, mask):
        """
@@ -96,6 +100,8 @@ def encode_segmap(self, mask):
            mask[mask == voidc] = self.ignore_index
        for validc in self.valid_labels:
            mask[mask == validc] = self.class_map[validc]
+        # remove extra idxs from updated dataset
+        mask[mask > 18] = self.ignore_index
        return mask

    def get_filenames(self, path):
@@ -124,17 +130,19 @@ class SegModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
-        self.root_path = hparams.root
+        self.hparams = hparams
+        self.data_path = hparams.data_path
        self.batch_size = hparams.batch_size
        self.learning_rate = hparams.lr
-        self.net = UNet(num_classes=19)
+        self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
+                        features_start=hparams.features_start, bilinear=hparams.bilinear)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])
-        self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
-        self.testset = KITTI(self.root_path, split='test', transform=self.transform)
+        self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
+        self.validset = KITTI(self.data_path, split='valid', transform=self.transform)

    def forward(self, x):
        return self.net(x)
@@ -145,7 +153,21 @@ def training_step(self, batch, batch_nb):
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
-        return {'loss': loss_val}
+        log_dict = {'train_loss': loss_val}
+        return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}
+
+    def validation_step(self, batch, batch_idx):
+        img, mask = batch
+        img = img.float()
+        mask = mask.long()
+        out = self(img)
+        loss_val = F.cross_entropy(out, mask, ignore_index=250)
+        return {'val_loss': loss_val}
+
+    def validation_epoch_end(self, outputs):
+        loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
+        log_dict = {'val_loss': loss_val}
+        return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
@@ -155,8 +177,8 @@ def configure_optimizers(self):
    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

-    def test_dataloader(self):
-        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
+    def val_dataloader(self):
+        return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)


def main(hparams):
@@ -166,24 +188,49 @@ def main(hparams):
    model = SegModel(hparams)

    # ------------------------
-    # 2 INIT TRAINER
+    # 2 SET LOGGER
    # ------------------------
+    logger = False
+    if hparams.log_wandb:
+        logger = WandbLogger()
+
+        # optional: log model topology
+        logger.watch(model.net)
+
+    # ------------------------
+    # 3 INIT TRAINER
+    # ------------------------
    trainer = pl.Trainer(
-        gpus=hparams.gpus
+        gpus=hparams.gpus,
+        logger=logger,
+        max_epochs=hparams.epochs,
+        accumulate_grad_batches=hparams.grad_batches,
+        distributed_backend=hparams.distributed_backend,
+        precision=16 if hparams.use_amp else 32,
    )

    # ------------------------
-    # 3 START TRAINING
+    # 4 START TRAINING
    # ------------------------
    trainer.fit(model)


if __name__ == '__main__':
    parser = ArgumentParser()
-    parser.add_argument("--root", type=str, help="path where dataset is stored")
-    parser.add_argument("--gpus", type=int, help="number of available GPUs")
+    parser.add_argument("--data_path", type=str, help="path where dataset is stored")
+    parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
+    parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
+                        help='supports three options: dp, ddp, ddp2')
+    parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision')
+    parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
+    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
+    parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
+    parser.add_argument("--features_start", type=int, default=64, help="number of features in first layer")
+    parser.add_argument("--bilinear", action='store_true', default=False,
+                        help="whether to use bilinear interpolation or transposed")
+    parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
+    parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
+    parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")

    hparams = parser.parse_args()

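One detail of the dataset change worth spelling out: the train/valid split above is reproducible because every KITTI instance seeds its own random.Random(12345), so the 'train' and 'valid' objects carve out complementary index sets from the same file listing. A self-contained sketch of that mechanism, with plain lists standing in for the get_filenames output:

    import random

    files = ['img_{:03d}.png'.format(i) for i in range(200)]  # stand-in for get_filenames() output

    def split_indices(n_items, split):
        random_inst = random.Random(12345)  # same seed in every instantiation
        idxs = random_inst.sample(range(n_items), n_items // 5)
        if split == 'train':
            idxs = [idx for idx in range(n_items) if idx not in idxs]
        return idxs

    train_idxs = split_indices(len(files), 'train')
    valid_idxs = split_indices(len(files), 'valid')
    assert not set(train_idxs) & set(valid_idxs)  # disjoint subsets
    assert len(valid_idxs) == len(files) // 5     # 20% held out for validation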
61 changes: 34 additions & 27 deletions pl_examples/models/unet.py
@@ -9,39 +9,46 @@ class UNet(nn.Module):
    Link - https://arxiv.org/abs/1505.04597

    Parameters:
-        num_classes (int): Number of output classes required (default 19 for KITTI dataset)
-        bilinear (bool): Whether to use bilinear interpolation or transposed
+        num_classes: Number of output classes required (default 19 for KITTI dataset)
+        num_layers: Number of layers in each side of U-net
+        features_start: Number of features in first layer
+        bilinear: Whether to use bilinear interpolation or transposed
            convolutions for upsampling.
    """

-    def __init__(self, num_classes=19, bilinear=False):
+    def __init__(
+        self, num_classes: int = 19,
+        num_layers: int = 5,
+        features_start: int = 64,
+        bilinear: bool = False
+    ):
        super().__init__()
-        self.layer1 = DoubleConv(3, 64)
-        self.layer2 = Down(64, 128)
-        self.layer3 = Down(128, 256)
-        self.layer4 = Down(256, 512)
-        self.layer5 = Down(512, 1024)
+        self.num_layers = num_layers

-        self.layer6 = Up(1024, 512, bilinear=bilinear)
-        self.layer7 = Up(512, 256, bilinear=bilinear)
-        self.layer8 = Up(256, 128, bilinear=bilinear)
-        self.layer9 = Up(128, 64, bilinear=bilinear)
+        layers = [DoubleConv(3, features_start)]

-        self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)
+        feats = features_start
+        for _ in range(num_layers - 1):
+            layers.append(Down(feats, feats * 2))
+            feats *= 2

-    def forward(self, x):
-        x1 = self.layer1(x)
-        x2 = self.layer2(x1)
-        x3 = self.layer3(x2)
-        x4 = self.layer4(x3)
-        x5 = self.layer5(x4)
+        for _ in range(num_layers - 1):
+            layers.append(Up(feats, feats // 2, bilinear))
+            feats //= 2

+        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

-        x6 = self.layer6(x5, x4)
-        x6 = self.layer7(x6, x3)
-        x6 = self.layer8(x6, x2)
-        x6 = self.layer9(x6, x1)
+        self.layers = nn.ModuleList(layers)

-        return self.layer10(x6)
+    def forward(self, x):
+        xi = [self.layers[0](x)]
+        # Down path
+        for layer in self.layers[1:self.num_layers]:
+            xi.append(layer(xi[-1]))
+        # Up path
+        for i, layer in enumerate(self.layers[self.num_layers:-1]):
+            xi[-1] = layer(xi[-1], xi[-2 - i])
+        return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
@@ -50,7 +57,7 @@ class DoubleConv(nn.Module):
    (3x3 conv -> BN -> ReLU) ** 2
    """

-    def __init__(self, in_ch, out_ch):
+    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
@@ -70,7 +77,7 @@ class Down(nn.Module):
    Combination of MaxPool2d and DoubleConv in series
    """

-    def __init__(self, in_ch, out_ch):
+    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
@@ -88,7 +95,7 @@ class Up(nn.Module):
    followed by double 3x3 convolution.
    """

-    def __init__(self, in_ch, out_ch, bilinear=False):
+    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        self.upsample = None
        if bilinear:
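A quick way to convince yourself that the channel arithmetic in the rewritten UNet closes is a shape check. This is a sketch, assuming the module above is importable and an input whose sides are divisible by 2 ** (num_layers - 1):

    import torch

    from pl_examples.models.unet import UNet

    net = UNet(num_classes=19, num_layers=4, features_start=32, bilinear=True)
    x = torch.rand(1, 3, 96, 320)  # 96 and 320 are divisible by 2 ** 3 = 8
    out = net(x)
    assert out.shape == (1, 19, 96, 320)  # per-pixel logits, one channel per class
    # the down path widens 32 -> 64 -> 128 -> 256, the up path mirrors it back to 32,
    # and the final 1x1 conv maps those 32 features to 19 class scores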
