Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(semseg): allow model customization #1371

Merged
merged 8 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))

- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))

-

### Deprecated

-

-


### Removed

Expand Down
123 changes: 85 additions & 38 deletions pl_examples/domain_templates/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import random

import pytorch_lightning as pl
from pl_examples.models.unet import UNet
from pytorch_lightning.loggers import WandbLogger

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


class KITTI(Dataset):
Expand All @@ -34,37 +39,40 @@ class KITTI(Dataset):
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')

def __init__(
self,
root_path,
split='test',
img_size=(1242, 376),
void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1],
valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
data_path: str,
split: str,
img_size: tuple = (1242, 376),
void_labels: list = DEFAULT_VOID_LABELS,
valid_labels: list = DEFAULT_VALID_LABELS,
transform=None
):
self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.split = split
self.root = root_path
if self.split == 'train':
self.img_path = os.path.join(self.root, 'training/image_2')
self.mask_path = os.path.join(self.root, 'training/semantic')
else:
self.img_path = os.path.join(self.root, 'testing/image_2')
self.mask_path = None

self.transform = transform

self.split = split
self.data_path = data_path
self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

# Split between train and valid set (80/20)
random_inst = random.Random(12345) # for repeatability
n_items = len(self.img_list)
idxs = random_inst.sample(range(n_items), n_items // 5)
if self.split == 'train':
self.mask_list = self.get_filenames(self.mask_path)
else:
self.mask_list = None
idxs = [idx for idx in range(n_items) if idx not in idxs]
self.img_list = [self.img_list[i] for i in idxs]
self.mask_list = [self.mask_list[i] for i in idxs]

def __len__(self):
return len(self.img_list)
Expand All @@ -74,19 +82,15 @@ def __getitem__(self, idx):
img = img.resize(self.img_size)
img = np.array(img)

if self.split == 'train':
mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)
mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
img = self.transform(img)

if self.split == 'train':
return img, mask
else:
return img
return img, mask

def encode_segmap(self, mask):
"""
Expand All @@ -96,6 +100,8 @@ def encode_segmap(self, mask):
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# remove extra idxs from updated dataset
mask[mask > 18] = self.ignore_index
return mask

def get_filenames(self, path):
Expand Down Expand Up @@ -124,17 +130,19 @@ class SegModel(pl.LightningModule):

def __init__(self, hparams):
super().__init__()
self.root_path = hparams.root
self.hparams = hparams
self.data_path = hparams.data_path
self.batch_size = hparams.batch_size
self.learning_rate = hparams.lr
self.net = UNet(num_classes=19)
self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
features_start=hparams.features_start, bilinear=hparams.bilinear)
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])
self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
self.testset = KITTI(self.root_path, split='test', transform=self.transform)
self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
self.validset = KITTI(self.data_path, split='valid', transform=self.transform)

def forward(self, x):
return self.net(x)
Expand All @@ -145,7 +153,21 @@ def training_step(self, batch, batch_nb):
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
return {'loss': loss_val}
log_dict = {'train_loss': loss_val}
return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}

def validation_step(self, batch, batch_idx):
img, mask = batch
img = img.float()
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
return {'val_loss': loss_val}

def validation_epoch_end(self, outputs):
loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
log_dict = {'val_loss': loss_val}
return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}

def configure_optimizers(self):
opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
Expand All @@ -155,8 +177,8 @@ def configure_optimizers(self):
def train_dataloader(self):
return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

def test_dataloader(self):
return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
def val_dataloader(self):
return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)


def main(hparams):
Expand All @@ -166,24 +188,49 @@ def main(hparams):
model = SegModel(hparams)

# ------------------------
# 2 INIT TRAINER
# 2 SET LOGGER
# ------------------------
logger = False
if hparams.log_wandb:
logger = WandbLogger()

# optional: log model topology
logger.watch(model.net)

# ------------------------
# 3 INIT TRAINER
# ------------------------
trainer = pl.Trainer(
gpus=hparams.gpus
gpus=hparams.gpus,
logger=logger,
max_epochs=hparams.epochs,
accumulate_grad_batches=hparams.grad_batches,
distributed_backend=hparams.distributed_backend,
precision=16 if hparams.use_amp else 32,
)

# ------------------------
# 3 START TRAINING
# 5 START TRAINING
# ------------------------
trainer.fit(model)


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--root", type=str, help="path where dataset is stored")
parser.add_argument("--gpus", type=int, help="number of available GPUs")
parser.add_argument("--data_path", type=str, help="path where dataset is stored")
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
help='supports three options dp, ddp, ddp2')
parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision')
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
parser.add_argument("--bilinear", action='store_true', default=False,
help="whether to use bilinear interpolation or transposed")
parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")

hparams = parser.parse_args()

Expand Down
61 changes: 34 additions & 27 deletions pl_examples/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,46 @@ class UNet(nn.Module):
Link - https://arxiv.org/abs/1505.04597

Parameters:
num_classes (int): Number of output classes required (default 19 for KITTI dataset)
bilinear (bool): Whether to use bilinear interpolation or transposed
num_classes: Number of output classes required (default 19 for KITTI dataset)
num_layers: Number of layers in each side of U-net
features_start: Number of features in first layer
bilinear: Whether to use bilinear interpolation or transposed
convolutions for upsampling.
"""

def __init__(self, num_classes=19, bilinear=False):
def __init__(
self, num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False
):
super().__init__()
self.layer1 = DoubleConv(3, 64)
self.layer2 = Down(64, 128)
self.layer3 = Down(128, 256)
self.layer4 = Down(256, 512)
self.layer5 = Down(512, 1024)
self.num_layers = num_layers

self.layer6 = Up(1024, 512, bilinear=bilinear)
self.layer7 = Up(512, 256, bilinear=bilinear)
self.layer8 = Up(256, 128, bilinear=bilinear)
self.layer9 = Up(128, 64, bilinear=bilinear)
layers = [DoubleConv(3, features_start)]

self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)
feats = features_start
for _ in range(num_layers - 1):
layers.append(Down(feats, feats * 2))
feats *= 2

def forward(self, x):
x1 = self.layer1(x)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
for _ in range(num_layers - 1):
layers.append(Up(feats, feats // 2), bilinear)
feats //= 2

layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

x6 = self.layer6(x5, x4)
x6 = self.layer7(x6, x3)
x6 = self.layer8(x6, x2)
x6 = self.layer9(x6, x1)
self.layers = nn.ModuleList(layers)

return self.layer10(x6)
def forward(self, x):
xi = [self.layers[0](x)]
# Down path
for layer in self.layers[1:self.num_layers]:
xi.append(layer(xi[-1]))
# Up path
for i, layer in enumerate(self.layers[self.num_layers:-1]):
xi[-1] = layer(xi[-1], xi[-2 - i])
return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
Expand All @@ -50,7 +57,7 @@ class DoubleConv(nn.Module):
(3x3 conv -> BN -> ReLU) ** 2
"""

def __init__(self, in_ch, out_ch):
def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
Expand All @@ -70,7 +77,7 @@ class Down(nn.Module):
Combination of MaxPool2d and DoubleConv in series
"""

def __init__(self, in_ch, out_ch):
def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
Expand All @@ -88,7 +95,7 @@ class Up(nn.Module):
followed by double 3x3 convolution.
"""

def __init__(self, in_ch, out_ch, bilinear=False):
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
super().__init__()
self.upsample = None
if bilinear:
Expand Down