diff --git a/CHANGELOG.md b/CHANGELOG.md index 62b76cac58518..b101cd0932592 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,14 +17,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475)) +- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371)) + - ### Deprecated - -- - ### Removed diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 8ce01c3b7088b..4604b6454db98 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -7,9 +7,14 @@ import torchvision.transforms as transforms from PIL import Image from torch.utils.data import DataLoader, Dataset +import random import pytorch_lightning as pl from pl_examples.models.unet import UNet +from pytorch_lightning.loggers import WandbLogger + +DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) +DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) class KITTI(Dataset): @@ -34,14 +39,16 @@ class KITTI(Dataset): encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). 
""" + IMAGE_PATH = os.path.join('training', 'image_2') + MASK_PATH = os.path.join('training', 'semantic') def __init__( self, - root_path, - split='test', - img_size=(1242, 376), - void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1], - valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33], + data_path: str, + split: str, + img_size: tuple = (1242, 376), + void_labels: list = DEFAULT_VOID_LABELS, + valid_labels: list = DEFAULT_VALID_LABELS, transform=None ): self.img_size = img_size @@ -49,22 +56,23 @@ def __init__( self.valid_labels = valid_labels self.ignore_index = 250 self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels)))) - self.split = split - self.root = root_path - if self.split == 'train': - self.img_path = os.path.join(self.root, 'training/image_2') - self.mask_path = os.path.join(self.root, 'training/semantic') - else: - self.img_path = os.path.join(self.root, 'testing/image_2') - self.mask_path = None - self.transform = transform + self.split = split + self.data_path = data_path + self.img_path = os.path.join(self.data_path, self.IMAGE_PATH) + self.mask_path = os.path.join(self.data_path, self.MASK_PATH) self.img_list = self.get_filenames(self.img_path) + self.mask_list = self.get_filenames(self.mask_path) + + # Split between train and valid set (80/20) + random_inst = random.Random(12345) # for repeatability + n_items = len(self.img_list) + idxs = random_inst.sample(range(n_items), n_items // 5) if self.split == 'train': - self.mask_list = self.get_filenames(self.mask_path) - else: - self.mask_list = None + idxs = [idx for idx in range(n_items) if idx not in idxs] + self.img_list = [self.img_list[i] for i in idxs] + self.mask_list = [self.mask_list[i] for i in idxs] def __len__(self): return len(self.img_list) @@ -74,19 +82,15 @@ def __getitem__(self, idx): img = img.resize(self.img_size) img = np.array(img) - if self.split == 'train': - mask = 
Image.open(self.mask_list[idx]).convert('L') - mask = mask.resize(self.img_size) - mask = np.array(mask) - mask = self.encode_segmap(mask) + mask = Image.open(self.mask_list[idx]).convert('L') + mask = mask.resize(self.img_size) + mask = np.array(mask) + mask = self.encode_segmap(mask) if self.transform: img = self.transform(img) - if self.split == 'train': - return img, mask - else: - return img + return img, mask def encode_segmap(self, mask): """ @@ -96,6 +100,8 @@ def encode_segmap(self, mask): mask[mask == voidc] = self.ignore_index for validc in self.valid_labels: mask[mask == validc] = self.class_map[validc] + # remove extra idxs from updated dataset + mask[mask > 18] = self.ignore_index return mask def get_filenames(self, path): @@ -124,17 +130,19 @@ class SegModel(pl.LightningModule): def __init__(self, hparams): super().__init__() - self.root_path = hparams.root + self.hparams = hparams + self.data_path = hparams.data_path self.batch_size = hparams.batch_size self.learning_rate = hparams.lr - self.net = UNet(num_classes=19) + self.net = UNet(num_classes=19, num_layers=hparams.num_layers, + features_start=hparams.features_start, bilinear=hparams.bilinear) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) ]) - self.trainset = KITTI(self.root_path, split='train', transform=self.transform) - self.testset = KITTI(self.root_path, split='test', transform=self.transform) + self.trainset = KITTI(self.data_path, split='train', transform=self.transform) + self.validset = KITTI(self.data_path, split='valid', transform=self.transform) def forward(self, x): return self.net(x) @@ -145,7 +153,21 @@ def training_step(self, batch, batch_nb): mask = mask.long() out = self(img) loss_val = F.cross_entropy(out, mask, ignore_index=250) - return {'loss': loss_val} + log_dict = {'train_loss': loss_val} + return {'loss': loss_val, 'log': log_dict, 'progress_bar': 
log_dict} + + def validation_step(self, batch, batch_idx): + img, mask = batch + img = img.float() + mask = mask.long() + out = self(img) + loss_val = F.cross_entropy(out, mask, ignore_index=250) + return {'val_loss': loss_val} + + def validation_epoch_end(self, outputs): + loss_val = sum(output['val_loss'] for output in outputs) / len(outputs) + log_dict = {'val_loss': loss_val} + return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict} def configure_optimizers(self): opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate) @@ -155,8 +177,8 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True) - def test_dataloader(self): - return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False) + def val_dataloader(self): + return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False) def main(hparams): @@ -166,24 +188,49 @@ def main(hparams): model = SegModel(hparams) # ------------------------ - # 2 INIT TRAINER + # 2 SET LOGGER + # ------------------------ + logger = False + if hparams.log_wandb: + logger = WandbLogger() + + # optional: log model topology + logger.watch(model.net) + + # ------------------------ + # 3 INIT TRAINER + # ------------------------ trainer = pl.Trainer( - gpus=hparams.gpus + gpus=hparams.gpus, + logger=logger, + max_epochs=hparams.epochs, + accumulate_grad_batches=hparams.grad_batches, + distributed_backend=hparams.distributed_backend, + precision=16 if hparams.use_amp else 32, ) # ------------------------ - # 3 START TRAINING + # 4 START TRAINING + # ------------------------ trainer.fit(model) if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument("--root", type=str, help="path where dataset is stored") - parser.add_argument("--gpus", type=int, help="number of available GPUs") + parser.add_argument("--data_path", type=str, help="path where dataset is stored") + 
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs") + parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), + help='supports three options dp, ddp, ddp2') + parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision') parser.add_argument("--batch_size", type=int, default=4, help="size of the batches") parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") + parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") + parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") + parser.add_argument("--bilinear", action='store_true', default=False, + help="whether to use bilinear interpolation or transposed") + parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate") + parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train") + parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases") hparams = parser.parse_args() diff --git a/pl_examples/models/unet.py b/pl_examples/models/unet.py index a7c474f3fc47c..5e85802bfe695 100644 --- a/pl_examples/models/unet.py +++ b/pl_examples/models/unet.py @@ -9,39 +9,46 @@ class UNet(nn.Module): Link - https://arxiv.org/abs/1505.04597 Parameters: - num_classes (int): Number of output classes required (default 19 for KITTI dataset) - bilinear (bool): Whether to use bilinear interpolation or transposed + num_classes: Number of output classes required (default 19 for KITTI dataset) + num_layers: Number of layers in each side of U-net + features_start: Number of features in first layer + bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling. 
""" - def __init__(self, num_classes=19, bilinear=False): + def __init__( + self, num_classes: int = 19, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False + ): super().__init__() - self.layer1 = DoubleConv(3, 64) - self.layer2 = Down(64, 128) - self.layer3 = Down(128, 256) - self.layer4 = Down(256, 512) - self.layer5 = Down(512, 1024) + self.num_layers = num_layers - self.layer6 = Up(1024, 512, bilinear=bilinear) - self.layer7 = Up(512, 256, bilinear=bilinear) - self.layer8 = Up(256, 128, bilinear=bilinear) - self.layer9 = Up(128, 64, bilinear=bilinear) + layers = [DoubleConv(3, features_start)] - self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1) + feats = features_start + for _ in range(num_layers - 1): + layers.append(Down(feats, feats * 2)) + feats *= 2 - def forward(self, x): - x1 = self.layer1(x) - x2 = self.layer2(x1) - x3 = self.layer3(x2) - x4 = self.layer4(x3) - x5 = self.layer5(x4) + for _ in range(num_layers - 1): + layers.append(Up(feats, feats // 2), bilinear) + feats //= 2 + + layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) - x6 = self.layer6(x5, x4) - x6 = self.layer7(x6, x3) - x6 = self.layer8(x6, x2) - x6 = self.layer9(x6, x1) + self.layers = nn.ModuleList(layers) - return self.layer10(x6) + def forward(self, x): + xi = [self.layers[0](x)] + # Down path + for layer in self.layers[1:self.num_layers]: + xi.append(layer(xi[-1])) + # Up path + for i, layer in enumerate(self.layers[self.num_layers:-1]): + xi[-1] = layer(xi[-1], xi[-2 - i]) + return self.layers[-1](xi[-1]) class DoubleConv(nn.Module): @@ -50,7 +57,7 @@ class DoubleConv(nn.Module): (3x3 conv -> BN -> ReLU) ** 2 """ - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch: int, out_ch: int): super().__init__() self.net = nn.Sequential( nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), @@ -70,7 +77,7 @@ class Down(nn.Module): Combination of MaxPool2d and DoubleConv in series """ - def __init__(self, in_ch, out_ch): + def __init__(self, 
in_ch: int, out_ch: int): super().__init__() self.net = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), @@ -88,7 +95,7 @@ class Up(nn.Module): followed by double 3x3 convolution. """ - def __init__(self, in_ch, out_ch, bilinear=False): + def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): super().__init__() self.upsample = None if bilinear: