#! /usr/bin/env python
"""Confusion-matrix metric check under a multi-GPU DDP trainer.

This is a continuation of the basic tests provided previously, with
extensions to their computation over multiple GPUs using DDP.

The metric is computed within a model under the control of a DDP trainer.
The "model" that is tested does not actually run a forward pass; it just
computes the metric. The training step is configured to train on MNIST so
that PyTorch Lightning can perform the appropriate training operations
without crashing on issues such as missing parameters or bad data.

Command-line flags:
    --normalize         pass ``normalize=True`` to the confusion matrix metric
    --set-num-classes   pass ``num_classes`` to the metric (requires patch)
"""
import argparse
import os

import numpy as np
import pytorch_lightning as pl
import pytorch_lightning.metrics
import torch
import torch.nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Number of label classes: sizes the reference confusion matrix and is the
# value handed to the metric when --set-num-classes is given.
NUM_CLASSES = 4

parser = argparse.ArgumentParser()
parser.add_argument(
    '--normalize',
    action='store_true',
    help='Normalize the confusion matrix',
)
parser.add_argument(
    '--set-num-classes',
    action='store_true',
    help='Set number of classes for metrics. Requires patch.',
)
args = parser.parse_args()

# Each row is one sample: column 0 is the true label, column 1 is the
# declared (predicted) label.
data = np.array([
    [3, 3],
    [0, 0], [0, 0], [0, 0], [0, 0],
    [0, 1],
    [0, 2], [0, 2], [0, 2], [0, 2],
    [2, 0], [2, 0], [2, 0],
    [2, 1],
    [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2],
])

# For this problem, the confusion matrix can be computed directly:
# rows are indexed by the true label, columns by the declared label.
true_cm_count = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
for true_label, declared_label in data:
    true_cm_count[true_label, declared_label] += 1

# Normalization of the confusion matrix divides each row by the sum of that
# row; however, care must be taken for rows with no data. Clamping the
# divisor to at least 1 leaves such all-zero rows as zeros instead of
# dividing by zero.
row_sums = true_cm_count.sum(axis=1, keepdims=True)
row_divisor = np.maximum(row_sums, 1)
true_cm_norm = true_cm_count.astype(float) / row_divisor

# Print the reference results only when LOCAL_RANK is absent from the
# environment, so they are not repeated by processes that have it set.
rank = os.environ.get('LOCAL_RANK')
if rank is None:
    print('Raw data with true label and declared label')
    print(data)
    print('True confusion matrix with counts')
    print(true_cm_count)
    print('Normalized confusion matrix by rows')
    print(true_cm_norm)
    print('')


# Now let's use Pytorch Lightning to compute the same metric.
# This model is set up to train on MNIST so that Pytorch Lightning does not
# fail on the fact that I am not actually training a model.
class MetricComputer(pl.LightningModule):
    """Lightning module whose validation pass computes a confusion matrix.

    Training runs a single linear layer on MNIST; validation iterates over
    the module-level ``data`` array and feeds it to the metric.
    """

    def __init__(self, **kwargs):
        """Build the metric and the linear layer.

        Args:
            **kwargs: forwarded unchanged to ``pl.metrics.ConfusionMatrix``.
        """
        super().__init__()
        self.metric = pl.metrics.ConfusionMatrix(**kwargs)
        # 28 * 28 flattened input features mapped to 10 outputs.
        self.model = torch.nn.Linear(28 * 28, 10)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        """Return an MNIST training loader (downloaded into the cwd)."""
        mnist_train = MNIST(
            os.getcwd(),
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        return DataLoader(mnist_train, batch_size=32)

    def val_dataloader(self):
        """Return a loader over the module-level ``data`` label pairs."""
        return DataLoader(data, batch_size=4)

    def forward(self, x):
        """Not implemented: this module never runs a forward pass."""
        raise NotImplementedError

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy loss for one training batch.

        Args:
            batch: an ``(x, y)`` pair of inputs and targets.
            batch_nb: batch index (unused).

        Returns:
            A dict holding the loss under the key ``'loss'``.
        """
        x, y = batch
        # Flatten every sample to one row before the linear layer.
        flat_x = x.view(x.size(0), -1)
        z = torch.relu(self.model(flat_x))
        loss = F.cross_entropy(z, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        """Pass the raw batch through so it can be collected at epoch end."""
        return {'batch': batch}

    def validation_epoch_end(self, val_data):
        """Compute and print the confusion matrix over the collected batches.

        Args:
            val_data: the per-step outputs of ``validation_step``; each item
                is a dict with a ``'batch'`` entry.

        Returns:
            An empty dict.
        """
        val_data = torch.cat([step['batch'] for step in val_data], dim=0)
        target = val_data[:, 0]  # column 0: true label
        pred = val_data[:, 1]    # column 1: declared label
        cm = self.metric(pred, target)
        self.print('Computed confusion matrix')
        self.print(cm)
        return {}


# GPU model with DDP; normalization follows the --normalize flag.
model_args = {
    'normalize': args.normalize,
}
if args.set_num_classes:
    model_args['num_classes'] = NUM_CLASSES

model = MetricComputer(**model_args)
trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,
    limit_train_batches=0.01,
    progress_bar_refresh_rate=0,
    gpus=[0, 1],
    distributed_backend='ddp',
)
trainer.fit(model)