diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 4917010fdd4152..3945d770fe8d42 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -70,7 +70,7 @@ def forward(self, *inputs, **kwargs):
         if isinstance(outputs[0], Result):
             outputs = self.__gather_structured_result(outputs)
         else:
-            outputs = self.gather(outputs, self.output_device)
+            outputs = self.gather(outputs)
         return outputs

     def __gather_structured_result(self, outputs):
@@ -83,7 +83,7 @@ def __gather_structured_result(self, outputs):
        for i, output in enumerate(outputs):
            del output['meta']

-        outputs = self.gather(outputs, self.output_device)
+        outputs = self.gather(outputs)

        # pass minimize to constructor for TrainResult
        if 'minimize' in outputs:
@@ -106,16 +106,16 @@ def gather_map(outputs):
            if isinstance(elem, torch.Tensor):
                return Gather.apply(self.output_device, self.dim, *outputs)

-            elif elem is None:
+            if elem is None:
                return None

-            elif isinstance(elem, Mapping):
+            if isinstance(elem, Mapping):
                if not all((len(elem) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')

                return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem))

-            elif isinstance(elem, Iterable) and not isinstance(elem, str):
+            if isinstance(elem, Iterable) and not isinstance(elem, str):
                return elem_type(map(gather_map, zip(*outputs)))

            return outputs
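
For context on why the second argument can be dropped at the call sites above: the class overrides gather itself, so the override can read self.output_device and self.dim from the instance instead of receiving the device as a parameter. Below is a minimal sketch of that pattern; the class name is hypothetical and this is not Lightning's actual implementation, only an illustration of the one-argument gather signature the diff relies on.

    # Hypothetical sketch: a DataParallel subclass whose gather() takes only the
    # per-replica outputs and reads output_device/dim from the instance, so call
    # sites like the ones in the diff can drop the explicit device argument.
    from collections.abc import Iterable, Mapping

    import torch
    from torch.nn.parallel import DataParallel
    from torch.nn.parallel._functions import Gather


    class OneArgGatherDataParallel(DataParallel):  # hypothetical name
        def gather(self, outputs):
            def gather_map(outputs):
                elem = outputs[0]
                elem_type = type(elem)

                if isinstance(elem, torch.Tensor):
                    # Gather tensors onto the device this wrapper was built with.
                    return Gather.apply(self.output_device, self.dim, *outputs)

                if elem is None:
                    return None

                if isinstance(elem, Mapping):
                    if not all(len(elem) == len(d) for d in outputs):
                        raise ValueError('All dicts must have the same number of keys')
                    return elem_type((k, gather_map([d[k] for d in outputs])) for k in elem)

                if isinstance(elem, Iterable) and not isinstance(elem, str):
                    return elem_type(map(gather_map, zip(*outputs)))

                return outputs

            return gather_map(outputs)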