Refactor model summary + generalize example input array (#1773)
* squash

  variant a
  variant b
  add test
  revert rename
  add changelog
  docs
  move changelog entry to top
  use hooks
  wip
  wipp
  layer summary
  clean up, refactor
  type hints
  rename
  remove obsolete code
  rename
  unused imports
  simplify formatting of table and increase readability
  doctest
  superclass object
  update examples
  print unknown sizes
  more docs and doctest
  testing
  unknown layers
  add rnn test
  remove main
  restore train mode
  test device wip
  device
  constant
  simplify model forward transfer
  return summary object in method
  extend tests
  fix summary for empty module
  extend tests
  refactor and added hook
  remove hardcoded string
  simplify
  test unknown shapes and all others
  comments for tests
  fix hparams attribute
* update default
* unused import
* clean up
* replace hardcoded strings
* fix doctest
* fix top/full
* black
* fix rnn test
* fix rnn
* update debugging docs

  update docs
  typo
  update docs
  update docs

* add changelog
* extract constant
* setter and getter
* move parity models to test folder
* parameterize mode
awaelchli authored Jun 15, 2020
1 parent 22d9464 commit 7dc58bd
Showing 13 changed files with 541 additions and 287 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012))
 - Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118))
 
+- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added
77 changes: 0 additions & 77 deletions benchmarks/parity_modules.py

This file was deleted.

2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
@@ -5,8 +5,8 @@
 import torch
 
 import tests.base.utils as tutils
-from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST
 from pytorch_lightning import Trainer, seed_everything
+from tests.base.models import ParityModuleRNN, ParityModuleMNIST
 
 
 @pytest.mark.parametrize('cls_model,max_diff', [
31 changes: 23 additions & 8 deletions docs/source/debugging.rst
@@ -55,17 +55,32 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
     trainer = Trainer(overfit_pct=0.01)
 
-Print the parameter count by layer
-----------------------------------
-Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule.
-To disable this behavior, turn off this flag:
-
-(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary`
-argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
+Print a summary of your LightningModule
+---------------------------------------
+Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
+By default it only prints the top-level modules. If you want to show all submodules in your network, use the
+``'full'`` option:
 
 .. testcode::
 
-    trainer = Trainer(weights_summary=None)
+    trainer = Trainer(weights_summary='full')
+
+You can also display the intermediate input- and output sizes of all your layers by setting the
+``example_input_array`` attribute in your LightningModule. It will print a table like this
+
+.. code-block:: text
+
+      | Name  | Type        | Params | In sizes  | Out sizes
+    --------------------------------------------------------------
+    0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
+    1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
+    2 | net.1 | BatchNorm1d | 1 K    | [10, 512] | [10, 512]
+
+when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
+
+See Also:
+    - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
+    - :class:`~pytorch_lightning.core.memory.ModelSummary`
 
 
 Set the number of validation sanity steps
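As a rough, self-contained sketch of the behavior documented above (the module name LitModel and its layers are invented here to match the shapes in the docs table; they are not part of this diff):

    # Hypothetical module for illustration; the layer sizes reproduce the
    # docs table above (Linear: ~131 K params, BatchNorm1d: ~1 K params).
    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(256, 512),
                nn.BatchNorm1d(512),
            )
            # Setting this attribute enables the "In sizes" / "Out sizes"
            # columns in the summary table.
            self.example_input_array = torch.zeros(10, 256)

        def forward(self, x):
            return self.net(x)

    # 'full' lists every submodule (net, net.0, net.1), not just the top level.
    trainer = pl.Trainer(weights_summary='full')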
2 changes: 2 additions & 0 deletions pl_examples/domain_templates/generative_adversarial_net.py
@@ -93,6 +93,8 @@ def __init__(self,
 
         self.validation_z = torch.randn(8, self.latent_dim)
 
+        self.example_input_array = torch.zeros(2, hparams.latent_dim)
+
     def forward(self, z):
         return self.generator(z)
 
2 changes: 2 additions & 0 deletions pl_examples/models/lightning_template.py
@@ -66,6 +66,8 @@ def __init__(self,
         self.c_d2 = nn.Linear(in_features=self.hidden_dim,
                               out_features=self.out_features)
 
+        self.example_input_array = torch.zeros(2, 1, 28, 28)
+
     def forward(self, x):
         """
         No special modification required for Lightning, define it as you normally would
17 changes: 14 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -52,7 +52,6 @@ def __init__(self, *args, **kwargs):
 
         #: Pointer to the logger object
         self.logger = None
-        self.example_input_array = None
 
         #: True if using dp
         self.use_dp = False
@@ -75,6 +74,17 @@ def __init__(self, *args, **kwargs):
         #: device reference
         self._device = torch.device('cpu')
 
+        # optionally can be set by user
+        self._example_input_array = None
+
+    @property
+    def example_input_array(self) -> Any:
+        return self._example_input_array
+
+    @example_input_array.setter
+    def example_input_array(self, example: Any) -> None:
+        self._example_input_array = example
+
     @property
     def on_gpu(self):
         """
@@ -1445,9 +1455,10 @@ def val_dataloader(self):
         will have an argument ``dataset_idx`` which matches the order here.
         """
 
-    def summarize(self, mode: str) -> None:
+    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
         model_summary = ModelSummary(self, mode=mode)
-        log.info('\n' + model_summary.__str__())
+        log.info('\n' + str(model_summary))
+        return model_summary
 
     def freeze(self) -> None:
         r"""
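With this change, summarize logs the table and also returns the ModelSummary object instead of None. A short usage sketch (reusing the hypothetical LitModel from the earlier example; 'top' and 'full' are the two supported mode strings in this version, with 'top' as the default):

    model = LitModel()

    # Logs the summary table and now also returns the ModelSummary object,
    # so the summary can be inspected programmatically.
    summary = model.summarize(mode='full')
    print(summary)  # same table, available outside of logging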