Add PredictLoop (#5752)
* integrate distrib_type

* sync changes

* sync

* fixes

* add forgotten generators

* add missing logic

* update

* import

* missed imports

* import fixes

* isort

* mv f

* changelog

* format

* move helper to parallel plugin

* d

* add world size

* clean up

* duplicate

* activate ddp_sharded and tpu

* set nvidia flags

* remove unused colab var

* use_tpu <-> on_tpu attrs

* make some ddp_cpu and clusterplugin tests pass

* Ref/accelerator connector (#5742)

* final cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* connector cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* trainer cleanup

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* accelerator cleanup + missing logic in accelerator connector

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* add missing changes to callbacks

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* reflect accelerator changes to lightning module

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* clean cluster envs

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* cleanup plugins

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* add broadcasting

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* yapf

* remove plugin connector

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* plugins

* add predict_loop

* manual optimization

* clean predictloop

* update optimizer routing

* add predict loop on new accelerator

* resolve a bug

* add rank to torchelastic

* add predict_loop

* add predict loop on new accelerator

* resolve a bug

* fix memory mixed precision

* update

* setstate on trainer for pickling in ddp spawn

* resolve tests

* add predict method

* add back commented accelerator code

* adapt test for sync_batch_norm to new plugin

* fix deprecated tests

* fix ddp_cpu choice when num_processes is not given

* yapf format

* skip a memory test that cannot pass anymore

* remove sanitize

* rename train to run_train

* remove useless hooks

* add MisconfigurationException

* remove wrong naming

* resolve some legacy

* update docstring

* fix pickle error in spawn plugin

* x

* avoid

* x

* fix cyclic import in docs build

* add support for sharded

* update typing

* add sharded and sharded_spawn to distributed types

* make unwrap model default

* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel

* update sharded spawn to reflect changes

* update sharded to reflect changes

* Merge 1.1.5 changes

* fix merge

* fix merge

* yapf isort

* fix merge

* yapf isort

* fix indentation in test

* copy over reinit scheduler implementation from dev1.2

* fix apex tracking calls with dev_debugger

* reduce diff to dev1.2, clean up

* fix trainer config test when gpus > 0, num_processes > 0 and ddp_cpu

* sort plugin tests legacy/new

* fix error handling for amp on cpu

* fix merge

* [Feat] Resolve manual_backward (#5837)

* resolve manual_backward

* resolve flake8

* update

* resolve for ddp_spawn

* resolve flake8

* resolve flake8

* resolve flake8

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* fix tests/accelerator tests on cpu

* [BugFix] Resolve manual optimization (#5852)

* resolve manual_optimization

* update

* update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)

* resolve a bug

* Accelerator refactor sharded rpc (#5854)

* rpc branch

* merge

* update handling of rpc

* make devices etc. Optional in RPC

* set devices etc. later if necessary

* remove devices from sequential

* make devices optional in rpc

* fix import

* uncomment everything

* fix cluster selection

Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>

* resolve bug

* fix assert in rpc test

* resolve a test

* fix docs compilation

* accelerator refactor - fix for sharded parity test (#5866)

* fix memory issue with ddp_spawn

* x

* Remove DDP2 as this does not apply

* Add missing pre optimizer hook to ensure lambda closure is called

* fix apex docstring

* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)

* update

* revert init

* resolve a bug

* update

* resolve flake8

* update

* update

* update

* revert init

* resolve a bug

* update

* resolve flake8

* update

* update

* update

* update

* update

* revert init

* resolve a bug

* update

* resolve flake8

* update

* update

* update

* revert init

* update

* resolve flake8

* update

* update

* update

* update

* update

* all_gather

* update

* make plugins work, add misconfig for RPC

* update

* update

* remove breaking test

* resolve some tests

* resolve flake8

* revert to ddp_spawn

Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>

* yapf isort

* resolve flake8

* fix apex doctests

* fix apex doctests 2

* resolve docs

* update drone

* clean env

* update

* update

* update

* update

* merge

* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)

* Fix RPC related tests, clean out old API, update for new accelerator API

* Move tests out of legacy folder, update paths and names

* Update test_remove_1-4.py

* Expose properties for tpu cores/gpus/num_gpus

* Add root GPU property

* Move properties to properties.py

* move tests that were previously in drone

* Fix root GPU property (#5908)

* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator

* Add missing tests back

* fix best model path transfer when no checkpoint callback available

* Fix setup hook order [wip] (#5858)

* Call trainer setup hook before accelerator setup

* Add test case

* add new test

* typo

* fix callback order in test

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* rename ddp sequential -> rpc sequential for special test

* revert

* fix stupid merge problem

* Use property in connector for sampler (#5913)

* merge the import conflicts

* fix spawning of processes in slurm

* [wip] Fix some bugs for TPU [skip ci] (#5878)

* fixed for single tpu

* fixed spawn

* fixed spawn

* update

* update

* wip

* resolve bugs

* resolve bug

* update on comment

* removed decorator

* resolve comments

* set to 4

* update

* update

* need cleaning

* update

* update

* update

* resolve flake8

* resolve bugs

* exclude broadcast

* resolve bugs

* change test

* update

* update

* skip if meet fails

* properly raise trace

* update

* add catch

* wrap test

* resolve typo

* update

* typo

Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>

* resolve some tests

* update

* fix imports

* update

* resolve flake8

* update azure pipeline

* skip a sharded test on cpu that requires a gpu

* resolve tpus

* resolve bug

* resolve flake8

* update

* update utils

* revert permission change on files

* suggestions from carlos

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove unrelated formatting changes

* remove incomplete comment

* Update pytorch_lightning/accelerators/__init__.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* remove unrelated formatting change

* add types

* warn 1.7 ddp manual backward only if ddp kwarg unset

* yapf + isort

* pep8 unused imports

* fix cyclic import in docs

* Apply suggestions from code review

* typo in accelerator.py

* typo

* resolve flake8

* update code

* update

* Update pytorch_lightning/trainer/predict_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/trainer/predict_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix merge

* fix merge

* reset legacy accelerator

* add missing rename dispatch

* rename post training

* update code

* resolved comments

* typo

* typo

* add flow description

* resolve comments

* update on comments

* update flow

* add backticks

* resolve tpu

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: justusschock <justus.schock@posteo.de>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
14 people authored Feb 16, 2021
1 parent a52be5b commit e982800
Showing 25 changed files with 454 additions and 136 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))


- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752))


- Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706))


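For orientation, here is a minimal sketch of how the prediction path added in this PR could be exercised from user code. It assumes the `Trainer.predict` entry point implied by the "add predict method" commit above and the module-level `predict` hook visible later in this diff; the model, data, and exact signatures are illustrative assumptions, not taken from the diff itself.

```python
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    # hook name taken from the hooks.py / overrides changes in this diff
    def predict(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)


# a plain tensor is a valid map-style dataset, so each batch is a (16, 32) tensor
predict_loader = DataLoader(torch.randn(64, 32), batch_size=16)

trainer = pl.Trainer(max_epochs=1)
predictions = trainer.predict(LitModel(), predict_loader)  # assumed entry point
```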
33 changes: 24 additions & 9 deletions pytorch_lightning/accelerators/accelerator.py
@@ -139,9 +139,8 @@ def training_step(self, args):

args[0] = batch

- with self.precision_plugin.train_step_context():
-     with self.training_type_plugin.train_step_context():
-         return self.training_type_plugin.training_step(*args)
+ with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
+     return self.training_type_plugin.training_step(*args)

def post_training_step(self):
self.training_type_plugin.post_training_step()
@@ -161,9 +160,8 @@ def validation_step(self, args):

args[0] = batch

- with self.precision_plugin.val_step_context():
-     with self.training_type_plugin.val_step_context():
-         return self.training_type_plugin.validation_step(*args)
+ with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
+     return self.training_type_plugin.validation_step(*args)

def test_step(self, args):
"""The actual test step.
@@ -180,9 +178,26 @@ def test_step(self, args):

args[0] = batch

- with self.precision_plugin.test_step_context():
-     with self.training_type_plugin.test_step_context():
-         return self.training_type_plugin.test_step(*args)
+ with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
+     return self.training_type_plugin.test_step(*args)

def predict(self, args):
"""The actual predict step.
Args:
args: the arguments for the models predict step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): The index of this batch.
dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
batch = self.to_device(args[0])

args[0] = batch

with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)

def training_step_end(self, output):
"""A hook to do something at the end of the training step
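The step methods above replace two nested `with` blocks with a single statement that enters both plugin contexts. A small self-contained sketch of that pattern, using stand-in classes rather than the real Lightning plugins:

```python
import contextlib


class DemoPrecisionPlugin:
    @contextlib.contextmanager
    def train_step_context(self):
        print("enter precision context")  # e.g. enable autocast for AMP
        yield
        print("exit precision context")


class DemoTrainingTypePlugin:
    @contextlib.contextmanager
    def train_step_context(self):
        print("enter training-type context")  # e.g. DDP-specific handling
        yield
        print("exit training-type context")


precision = DemoPrecisionPlugin()
training_type = DemoTrainingTypePlugin()

# one flat statement; equivalent to nesting the two `with` blocks
with precision.train_step_context(), training_type.train_step_context():
    print("run training_step")
```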
63 changes: 56 additions & 7 deletions pytorch_lightning/callbacks/progress.py
@@ -67,6 +67,7 @@ def __init__(self):
self._train_batch_idx = 0
self._val_batch_idx = 0
self._test_batch_idx = 0
self._predict_batch_idx = 0

@property
def trainer(self):
@@ -96,6 +97,14 @@ def test_batch_idx(self) -> int:
"""
return self._test_batch_idx

@property
def predict_batch_idx(self) -> int:
"""
The current batch index being processed during predicting.
Use this to update your progress bar.
"""
return self._predict_batch_idx

@property
def total_train_batches(self) -> int:
"""
Expand All @@ -108,7 +117,7 @@ def total_train_batches(self) -> int:
@property
def total_val_batches(self) -> int:
"""
The total number of training batches during validation, which may change from epoch to epoch.
The total number of validation batches during validation, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
@@ -121,12 +130,21 @@ def total_val_batches(self) -> int:
@property
def total_test_batches(self) -> int:
"""
- The total number of training batches during testing, which may change from epoch to epoch.
+ The total number of testing batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
return sum(self.trainer.num_test_batches)

@property
def total_predict_batches(self) -> int:
"""
The total number of predicting batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
predict dataloader is of infinite size.
"""
return sum(self.trainer.num_predict_batches)

def disable(self):
"""
You should provide a way to disable the progress bar.
@@ -168,6 +186,12 @@ def on_test_start(self, trainer, pl_module):
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1

def on_predict_start(self, trainer, pl_module):
self._predict_batch_idx = 0

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1


class ProgressBar(ProgressBarBase):
r"""
@@ -282,6 +306,20 @@ def init_train_tqdm(self) -> tqdm:
)
return bar

def init_predict_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for predicting. """
bar = tqdm(
desc='Predicting',
initial=self.train_batch_idx,
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
dynamic_ncols=True,
file=sys.stdout,
smoothing=0,
)
return bar

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
bar = tqdm(
@@ -294,12 +332,10 @@ def init_validation_tqdm(self) -> tqdm:
)
return bar

- def init_test_tqdm(self, trainer=None) -> tqdm:
+ def init_test_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for testing. """
- desc = "Testing"
- desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
bar = tqdm(
- desc=desc,
+ desc="Testing",
position=(2 * self.process_position),
disable=self.is_disabled,
leave=True,
@@ -365,7 +401,7 @@ def on_train_end(self, trainer, pl_module):

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
- self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
+ self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
Expand All @@ -377,6 +413,19 @@ def on_test_end(self, trainer, pl_module):
super().on_test_end(trainer, pl_module)
self.test_progress_bar.close()

def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self.predict_progress_bar = self.init_predict_tqdm()
self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
self._update_bar(self.predict_progress_bar)

def on_predict_end(self, trainer, pl_module):
self.predict_progress_bar.close()

def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

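As a hedged illustration of the new progress-bar hooks, a subclass that customizes the prediction bar; `ProgressBar` and `init_predict_tqdm` come from the diff above, while the description text is only an example. The subclass would be passed to the Trainer like any other callback.

```python
from tqdm import tqdm

from pytorch_lightning.callbacks import ProgressBar


class InferenceProgressBar(ProgressBar):
    def init_predict_tqdm(self) -> tqdm:
        # reuse the parent bar and only adjust its description
        bar = super().init_predict_tqdm()
        bar.set_description("Running inference")
        return bar
```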
4 changes: 4 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -260,6 +260,10 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
pass
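`predict_dataloader` joins the other dataloader hooks on the datamodule. A minimal sketch of a datamodule providing it; the random tensors are purely illustrative, and depending on the Lightning version the other dataloader hooks may also need to be defined:

```python
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl


class RandomPredictDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 16):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # a plain tensor acts as a map-style dataset
        self.predict_data = torch.randn(64, 32)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.predict_data, batch_size=self.batch_size)
```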
37 changes: 34 additions & 3 deletions pytorch_lightning/core/hooks.py
@@ -204,17 +204,23 @@ def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader
"""
# do something when the batch ends

+ def on_test_model_train(self) -> None:
+     """
+     Sets the model to train during the test loop
+     """
+     self.train()
+
def on_test_model_eval(self) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()

- def on_test_model_train(self) -> None:
+ def on_predict_model_eval(self) -> None:
"""
- Sets the model to train during the test loop
+ Sets the model to eval during the predict loop
"""
- self.train()
+ self.eval()

def on_epoch_start(self) -> None:
"""
@@ -518,6 +524,31 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement one or multiple PyTorch DataLoaders for prediction.
It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
- :meth:`~pytorch_lightning.trainer.Trainer.fit`
- ...
- :meth:`prepare_data`
- :meth:`train_dataloader`
- :meth:`val_dataloader`
- :meth:`test_dataloader`
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware
There is no need to set it yourself.
Return:
Single or multiple PyTorch DataLoaders.
Note:
In the case where you return multiple prediction dataloaders, the :meth:`predict`
will have an argument ``dataloader_idx`` which matches the order here.
"""

def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
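The `on_*_model_eval`/`on_*_model_train` hooks shown above control the module's mode for each stage. As a hedged example of overriding the new `on_predict_model_eval`, a module that keeps its dropout layers sampling during prediction (e.g. for Monte Carlo dropout); this override is illustrative and not part of the diff:

```python
import torch

import pytorch_lightning as pl


class MCDropoutModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(64, 2),
        )

    def forward(self, x):
        return self.net(x)

    def on_predict_model_eval(self) -> None:
        # default behaviour is self.eval(); re-enable dropout so repeated
        # predict passes give stochastic outputs
        self.eval()
        for module in self.modules():
            if isinstance(module, torch.nn.Dropout):
                module.train()

    def predict(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)
```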
10 changes: 9 additions & 1 deletion pytorch_lightning/overrides/base.py
@@ -54,14 +54,22 @@ def forward(self, *inputs, **kwargs):
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif running_stage == RunningStage.TESTING:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif running_stage == RunningStage.EVALUATING:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
- else:
+ elif running_stage == RunningStage.PREDICTING:
+     output = self.module.predict(*inputs, **kwargs)
+     warn_if_output_is_none(output, "predict")
+
+ else:
output = self.module(*inputs, **kwargs)

return output


13 changes: 9 additions & 4 deletions pytorch_lightning/plugins/base_plugin.py
@@ -33,11 +33,11 @@ def connect(
Will be called by the accelerator.
"""

- def pre_training(self) -> None:
-     """Hook to do something before the training starts."""
+ def pre_dispatch(self) -> None:
+     """Hook to do something before the training/evaluation/prediction starts."""

- def post_training(self) -> None:
-     """Hook to do something after the training finishes."""
+ def post_dispatch(self) -> None:
+     """Hook to do something after the training/evaluation/prediction finishes."""

@contextlib.contextmanager
def train_step_context(self) -> Generator:
@@ -53,3 +53,8 @@ def val_step_context(self) -> Generator:
def test_step_context(self) -> Generator:
"""A contextmanager for the teststep"""
yield

@contextlib.contextmanager
def predict_context(self) -> Generator:
"""A contextmanager for the predict step"""
yield
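`predict_context` mirrors the existing per-step context managers. A self-contained sketch of what such a context might do, using a stand-in class rather than the real plugin hierarchy; wrapping prediction in `torch.no_grad()` here is an assumption for illustration, not necessarily what the Lightning plugins do:

```python
import contextlib

import torch


class DemoPredictPlugin:
    @contextlib.contextmanager
    def predict_context(self):
        # run the wrapped predict step without building autograd graphs
        with torch.no_grad():
            yield


plugin = DemoPredictPlugin()
model = torch.nn.Linear(4, 2)

with plugin.predict_context():
    out = model(torch.randn(8, 4))

assert not out.requires_grad  # gradients were disabled inside the context
```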
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -215,7 +215,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

- def pre_training(self):
+ def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
@@ -232,7 +232,7 @@ def pre_training(self):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

- # TODO: we moved it to the trainer.fit after calling pre_training
+ # TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

@@ -257,7 +257,7 @@ def pre_training(self):

self.barrier()

- def post_training(self):
+ def post_dispatch(self):
if "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]

12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -110,6 +110,9 @@ def start_training(self, trainer):
def start_testing(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def start_predicting(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def new_process(self, process_idx, trainer, mp_queue):
self.mp_queue = mp_queue

@@ -128,7 +131,7 @@ def new_process(self, process_idx, trainer, mp_queue):
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

- # TODO: we moved it to the trainer.fit after calling pre_training
+ # TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

@@ -153,15 +156,12 @@ def new_process(self, process_idx, trainer, mp_queue):

self.barrier()

- if trainer.testing:
-     results = trainer.run_test()
- else:
-     results = trainer.train()
+ results = trainer.train_or_test_or_predict()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)

- def post_training(self):
+ def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
(Diffs for the remaining changed files are omitted here.)