Refactor model summary + generalize example input array #1773

Merged
merged 19 commits on Jun 15, 2020
Changes from 14 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -80,6 +80,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with `auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

+- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773))

## [0.7.6] - 2020-05-16

### Added
31 changes: 23 additions & 8 deletions docs/source/debugging.rst
@@ -55,17 +55,32 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

    trainer = Trainer(overfit_pct=0.01)

-Print the parameter count by layer
-----------------------------------
-Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule.
-To disable this behavior, turn off this flag:
-
-(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary`
-argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
+Print a summary of your LightningModule
+---------------------------------------
+Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
+By default it only prints the top-level modules. If you want to show all submodules in your network, use the
+``'full'`` option:

.. testcode::

-    trainer = Trainer(weights_summary=None)
+    trainer = Trainer(weights_summary='full')

+You can also display the intermediate input and output sizes of all your layers by setting the
+``example_input_array`` attribute in your LightningModule. It will print a table like this
+
+.. code-block:: text
+
+     | Name  | Type        | Params | In sizes  | Out sizes
+   --------------------------------------------------------------
+   0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
+   1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
+   2 | net.1 | BatchNorm1d | 1 K    | [10, 512] | [10, 512]
+
+when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
+
+See Also:
+    - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
+    - :class:`~pytorch_lightning.core.memory.ModelSummary`


Set the number of validation sanity steps
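To make the docs example above concrete, here is a minimal sketch of a LightningModule that would produce that summary table (the class name, layer sizes, and Trainer usage are illustrative, not part of this diff):

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
            # setting this attribute is what enables the In sizes / Out sizes columns
            self.example_input_array = torch.zeros(10, 256)

        def forward(self, x):
            return self.net(x)

    trainer = pl.Trainer(weights_summary='full')  # 'full' also lists net.0 and net.1

A Linear(256, 512) layer holds about 131 K parameters and BatchNorm1d(512) about 1 K, which is where the 132 K total in the example table comes from.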
2 changes: 2 additions & 0 deletions pl_examples/domain_templates/generative_adversarial_net.py
@@ -93,6 +93,8 @@ def __init__(self,

        self.validation_z = torch.randn(8, self.latent_dim)

+        self.example_input_array = torch.zeros(2, hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

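The shape of this dummy tensor matters because the summary feeds ``example_input_array`` through ``forward()`` to record the input and output sizes. A quick sanity check (hypothetical snippet; assumes ``hparams`` provides ``latent_dim`` as in this template):

    model = GAN(hparams)
    z = model.example_input_array   # latent batch of shape (2, latent_dim)
    imgs = model(z)                 # exercises forward(z) the same way the summary does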
2 changes: 2 additions & 0 deletions pl_examples/models/lightning_template.py
@@ -66,6 +66,8 @@ def __init__(self,
        self.c_d2 = nn.Linear(in_features=self.hidden_dim,
                              out_features=self.out_features)

+        self.example_input_array = torch.zeros(2, 1, 28, 28)

    def forward(self, x):
        """
        No special modification required for Lightning, define it as you normally would
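A brief note on the chosen shape (a reading of the example, not stated in the diff): ``(2, 1, 28, 28)`` is a batch of two single-channel 28x28 MNIST images, and a batch of two rather than one keeps layers that rely on batch statistics working:

    model = LightningTemplateModel(hparams)   # hypothetical instantiation
    out = model(model.example_input_array)    # forward() must accept a (2, 1, 28, 28) batch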
17 changes: 14 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -54,7 +54,6 @@ def __init__(self, *args, **kwargs):

        #: Pointer to the logger object
        self.logger = None
-        self.example_input_array = None

        #: True if using dp
        self.use_dp = False
@@ -77,6 +76,17 @@ def __init__(self, *args, **kwargs):
        #: device reference
        self._device = torch.device('cpu')

+        # optionally can be set by user
+        self._example_input_array = None
+
+    @property
+    def example_input_array(self) -> Any:
+        return self._example_input_array
+
+    @example_input_array.setter
+    def example_input_array(self, example: Any) -> None:
+        self._example_input_array = example
+
    @property
    def on_gpu(self):
        """
@@ -1598,9 +1608,10 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':

        return model

-    def summarize(self, mode: str) -> None:
+    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
        model_summary = ModelSummary(self, mode=mode)
-        log.info('\n' + model_summary.__str__())
+        log.info('\n' + str(model_summary))
+        return model_summary

    def freeze(self) -> None:
        r"""