Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support returning python scalars in DP #1935

Merged
merged 16 commits on Aug 7, 2020
46 changes: 23 additions & 23 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def forward(self, batch):

"""

def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]:
def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]]:
r"""
Here you compute and return the training loss and some additional metrics for e.g.
the progress bar or logger.
Expand All @@ -186,8 +186,8 @@ def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, D
When implementing :meth:`training_step`, return whatever you need in that step:

- loss -> tensor scalar **REQUIRED**
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
- progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
- log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)

In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
Expand All @@ -202,14 +202,14 @@ def training_step(self, batch, batch_idx):
out = self(x)
loss = self.loss(out, x)

logger_logs = {'training_loss': loss} # optional (MUST ALL BE TENSORS)
logger_logs = {'training_loss': loss} # optional

# if using TestTubeLogger or TensorBoardLogger you can nest scalars
logger_logs = {'losses': logger_logs} # optional (MUST ALL BE TENSORS)
logger_logs = {'losses': logger_logs} # optional

output = {
'loss': loss, # required
'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
'progress_bar': {'training_loss': loss}, # optional
'log': logger_logs
}

Expand Down Expand Up @@ -259,8 +259,8 @@ def training_end(self, *args, **kwargs):
"""

def training_epoch_end(
self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Union[float, Tensor]]]]]
) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""Called at the end of the training epoch with the outputs of all training steps.

.. code-block:: python
Expand Down Expand Up @@ -334,7 +334,7 @@ def training_epoch_end(self, outputs):
return results
"""

def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Union[float, Tensor]]]]:
"""
Use this when training with dp or ddp2 because :meth:`training_step`
will operate on only part of the batch. However, this is still optional
Expand All @@ -358,8 +358,8 @@ def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str
Dict with loss key and optional log or progress bar keys.

- loss -> tensor scalar **REQUIRED**
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
- progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars
- log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc)

Examples:
.. code-block:: python
Expand Down Expand Up @@ -396,7 +396,7 @@ def training_step_end(self, outputs):
See the :ref:`multi-gpu-training` guide for more details.
"""

def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:
def validation_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
r"""
Operates on a single batch of data from the validation set.
In this step you'd might generate examples or calculate anything of interest like accuracy.
Expand Down Expand Up @@ -486,7 +486,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
the model goes back to training mode and gradients are enabled.
"""

def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
def validation_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
"""
Use this when validating with dp or ddp2 because :meth:`validation_step`
will operate on only part of the batch. However, this is still optional
Expand Down Expand Up @@ -553,8 +553,8 @@ def validation_end(self, outputs):
"""

def validation_epoch_end(
self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""
Called at the end of the validation epoch with the outputs of all validation steps.

Expand All @@ -575,8 +575,8 @@ def validation_epoch_end(
Dict or OrderedDict.
May have the following optional keys:

- progress_bar (dict for progress bar display; only tensors)
- log (dict of metrics to add to logger; only tensors).
- progress_bar (dict for progress bar display; either scalar tensors or Python scalars)
- log (dict of metrics to add to logger; either scalar tensors or Python scalars).

Note:
If you didn't define a :meth:`validation_step`, this won't be called.
Expand Down Expand Up @@ -630,7 +630,7 @@ def validation_epoch_end(self, outputs):
return results
"""

def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
def test_step(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
r"""
Operates on a single batch of data from the test set.
In this step you'd normally generate examples or calculate anything of interest
Expand Down Expand Up @@ -713,7 +713,7 @@ def test_step(self, batch, batch_idx, dataloader_idx):
to training mode and gradients are enabled.
"""

def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]:
def test_step_end(self, *args, **kwargs) -> Dict[str, Union[float, Tensor]]:
"""
Use this when testing with dp or ddp2 because :meth:`test_step` will operate
on only part of the batch. However, this is still optional
Expand Down Expand Up @@ -779,8 +779,8 @@ def test_end(self, outputs):
"""

def test_epoch_end(
self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]
) -> Dict[str, Dict[str, Tensor]]:
self, outputs: Union[List[Dict[str, Union[float, Tensor]]], List[List[Dict[str, Union[float, Tensor]]]]]
) -> Dict[str, Dict[str, Union[float, Tensor]]]:
"""
Called at the end of a test epoch with the output of all test steps.

Expand All @@ -800,8 +800,8 @@ def test_epoch_end(
Return:
Dict or OrderedDict: Dict has the following optional keys:

- progress_bar -> Dict for progress bar display. Must have only tensors.
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc).
- progress_bar -> Dict for progress bar display. Must have either scalar tensors or Python scalars.
- log -> Dict of metrics to add to logger. Must have either scalar tensors or Python scalars (no images, etc).

Note:
If you didn't define a :meth:`test_step`, this won't be called.
Expand Down
42 changes: 38 additions & 4 deletions pytorch_lightning/overrides/data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import itertools
import threading
from collections.abc import Iterable, Mapping
from itertools import chain

import torch
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.step_result import Result


Expand Down Expand Up @@ -68,7 +70,7 @@ def forward(self, *inputs, **kwargs):
if isinstance(outputs[0], Result):
outputs = self.__gather_structured_result(outputs)
else:
outputs = self.gather(outputs, self.output_device)
outputs = self.gather(outputs)
return outputs

def __gather_structured_result(self, outputs):
Expand All @@ -81,7 +83,7 @@ def __gather_structured_result(self, outputs):
for i, output in enumerate(outputs):
del output['meta']

outputs = self.gather(outputs, self.output_device)
outputs = self.gather(outputs)

# pass minimize to constructor for TrainResult
if 'minimize' in outputs:
Expand All @@ -93,6 +95,39 @@ def __gather_structured_result(self, outputs):
result['meta'] = meta
return result

def gather(self, outputs):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
r"""
Override the gather method to support python scalars as well.
"""
def gather_map(outputs):
elem = outputs[0]
elem_type = type(elem)

if isinstance(elem, torch.Tensor):
return Gather.apply(self.output_device, self.dim, *outputs)

if elem is None:
return None

if isinstance(elem, Mapping):
if not all((len(elem) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return elem_type(((k, gather_map([d[k] for d in outputs]))
for k in elem))

if isinstance(elem, Iterable) and not isinstance(elem, str):
return elem_type(map(gather_map, zip(*outputs)))

return outputs

# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.
try:
res = gather_map(outputs)
finally:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
gather_map = None
return res

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

Expand Down Expand Up @@ -126,9 +161,8 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
# normal
# output = self.module(*inputs, **kwargs)
# lightning (ddp_cpu)
# normal lightning (ddp_cpu)
if self.module.training:
output = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
Expand Down
9 changes: 3 additions & 6 deletions pytorch_lightning/trainer/ignored_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

def ignore_scalar_return_in_dp():
    """Globally silence torch's scalar-gather warning raised under DataParallel.

    Users get confused by this warning, so it is filtered out.
    """
    message = ('Was asked to gather along dimension 0, but all'
               ' input tensors were scalars; will instead unsqueeze'
               ' and return a vector.')
    warnings.filterwarnings('ignore', message=message)


ignore_scalar_return_in_dp()
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ def reduce_distributed_output(self, output, num_gpus):
if isinstance(output[k], dict):
output[k] = self.reduce_distributed_output(output[k], num_gpus)

# compute the average of scalars
elif isinstance(output[k], list):
output[k] = sum(output[k]) / len(output[k])
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

# do nothing when there's a scalar
elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
pass
Expand Down
22 changes: 14 additions & 8 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,31 @@ class TrainingStepVariations(ABC):
"""
Houses all variations of training steps
"""

test_step_inf_loss = float('inf')

def training_step(self, batch, batch_idx, optimizer_idx=None):
"""Lightning calls this inside the training loop"""
# forward pass
x, y = batch
x = x.view(x.size(0), -1)

y_hat = self(x)

# calculate loss
loss_val = self.loss(y, y_hat)

# alternate possible outputs to test
output = OrderedDict({
'loss': loss_val,
'progress_bar': {'some_val': loss_val * loss_val},
'log': {'train_some_val': loss_val * loss_val},
})
log_val = loss_val

# alternate between tensors and scalars for "log" and "progress_bar"
if batch_idx % 2 == 0:
log_val = log_val.item()

output = OrderedDict(
{
'loss': loss_val,
'progress_bar': {'some_val': log_val * log_val},
'log': {'train_some_val': log_val * log_val},
}
)
return output

def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None):
Expand Down
7 changes: 6 additions & 1 deletion tests/base/model_valid_epoch_ends.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def _mean(res, key):
val_loss_mean = _mean(outputs, 'val_loss')
val_acc_mean = _mean(outputs, 'val_acc')

# alternate between tensor and scalar
if self.current_epoch % 2 == 0:
val_loss_mean = val_loss_mean.item()
val_acc_mean = val_acc_mean.item()

metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results
Expand Down Expand Up @@ -54,6 +59,6 @@ def _mean(res, key):
results = {
'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(),
'progress_bar': pbar,
'log': logs
'log': logs,
}
return results