From 309ed75c5d6740538fca6d9a571d85606ac6d48b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 30 Jun 2020 16:15:35 -0400 Subject: [PATCH] added reduce ddp results on eval (#2434) * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval * added reduce ddp results on eval --- pytorch_lightning/trainer/evaluation_loop.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 349cf1635c15c..7f37edc04b639 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -132,6 +132,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE +from torch import distributed as dist try: import torch_xla.distributed.parallel_loader as xla_pl @@ -163,6 +164,7 @@ class TrainerEvaluationLoopMixin(ABC): model: LightningModule num_test_batches: List[int] num_val_batches: int + world_size: int fast_dev_run: ... process_output: ... progress_bar_dict: ... 
def reduce_eval_ddp(self, eval_results):
    """Average every tensor in ``eval_results`` across DDP processes.

    The dict is updated in place: each ``torch.Tensor`` value is replaced
    by the sum of that value over all processes divided by
    ``self.world_size``. Nested dicts are handled recursively; values of
    any other type are left untouched.

    Args:
        eval_results: The (possibly nested) dict returned by the epoch-end
            hook. ``None``, empty, or non-dict inputs are ignored.

    Returns:
        None. ``eval_results`` is modified in place.
    """
    # ignore bad inputs (None, empty, or anything that is not a dict)
    if not isinstance(eval_results, dict) or not eval_results:
        return

    for k, v in eval_results.items():
        if isinstance(v, dict):
            self.reduce_eval_ddp(v)
        elif isinstance(v, torch.Tensor):
            # Reduce a copy: all_reduce works in place, and the same tensor
            # object may be stored under several keys (or nested dicts).
            # Mutating it would make later visits reduce an already-summed
            # value and scale the result by world_size again.
            reduced = v.clone()
            dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
            eval_results[k] = reduced / self.world_size