
Commit

Add descriptions to accelerator broadcast function/clean up all_gather (#6044)

* Add descriptions to accelerator broadcast function/clean up all_gather

* Remove todo
SeanNaren authored Feb 18, 2021
1 parent 049006a commit b019c25
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -383,12 +383,18 @@ def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        """Broadcasts an object to all processes"""
+        """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.
+
+        Args:
+            obj: Object to broadcast to all process, usually a tensor or collection of tensors.
+            src: The source rank of which the object will be broadcast from
+        """
         return self.training_type_plugin.broadcast(obj, src)
 
     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
-        Function to gather a tensor from several distributed processes
+        Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
@@ -409,8 +415,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     @property
     def results(self) -> Any:
         """
-        The results of the last training/testing run will be cached here.
+        The results of the last training/testing run will be cached within the training type plugin.
         In distributed training, we make sure to transfer the results to the appropriate master process.
         """
-        # TODO: improve these docs
         return self.training_type_plugin.results
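
For readers skimming the diff, here is a minimal, illustrative sketch (not part of this commit) of the plain torch.distributed collectives that the accelerator's broadcast and all_gather wrappers ultimately delegate to through the training type plugin. The helper names broadcast_object and all_gather_tensor are hypothetical, and the sketch assumes a default process group has already been initialized, as a DDP-style plugin would do.

import torch
import torch.distributed as dist


def broadcast_object(obj: object, src: int = 0) -> object:
    # Illustrative only: broadcast_object_list works in place, so wrap the
    # object in a single-element list and unwrap it after the collective.
    buffer = [obj]
    dist.broadcast_object_list(buffer, src=src)
    return buffer[0]


def all_gather_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # Illustrative only: gather a tensor of shape (batch, ...) from every rank;
    # the stacked result has shape (world_size, batch, ...). A plain all_gather
    # does not track gradients, which is roughly what the sync_grads flag on
    # the accelerator method is meant to control.
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.stack(gathered)

In practice, users rarely call these accelerator methods directly; the usual entry point is self.all_gather inside a LightningModule, which routes through the same training type plugin machinery.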
