diff --git a/CHANGELOG.md b/CHANGELOG.md index 2348f6afc4ff1..99eb26935f197 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -128,6 +128,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) +- Fixed broacast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410)) + + - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 84d53b5addd6b..f9ccc7a42fa06 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,7 +21,6 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -396,7 +395,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s Return: A tensor of shape (world_size, batch, ...) """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d114641f1af72..2dfd0afb02634 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -190,4 +190,4 @@ def _run_early_stopping_check(self, trainer, pl_module): trainer.should_stop = True # stop every ddp process if any world process decides to stop - trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop) + trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5b5a851a922b7..383e1caa6a7dc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -336,7 +336,7 @@ def _save_model(self, filepath: str, trainer, pl_module): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current) -> bool: + def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: if current is None: return False @@ -356,7 +356,12 @@ def check_monitor_top_k(self, current) -> bool: current = torch.tensor(current) monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] - return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item() + should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) + + # If using multiple devices, make sure all processes are unanimous on the decision. + should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + + return should_update_best_and_save @classmethod def _format_checkpoint_name( @@ -554,15 +559,7 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): epoch = metrics.get("epoch") step = metrics.get("step") - # when `val_loss` is being logged and no ModelCheckpoint is being provided - # `val_loss` will be selected for monitor and need to be reduced to - # prevent processes divergence - # TODO: Move this logic to logger_connector. This also needs to be fixed for any - # other monitor logged value which aren't produced from a Metric. - if self.monitor == "val_loss": - current = trainer.training_type_plugin.reduce(current, reduce_op="mean") - - if self.check_monitor_top_k(current): + if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics) elif self.verbose: rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") @@ -627,5 +624,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool: the internal state to diverge between ranks. """ exists = self._fs.exists(filepath) - exists = trainer.training_type_plugin.broadcast(exists) - return exists + return trainer.training_type_plugin.broadcast(exists) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 5da7dfa86084d..37ac5d8b13462 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -11,18 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io from typing import Any -import torch -from torch import distributed as torch_distrib - -from pytorch_lightning.utilities import _GROUP_AVAILABLE - -WORLD = None -if _GROUP_AVAILABLE: - from torch.distributed import group - WORLD = group.WORLD +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.utilities.distributed import group as _group class LightningDistributed: @@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any, group=WORLD): - if self.rank == 0: - self._emit(obj, group) - else: - obj = self._receive(group) - return obj - - def _broadcast(self, tensor, src=0, group=WORLD): - if group is None: - return torch_distrib.broadcast(tensor, src=src) - return torch_distrib.broadcast(tensor, src=0, group=group) - - def _emit(self, obj: Any, group=WORLD): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.ByteTensor(data).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - - def _receive(self, group=WORLD): - length_tensor = torch.tensor([0]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - buffer = io.BytesIO(data_tensor.cpu().numpy()) - obj = torch.load(buffer) - return obj + def broadcast(self, obj: Any, group=_group.WORLD): + # always wrap into a list so list can be brodcasted. + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + broadcast_object_list(obj, 0, group=group or _group.WORLD) + + return obj[0] diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py new file mode 100644 index 0000000000000..67b64c046dc18 --- /dev/null +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -0,0 +1,94 @@ +import logging +import pickle + +import torch + +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 + +log = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember + +# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` +# and enable broadcasting for PyTorch 1.6 and lower. + + +# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 +def _rank_not_in_group(group): + """ + Helper that checks if the current process's rank is not in a given group. + """ + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 +def _object_to_tensor(obj): + buffer = pickle.dumps(obj) + byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py +def _tensor_to_object(tensor, tensor_size): + buf = tensor.numpy().tobytes()[:tensor_size] + out = pickle.loads(buf) + return out + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 +def _broadcast_object_list(object_list, src=0, group=None): + if _rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.LongTensor(len(object_list)) + + group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + broadcast(object_tensor, src=src, group=group) + + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): + from torch.distributed.distributed_c10d import broadcast_object_list +else: + broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index a1f33b9931cf5..b600eca5e6bc2 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Tuple +from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING import torch -from torch.optim import Optimizer from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin @@ -23,24 +22,28 @@ if _APEX_AVAILABLE: from apex import amp +if TYPE_CHECKING: + from torch.optim import Optimizer + class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" - def __init__(self, amp_level: str): + def __init__(self, amp_level: str = "O2") -> None: self.backend = AMPType.APEX self.amp_level = amp_level - def master_params(self, optimizer: torch.optim.Optimizer): + def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]: return amp.master_params(optimizer) - def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]: """Connects the precision plugin to the training process, configures apex and reinits the schedulers """ if model.device.type != "cuda": return model, optimizers, lr_schedulers - model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level) + model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level) self.reinit_scheduler_properties(optimizers, lr_schedulers) return model, optimizers, lr_schedulers @@ -48,12 +51,12 @@ def backward( self, model: LightningModule, closure_loss: torch.Tensor, - optimizer: torch.optim.Optimizer, + optimizer: 'Optimizer', opt_idx: int, should_accumulate: bool, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: """performs the actual backpropagation Args: @@ -94,11 +97,11 @@ def backward( def configure_apex( self, - amp: object, + amp: Type, model: LightningModule, - optimizers: List[Optimizer], + optimizers: List['Optimizer'], amp_level: str, - ) -> Tuple[LightningModule, List[Optimizer]]: + ) -> Tuple[LightningModule, List['Optimizer']]: r""" Override to init AMP your own way. Must return a model and list of optimizers. @@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers @staticmethod - def reinit_scheduler_properties(optimizers: list, schedulers: list): + def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None: """Reinitializes schedulers with correct properties""" # Reinitialize optimizer.step properties added by schedulers for scheduler in schedulers: @@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list): break def pre_optimizer_step( - self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs + self, + pl_module: LightningModule, + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, ) -> bool: """ always called before the optimizer step. @@ -160,6 +168,5 @@ def pre_optimizer_step( if not pl_module.automatic_optimization: pl_module.trainer.call_hook("on_after_backward") - optimizer.step() - + optimizer.step(**kwargs) return False diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index af8cfa7755974..7a1f7ac1201c0 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return obj - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - return should_stop + def reduce_boolean_decision(self, decision: bool) -> bool: + return decision def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 13585f8f368f4..27ae26d67b493 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -147,8 +147,13 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ hvd.join() return hvd.allreduce(output, op=reduce_op) - def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): - if group is not None: + def all_gather( + self, + result: Union[torch.Tensor], + group: Optional[Any] = group.WORLD, + sync_grads: bool = False + ) -> torch.Tensor: + if group is not None and group != group.WORLD: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index f3c825fe9cd7a..9809443aff3fb 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import List, Optional +from typing import Any, List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -36,9 +35,10 @@ def __init__( ): super().__init__() self.parallel_devices = parallel_devices + self.cluster_environment = cluster_environment + self.global_rank = 0 self.world_size = 1 self.local_rank = 0 - self.cluster_environment = cluster_environment @property def cluster_local_rank(self): @@ -77,11 +77,15 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) return distributed_sampler_kwargs - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM) - should_stop = bool(should_stop == self.world_size) - return should_stop + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision @property def torch_distributed_backend(self): @@ -119,13 +123,3 @@ def block_backward_sync(self): yield None else: yield None - - def broadcast(self, obj: object, src: int) -> object: - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float) - data = all_gather_ddp_if_available(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 4b1d24301b8a0..39fe06e1d46f2 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -1,4 +1,17 @@ -from typing import Any, Union +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union import torch @@ -10,6 +23,9 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__(self, device: torch.device): super().__init__() self.device: torch.device = device + self.global_rank = 0 + self.local_rank = 0 + self.world_size = 1 @property def on_tpu(self) -> bool: @@ -20,8 +36,24 @@ def on_gpu(self) -> bool: return self.device.type == "cuda" and torch.cuda.is_available() def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + """ + Reduces output from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + output: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ return output + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return tensor + @property def root_device(self) -> torch.device: return self.device diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 371649057909b..1e951329b22cc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -188,12 +188,11 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]: model.trainer.save_checkpoint(path) return path - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - stop = xm.mesh_reduce('stop_signal', should_stop, sum) - rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.world_size - return should_stop + def reduce_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.device) + decision = self.reduce(decision, "sum") + decision = bool(decision == self.world_size) + return decision def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d7c3b4d4d77e1..b3a6c36bfbbf6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -33,7 +33,6 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None - self.global_rank = 0 @property @abstractmethod @@ -66,9 +65,13 @@ def barrier(self, name: Optional[str] = None) -> None: def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes""" - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - """Reduce the early stopping decision across all possibly spawned processes""" - return should_stop + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes""" + return decision def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a547144c8a6f3..b6270ef0cb835 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -13,9 +13,11 @@ # limitations under the License. from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple +from weakref import proxy import torch +import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -50,7 +52,7 @@ class HookResultStore: Those data structures enables us to reduce properly Result object when batch loop is finished. """ - def __init__(self, fx_name): + def __init__(self, fx_name: str) -> None: self._fx_name = fx_name self._internals = {} self._internals_reduced = {} @@ -104,6 +106,7 @@ def get_batch_log_metrics(self, *args, **kwargs): def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if not isinstance(opt_metric, Result): raise Exception("The provided opt_metric should be a Result Object. Something is wrong") + func = getattr(opt_metric, func_name) metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) results.append(metrics_to_log) @@ -222,7 +225,7 @@ class EpochResultStore: ``` """ - def __init__(self, trainer, stage): + def __init__(self, trainer: 'pl.Trainer', stage): self.trainer = trainer self._stage = stage self.reset() diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index f283497e5e5a1..61f581a5b5571 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -23,6 +23,7 @@ if torch.distributed.is_available(): from torch.distributed import group, ReduceOp + else: class ReduceOp: diff --git a/setup.cfg b/setup.cfg index d9b27118a4030..4c478dccb709e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,10 +43,6 @@ exclude_lines = # The actual coverage for each is 90%+ # *metrics (94%+) are temporarily removed from testing while tests speed up omit = - pytorch_lightning/accelerators/ddp_*.py - pytorch_lightning/accelerators/ddp2_*.py - pytorch_lightning/accelerators/dp_*.py - pytorch_lightning/accelerators/tpu_*.py pytorch_lightning/cluster_environments/*.py pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 6ce1938d3990f..ca37a2d3ec86a 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -19,6 +19,7 @@ from pytorch_lightning import callbacks, seed_everything, Trainer from tests.helpers import BoringModel +from tests.helpers.runif import RunIf @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -98,3 +99,42 @@ def training_step(self, batch, batch_idx): # make sure types are correct assert save_mock.call_count == expected + + +@mock.patch('torch.save') +@RunIf(special=True, min_gpus=2) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + local_rank = int(os.getenv("LOCAL_RANK")) + self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True) + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + data = str(self.global_rank) + obj = [[data], (data, ), set(data)] + out = self.trainer.training_type_plugin.broadcast(obj) + assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))] + assert out == [['0'], ('0', ), set('0')] + + model = TestModel() + trainer = Trainer( + callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")], + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval, + accelerator="ddp", + gpus=2, + limit_train_batches=64, + limit_val_batches=32, + ) + if os.getenv("LOCAL_RANK") == "0": + with pytest.raises(UserWarning, match="The value associated to the key my_loss_epoch: [15.5, 31.0]"): + trainer.fit(model) + assert save_mock.call_count == expected + else: + trainer.fit(model) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 06a114ca15eb9..29eaebc031e3c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -347,7 +347,7 @@ def on_train_start(self, trainer, pl_module): torch.save = Mock(wraps=torch.save) def on_save_checkpoint(self, trainer, pl_module, checkpoint): - # expect all ranks to run but only rank 0 will actually write the checkpoint file + # only rank 0 will call ``torch.save`` super().on_save_checkpoint(trainer, pl_module, checkpoint) self.on_save_checkpoint_count += 1 @@ -357,8 +357,7 @@ def on_train_end(self, trainer, pl_module): assert self.best_model_score assert self.on_save_checkpoint_count == self.expected_count if trainer.is_global_zero: - # twice the calls expected because ddp broadcast also uses torch.save - assert torch.save.call_count == self.expected_count * 2 + assert torch.save.call_count == self.expected_count else: assert torch.save.call_count == 0 diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 93a637dda1071..1ef55e729912b 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,6 +21,8 @@ import os import sys +import torch + # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do PYTHONPATH = os.getenv('PYTHONPATH', '') if ':' in PYTHONPATH: @@ -53,8 +55,13 @@ def run_test_from_config(trainer_options): ckpt_path = trainer_options['weights_save_path'] trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)]) - model = BoringModel() + class TestModel(BoringModel): + + def training_epoch_end(self, outputs) -> None: + res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") + assert res.sum() == self.trainer.training_type_plugin.world_size + model = TestModel() trainer = Trainer(**trainer_options) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 0b89c3b06c041..8a1260251eb14 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -17,11 +17,13 @@ import shlex import subprocess import sys +from unittest.mock import patch import numpy as np import pytest import torch from sklearn.metrics import accuracy_score +from torch import optim import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils @@ -55,6 +57,9 @@ def _run_horovod(trainer_options, on_gpu=False): # for Horovod, we interpret `gpus` to be set per worker trainer_options.update(gpus=1 if on_gpu else None) tutils.reset_seed() + # todo: Find why coverage breaks CI. + # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 + # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ 'horovodrun', '-np', str(num_processes), sys.executable, TEST_SCRIPT, '--trainer-options', @@ -119,6 +124,8 @@ def test_horovod_multi_gpu(tmpdir): _run_horovod(trainer_options, on_gpu=True) +# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 +# Check with (tgaddair) on Horovod issues if this feature is needed @pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @@ -167,6 +174,27 @@ def test_horovod_amp(tmpdir): _run_horovod(trainer_options, on_gpu=True) +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp") +def test_horovod_gather(tmpdir): + """Test Horovod with multi-GPU support using native amp.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @@ -198,6 +226,7 @@ def validation_step(self, batch, *args, **kwargs): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") def test_horovod_multi_optimizer(tmpdir): model = BasicGAN() @@ -230,7 +259,7 @@ def get_optimizer_params(optimizer): # TODO: unclear Horovod failure... -@pytest.mark.skip(reason="unclear Horovod failure...") +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") @pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_result_reduce_horovod(tmpdir): @@ -273,6 +302,7 @@ def training_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, + logger=False ) trainer.fit(model) @@ -281,7 +311,7 @@ def training_epoch_end(self, outputs) -> None: # TODO: unclear Horovod failure... -@pytest.mark.skip(reason="unclear Horovod failure...") +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") @pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_accuracy_metric_horovod(): @@ -298,10 +328,7 @@ def sk_metric(preds, target): target = torch.randint(high=2, size=(num_batches, batch_size)) def _compute_batch(): - trainer = Trainer( - fast_dev_run=True, - accelerator='horovod', - ) + trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False) assert isinstance(trainer.accelerator, CPUAccelerator) # TODO: test that we selected the correct training_type_plugin based on horovod flags @@ -309,7 +336,7 @@ def _compute_batch(): metric = Accuracy( compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, + dist_sync_fn=trainer.training_type_plugin.all_gather, threshold=threshold ) @@ -334,33 +361,46 @@ def _compute_batch(): horovod.run(_compute_batch, np=2) -# @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): -# model = BoringModel() -# model.configure_optimizers = model.configure_optimizers__multiple_schedulers -# -# num_workers = 8 -# init_lr = hparams.get('learning_rate') * num_workers -# -# with patch('pytorch_lightning.accelerators.legacy.horovod_backend.hvd.size') as mock_hvd_size: -# mock_hvd_size.return_value = 8 -# -# # fit model -# trainer = Trainer( -# default_root_dir=tmpdir, -# max_epochs=1, -# limit_val_batches=0.5, -# limit_train_batches=0.2, -# distributed_backend='horovod' -# ) -# results = trainer.fit(model) -# assert results == 1 -# -# adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] -# adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] -# -# # Called ones after end of epoch with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr1 -# -# # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr2 +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") +def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.1) + optimizer2 = optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + model = TestModel() + model.training_epoch_end = None + + num_workers = 8 + init_lr = 0.1 * num_workers + + with patch('horovod.torch.size', return_value=8): + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.5, + limit_train_batches=0.2, + accelerator='horovod' + ) + results = trainer.fit(model) + assert results == 1 + + adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] + adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] + + # Called ones after end of epoch with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr1 + + # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr2 diff --git a/tests/special_tests.sh b/tests/special_tests.sh index a2373d05a42ef..e1dabc7daa5b5 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -35,4 +35,5 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model +python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx