
Trainer only references accelerator #6039

Merged (2 commits) on Feb 17, 2021

49 changes: 32 additions & 17 deletions pytorch_lightning/accelerators/accelerator.py
@@ -76,6 +76,25 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)

+    def start_training(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_training(trainer)
+
+    def start_testing(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_testing(trainer)
+
+    def start_predicting(self, trainer: 'Trainer'):
+        self.training_type_plugin.start_predicting(trainer)
+
+    def pre_dispatch(self) -> None:
+        """Hook to do something before the training/evaluation/prediction starts."""
+        self.training_type_plugin.pre_dispatch()
+        self.precision_plugin.pre_dispatch()
+
+    def post_dispatch(self) -> None:
+        """Hook to do something after the training/evaluation/prediction finishes."""
+        self.training_type_plugin.post_dispatch()
+        self.precision_plugin.post_dispatch()
+
     @property
     def model(self) -> torch.nn.Module:
         """Returns the model. This can also be a wrapped LightningModule.
@@ -224,23 +243,6 @@ def validation_step_end(self, output):
         """
         return self.training_type_plugin.validation_step_end(output)

-    def predict(self, args):
"""The prediction step.

Args:
args: the arguments for the models predict step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): Integer displaying index of this batch
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

"""
batch = self.to_device(args[0])
args[0] = batch
return self.training_type_plugin.predict(*args)

def backward(
self,
closure_loss: torch.Tensor,
@@ -378,6 +380,10 @@ def on_save(self, checkpoint):
     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)

+    def broadcast(self, obj: object, src: int = 0) -> object:
[Review comment — Member] What is `src`? It could be better named, or given a docstring.

"""Broadcasts an object to all processes"""
return self.training_type_plugin.broadcast(obj, src)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
@@ -397,3 +403,12 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
             dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
         """
         return self.training_type_plugin.process_dataloader(dataloader)
+
+    @property
+    def results(self) -> Any:
[Review comment — Contributor] Like it!

"""
The results of the last training/testing run will be cached here.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
# TODO: improve these docs
return self.training_type_plugin.results
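
Taken together, the additions to accelerator.py make the Accelerator a thin facade: the trainer calls one object, and that object forwards to whichever training type plugin and precision plugin are configured. Below is a minimal sketch of that delegation pattern; the Dummy* classes are simplified stand-ins invented for illustration, not Lightning's real plugin classes.

# Illustrative sketch only — simplified stand-ins, not the real pytorch_lightning classes.
from typing import Any, Optional


class DummyTrainingTypePlugin:
    """Stands in for a training type plugin (single device, DDP, ...)."""

    def __init__(self) -> None:
        self.results: Any = None

    def start_training(self, trainer: Any) -> None:
        # The plugin decides *how* to launch; here it simply calls back into the trainer.
        self.results = trainer.run_train()

    def pre_dispatch(self) -> None:
        pass

    def post_dispatch(self) -> None:
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj  # single process: nothing to broadcast

    def barrier(self, name: Optional[str] = None) -> None:
        pass  # single process: nothing to synchronize


class DummyPrecisionPlugin:
    """Stands in for a precision plugin (fp32, AMP, ...)."""

    def pre_dispatch(self) -> None:
        pass

    def post_dispatch(self) -> None:
        pass


class DummyAccelerator:
    """Facade the trainer talks to; every call is forwarded to a plugin."""

    def __init__(self, training_type_plugin: DummyTrainingTypePlugin,
                 precision_plugin: DummyPrecisionPlugin) -> None:
        self.training_type_plugin = training_type_plugin
        self.precision_plugin = precision_plugin

    def start_training(self, trainer: Any) -> None:
        self.training_type_plugin.start_training(trainer)

    def pre_dispatch(self) -> None:
        self.training_type_plugin.pre_dispatch()
        self.precision_plugin.pre_dispatch()

    def post_dispatch(self) -> None:
        self.training_type_plugin.post_dispatch()
        self.precision_plugin.post_dispatch()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.training_type_plugin.broadcast(obj, src)

    def barrier(self, name: Optional[str] = None) -> None:
        self.training_type_plugin.barrier(name)

    @property
    def results(self) -> Any:
        return self.training_type_plugin.results
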
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -399,7 +399,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         dataloader = self._flatten_dl_only(dataloader)

         if self.accelerator_backend is not None:
-            self.training_type_plugin.barrier('get_dataloaders')
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader

     def _flatten_dl_only(self, dataloaders):
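
For context on the call that changed above: the barrier keeps distributed workers in step while dataloaders are requested, so that (typically) rank 0 can prepare or download data once and the other ranks wait before reading it. A rough sketch of that pattern with plain torch.distributed — the helper name and prepare_fn are made up for illustration, and a process group is assumed to already be initialized when running distributed.

import os
from typing import Callable

import torch.distributed as dist


def prepare_data_rank_zero_first(prepare_fn: Callable[[str], None], data_dir: str) -> None:
    """Let rank 0 prepare the dataset once, then release the other ranks."""
    rank = dist.get_rank() if dist.is_initialized() else 0

    if rank == 0 and not os.path.isdir(data_dir):
        prepare_fn(data_dir)  # e.g. download and preprocess the data once

    if dist.is_initialized():
        dist.barrier()  # every rank waits here until rank 0 is done
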
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/properties.py
@@ -21,14 +21,14 @@
 from torch.optim import Optimizer

 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
@@ -138,7 +138,7 @@ def log_dir(self) -> Optional[str]:
         else:
             dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

-        dirpath = self.training_type_plugin.broadcast(dirpath)
+        dirpath = self.accelerator_backend.broadcast(dirpath)
         return dirpath

     @property
@@ -365,7 +365,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:

     @property
     def lightning_module(self) -> LightningModule:
-        return self.training_type_plugin.lightning_module
+        return self.accelerator_backend.lightning_module

     @property
     def optimizers(self) -> Optional[List[Optimizer]]:
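
The log_dir change above shows why broadcast is part of the accelerator's public surface: the directory may be resolved differently (or only) on rank 0, so the resolved path is broadcast to keep every process writing to the same place. A rough equivalent with plain torch.distributed — the helper below is hypothetical and assumes an initialized process group and a PyTorch version that provides broadcast_object_list.

import torch.distributed as dist


def broadcast_log_dir(dirpath: str, src: int = 0) -> str:
    """Return the directory chosen on the source rank on every rank."""
    if not dist.is_initialized():
        return dirpath  # single process: nothing to synchronize

    payload = [dirpath if dist.get_rank() == src else None]
    dist.broadcast_object_list(payload, src=src)  # non-src ranks receive rank src's value
    return payload[0]
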
22 changes: 10 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -22,7 +22,6 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -32,6 +31,7 @@
 from pytorch_lightning.profiler import BaseProfiler
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -483,7 +483,7 @@ def fit(
         # trainer.dispatch || LIGHTNING
         #      | ||
         # start_training or start_testing or start_predicting call || FLOW
-        # from `accelerator.training_type_plugin` ||
+        # from `accelerator` ||
         #      | || DIRECTION
         # run_train or run_test or run_predict call ||
         # from `trainer` ||
@@ -531,26 +531,24 @@

         self._set_running_stage(None, model)

-        return self.training_type_plugin.results or 1
+        return self.accelerator_backend.results or 1

     def pre_dispatch(self):
-        self.training_type_plugin.pre_dispatch()
-        self.precision_plugin.pre_dispatch()
+        self.accelerator_backend.pre_dispatch()

     def post_dispatch(self):
-        self.training_type_plugin.post_dispatch()
-        self.precision_plugin.post_dispatch()
+        self.accelerator_backend.post_dispatch()
         self.accelerator_backend.teardown()

     def dispatch(self):
         if self.testing:
-            self.training_type_plugin.start_testing(self)
+            self.accelerator_backend.start_testing(self)

         elif self.predicting:
-            self.training_type_plugin.start_predicting(self)
+            self.accelerator_backend.start_predicting(self)

         else:
-            self.training_type_plugin.start_training(self)
+            self.accelerator_backend.start_training(self)

     def train_or_test_or_predict(self):
         if self.testing:
@@ -574,7 +572,7 @@ def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):

     def _pre_training_routine(self):
         # wait for all to join if on distributed
-        self.accelerator.training_type_plugin.barrier("setup_training")
+        self.accelerator.barrier("setup_training")

         # register auto-resubmit when on SLURM
         self.slurm_connector.register_slurm_signal_handlers()
@@ -947,7 +945,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                 )
                 return {}
             if not self._device_type == DeviceType.TPU:
-                self.training_type_plugin.barrier()
+                self.accelerator_backend.barrier()

             ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt['state_dict'])
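
Putting the trainer-side changes together: fit now talks only to the accelerator — pre_dispatch, then dispatch (which picks start_training, start_testing, or start_predicting), then post_dispatch, and finally it reads results back off the accelerator backend. A compressed, hypothetical sketch of that control flow (MiniTrainer is invented here; it is not the real Trainer).

from typing import Any


class MiniTrainer:
    """Toy trainer showing the dispatch order used in this PR."""

    def __init__(self, accelerator_backend: Any, testing: bool = False, predicting: bool = False) -> None:
        self.accelerator_backend = accelerator_backend
        self.testing = testing
        self.predicting = predicting

    def fit(self) -> Any:
        self.accelerator_backend.pre_dispatch()   # plugins prepare (e.g. spawn or launch processes)
        self.dispatch()                           # hand control to the accelerator
        self.accelerator_backend.post_dispatch()  # plugins clean up; the real trainer also calls teardown()
        return self.accelerator_backend.results or 1

    def dispatch(self) -> None:
        if self.testing:
            self.accelerator_backend.start_testing(self)
        elif self.predicting:
            self.accelerator_backend.start_predicting(self)
        else:
            self.accelerator_backend.start_training(self)

    def run_train(self) -> str:
        return "fit results"  # stands in for the real training loop

Wired to the facade sketch shown earlier, MiniTrainer(DummyAccelerator(DummyTrainingTypePlugin(), DummyPrecisionPlugin())).fit() returns "fit results", mirroring how the real fit ends with `return self.accelerator_backend.results or 1`.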