squash

* variant a
* variant b
* add test
* revert rename
* add changelog
* docs
* move changelog entry to top
* use hooks
* wip
* wipp
* layer summary
* clean up, refactor
* type hints
* rename
* remove obsolete code
* rename
* unused imports
* simplify formatting of table and increase readability
* doctest
* superclass object
* update examples
* print unknown sizes
* more docs and doctest
* testing
* unknown layers
* add rnn test
* remove main
* restore train mode
* test device wip
* device
* constant
* simplify model forward transfer
* return summary object in method
* extend tests
* fix summary for empty module
* extend tests
* refactor and added hook
* variant a
* variant b
* add test
* revert rename
* add changelog
* docs
* move changelog entry to top
* remove hardcoded string
* simplify
* test unknown shapes and all others
* comments for tests
* fix hparams attribute
awaelchli committed Jun 6, 2020
1 parent c09317e commit cf303ba
Showing 10 changed files with 529 additions and 177 deletions.
154 changes: 154 additions & 0 deletions benchmarks/test_rnn_parity.py
@@ -0,0 +1,154 @@
import time

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import tests.base.utils as tutils

from pytorch_lightning import Trainer, LightningModule, seed_everything


class AverageDataset(Dataset):
    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
        top, bottom = self.input_seq.chunk(2, -1)
        self.output_seq = top + bottom.roll(shifts=1, dims=-1)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, item):
        return self.input_seq[item], self.output_seq[item]


class ParityRNN(LightningModule):
    def __init__(self):
        super(ParityRNN, self).__init__()
        self.rnn = nn.LSTM(10, 20, batch_first=True)
        self.linear_out = nn.Linear(in_features=20, out_features=5)
        self.example_input_array = torch.zeros(1, 5, 10)

    def forward(self, x):
        seq, last = self.rnn(x)
        return self.linear_out(seq)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(AverageDataset(), batch_size=30)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir):
    """Verify that the same PyTorch and Lightning models achieve the same results."""
    num_epochs = 2
    num_runs = 3

    lightning_outs, pl_times = lightning_loop(ParityRNN, num_runs, num_epochs)
    manual_outs, pt_times = vanilla_loop(ParityRNN, num_runs, num_epochs)
    # make sure the losses match to 8 decimal places
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, 8)

    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)


def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
    """Returns the final-epoch loss and the wall-clock training time for each run."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    errors = []
    times = []

    torch.backends.cudnn.deterministic = True
    for i in range(num_runs):
        time_start = time.perf_counter()

        # set seed
        seed = i
        seed_everything(seed)

        # init model parts
        model = MODEL()
        dl = model.train_dataloader()
        optimizer = model.configure_optimizers()

        # model to GPU
        model = model.to(device)

        epoch_losses = []
        for epoch in range(num_epochs):

            # run through full training set
            for j, batch in enumerate(dl):
                x, y = batch
                x = x.to(device)  # use the selected device rather than hardcoding cuda:0
                y = y.to(device)
                batch = (x, y)

                loss_dict = model.training_step(batch, j)
                loss = loss_dict['loss']
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # track last epoch loss
            epoch_losses.append(loss.item())

        time_end = time.perf_counter()
        times.append(time_end - time_start)

        errors.append(epoch_losses[-1])

    return errors, times


def lightning_loop(MODEL, num_runs=10, num_epochs=10):
    errors = []
    times = []

    for i in range(num_runs):
        time_start = time.perf_counter()

        # set seed
        seed = i
        seed_everything(seed)
        model = MODEL()

        # init trainer
        trainer = Trainer(
            max_epochs=num_epochs,
            progress_bar_refresh_rate=0,
            weights_summary=None,
            gpus=1,
            early_stop_callback=False,
            checkpoint_callback=False,
            distributed_backend='dp',
            deterministic=True,
        )
        trainer.fit(model)

        final_loss = trainer.running_loss.last().item()
        errors.append(final_loss)

        time_end = time.perf_counter()
        times.append(time_end - time_start)

    return errors, times
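
As a sanity check, the target construction in AverageDataset can be reproduced in isolation; a minimal sketch (shapes as in the file above):

import torch

# Reproduce the AverageDataset target logic to confirm the shapes
# that ParityRNN is trained against: the 10 input features are split
# into two halves of 5, and the second half is rolled along the feature dim.
input_seq = torch.randn(300, 100, 10)
top, bottom = input_seq.chunk(2, -1)               # two (300, 100, 5) tensors
output_seq = top + bottom.roll(shifts=1, dims=-1)
assert output_seq.shape == (300, 100, 5)           # matches nn.Linear(20, 5) output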
2 changes: 2 additions & 0 deletions pl_examples/domain_templates/generative_adversarial_net.py
@@ -93,6 +93,8 @@ def __init__(self,
 
         self.validation_z = torch.randn(8, self.latent_dim)
 
+        self.example_input_array = torch.zeros(2, hparams.latent_dim)
+
     def forward(self, z):
         return self.generator(z)

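Setting example_input_array is what allows the new model summary to run a forward pass and report input and output shapes per layer. A minimal sketch of the pattern, using a hypothetical TinyModule that stands in for the examples in this commit:

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule

class TinyModule(LightningModule):  # hypothetical, not part of this commit
    def __init__(self, latent_dim=100):
        super().__init__()
        self.generator = nn.Linear(latent_dim, 784)
        # as in the diffs above: a dummy batch the summary can trace
        self.example_input_array = torch.zeros(2, latent_dim)

    def forward(self, z):
        return self.generator(z)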
2 changes: 2 additions & 0 deletions pl_examples/models/lightning_template.py
@@ -66,6 +66,8 @@ def __init__(self,
         self.c_d2 = nn.Linear(in_features=self.hidden_dim,
                               out_features=self.out_features)
 
+        self.example_input_array = torch.zeros(2, 1, 28, 28)
+
     def forward(self, x):
         """
         No special modification required for Lightning, define it as you normally would
5 changes: 3 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1598,9 +1598,10 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
 
         return model
 
-    def summarize(self, mode: str) -> None:
+    def summarize(self, mode: str = 'full') -> ModelSummary:
         model_summary = ModelSummary(self, mode=mode)
-        log.info('\n' + model_summary.__str__())
+        log.info('\n' + str(model_summary))
+        return model_summary
 
     def freeze(self) -> None:
         r"""
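A minimal sketch of the changed behaviour, assuming a module with example_input_array set (such as the hypothetical TinyModule above): summarize still logs the table via log.info, but now also returns the ModelSummary object, and mode defaults to 'full'.

model = TinyModule()
summary = model.summarize()   # mode defaults to 'full' after this change
print(str(summary))           # the same table that was logged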
(The remaining 6 changed files in this commit are not shown.)
