diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 0600aca92e29f..2ba97d7d9a725 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -383,12 +383,18 @@ def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)

     def broadcast(self, obj: object, src: int = 0) -> object:
-        """Broadcasts an object to all processes"""
+        """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.
+
+        Args:
+            obj: Object to broadcast to all processes, usually a tensor or collection of tensors.
+            src: The source rank from which the object will be broadcast.
+        """
         return self.training_type_plugin.broadcast(obj, src)

     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
-        Function to gather a tensor from several distributed processes
+        Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
@@ -409,8 +415,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     @property
     def results(self) -> Any:
         """
-        The results of the last training/testing run will be cached here.
+        The results of the last training/testing run will be cached within the training type plugin.
         In distributed training, we make sure to transfer the results to the appropriate master process.
         """
-        # TODO: improve these docs
         return self.training_type_plugin.results
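
For context, a minimal usage sketch of the two calls documented above (illustrative only, not part of the diff): it assumes a `Trainer` running under a distributed plugin whose `trainer.accelerator` attribute exposes the `Accelerator` patched here; the function name and checkpoint path are made up for the example.

```python
import torch

def sync_across_ranks(trainer, loss: torch.Tensor):
    # Assumption: `trainer.accelerator` is the Accelerator patched in this
    # diff; this sketch is illustrative, not taken from the PR.

    # all_gather stacks the tensor from every process: with world size N and
    # a scalar loss, the gathered result has shape (N,).
    gathered_loss = trainer.accelerator.all_gather(loss)

    # broadcast replicates the object held on the src rank (rank 0 here) on
    # every other rank, so all processes agree on the same value.
    best_ckpt = trainer.accelerator.broadcast("checkpoints/best.ckpt", src=0)

    return gathered_loss, best_ckpt
```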