diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c11163bb244..95ca3329f8497 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479)) +- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752)) + + - Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 893456f4035e6..2e8e31139dda2 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -139,9 +139,8 @@ def training_step(self, args): args[0] = batch - with self.precision_plugin.train_step_context(): - with self.training_type_plugin.train_step_context(): - return self.training_type_plugin.training_step(*args) + with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context(): + return self.training_type_plugin.training_step(*args) def post_training_step(self): self.training_type_plugin.post_training_step() @@ -161,9 +160,8 @@ def validation_step(self, args): args[0] = batch - with self.precision_plugin.val_step_context(): - with self.training_type_plugin.val_step_context(): - return self.training_type_plugin.validation_step(*args) + with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context(): + return self.training_type_plugin.validation_step(*args) def test_step(self, args): """The actual test step. @@ -180,9 +178,26 @@ def test_step(self, args): args[0] = batch - with self.precision_plugin.test_step_context(): - with self.training_type_plugin.test_step_context(): - return self.training_type_plugin.test_step(*args) + with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): + return self.training_type_plugin.test_step(*args) + + def predict(self, args): + """The actual predict step. + + Args: + args: the arguments for the models predict step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple predict dataloaders used). + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): + return self.training_type_plugin.predict(*args) def training_step_end(self, output): """A hook to do something at the end of the training step diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a37a979c9d971..7de7982b4a2de 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -67,6 +67,7 @@ def __init__(self): self._train_batch_idx = 0 self._val_batch_idx = 0 self._test_batch_idx = 0 + self._predict_batch_idx = 0 @property def trainer(self): @@ -96,6 +97,14 @@ def test_batch_idx(self) -> int: """ return self._test_batch_idx + @property + def predict_batch_idx(self) -> int: + """ + The current batch index being processed during predicting. + Use this to update your progress bar. 
+ """ + return self._predict_batch_idx + @property def total_train_batches(self) -> int: """ @@ -108,7 +117,7 @@ def total_train_batches(self) -> int: @property def total_val_batches(self) -> int: """ - The total number of training batches during validation, which may change from epoch to epoch. + The total number of validation batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. """ @@ -121,12 +130,21 @@ def total_val_batches(self) -> int: @property def total_test_batches(self) -> int: """ - The total number of training batches during testing, which may change from epoch to epoch. + The total number of testing batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ return sum(self.trainer.num_test_batches) + @property + def total_predict_batches(self) -> int: + """ + The total number of predicting batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + predict dataloader is of infinite size. + """ + return sum(self.trainer.num_predict_batches) + def disable(self): """ You should provide a way to disable the progress bar. @@ -168,6 +186,12 @@ def on_test_start(self, trainer, pl_module): def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._test_batch_idx += 1 + def on_predict_start(self, trainer, pl_module): + self._predict_batch_idx = 0 + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._predict_batch_idx += 1 + class ProgressBar(ProgressBarBase): r""" @@ -282,6 +306,20 @@ def init_train_tqdm(self) -> tqdm: ) return bar + def init_predict_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for predicting. """ + bar = tqdm( + desc='Predicting', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ bar = tqdm( @@ -294,12 +332,10 @@ def init_validation_tqdm(self) -> tqdm: ) return bar - def init_test_tqdm(self, trainer=None) -> tqdm: + def init_test_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for testing. 
""" - desc = "Testing" - desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing" bar = tqdm( - desc=desc, + desc="Testing", position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -365,7 +401,7 @@ def on_train_end(self, trainer, pl_module): def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) - self.test_progress_bar = self.init_test_tqdm(trainer=trainer) + self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -377,6 +413,19 @@ def on_test_end(self, trainer, pl_module): super().on_test_end(trainer, pl_module) self.test_progress_bar.close() + def on_predict_start(self, trainer, pl_module): + super().on_predict_start(trainer, pl_module) + self.predict_progress_bar = self.init_predict_tqdm() + self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.predict_batch_idx, self.total_predict_batches): + self._update_bar(self.predict_progress_bar) + + def on_predict_end(self, trainer, pl_module): + self.predict_progress_bar.close() + def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index d0e1725b2c4ac..3b9d8e7de49e1 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -260,6 +260,10 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]] def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: pass + @abstractmethod + def predict_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + pass + @abstractmethod def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 23fd5d9b58755..ac7bb2a1d20e1 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -204,17 +204,23 @@ def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader """ # do something when the batch ends + def on_test_model_train(self) -> None: + """ + Sets the model to train during the test loop + """ + self.train() + def on_test_model_eval(self) -> None: """ Sets the model to eval during the test loop """ self.eval() - def on_test_model_train(self) -> None: + def on_predict_model_eval(self) -> None: """ - Sets the model to train during the test loop + Sets the model to eval during the predict loop """ - self.train() + self.eval() def on_epoch_start(self) -> None: """ @@ -518,6 +524,31 @@ def val_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ + def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement one or multiple PyTorch DataLoaders for prediction. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... 
+ - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. + + Return: + Single or multiple PyTorch DataLoaders. + + Note: + In the case where you return multiple prediction dataloaders, the :meth:`predict` + will have an argument ``dataloader_idx`` which matches the order here. + """ + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 1a33556991148..2fcb4b11a0b7f 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -54,14 +54,22 @@ def forward(self, *inputs, **kwargs): if not self.module.automatic_optimization: self.module.trainer.model.require_backward_grad_sync = False warn_if_output_is_none(output, "training_step") + elif running_stage == RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") + elif running_stage == RunningStage.EVALUATING: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - else: + + elif running_stage == RunningStage.PREDICTING: output = self.module.predict(*inputs, **kwargs) + warn_if_output_is_none(output, "predict") + + else: + output = self.module(*inputs, **kwargs) + return output diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index b8bdf38a57137..e495d9ffadc3f 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -33,11 +33,11 @@ def connect( Will be called by the accelerator. 
""" - def pre_training(self) -> None: - """Hook to do something before the training starts.""" + def pre_dispatch(self) -> None: + """Hook to do something before the training/evaluation/prediction starts.""" - def post_training(self) -> None: - """Hook to do something after the training finishes.""" + def post_dispatch(self) -> None: + """Hook to do something after the training/evaluation/prediction finishes.""" @contextlib.contextmanager def train_step_context(self) -> Generator: @@ -53,3 +53,8 @@ def val_step_context(self) -> Generator: def test_step_context(self) -> Generator: """A contextmanager for the teststep""" yield + + @contextlib.contextmanager + def predict_context(self) -> Generator: + """A contextmanager for the predict step""" + yield diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 52a24655f0846..6e6c292eec140 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -215,7 +215,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) - def pre_training(self): + def pre_dispatch(self): # TODO: check if needed seed = os.environ.get("PL_GLOBAL_SEED") if seed is not None: @@ -232,7 +232,7 @@ def pre_training(self): # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) - # TODO: we moved it to the trainer.fit after calling pre_training + # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) @@ -257,7 +257,7 @@ def pre_training(self): self.barrier() - def post_training(self): + def post_dispatch(self): if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 6b6d85ee0d29f..449373e2c35ea 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -110,6 +110,9 @@ def start_training(self, trainer): def start_testing(self, trainer): mp.spawn(self.new_process, **self.mp_spawn_kwargs) + def start_predicting(self, trainer): + mp.spawn(self.new_process, **self.mp_spawn_kwargs) + def new_process(self, process_idx, trainer, mp_queue): self.mp_queue = mp_queue @@ -128,7 +131,7 @@ def new_process(self, process_idx, trainer, mp_queue): # where to store ip_table self.init_ddp_connection(self.global_rank, self.world_size) - # TODO: we moved it to the trainer.fit after calling pre_training + # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... 
need to double check that it is the correct place # self.trainer.call_setup_hook(self.model) @@ -153,15 +156,12 @@ def new_process(self, process_idx, trainer, mp_queue): self.barrier() - if trainer.testing: - results = trainer.run_test() - else: - results = trainer.train() + results = trainer.train_or_test_or_predict() # persist info in ddp_spawn self.transfer_distrib_spawn_state_on_fit_end(results) - def post_training(self): + def post_dispatch(self): # restore main state with best weights best_path = self.mp_queue.get() last_path = self.mp_queue.get() diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 995c83079992c..c1de2d7833177 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -50,7 +50,7 @@ def setup(self, model): self.model_to_device() - def pre_training(self): + def pre_dispatch(self): def _unpack_lightning_optimizer(opt): return opt._optimizer if isinstance(opt, LightningOptimizer) else opt @@ -95,20 +95,26 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.train() + self._results = trainer.run_train() # Make sure all workers have finished training before returning to the user hvd.join() def start_testing(self, trainer): with ExitStack() as stack: - # set up training routine - # self.trainer.train_loop.setup_training(self.trainer.model) self._results = trainer.run_test() # Make sure all workers have finished training before returning to the user hvd.join() + def start_predicting(self, trainer): + with ExitStack() as stack: + # set up training routine + self._results = trainer.run_predict() + + # Make sure all workers have finished training before returning to the user + hvd.join() + def barrier(self, *args, **kwargs): hvd.join() diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 345f208b97cde..fc707afb3e2c2 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -325,9 +325,9 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs # Initialize optimizer step on main process self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs) - def post_training(self): + def post_training_step(self): if self.main_rpc_process: - super().post_training() + super().post_training_step() def start_training(self, trainer: 'Trainer') -> None: if self.main_rpc_process: diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 46df404bdc02f..40fc9fba3a6a7 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -36,14 +36,14 @@ def connect(self, model: torch.nn.Module) -> torch.nn.Module: def model_to_device(self) -> None: self._model.to(self.root_device) - def pre_training(self) -> None: + def pre_dispatch(self) -> None: if isinstance(self.device, int): self.device = xm.xla_device(self.device) self.tpu_local_core_rank = xm.get_local_ordinal() self.tpu_global_core_rank = xm.get_ordinal() - def post_training(self) -> None: + def post_dispatch(self) -> None: model = self.lightning_module if on_colab_kaggle(): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py 
b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4c5844da94ced..d4374d0ef9c6a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -95,10 +95,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: trainer.save_checkpoint = self.save_checkpoint self.barrier() - if trainer.testing: - results = trainer.run_test() - else: - results = trainer.train() + results = trainer.train_or_test_or_predict() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) @@ -182,7 +179,7 @@ def reduce_early_stopping_decision(self, should_stop: bool) -> bool: should_stop = int(stop.item()) == self.world_size return should_stop - def post_training(self) -> None: + def post_dispatch(self) -> None: # TODO: Check if trainer references can be resolved otherwise model = self.lightning_module @@ -233,6 +230,9 @@ def start_training(self, trainer) -> None: def start_testing(self, trainer) -> None: xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + def start_predicting(self, trainer) -> None: + xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 74f5837afc67f..cede3e5f98b43 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -112,12 +112,16 @@ def rpc_enabled(self) -> bool: def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.train() + self._results = trainer.run_train() def start_testing(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop self._results = trainer.run_test() + def start_predicting(self, trainer: 'Trainer') -> None: + # double dispatch to initiate the predicting loop + self._results = trainer.run_predict() + def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index a7e13de8ede39..9cb22f39b7228 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model): # verify model has a train dataloader # ----------------------------------- has_train_dataloader = is_overridden('train_dataloader', model) - if not has_train_dataloader and not self.trainer._predicting: + if not has_train_dataloader: raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' @@ -62,7 +62,7 @@ def __verify_train_loop_configuration(self, model): # verify model has optimizer # ----------------------------------- has_optimizers = is_overridden('configure_optimizers', model) - if not has_optimizers and not self.trainer._predicting: + if not has_optimizers: raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' 
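For orientation before the next file: a minimal, self-contained sketch of the double dispatch that the `start_training`/`start_testing`/`start_predicting` hooks above set up. The trainer hands control to the training-type plugin, and the plugin calls back into the matching trainer loop. The class names below are illustrative stand-ins, not the real plugin hierarchy, and the routing side is shown ahead of the `Trainer.dispatch` change that appears later in this diff.

```python
class TrainingTypePluginSketch:
    """Stand-in for TrainingTypePlugin: each entry point calls back into the trainer."""

    def start_training(self, trainer):
        self._results = trainer.run_train()

    def start_testing(self, trainer):
        self._results = trainer.run_test()

    def start_predicting(self, trainer):
        self._results = trainer.run_predict()


class TrainerSketch:
    """Stand-in for the routing side (see Trainer.dispatch further down in this diff)."""

    def dispatch(self):
        if self.testing:
            self.training_type_plugin.start_testing(self)
        elif self.predicting:
            self.training_type_plugin.start_predicting(self)
        else:
            self.training_type_plugin.start_training(self)
```

Spawn-based plugins (`ddp_spawn`, `tpu_spawn`) override `start_predicting` to launch worker processes instead; each worker then calls `trainer.train_or_test_or_predict()` to pick the same loop.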
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9161f3e8754ec..2852d9dfaf22f 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -90,7 +90,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' ) - def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): + def attach_dataloaders( + self, + model, + train_dataloader=None, + val_dataloaders=None, + test_dataloaders=None, + predict_dataloaders=None + ): # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -102,6 +109,9 @@ def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, if test_dataloaders is not None: model.test_dataloader = _PatchDataLoader(test_dataloaders) + if predict_dataloaders is not None: + model.predict_dataloader = _PatchDataLoader(predict_dataloaders) + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None: # Todo: required argument `stage` is not used @@ -118,6 +128,8 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st model.val_dataloader = datamodule.val_dataloader if is_overridden('test_dataloader', datamodule): model.test_dataloader = datamodule.test_dataloader + if is_overridden('predict_dataloader', datamodule): + model.predict_dataloader = datamodule.predict_dataloader # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule if is_overridden('transfer_batch_to_device', datamodule): diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index cb2ecc20f51ce..28c99f8f4de6d 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -29,6 +29,7 @@ def on_init_start( limit_train_batches, limit_val_batches, limit_test_batches, + limit_predict_batches, val_check_interval, overfit_batches, fast_dev_run, @@ -56,6 +57,7 @@ def on_init_start( limit_train_batches = fast_dev_run limit_val_batches = fast_dev_run limit_test_batches = fast_dev_run + limit_predict_batches = fast_dev_run self.trainer.max_steps = fast_dev_run self.trainer.num_sanity_val_steps = 0 self.trainer.max_epochs = 1 @@ -71,6 +73,7 @@ def on_init_start( self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') + self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches') self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') self.determine_data_use_amount(self.trainer.overfit_batches) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index fd93c559ff7d2..946a9006442e8 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ 
b/pytorch_lightning/trainer/data_loading.py @@ -293,7 +293,8 @@ def _reset_eval_dataloader( loader = dataloaders[loader_i] # shuffling in val and test set is bad practice - if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): + modes = ('val', 'test', 'predict') + if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): # when overfitting, the dataloader should not have sampler if self.overfit_batches > 0: @@ -363,7 +364,7 @@ def reset_val_dataloader(self, model: LightningModule) -> None: self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val') def reset_test_dataloader(self, model) -> None: - """Resets the validation dataloader and determines the number of batches. + """Resets the test dataloader and determines the number of batches. Args: model: The current `LightningModule` @@ -374,6 +375,17 @@ def reset_test_dataloader(self, model) -> None: self.num_test_batches, self.test_dataloaders =\ self._reset_eval_dataloader(model, 'test') + def reset_predict_dataloader(self, model) -> None: + """Resets the predict dataloader and determines the number of batches. + + Args: + model: The current `LightningModule` + """ + has_loader = is_overridden('predict_dataloader', model) + if has_loader: + self.num_predict_batches, self.predict_dataloaders =\ + self._reset_eval_dataloader(model, 'predict') + def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: """Handles downloading data in the GPU or TPU case. diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1fbcc80ca424b..fe3fc62ff1189 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -154,16 +154,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.get_model() model_ref._results = Result() - if self.trainer._predicting: - model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator_backend.predict(args) - self._predictions[dataloader_idx].append(predictions) - self.trainer._progress_bar_callback.on_test_batch_end( - self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx - ) - return - - elif self.testing: + if self.testing: model_ref._current_fx_name = "test_step" with self.trainer.profiler.profile("test_step"): output = self.trainer.accelerator_backend.test_step(args) diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py new file mode 100644 index 0000000000000..43016b8943c81 --- /dev/null +++ b/pytorch_lightning/trainer/predict_loop.py @@ -0,0 +1,97 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class PredictLoop(object): + + def __init__(self, trainer): + self.trainer = trainer + self.max_batches = None + self.num_dataloaders = None + + def on_trainer_init(self): + self.trainer.num_predict_batches = [] + + def get_predict_dataloaders(self, max_batches): + # select dataloaders + model = self.trainer.get_model() + self.trainer.reset_predict_dataloader(model) + dataloaders = self.trainer.predict_dataloaders + if max_batches is None: + max_batches = self.trainer.num_predict_batches + + return dataloaders, max_batches + + def should_skip_predict(self, dataloaders, max_batches): + return dataloaders is None or not sum(max_batches) + + def on_predict_model_eval(self, *_, **__): + model_ref = self.trainer.get_model() + model_ref.on_predict_model_eval() + + def setup(self, model, max_batches, dataloaders): + # copy properties for forward overrides + self.trainer.model_connector.copy_trainer_model_properties(model) + + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) + + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + self._predictions = [[] for _ in range(self.num_dataloaders)] + + self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.get_model()) + + def _get_num_dataloaders(self, dataloaders): + # case where user does: + # return dl1, dl2 + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length + + def predict(self, batch, batch_idx, dataloader_idx): + # configure args + args = [batch, batch_idx] + if self.num_dataloaders: + args.append(dataloader_idx) + + model_ref = self.trainer.get_model() + + model_ref._current_fx_name = "predict" + predictions = self.trainer.accelerator_backend.predict(args) + self._predictions[dataloader_idx].append(predictions) + self.trainer._progress_bar_callback.on_predict_batch_end( + self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx + ) + return + + def on_predict_epoch_end(self): + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.get_model()) + + results = self._predictions + + def _convert_to_numpy(v): + return v.cpu().numpy() + + results = apply_to_collection(results, torch.Tensor, _convert_to_numpy) + + if len(results) == 1: + return results[0] + + return results diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index db04734d2f2f9..2e45f9502edee 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -49,6 +49,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +from pytorch_lightning.trainer.predict_loop import PredictLoop from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop @@ -57,6 +58,7 @@ from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger +from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import 
MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden @@ -107,6 +109,7 @@ def __init__( limit_train_batches: Union[int, float] = 1.0, limit_val_batches: Union[int, float] = 1.0, limit_test_batches: Union[int, float] = 1.0, + limit_predict_batches: Union[int, float] = 1.0, val_check_interval: Union[int, float] = 1.0, flush_logs_every_n_steps: int = 100, log_every_n_steps: int = 50, @@ -296,7 +299,6 @@ def __init__( """ super().__init__() self._running_stage = None - self._predicting = False distributed_backend = distributed_backend or accelerator @@ -319,8 +321,9 @@ def __init__( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.evaluation_loop = EvaluationLoop(self) self.train_loop = TrainLoop(self, multiple_trainloader_mode) + self.evaluation_loop = EvaluationLoop(self) + self.predict_loop = PredictLoop(self) # training state self.weights_summary = weights_summary @@ -393,6 +396,7 @@ def __init__( limit_train_batches, limit_val_batches, limit_test_batches, + limit_predict_batches, val_check_interval, overfit_batches, fast_dev_run, @@ -440,7 +444,11 @@ def fit( """ # bookkeeping self._state = TrainerState.RUNNING - self._set_wide_running_stage(RunningStage.TRAINING) + + # bookkeeping + # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. + if self._running_stage is None: + self._set_running_stage(RunningStage.TRAINING, model) # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -463,27 +471,47 @@ def fit( self.accelerator_backend.setup(self, model) self.setup_trainer(model) + # ---------------------------- + # INSPECT THE CORE LOOPS + # ---------------------------- + # Lightning internal flow looks like this. + # + # trainer.fit(...) or trainer.test(...) or trainer.predict(...) || + # | || + # create accelerator || + # | || + # trainer.dispatch || LIGHTNING + # | || + # start_training or start_testing or start_predicting call || FLOW + # from `accelerator.training_type_plugin` || + # | || DIRECTION + # run_train or run_test or run_predict call || + # from `trainer` || + # | || + # results \/ + # This is used to guide readers to the core loops: train, test, predict. + # `run_predict` is the simplest to understand, use `Go to Definition` to read it :) + # Search for `start_training` or `start_testing` or `start_predicting` in + # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions. + self.accelerator.train_loop = self.run_train + self.accelerator.validation_loop = self.run_evaluation + self.accelerator.test_loop = self.run_evaluation + self.accelerator.predict_loop = self.run_predict + # ---------------------------- # TRAIN # ---------------------------- # hook self.call_hook("on_fit_start") - # plugin will setup training (e.g. ddp will launch child processes) - # TODO: the old setup is now called "pre_training", where should this hook be called now? - self.training_type_plugin.pre_training() - self.precision_plugin.pre_training() + # plugin will setup fitting (e.g. ddp will launch child processes) + self.pre_dispatch() - # double dispatch: let the plugin initiate the training/test loop. 
- if self.testing: - self.training_type_plugin.start_testing(self) - else: - self.training_type_plugin.start_training(self) + # dispath `start_training` or `start_testing` or `start_predicting` + self.dispatch() - self.precision_plugin.post_training() - self.training_type_plugin.post_training() - self.accelerator_backend.teardown() - results = self.training_type_plugin.results + # plugin will finalized fitting (e.g. ddp_spawn will load trained model) + self.post_dispatch() # ---------------------------- # POST-Training CLEAN UP @@ -501,31 +529,47 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self._set_wide_running_stage(None) + self._set_running_stage(None, model) - return results or 1 + return self.training_type_plugin.results or 1 - def _set_wide_running_stage(self, stage): - model_ref = self.get_model() + def pre_dispatch(self): + self.training_type_plugin.pre_dispatch() + self.precision_plugin.pre_dispatch() + + def post_dispatch(self): + self.training_type_plugin.post_dispatch() + self.precision_plugin.post_dispatch() + self.accelerator_backend.teardown() - if stage is None: - self._running_stage = stage - model_ref.running_stage = stage - return + def dispatch(self): + if self.testing: + self.training_type_plugin.start_testing(self) - # todo: clean up this routing mess. - if self._running_stage == RunningStage.TESTING: - stage = RunningStage.TESTING + elif self.predicting: + self.training_type_plugin.start_predicting(self) - # WARNING: With predicting, - # trainer _running_state should be RunningStage.TESTING - # however, the model running_stage should be RunningStage.PREDICTING or None - if model_ref is not None: - if self._predicting: - model_ref.running_stage = RunningStage.PREDICTING - else: - model_ref.running_stage = stage + else: + self.training_type_plugin.start_training(self) + + def train_or_test_or_predict(self): + if self.testing: + results = self.run_test() + + elif self.predicting: + results = self.run_predict() + + else: + results = self.run_train() + + return results + def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule): + """ + This function is used to set the running_state on both + the trainer and the model + """ + model_ref.running_stage = stage self._running_stage = stage def _pre_training_routine(self): @@ -560,7 +604,7 @@ def _pre_training_routine(self): if self.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() - def train(self): + def run_train(self): self._pre_training_routine() @@ -570,7 +614,7 @@ def train(self): self.run_sanity_check(self.get_model()) # set stage for logging - self._set_wide_running_stage(RunningStage.TRAINING) + self._set_running_stage(RunningStage.TRAINING, self.get_model()) self.checkpoint_connector.has_trained = False @@ -634,7 +678,7 @@ def train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_wide_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING) + self._set_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.get_model()) self.logger_connector.reset() # bookkeeping @@ -647,11 +691,10 @@ def run_evaluation(self, max_batches=None, on_epoch=False): if self.evaluation_loop.should_skip_evaluation(max_batches): return [], [] - # ref model - model = self.get_model() - # enable eval mode + no grads self.evaluation_loop.on_evaluation_model_eval() + # ref model + model 
= self.get_model() model.zero_grad() torch.set_grad_enabled(False) @@ -685,8 +728,6 @@ def run_evaluation(self, max_batches=None, on_epoch=False): # lightning module methods with self.profiler.profile("evaluation_step_and_end"): output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) - if self._predicting: - continue output = self.evaluation_loop.evaluation_step_end(output) # hook + store predictions @@ -701,9 +742,6 @@ def run_evaluation(self, max_batches=None, on_epoch=False): # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - if self._predicting: - return self.evaluation_loop.on_predict_epoch_end() - # lightning module method deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() @@ -764,6 +802,45 @@ def run_test(self): return eval_loop_results + def run_predict(self): + # prepare dataloaders + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None) + + # check if we want to skip this evaluation + if self.predict_loop.should_skip_predict(dataloaders, max_batches): + return [] + + # ref model + model = self.get_model() + + # enable eval mode + no grads + self.predict_loop.on_predict_model_eval() + model.zero_grad() + torch.set_grad_enabled(False) + + # set up the eval loop + self.predict_loop.setup(model, max_batches, dataloaders) + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dataloader = self.accelerator_backend.process_dataloader(dataloader) + dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # lightning module methods + with self.profiler.profile("predict"): + self.predict_loop.predict(batch, batch_idx, dataloader_idx) + + results = self.predict_loop.on_predict_epoch_end() + return results + def run_sanity_check(self, ref_model): using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 @@ -828,7 +905,7 @@ def test( # -------------------- self.verbose_test = verbose - self._set_wide_running_stage(RunningStage.TESTING) + self._set_running_stage(RunningStage.TESTING, model or self.get_model()) # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -845,9 +922,7 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - - self._set_wide_running_stage(None) - + self._set_running_stage(None, model or self.get_model()) return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): @@ -935,35 +1010,28 @@ def predict( # -------------------- # SETUP HOOK # -------------------- - self._set_wide_running_stage(RunningStage.TESTING) - # If you supply a datamodule you can't supply dataloaders + + model = model or self.get_model() + + self._set_running_stage(RunningStage.PREDICTING, model) + if dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
) - if model is None: - raise MisconfigurationException('You need to pass a model to `trainer.predict`.') - if datamodule is not None: # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule, 'test') + self.data_connector.attach_datamodule(model, datamodule, 'predict') # attach data if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - # set path variable - self._predicting = True self.model = model - results = self.fit(model) - - # unset path variable - self.teardown('test') - self._predicting = False - self._set_wide_running_stage(None) + self._set_running_stage(None, model) return results @@ -1069,6 +1137,17 @@ def testing(self, val: bool) -> None: elif self.testing: self._running_stage = None + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + @property def tuning(self) -> bool: return self._running_stage == RunningStage.TUNING diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1640afe97fba2..0908e96bd1c17 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -547,7 +547,7 @@ def run_training_epoch(self): self.trainer.run_evaluation() # reset stage to train - self.trainer._set_wide_running_stage(RunningStage.TRAINING) + self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
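A hedged sketch of how the stage set by `_set_running_stage` above is consumed: the wrapped `forward` in `overrides/base.py` (changed earlier in this diff) routes each call to the step method matching the module's `running_stage`. The function below is an illustrative condensation, not the actual wrapper code.

```python
from pytorch_lightning.trainer.states import RunningStage


def route_forward(module, *inputs, **kwargs):
    # mirrors the stage-based branching added to overrides/base.py
    stage = module.running_stage
    if stage == RunningStage.TRAINING:
        return module.training_step(*inputs, **kwargs)
    if stage == RunningStage.TESTING:
        return module.test_step(*inputs, **kwargs)
    if stage == RunningStage.EVALUATING:
        return module.validation_step(*inputs, **kwargs)
    if stage == RunningStage.PREDICTING:
        return module.predict(*inputs, **kwargs)
    # no stage set: fall back to the plain forward
    return module(*inputs, **kwargs)
```

This is why `_set_running_stage` sets the stage on both the trainer and the model: the wrapper reads it from the module, while the trainer loops read `self._running_stage`.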
@@ -594,7 +594,7 @@ def run_training_epoch(self): self.trainer.run_evaluation(on_epoch=True) # reset stage to train - self.trainer._set_wide_running_stage(RunningStage.TRAINING) + self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 4ea4d511e1d0c..8398aec88fe68 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -288,8 +288,8 @@ def init_validation_tqdm(self): bar = super().init_validation_tqdm() return self._mock_bar_update(bar) - def init_test_tqdm(self, trainer=None): - bar = super().init_test_tqdm(trainer=trainer) + def init_test_tqdm(self): + bar = super().init_test_tqdm() return self._mock_bar_update(bar) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 15faf787b57f3..64481bd70390d 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -39,7 +39,7 @@ def test_lightning_wrapper_module_methods(wrapper_class): wrapped_module(batch, batch_idx) pl_module.validation_step.assert_called_with(batch, batch_idx) - pl_module.running_stage = None + pl_module.running_stage = RunningStage.PREDICTING wrapped_module(batch) pl_module.predict.assert_called_with(batch) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4e85a5695b9f2..71caaaad4d7f9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1498,6 +1498,9 @@ def __init__(self, dataloaders): def test_dataloader(self): return self._dataloaders + def predict_dataloader(self): + return self._dataloaders + def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True): @@ -1515,7 +1518,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T gpus=gpus, num_processes=num_processes, plugins=plugins, - num_sanity_val_steps=0 ) if datamodule: results = trainer.predict(model, datamodule=datamodule) @@ -1529,9 +1531,6 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T assert results[0][0].shape == torch.Size([1, 2]) -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule)
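To close the loop, a hedged end-to-end usage sketch of the API exercised by the test above. The tiny module, dataset sizes, and trainer flags are illustrative assumptions (the real test uses `BoringModel`); only `predict`, `predict_dataloader`, `limit_predict_batches`, and `Trainer.predict` come from this PR.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    # returns plain tensors, like the helper the tests above rely on
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class PredictModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 8), batch_size=4)

    def predict(self, batch, batch_idx, dataloader_idx=None):
        # called once per batch by the new PredictLoop
        return self(batch)

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 8), batch_size=4)


if __name__ == "__main__":
    model = PredictModel()
    trainer = pl.Trainer(max_epochs=1, limit_predict_batches=2)
    # with no dataloaders/datamodule passed, the module's predict_dataloader hook is used
    results = trainer.predict(model)
    # PredictLoop.on_predict_epoch_end converts the returned tensors to numpy arrays
```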