Set find unused parameters to True by default to fix breaking compatibility #6438

Merged 2 commits on Mar 10, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))


- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438))


### Deprecated


15 changes: 15 additions & 0 deletions docs/source/benchmarking/performance.rst
@@ -94,6 +94,21 @@ DP performs three GPU transfers for EVERY batch:

DDP, on the other hand, performs only one transfer to sync gradients, which makes it MUCH faster than DP.

When using DDP, set find_unused_parameters=False
------------------------------------------------

By default, ``find_unused_parameters`` is set to ``True``. This is for compatibility with issues that have arisen in the past (see the `discussion <https://github.com/PyTorchLightning/pytorch-lightning/discussions/6219>`_ for more information).
This default comes with a performance hit, and in most cases the flag can safely be disabled.

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import DDPPlugin

    trainer = pl.Trainer(
        gpus=2,
        plugins=DDPPlugin(find_unused_parameters=False),
    )
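
For DDP Spawn the same override applies. A minimal sketch, assuming ``DDPSpawnPlugin`` is importable from ``pytorch_lightning.plugins`` in this release:

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import DDPSpawnPlugin

    # Disable unused-parameter detection only after verifying that every
    # parameter receives a gradient on each training step.
    trainer = pl.Trainer(
        gpus=2,
        plugins=DDPSpawnPlugin(find_unused_parameters=False),
    )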

----------

16-bit precision
7 changes: 7 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -175,6 +175,13 @@ def set_world_ranks(self):
self.world_size = self.num_nodes * self.num_processes

def pre_configure_ddp(self):
# If unset, default `find_unused_parameters` to `True`.
# Many models require setting this parameter to True, as there are corner cases
# where not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
# This flag does come with a performance hit, so it is suggested to disable it where possible.
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
)
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization
if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
"find_unused_parameters", False
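
As a note on the change above: because the default is applied with `dict.get`, an explicit value passed by the user through `DDPPlugin(find_unused_parameters=...)` is preserved, and `True` is only filled in when the key is absent. A standalone sketch of that merge behaviour (`ddp_kwargs` stands in for `self._ddp_kwargs`):

```python
# Illustrative only: `ddp_kwargs` stands in for the plugin's `self._ddp_kwargs`.
ddp_kwargs = {}  # user passed no override
ddp_kwargs["find_unused_parameters"] = ddp_kwargs.get("find_unused_parameters", True)
assert ddp_kwargs["find_unused_parameters"] is True

ddp_kwargs = {"find_unused_parameters": False}  # explicit user override is preserved
ddp_kwargs["find_unused_parameters"] = ddp_kwargs.get("find_unused_parameters", True)
assert ddp_kwargs["find_unused_parameters"] is False
```
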
7 changes: 7 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -168,6 +168,13 @@ def post_dispatch(self):
self.__recover_child_process_weights(best_path, last_path)

def pre_configure_ddp(self):
# If unset, default `find_unused_parameters` to `True`.
# Many models require setting this parameter to True, as there are corner cases
# where not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
# This flag does come with a performance hit, so it is suggested to disable it where possible.
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
)
# todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization
if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
"find_unused_parameters", False
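
For context on the corner case these comments describe, here is a hypothetical module whose auxiliary head only participates on some steps; with `find_unused_parameters=False`, DDP can error because no backward hook fires for `aux` on those steps, which is why the default is now `True`:

```python
import torch.nn as nn


class ConditionalHead(nn.Module):
    """Hypothetical example: `aux` receives no gradient on odd-numbered steps."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)
        self.aux = nn.Linear(32, 10)  # unused on some forward passes

    def forward(self, x, step):
        h = self.backbone(x)
        if step % 2 == 0:
            return self.aux(h).sum()
        return h.sum()  # no backward hook fires for `aux` on this step
```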