Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up gan and vae #153

Merged
merged 39 commits into from
Aug 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
2fb9a8d
pl 0.9 update
williamFalcon Aug 9, 2020
f534943
pl 0.9 update
williamFalcon Aug 9, 2020
25b252a
pl 0.9 update
williamFalcon Aug 9, 2020
ca06a67
pl 0.9 update
williamFalcon Aug 9, 2020
2fe32b4
pl 0.9 update
williamFalcon Aug 9, 2020
0066e32
pl 0.9 update
williamFalcon Aug 9, 2020
6bd423a
pl 0.9 update
williamFalcon Aug 9, 2020
3e11762
pl 0.9 update
williamFalcon Aug 9, 2020
447cc0e
pl 0.9 update
williamFalcon Aug 9, 2020
a5ca01b
pl 0.9 update
williamFalcon Aug 9, 2020
e2171d3
pl 0.9 update
williamFalcon Aug 9, 2020
3ca4839
pl 0.9 update
williamFalcon Aug 9, 2020
845f2ad
pl 0.9 update
williamFalcon Aug 9, 2020
98de7d2
pl 0.9 update
williamFalcon Aug 9, 2020
27c30ad
pl 0.9 update
williamFalcon Aug 9, 2020
962463f
pl 0.9 update
williamFalcon Aug 9, 2020
cd0c96d
pl 0.9 update
williamFalcon Aug 9, 2020
bc81dd3
pl 0.9 update
williamFalcon Aug 9, 2020
f7104d6
pl 0.9 update
williamFalcon Aug 9, 2020
b2b63a8
pl 0.9 update
williamFalcon Aug 9, 2020
686eb2a
pl 0.9 update
williamFalcon Aug 9, 2020
7ab01c0
pl 0.9 update
williamFalcon Aug 9, 2020
43c6fbb
pl 0.9 update
williamFalcon Aug 9, 2020
b229a4f
pl 0.9 update
williamFalcon Aug 9, 2020
09973ef
pl 0.9 update
williamFalcon Aug 9, 2020
e1ba7f4
pl 0.9 update
williamFalcon Aug 9, 2020
5b4c022
pl 0.9 update
williamFalcon Aug 9, 2020
73f1b53
pl 0.9 update
williamFalcon Aug 9, 2020
a660b4f
pl 0.9 update
williamFalcon Aug 9, 2020
8779aa9
pl 0.9 update
williamFalcon Aug 9, 2020
01c4eea
pl 0.9 update
williamFalcon Aug 9, 2020
5dbf3fc
pl 0.9 update
williamFalcon Aug 9, 2020
108b87e
pl 0.9 update
williamFalcon Aug 9, 2020
3ce7c64
pl 0.9 update
williamFalcon Aug 9, 2020
03f978b
update cpc
williamFalcon Aug 9, 2020
51cf0af
update cpc
williamFalcon Aug 9, 2020
dc2ec91
update cpc
williamFalcon Aug 9, 2020
6d9c174
update cpc
williamFalcon Aug 9, 2020
03b9230
update cpc
williamFalcon Aug 9, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/vision_datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Supervised learning
These are standard vision datasets with the train, test, val splits pre-generated in DataLoaders with
the standard transforms (and Normalization) values


BinaryMNIST
^^^^^^^^^^^

.. autoclass:: pl_bolts.datamodules.binary_mnist_datamodule.BinaryMNISTDataModule
:noindex:

CityScapes
^^^^^^^^^^

Expand Down
23 changes: 12 additions & 11 deletions pl_bolts/callbacks/self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def get_representations(self, pl_module, x):
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):

x, y = batch
x = x.to(pl_module.device)
y = y.to(pl_module.device)

with torch.no_grad():
representations = self.get_representations(pl_module, x)

Expand All @@ -76,7 +79,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
# log metrics
acc = accuracy(mlp_preds, y)
metrics = {'ft_callback_mlp_loss': mlp_loss, 'ft_callback_mlp_acc': acc}
pl_module.logger.log_metrics(metrics)
pl_module.logger.log_metrics(metrics, step=trainer.global_step)


class BYOLMAWeightUpdate(pl.Callback):
Expand Down Expand Up @@ -114,18 +117,16 @@ def __init__(self, initial_tau=0.996):
self.initial_tau = initial_tau
self.current_tau = initial_tau

def on_batch_end(self, trainer, pl_module):

if pl_module.training:
# get networks
online_net = pl_module.online_network
target_net = pl_module.target_network
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
# get networks
online_net = pl_module.online_network
target_net = pl_module.target_network

# update weights
self.update_weights(online_net, target_net)
# update weights
self.update_weights(online_net, target_net)

# update tau after
self.current_tau = self.update_tau(pl_module, trainer)
# update tau after
self.current_tau = self.update_tau(pl_module, trainer)

def update_tau(self, pl_module, trainer):
tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / trainer.max_steps) + 1) / 2
Expand Down
36 changes: 21 additions & 15 deletions pl_bolts/callbacks/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,27 @@ def on_epoch_end(self, trainer, pl_module):

def interpolate_latent_space(self, pl_module, latent_dim):
images = []
for z1 in range(self.range_start, self.range_end, 1):
for z2 in range(self.range_start, self.range_end, 1):
# set all dims to zero
z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device)

# set the fist 2 dims to the value
z[:, 0] = torch.tensor(z1)
z[:, 1] = torch.tensor(z2)

# sample
# generate images
with torch.no_grad():
pl_module.eval()
with torch.no_grad():
pl_module.eval()
for z1 in range(self.range_start, self.range_end, 1):
for z2 in range(self.range_start, self.range_end, 1):
# set all dims to zero
z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device)

# set the fist 2 dims to the value
z[:, 0] = torch.tensor(z1)
z[:, 1] = torch.tensor(z2)

# sample
# generate images
img = pl_module(z)
pl_module.train()
images.append(img)

if len(img.size()) == 2:
img = img.view(self.num_samples, *pl_module.img_dim)

img = img[0]
img = img.unsqueeze(0)
images.append(img)

pl_module.train()
return images
6 changes: 3 additions & 3 deletions pl_bolts/callbacks/vision/confused_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def __init__(
self.logging_batch_interval = logging_batch_interval
self.min_logit_value = min_logit_value

def on_batch_end(self, trainer, pl_module):

def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
# show images only every 20 batches
if (trainer.batch_idx + 1) % self.logging_batch_interval != 0:
return
Expand All @@ -81,7 +80,9 @@ def on_batch_end(self, trainer, pl_module):

mask_idxs = idxs[mask]

pl_module.eval()
self._plot(confusing_x, confusing_y, trainer, pl_module, mask_idxs)
pl_module.train()

def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs):
from matplotlib import pyplot as plt
Expand All @@ -91,7 +92,6 @@ def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs):
confusing_x = confusing_x[:self.top_k]
confusing_y = confusing_y[:self.top_k]

model.eval()
x_param_a = nn.Parameter(confusing_x)
x_param_b = nn.Parameter(confusing_x)

Expand Down
3 changes: 2 additions & 1 deletion pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, num_samples: int = 3):
def on_epoch_end(self, trainer, pl_module):
import torchvision

z = torch.randn(self.num_samples, pl_module.hparams.latent_dim, device=pl_module.device)
dim = (self.num_samples, pl_module.hparams.latent_dim)
z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)

# generate images
with torch.no_grad():
Expand Down
1 change: 1 addition & 0 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
188 changes: 188 additions & 0 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST
from PIL import Image


class BinaryMNISTDataModule(LightningDataModule):

name = 'mnist'

def __init__(
self,
data_dir: str,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
*args,
**kwargs,
):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
:width: 400
:alt: MNIST

Specs:
- 10 classes (1 per digit)
- Each image is (1 x 28 x 28)

Binary MNIST, train, val, test splits and transforms

Transforms::

mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor()
])

Example::

from pl_bolts.datamodules import BinaryMNISTDataModule

dm = BinaryMNISTDataModule('.')
model = LitModel()

Trainer().fit(model, dm)

Args:
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
normalize: If true applies image normalize
"""
super().__init__(*args, **kwargs)
self.dims = (1, 28, 28)
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.seed = seed

@property
def num_classes(self):
"""
Return:
10
"""
return 10

def prepare_data(self):
"""
Saves MNIST files to data_dir
"""
MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

def train_dataloader(self, batch_size=32, transforms=None):
"""
MNIST train set removes a subset to use for validation

Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.train_transforms or self._default_transforms()

dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_train,
batch_size=batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def val_dataloader(self, batch_size=32, transforms=None):
"""
MNIST val set uses a subset of the training set for validation

Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.val_transforms or self._default_transforms()
dataset = BinaryMNIST(self.data_dir, train=True, download=True, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def test_dataloader(self, batch_size=32, transforms=None):
"""
MNIST test set uses the test split

Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.val_transforms or self._default_transforms()

dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def _default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
])
else:
mnist_transforms = transform_lib.ToTensor()

return mnist_transforms


class BinaryMNIST(MNIST):
def __getitem__(self, idx):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[idx], int(self.targets[idx])

# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

# binary
img[img < 0.5] = 0.0
img[img >= 0.5] = 1.0

return img, target
Loading