Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added initial semantic segmentation example #751

Merged
merged 7 commits into from
Feb 16, 2020
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions pl_examples/full_examples/semantic_segmentation/semseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models.segmentation import fcn_resnet50

from PIL import Image
akshaykulkarni07 marked this conversation as resolved.
Show resolved Hide resolved
import pytorch_lightning as pl


class KITTI(Dataset):
    """KITTI semantic-segmentation images and, for training, label masks.

    Images are read from ``<root_path>/training/image_2`` (``split='train'``)
    or ``<root_path>/testing/image_2`` (any other split). Training masks are
    read from ``<root_path>/training/semantic``. An image and its mask are
    paired by their position in the *sorted* directory listings, so both
    folders are expected to contain matching file names.

    Mask values listed in ``void_labels`` are replaced by ``ignore_index``
    (250) and the 19 values in ``valid_labels`` are renumbered to ``0..18``.

    Args:
        root_path: dataset root containing the ``training``/``testing`` folders.
        split: ``'train'`` yields ``(img, mask)`` pairs; every other value
            selects the testing images and yields ``img`` only.
        img_size: target size handed to ``PIL.Image.resize``, i.e.
            ``(width, height)``.
        transform: optional callable applied to the image array (never to
            the mask).
    """

    def __init__(self, root_path, split='test', img_size=(1242, 376), transform=None):
        self.img_size = img_size
        self.void_labels = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.ignore_index = 250
        # Raw label id -> contiguous training class id (0..18).
        self.class_map = dict(zip(self.valid_labels, range(19)))
        self.split = split
        self.root = root_path
        self.transform = transform

        if self.split == 'train':
            self.img_path = os.path.join(self.root, 'training/image_2')
            self.mask_path = os.path.join(self.root, 'training/semantic')
        else:
            # No annotations are read for the testing images.
            self.img_path = os.path.join(self.root, 'testing/image_2')
            self.mask_path = None

        self.img_list = self.get_filenames(self.img_path)
        if self.split == 'train':
            self.mask_list = self.get_filenames(self.mask_path)
        else:
            self.mask_list = None

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, mask)`` when ``split == 'train'``, otherwise ``img`` alone.
            ``img`` is the resized image as a NumPy array, passed through
            ``transform`` when one was given; ``mask`` is the resized,
            re-encoded label array (see ``encode_segmap``).
        """
        # ``resize`` returns a fully loaded copy, so the file handle can be
        # released as soon as the block exits.
        with Image.open(self.img_list[idx]) as raw_img:
            img = np.array(raw_img.resize(self.img_size))

        if self.transform:
            img = self.transform(img)

        if self.split != 'train':
            return img

        with Image.open(self.mask_list[idx]) as raw_mask:
            # Label ids are categorical: an interpolating filter would blend
            # neighbouring ids into values that belong to no class, so the
            # mask must be resized with nearest-neighbour sampling.
            resized_mask = raw_mask.convert('L').resize(self.img_size, resample=Image.NEAREST)
        mask = self.encode_segmap(np.array(resized_mask))
        return img, mask

    def encode_segmap(self, mask):
        """Re-encode raw label ids in ``mask`` (modified in place and returned).

        Every value in ``void_labels`` becomes ``ignore_index`` so it can be
        excluded from the loss; every value in ``valid_labels`` becomes its
        contiguous class id from ``class_map``.
        """
        mask[np.isin(mask, self.void_labels)] = self.ignore_index
        # ``valid_labels`` is ascending and each new id is smaller than the
        # label it replaces, so an already remapped pixel is never hit again.
        for validc in self.valid_labels:
            mask[mask == validc] = self.class_map[validc]
        return mask

    def get_filenames(self, path):
        """Return the sorted paths of all entries directly inside ``path``.

        Each result is ``path`` joined with the entry name, so it is absolute
        only if ``path`` is. Sorting matters: ``os.listdir`` returns entries
        in arbitrary order, and images are matched to masks by list index.
        """
        return sorted(os.path.join(path, filename) for filename in os.listdir(path))


class SegModel(pl.LightningModule):
    """Semantic segmentation model for the KITTI dataset.

    This is a basic image segmentation example implemented with Lightning.
    It wraps torchvision's ``fcn_resnet50`` (built without pretrained
    weights) and trains it with a pixel-wise cross-entropy loss that skips
    pixels marked with the dataset's ``ignore_index``. Optimization uses
    Adam with a cosine-annealing learning-rate schedule (``T_max=10``).

    Args:
        hparams: namespace providing ``root`` (dataset root directory),
            ``batch_size`` and ``lr`` (Adam learning rate).

    References
        - Long et al., "Fully Convolutional Networks for Semantic
          Segmentation", https://arxiv.org/abs/1411.4038
    """

    def __init__(self, hparams):
        super().__init__()
        self.root_path = hparams.root
        self.batch_size = hparams.batch_size
        self.learning_rate = hparams.lr
        # Fixed per-channel mean/std constants (presumably statistics of the
        # KITTI images -- TODO confirm their origin).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
        ])
        self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
        self.testset = KITTI(self.root_path, split='test', transform=self.transform)
        # One output channel per class the dataset can emit, read from the
        # dataset so that the two cannot drift apart.
        self.net = fcn_resnet50(
            pretrained=False,
            progress=True,
            num_classes=len(self.trainset.valid_labels),
        )

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.net(x)

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy loss for one ``(img, mask)`` batch.

        Returns:
            ``{'loss': <loss tensor>}``.
        """
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self.forward(img)
        # The network output is indexed by key; 'out' holds the class scores.
        # Pixels equal to the dataset's ignore index do not contribute.
        loss_val = F.cross_entropy(out['out'], mask, ignore_index=self.trainset.ignore_index)
        return {'loss': loss_val}

    def configure_optimizers(self):
        """Return the Adam optimizer and its cosine-annealing scheduler."""
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        """Return a shuffling loader over the training split."""
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        """Return an in-order loader over the testing split."""
        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)


def main(hparams):
    """Build the segmentation model from ``hparams`` and train it.

    Args:
        hparams: parsed options; forwarded unchanged to ``SegModel``, and
            its ``gpus`` attribute is passed to the Lightning ``Trainer``.
    """
    # Step 1: the LightningModule.
    segmentation_model = SegModel(hparams)

    # Step 2: the trainer, configured only with the requested GPU count.
    lightning_trainer = pl.Trainer(gpus=hparams.gpus)

    # Step 3: run the training loop.
    lightning_trainer.fit(segmentation_model)
Borda marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == '__main__':
    # The parsed namespace is handed to ``main`` as ``hparams``.
    parser = ArgumentParser()
    # ``--root`` is mandatory: without it the dataset would try to join
    # ``None`` with a sub-folder name and fail with an unhelpful TypeError.
    parser.add_argument("--root", type=str, required=True, help="path where dataset is stored")
    # Left as ``None`` when omitted and passed to the trainer as-is.
    parser.add_argument("--gpus", type=int, help="number of available GPUs")
    parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")

    hparams = parser.parse_args()

    main(hparams)