From c425c11ad9649e243b5eb961cd692fff1e51f4c0 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 17 Feb 2021 20:06:36 +0000
Subject: [PATCH 1/2] Trainer only references accelerator where it can

---
 pytorch_lightning/accelerators/accelerator.py | 50 ++++++++++++-------
 pytorch_lightning/trainer/data_loading.py     |  2 +-
 pytorch_lightning/trainer/properties.py       |  6 +--
 pytorch_lightning/trainer/trainer.py          | 23 ++++-----
 4 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 967b6a85c878b..8c50831ce47be 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -76,6 +76,26 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)
 
+    def start_training(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_training(trainer)
+
+    def start_testing(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_testing(trainer)
+
+    def start_predicting(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_predicting(trainer)
+
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+        self.training_type_plugin.pre_dispatch()
+        self.precision_plugin.pre_dispatch()
+
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""
+        self.training_type_plugin.post_dispatch()
+        self.precision_plugin.post_dispatch()
+        self.teardown()
+
     @property
     def model(self) -> torch.nn.Module:
         """Returns the model. This can also be a wrapped LightningModule.
@@ -224,23 +244,6 @@ def validation_step_end(self, output):
         """
         return self.training_type_plugin.validation_step_end(output)
 
-    def predict(self, args):
-        """The prediction step.
-
-        Args:
-            args: the arguments for the models predict step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): Integer displaying index of this batch
-                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                hiddens(:class:`~torch.Tensor`): Passed in if
-                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
-
-        """
-        batch = self.to_device(args[0])
-        args[0] = batch
-        return self.training_type_plugin.predict(*args)
-
     def backward(
         self,
         closure_loss: torch.Tensor,
@@ -378,6 +381,10 @@ def on_save(self, checkpoint):
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
 
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        """Broadcasts an object to all processes"""
+        return self.training_type_plugin.broadcast(obj, src)
+
     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
         Function to gather a tensor from several distributed processes
@@ -397,3 +404,12 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
             dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
         """
         return self.training_type_plugin.process_dataloader(dataloader)
+
+    @property
+    def results(self) -> Any:
+        """
+        The results of the last training/testing run will be cached here.
+        In distributed training, we make sure to transfer the results to the appropriate master process.
+        """
+        # TODO: improve these docs
+        return self.training_type_plugin.results
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 946a9006442e8..6d8fa4ad9040f 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -399,7 +399,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         dataloader = self._flatten_dl_only(dataloader)
 
         if self.accelerator_backend is not None:
-            self.training_type_plugin.barrier('get_dataloaders')
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader
 
     def _flatten_dl_only(self, dataloaders):
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 282b4539df0be..ec735e9dccf71 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -21,7 +21,6 @@
 from torch.optim import Optimizer
 
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -29,6 +28,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
@@ -138,7 +138,7 @@ def log_dir(self) -> Optional[str]:
         else:
             dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
 
-        dirpath = self.training_type_plugin.broadcast(dirpath)
+        dirpath = self.accelerator_backend.broadcast(dirpath)
         return dirpath
 
     @property
@@ -365,7 +365,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:
 
     @property
     def lightning_module(self) -> LightningModule:
-        return self.training_type_plugin.lightning_module
+        return self.accelerator_backend.lightning_module
 
     @property
     def optimizers(self) -> Optional[List[Optimizer]]:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 46ca290b24d34..603c006f21092 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -22,7 +22,6 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -32,6 +31,7 @@
 from pytorch_lightning.profiler import BaseProfiler
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -483,7 +483,7 @@ def fit(
        #  trainer.dispatch                                          ||  LIGHTNING
        #         |                                                  ||
        #  start_training or start_testing or start_predicting call  ||  FLOW
-       #  from `accelerator.training_type_plugin`                   ||
+       #  from `accelerator`                                         ||
        #         |                                                  ||  DIRECTION
        #  run_train or run_test or run_predict call                 ||
        #  from `trainer`                                            ||
@@ -531,26 +531,23 @@ def fit(
 
         self._set_running_stage(None, model)
 
-        return self.training_type_plugin.results or 1
+        return self.accelerator_backend.results or 1
 
     def pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch()
-        self.precision_plugin.pre_dispatch()
+        self.accelerator_backend.pre_dispatch()
 
     def post_dispatch(self):
-        self.training_type_plugin.post_dispatch()
-        self.precision_plugin.post_dispatch()
-        self.accelerator_backend.teardown()
+        self.accelerator_backend.post_dispatch()
 
     def dispatch(self):
         if self.testing:
-            self.training_type_plugin.start_testing(self)
+            self.accelerator_backend.start_testing(self)
 
         elif self.predicting:
-            self.training_type_plugin.start_predicting(self)
+            self.accelerator_backend.start_predicting(self)
 
         else:
-            self.training_type_plugin.start_training(self)
+            self.accelerator_backend.start_training(self)
 
     def train_or_test_or_predict(self):
         if self.testing:
@@ -574,7 +571,7 @@ def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
 
     def _pre_training_routine(self):
         # wait for all to join if on distributed
-        self.accelerator.training_type_plugin.barrier("setup_training")
+        self.accelerator.barrier("setup_training")
 
         # register auto-resubmit when on SLURM
         self.slurm_connector.register_slurm_signal_handlers()
@@ -947,7 +944,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}
         if not self._device_type == DeviceType.TPU:
-            self.training_type_plugin.barrier()
+            self.accelerator_backend.barrier()
 
         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])


From 06c70f1dd4a2c58680b437bccbb5967fa0cb34f4 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 17 Feb 2021 20:13:35 +0000
Subject: [PATCH 2/2] Move teardown to the trainer, as it is responsible for the accelerator

---
 pytorch_lightning/accelerators/accelerator.py | 1 -
 pytorch_lightning/trainer/trainer.py          | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 8c50831ce47be..15186ae81d766 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -94,7 +94,6 @@ def post_dispatch(self) -> None:
         """Hook to do something after the training/evaluation/prediction finishes."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
-        self.teardown()
 
     @property
     def model(self) -> torch.nn.Module:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 603c006f21092..069180ac71941 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -538,6 +538,7 @@ def pre_dispatch(self):
 
     def post_dispatch(self):
         self.accelerator_backend.post_dispatch()
+        self.accelerator_backend.teardown()
 
     def dispatch(self):
         if self.testing:
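
For context, a minimal runnable sketch of the delegation pattern these two commits establish: the Trainer keeps a single accelerator_backend reference, and the Accelerator acts as a facade that forwards each hook to its training-type and precision plugins. The classes below are simplified stand-ins, not the real Lightning classes; only the hook names (pre_dispatch, start_training, post_dispatch, teardown) mirror the diff above.

class TrainingTypePlugin:
    # Stand-in for a Lightning training-type plugin (e.g. DDP, single device).
    def pre_dispatch(self):
        print("training-type plugin: pre_dispatch")

    def start_training(self, trainer):
        print("training-type plugin: start_training")

    def post_dispatch(self):
        print("training-type plugin: post_dispatch")


class PrecisionPlugin:
    # Stand-in for a precision plugin (e.g. native AMP).
    def pre_dispatch(self):
        print("precision plugin: pre_dispatch")

    def post_dispatch(self):
        print("precision plugin: post_dispatch")


class Accelerator:
    """Facade: the only object the trainer talks to after this patch."""

    def __init__(self, training_type_plugin, precision_plugin):
        self.training_type_plugin = training_type_plugin
        self.precision_plugin = precision_plugin

    def pre_dispatch(self):
        self.training_type_plugin.pre_dispatch()
        self.precision_plugin.pre_dispatch()

    def start_training(self, trainer):
        self.training_type_plugin.start_training(trainer)

    def post_dispatch(self):
        self.training_type_plugin.post_dispatch()
        self.precision_plugin.post_dispatch()

    def teardown(self):
        print("accelerator: teardown")


class Trainer:
    def __init__(self, accelerator):
        # A single accelerator reference instead of reaching into individual plugins.
        self.accelerator_backend = accelerator

    def fit(self):
        self.accelerator_backend.pre_dispatch()
        self.accelerator_backend.start_training(self)
        self.accelerator_backend.post_dispatch()
        # PATCH 2/2: teardown is invoked by the trainer, which owns the accelerator.
        self.accelerator_backend.teardown()


if __name__ == "__main__":
    Trainer(Accelerator(TrainingTypePlugin(), PrecisionPlugin())).fit()

Keeping teardown in the trainer (PATCH 2/2) follows the same ownership idea as the rest of the series: the object that creates and holds the accelerator is also the one that disposes of it.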