Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor model summary + generalize example input array #1773

Merged
merged 19 commits into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012))
- Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118))

- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773))

## [0.7.6] - 2020-05-16

### Added
Expand Down
77 changes: 0 additions & 77 deletions benchmarks/parity_modules.py

This file was deleted.

2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch

import tests.base.utils as tutils
from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST
from pytorch_lightning import Trainer, seed_everything
from tests.base.models import ParityModuleRNN, ParityModuleMNIST


@pytest.mark.parametrize('cls_model,max_diff', [
Expand Down
31 changes: 23 additions & 8 deletions docs/source/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,32 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

trainer = Trainer(overfit_pct=0.01)

Print the parameter count by layer
----------------------------------
Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule.
To disable this behavior, turn off this flag:

(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
Print a summary of your LightningModule
---------------------------------------
Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
By default it only prints the top-level modules. If you want to show all submodules in your network, use the
`'full'` option:

.. testcode::

trainer = Trainer(weights_summary=None)
trainer = Trainer(weights_summary='full')

You can also display the intermediate input- and output sizes of all your layers by setting the
``example_input_array`` attribute in your LightningModule. It will print a table like this

.. code-block:: text

| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512]

when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.

See Also:
- :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
- :class:`~pytorch_lightning.core.memory.ModelSummary`


Set the number of validation sanity steps
Expand Down
2 changes: 2 additions & 0 deletions pl_examples/domain_templates/generative_adversarial_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(self,

self.validation_z = torch.randn(8, self.latent_dim)

self.example_input_array = torch.zeros(2, hparams.latent_dim)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, z):
return self.generator(z)

Expand Down
2 changes: 2 additions & 0 deletions pl_examples/models/lightning_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(self,
self.c_d2 = nn.Linear(in_features=self.hidden_dim,
out_features=self.out_features)

self.example_input_array = torch.zeros(2, 1, 28, 28)

def forward(self, x):
"""
No special modification required for Lightning, define it as you normally would
Expand Down
17 changes: 14 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(self, *args, **kwargs):

#: Pointer to the logger object
self.logger = None
self.example_input_array = None

#: True if using dp
self.use_dp = False
Expand All @@ -75,6 +74,17 @@ def __init__(self, *args, **kwargs):
#: device reference
self._device = torch.device('cpu')

# optionally can be set by user
self._example_input_array = None

@property
def example_input_array(self) -> Any:
return self._example_input_array

@example_input_array.setter
def example_input_array(self, example: Any) -> None:
self._example_input_array = example

@property
def on_gpu(self):
"""
Expand Down Expand Up @@ -1442,9 +1452,10 @@ def val_dataloader(self):
will have an argument ``dataset_idx`` which matches the order here.
"""

def summarize(self, mode: str) -> None:
def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
model_summary = ModelSummary(self, mode=mode)
log.info('\n' + model_summary.__str__())
log.info('\n' + str(model_summary))
return model_summary

def freeze(self) -> None:
r"""
Expand Down
Loading