enabled no returns from eval (#2446)
* enabled no returns from eval

* fixed docs

williamFalcon committed Jul 1, 2020
1 parent fa2233f commit 325852c
Showing 8 changed files with 153 additions and 35 deletions.
89 changes: 89 additions & 0 deletions docs/source/bolts.rst
@@ -0,0 +1,89 @@
Bolts
=====
`PyTorch Lightning Bolts <https://pytorch-lightning-bolts.readthedocs.io/en/latest/>`_ is our official collection
of prebuilt models across many research domains.

.. code-block:: bash

    pip install pytorch-lightning-bolts

In bolts we have:

- A collection of pretrained state-of-the-art models.
- A collection of models designed to bootstrap your research.
- A collection of Callbacks, transforms, full datasets.
- All models work on CPUs, TPUs, GPUs and 16-bit precision.

-----------------

Quality control
---------------
Bolts are built by the Lightning community and contributed to bolts.
The Lightning team guarantees that contributions are:

- Rigorously Tested (CPUs, GPUs, TPUs)
- Rigorously Documented
- Standardized via PyTorch Lightning
- Optimized for speed
- Checked for correctness

---------

Example 1: Pretrained, prebuilt models
--------------------------------------

.. code-block:: python

    from pl_bolts.models import VAE, GPT2, ImageGPT, PixelCNN
    from pl_bolts.models.self_supervised import AMDIM, CPCV2, SimCLR, MocoV2
    from pl_bolts.models import LinearRegression, LogisticRegression
    from pl_bolts.models.gans import GAN
    from pl_bolts.callbacks import PrintTableMetricsCallback
    from pl_bolts.datamodules import FashionMNISTDataModule, CIFAR10DataModule, ImagenetDataModule
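As a rough usage sketch (the checkpoint path below is a placeholder, and constructor arguments vary between bolts releases), a pretrained bolts model can be restored with the standard Lightning checkpoint API:

.. code-block:: python

    from pl_bolts.models.self_supervised import SimCLR

    # hypothetical checkpoint path -- substitute a real pretrained weight file
    model = SimCLR.load_from_checkpoint("path/to/simclr.ckpt")
    model.freeze()  # use as a fixed feature extractor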
------------

Example 2: Extend for faster research
-------------------------------------
Bolts are contributed with benchmarks and continuous-integration tests. This means
you can trust the implementations and use them to bootstrap your research much faster.

.. code-block:: python

    from pl_bolts.models import ImageGPT
    from pl_bolts.self_supervised import SimCLR

    class VideoGPT(ImageGPT):

        def training_step(self, batch, batch_idx):
            x, y = batch
            x = _shape_input(x)

            logits = self.gpt(x)
            simclr_features = self.simclr(x)

            # -----------------
            # do something new with GPT logits + simclr_features
            # -----------------

            loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())

            logs = {"loss": loss}
            return {"loss": loss, "log": logs}
----------

Example 3: Callbacks
--------------------
We also have a collection of callbacks.

.. code-block:: python

    from pl_bolts.callbacks import PrintTableMetricsCallback
    import pytorch_lightning as pl

    trainer = pl.Trainer(callbacks=[PrintTableMetricsCallback()])

    # loss│train_loss│val_loss│epoch
    # ──────────────────────────────
    # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
16 changes: 8 additions & 8 deletions docs/source/callbacks.rst
@@ -49,14 +49,14 @@ We successfully extended functionality without polluting our super clean
----------------

Best Practices
==============

1. Callbacks should be isolated in their functionality. Your callback should not rely on the
behavior of other callbacks in order to work properly.
2. Do not manually call methods from the callback. The callbacks are designed to be
invoked at specific times during training. Directly calling methods (eg. `on_validation_end`)
is strongly discouraged.
3. Whenever possible, your callbacks should not depend on the order in which they are executed.
--------------
The following are best practices when using/designing callbacks.

1. Callbacks should be isolated in their functionality.
2. Your callback should not rely on the behavior of other callbacks in order to work properly.
3. Do not manually call methods from the callback.
4. Directly calling methods (eg. `on_validation_end`) is strongly discouraged.
5. Whenever possible, your callbacks should not depend on the order in which they are executed.
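As a minimal sketch of these guidelines (a hypothetical callback, not part of this diff), a well-behaved callback keeps its own state and never invokes other callbacks or hooks directly:

.. code-block:: python

    from pytorch_lightning.callbacks import Callback

    class BatchCounter(Callback):
        """Hypothetical callback: isolated, order-independent, never calls other callbacks."""

        def __init__(self):
            self.batches_seen = 0

        def on_batch_end(self, trainer, pl_module):
            # keeps only its own state; does not call hooks such as on_validation_end
            self.batches_seen += 1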


---------
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -139,6 +139,7 @@
'api/pytorch_lightning.rst',
'api/pl_examples.*',
'api/modules.rst',
'PULL_REQUEST_TEMPLATE.md',

# deprecated/renamed:
'api/pytorch_lightning.logging.*', # TODO: remove in v0.9.0
4 changes: 4 additions & 0 deletions docs/source/hooks.rst
@@ -20,6 +20,8 @@ Hooks lifecycle
Training set-up
^^^^^^^^^^^^^^^

- :meth:`~pytorch_lightning.core.lightning.LightningModule.prepare_data`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
@@ -30,6 +32,8 @@ Training set-up
- :meth:`~pytorch_lightning.core.lightning.LightningModule.summarize`
- :meth:`~pytorch_lightning.trainer.training_io.TrainerIOMixin.restore_weights`

.. warning:: `prepare_data` is only called from global_rank=0. Don't assign state (self.something), use `setup` for that
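To illustrate the warning above, a minimal sketch (a hypothetical module, not part of this diff) that downloads in ``prepare_data`` and assigns state only in ``setup``:

.. code-block:: python

    from pytorch_lightning import LightningModule
    from torch.utils.data import random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    class MyModule(LightningModule):

        def prepare_data(self):
            # runs only on global_rank=0: download, but assign no state here
            MNIST("./data", train=True, download=True)

        def setup(self, stage):
            # runs on every process: safe to assign state
            dataset = MNIST("./data", train=True, transform=transforms.ToTensor())
            self.train_set, self.val_set = random_split(dataset, [55000, 5000])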

----------

Training loop
9 changes: 7 additions & 2 deletions docs/source/index.rst
@@ -27,6 +27,13 @@ PyTorch Lightning Documentation
hooks
trainer

.. toctree::
:maxdepth: 1
:name: Bolts
:caption: Bolts

bolts

.. toctree::
:maxdepth: 1
:name: Community Examples
@@ -35,7 +42,6 @@ PyTorch Lightning Documentation
Contextual Emotion Detection (DoubleDistilBert) <https://github.com/PyTorchLightning/emotion_transformer>
Cotatron: Transcription-Guided Speech Encoder <https://github.com/mindslab-ai/cotatron>
FasterRCNN object detection + Hydra <https://github.com/PyTorchLightning/wheat>
Generative Adversarial Network <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=TyYOdg8g77P0>
Hyperparameter optimization with Optuna <https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py>
Image Inpainting using Partial Convolutions <https://github.com/ryanwongsa/Image-Inpainting>
MNIST on TPU <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3#scrollTo=BHBz1_AnamN_>
@@ -100,7 +106,6 @@ PyTorch Lightning Documentation
CODE_OF_CONDUCT.md
CONTRIBUTING.md
BECOMING_A_CORE_CONTRIBUTOR.md
PULL_REQUEST_TEMPLATE.md
governance.md

Indices and tables
21 changes: 19 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1337,11 +1337,19 @@ def train_dataloader(self) -> DataLoader:
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`setup`
- :meth:`train_dataloader`
Note:
@@ -1383,11 +1391,20 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
For data processing use the following pattern:
- download in :meth:`prepare_data`
- process and split in :meth:`setup`
However, the above are only necessary for distributed processing.
.. warning:: do not assign state in prepare_data
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`setup`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`
36 changes: 20 additions & 16 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -341,9 +341,10 @@ def _evaluate(
elif self.is_overridden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)

# aggregate ddp stats across
if self.use_ddp or self.use_ddp2:
self.reduce_eval_ddp(eval_results)
# aggregate ddp stats across
has_content = eval_results is not None and len(eval_results) > 0
if has_content and (self.use_ddp or self.use_ddp2):
self.reduce_eval_ddp(eval_results)

# enable train mode again
model.train()
@@ -406,23 +407,26 @@ def run_evaluation(self, test_mode: bool = False):

# run evaluation
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)

# add metrics to prog bar
self.add_progress_bar_metrics(prog_bar_metrics)
# enable no returns
if eval_results is not None and len(eval_results) > 0:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)

# log results of test
if test_mode and self.is_global_zero:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
print('-' * 80)
# add metrics to prog bar
self.add_progress_bar_metrics(prog_bar_metrics)

# log metrics
self.log_metrics(log_metrics, {})
# log results of test
if test_mode and self.is_global_zero:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
print('-' * 80)

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)
# log metrics
self.log_metrics(log_metrics, {})

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# hook
model.on_post_performance_check()
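For context, a minimal sketch (a hypothetical module, not taken from this diff) of the case this change enables: a ``validation_epoch_end`` that only logs or prints and returns nothing, whereas the result was previously passed to ``process_output`` unconditionally:

.. code-block:: python

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl

    class NoReturnModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return {"loss": F.cross_entropy(self(x), y)}

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return {"val_loss": F.cross_entropy(self(x), y)}

        def validation_epoch_end(self, outputs):
            # intentionally returns nothing -- the guarded eval loop now tolerates this
            avg = torch.stack([o["val_loss"] for o in outputs]).mean()
            print(f"avg val loss: {avg:.4f}")

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)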
12 changes: 5 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -129,11 +129,6 @@ class Trainer(
>>> trainer.fit(model, train_loader)
1
>>> trainer.test(model, train_loader) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
--------------------------------------------------------------------------------
TEST RESULTS
...
--------------------------------------------------------------------------------
"""
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')

@@ -1142,8 +1137,11 @@ def run_pretrain_routine(self, model: LightningModule):
self.val_dataloaders,
max_batches,
False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)
self.callback_metrics = callback_metrics

# allow no returns from eval
if eval_results is not None and len(eval_results) > 0:
_, _, _, callback_metrics, _ = self.process_output(eval_results)
self.callback_metrics = callback_metrics

self.on_sanity_check_end()

