
added reduce ddp results on eval #2434

Merged: 11 commits, Jun 30, 2020
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -132,6 +132,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from torch import distributed as dist

try:
import torch_xla.distributed.parallel_loader as xla_pl
@@ -163,6 +164,7 @@ class TrainerEvaluationLoopMixin(ABC):
model: LightningModule
num_test_batches: List[int]
num_val_batches: int
world_size: int
fast_dev_run: ...
process_output: ...
progress_bar_dict: ...
@@ -339,6 +341,10 @@ def _evaluate(
elif self.is_overridden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)

# aggregate ddp stats across processes
if self.use_ddp or self.use_ddp2:
self.reduce_eval_ddp(eval_results)

# enable train mode again
model.train()

@@ -347,6 +353,19 @@ def _evaluate(

return eval_results

    def reduce_eval_ddp(self, eval_results):
Member
This traversal of nested types looks like what @justusschock did in metrics. Maybe one could re-use his apply_to_collection from utilities.apply_func.

Contributor Author

Good point. @justusschock, how would I adapt that here?

Member
@awaelchli awaelchli Jun 30, 2020

Quick brain-compiled code:

eval_results = apply_to_collection(
    eval_results, 
    dtype=torch.Tensor, 
    function=reduce_eval_ddp, 
    world_size=self.world_size
)

with reduce_eval_ddp defined as

def reduce_eval_ddp(v, world_size):
    dist.all_reduce(v, op=dist.reduce_op.SUM)
    v = v / world_size
    return v
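For reference, here is a single-process sketch of how such a traversal works. This is a simplified stand-in for Lightning's `apply_to_collection` (from `pytorch_lightning.utilities.apply_func`), not the library implementation; plain floats stand in for `torch.Tensor` and the hypothetical `mean_over_world` replaces the `all_reduce(SUM)` + divide step, so it runs without a process group:

```python
def apply_to_collection(data, dtype, function, *args, **kwargs):
    # Apply `function` to every element of `data` matching `dtype`,
    # recursing into dicts, lists, and tuples; leave everything else as-is.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, function, *args, **kwargs)
                for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, function, *args, **kwargs)
                          for v in data)
    return data


def mean_over_world(v, world_size):
    # Stand-in for dist.all_reduce(v, op=SUM) followed by v / world_size.
    return v / world_size


results = {"val_loss": 8.0, "log": {"acc": 4.0}}
reduced = apply_to_collection(results, dtype=float,
                              function=mean_over_world, world_size=4)
print(reduced)  # {'val_loss': 2.0, 'log': {'acc': 1.0}}
```

The same call shape carries over to the real helper: pass the collection, the dtype to match, the function to apply, and any extra arguments it needs.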

        # ignore bad inputs
        if eval_results is None or len(eval_results) == 0:
            return

        for k, v in eval_results.items():
            if isinstance(v, dict):
                self.reduce_eval_ddp(v)
            elif isinstance(v, torch.Tensor):
                dist.all_reduce(v, op=dist.reduce_op.SUM)
                v = v / self.world_size
                eval_results[k] = v
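To see what this averaging does, here is a hypothetical single-process simulation of the reduction: `all_reduce(SUM)` across ranks followed by division by `world_size` is just a per-key mean. Plain floats stand in for `torch.Tensor`, and a list of per-rank dicts stands in for the distributed state; `reduce_across_ranks` is an illustrative name, not part of the PR:

```python
def reduce_across_ranks(per_rank_results):
    # Average each metric across ranks, recursing into nested dicts,
    # mirroring the recursion in reduce_eval_ddp above.
    world_size = len(per_rank_results)
    out = {}
    for k in per_rank_results[0]:
        vals = [r[k] for r in per_rank_results]
        if isinstance(vals[0], dict):
            out[k] = reduce_across_ranks(vals)   # recurse, like the dict branch
        else:
            out[k] = sum(vals) / world_size      # SUM then divide == mean
    return out


ranks = [
    {"val_loss": 1.0, "log": {"val_acc": 0.75}},
    {"val_loss": 3.0, "log": {"val_acc": 0.25}},
]
print(reduce_across_ranks(ranks))  # {'val_loss': 2.0, 'log': {'val_acc': 0.5}}
```

Note that the in-place version above mutates `eval_results` and relies on `all_reduce` giving every rank the same summed tensor, so all processes end up with identical averaged metrics.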

def run_evaluation(self, test_mode: bool = False):
# hook
model = self.get_model()