Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added accumulation of loggers' metrics for the same steps #1278

Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c5c467e
`add_argparse_args` method fixed (argument types added)
Mar 14, 2020
625c01a
autopep8 fixes
Mar 14, 2020
05b9d55
--gpus=0 removed from test (for ci tests)
Mar 14, 2020
a14da90
Update pytorch_lightning/trainer/trainer.py
alexeykarnachev Mar 20, 2020
30d0347
test_with_accumulate_grad_batches added
Mar 27, 2020
f222da1
agg_and_log_metrics logic added to the base logger class
Mar 29, 2020
0d69480
small format fix
Mar 29, 2020
40812f8
agg metrics strategies removed (not to complicate stuff)
Mar 29, 2020
06f450c
agg metrics: handle zero step
Mar 29, 2020
c4ee89b
autopep8
Mar 29, 2020
d206a13
changelog upd
Mar 29, 2020
aa22904
flake fix
Mar 29, 2020
bab529d
metrics aggregators factored out, metrics_agg.py added + tests
Mar 29, 2020
39745aa
metrics agg default value added
Mar 29, 2020
73a724a
Update pytorch_lightning/loggers/metrics_agg.py
alexeykarnachev Mar 30, 2020
c586217
remove .item which causes sync issues (#1254)
williamFalcon Mar 30, 2020
e48a711
test_metrics_agg.py removed (all tested in doctrings), agg metrics re…
Mar 30, 2020
8064470
autopep8
Mar 30, 2020
b11a80e
loggers base.py types fixed
Mar 30, 2020
8e7e28f
test
Apr 1, 2020
1143294
test
Apr 1, 2020
a04f2c3
metrics aggregation for loggers: each key now has a specific function…
Apr 2, 2020
94dd1f8
metrics aggregators factored out, metrics_agg.py added + tests
Mar 29, 2020
7ba9c96
metrics agg default value added
Mar 29, 2020
e9fc69a
Update pytorch_lightning/loggers/metrics_agg.py
alexeykarnachev Mar 30, 2020
6b90977
test_metrics_agg.py removed (all tested in doctrings), agg metrics re…
Mar 30, 2020
41aff77
metrics aggregation for loggers: each key now has a specific function…
Apr 2, 2020
2fd1c79
docstrings upd
Apr 3, 2020
5afe31d
manual typehints removed from docstrings
Apr 4, 2020
56a9b71
batch_size decreased for test `test_with_accumulate_grad_batches`
Apr 4, 2020
c6ab2e0
extend running accum
Borda Apr 7, 2020
014b6dd
refactor
Borda Apr 7, 2020
0cf3122
fix tests
Borda Apr 7, 2020
64ac60e
fix tests
Borda Apr 7, 2020
2ecb8c5
allowed_types generator scoped
alexeykarnachev Apr 7, 2020
382aadb
trainer.py distutils was imported twice, fixed
alexeykarnachev Apr 7, 2020
80da0fb
TensorRunningAccum refactored
alexeykarnachev Apr 7, 2020
6b64fde
TensorRunningAccum added to change log (Changed)
alexeykarnachev Apr 7, 2020
5f1f557
change log pull link added
alexeykarnachev Apr 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
- Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
Expand Down
156 changes: 152 additions & 4 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import argparse
import functools
import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple

import numpy as np
import torch


Expand All @@ -25,22 +28,119 @@ def wrapped_fn(self, *args, **kwargs):
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""

def __init__(self):
def __init__(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
    """
    Args:
        agg_key_funcs:
            Dictionary which maps a metric name to a function, which will
            aggregate the metric values for the same steps.
        agg_default_func:
            Default function to aggregate metric values. If some metric name
            is not present in the `agg_key_funcs` dictionary, then the
            `agg_default_func` will be used for aggregation.

    Notes:
        `agg_key_funcs` and `agg_default_func` are used only when one logs metrics with
        `LightningLoggerBase.agg_and_log_metrics` method.
    """
    self._rank = 0
    # No step has been received yet.
    self._prev_step = -1
    # Metrics received so far for the step stored in `_prev_step`.
    self._metrics_to_agg: List[Dict[str, float]] = []
    # Keep a private, mutable copy: the argument is only promised to be a
    # read-only Mapping, and `update_agg_funcs` later updates this attribute
    # in place, which must neither fail nor modify the caller's object.
    self._agg_key_funcs = dict(agg_key_funcs) if agg_key_funcs else {}
    self._agg_default_func = agg_default_func

def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
    """Change the functions used to aggregate metrics of the same step.

    Args:
        agg_key_funcs:
            Per-metric aggregation functions keyed by metric name. The entries
            are merged into the functions that are already registered.
        agg_default_func:
            Aggregation function for metric names that have no entry of their
            own. A falsy value leaves the current default untouched.
    """
    # The two settings are independent of each other.
    if agg_default_func:
        self._agg_default_func = agg_default_func
    if agg_key_funcs:
        self._agg_key_funcs.update(agg_key_funcs)

@property
@abstractmethod
def experiment(self) -> Any:
    """Return the experiment object associated with this logger.

    Declared abstract, so concrete loggers have to provide this property.
    """

def _aggregate_metrics(
        self, metrics: Dict[str, float], step: Optional[int] = None
) -> Tuple[int, Optional[Dict[str, float]]]:
    """Buffer metrics of the current step and aggregate once the step changes.

    Args:
        metrics: Dictionary with metric names as keys and measured quantities as values
        step: Step number at which the metrics should be recorded

    Returns:
        Step and aggregated metrics. The metrics part is None while values for
        the same step are still being collected; in that case the passed
        metrics were only added to the aggregation list.
    """
    step_changed = step != self._prev_step

    if step_changed:
        # the previous step is complete: aggregate what was collected for it
        finished_step, finished_metrics = self._finalize_agg_metrics()

        # start collecting for the newly received step
        self._metrics_to_agg = [metrics]
        self._prev_step = step
        return finished_step, finished_metrics

    # still the same step: keep accumulating
    self._metrics_to_agg.append(metrics)
    return step, None

def _finalize_agg_metrics(self):
    """Aggregate the accumulated metrics. This shall be called in close.

    Returns:
        The step the accumulated metrics belong to, together with the
        aggregated metrics (None when nothing has been accumulated).
    """
    pending = self._metrics_to_agg

    if not pending:
        return self._prev_step, None

    if len(pending) == 1:
        # a single record needs no aggregation
        return self._prev_step, pending[0]

    merged = merge_dicts(pending, self._agg_key_funcs, self._agg_default_func)
    return self._prev_step, merged

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like we can specify the mean/avg/... here

"""Aggregates and records metrics.
This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Record metrics.
"""Records metrics.
This method logs metrics as soon as it receives them. If you want to aggregate
metrics for one specific `step`, use the `agg_and_log_metrics` method.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
pass

@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
Expand Down Expand Up @@ -131,7 +231,10 @@ def finalize(self, status: str) -> None:

def close(self) -> None:
    """Do any cleanup that is necessary to close an experiment."""
    # Metrics of the last step may still be waiting for aggregation; log them
    # before the experiment is closed.
    last_step, pending_metrics = self._finalize_agg_metrics()

    if pending_metrics is None:
        return
    self.log_metrics(metrics=pending_metrics, step=last_step)

@property
def rank(self) -> int:
Expand Down Expand Up @@ -200,3 +303,48 @@ def name(self) -> str:
@property
def version(self) -> str:
return '_'.join([str(logger.version) for logger in self._logger_iterable])


def merge_dicts(
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
default_func: Callable[[Sequence[float]], float] = np.mean
) -> Dict:
"""Merge a sequence with dictionaries into one dictionary by aggregating the
same keys with some given function.

Args:
dicts:
Sequence of dictionaries to be merged.
agg_key_funcs:
Mapping from key name to function. This function will aggregate a
list of values, obtained from the same key of all dictionaries.
If some key has no specified aggregation function, the default one
will be used. Default is: None (all keys will be aggregated by the
default function).
default_func:
Default function to aggregate keys, which are not presented in the
`agg_key_funcs` map.

Returns:
Dictionary with merged values.

Examples:
>>> import pprint
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
>>> d3 = {'a': 1.1, 'v': 2.3}
>>> dflt_func = min
>>> agg_funcs = {'a': np.mean, 'v': max}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't numpy functions slow things down because everything needs to go to CPU? We don't want to move things to CPU for the user ever, haha. Every CPU call slows training down a ton.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so to have it rather as Tensor...

Copy link
Member

@Borda Borda Apr 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexeykarnachev we may use the Running accum or make another for full accum and extending every N steps and copy existing...
https://discuss.pytorch.org/t/dynamically-extend-the-tensor/39553/2
or https://discuss.pytorch.org/t/appending-to-a-tensor/2665/9

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the structure will change from list of dict to dict of tensors but not sure if it makes much faster... also it will allow us to use agg implemented for Torch

Copy link
Contributor Author

@alexeykarnachev alexeykarnachev Apr 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. I thought, that at this point all values are already on cpu.
Here we call the log procedure:
https://github.com/PyTorchLightning/pytorch-lightning/blob/f7622ebfca45abe7d8d34f2ee2070d6856e24646/pytorch_lightning/trainer/logging.py#L74

And few lines above this call, we transform the metrics to the scalars:
https://github.com/PyTorchLightning/pytorch-lightning/blob/f7622ebfca45abe7d8d34f2ee2070d6856e24646/pytorch_lightning/trainer/logging.py#L64

So, do we really need tensors here? Metrics which come to the LightningLoggerBase are already itemed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda, the TrainerLoggingMixin forces all metrics to be scalars, and it does not seem very convenient to transform them back to tensors. Besides, as I see it, a metric might never be a tensor at all. For example, I can pass a scalar-valued metric even in the training_step, like so:

    def training_step(self, batch, batch_idx):
        loss, logits, _ = self.forward(batch)
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        log = {'Loss/train': loss, 'Learning-Rate': lr}
        return {'loss': loss, 'log': log}

Here, the lr is a scalar, it's on cpu and it is not a tensor.

What do you think on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree, let's get this done and think about speedup later... :]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok perfect.

this might actually be the main cause of the minimal speed discrepancy between lightning and pure pytorch.

>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
"""

keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
d_out = {}
for k in keys:
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
d_out[k] = agg_val

return d_out
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
step = step if step is not None else self.global_step
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step=step)
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

def add_tqdm_metrics(self, metrics):
Expand Down
24 changes: 19 additions & 5 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch


class TensorRunningMean(object):
class TensorRunningAccum(object):
"""
Tracks a running mean without graph references.
Round robin for the mean

Examples:
>>> accum = TensorRunningMean(5)
>>> accum = TensorRunningAccum(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
Expand All @@ -18,8 +18,8 @@ class TensorRunningMean(object):
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean()
(tensor(12.), tensor(10.))
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""
def __init__(self, window_length: int):
self.window_length = window_length
Expand All @@ -29,13 +29,16 @@ def __init__(self, window_length: int):
self.rotated: bool = False

def reset(self) -> None:
    """Empty the accumulator."""
    # Assigning a fresh object to the local name `self` would leave this
    # instance untouched, so re-run the initializer on the instance itself
    # (keeping the configured window length) to actually clear its state.
    self.__init__(self.window_length)

def last(self):
    """Return the most recently added element, or None if nothing was added yet."""
    if self.last_idx is None:
        return None
    return self.memory[self.last_idx]

def append(self, x):
"""Add an element to the accumulator."""
# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)
Expand All @@ -54,5 +57,16 @@ def append(self, x):
self.rotated = True

def mean(self):
    """Return the mean of the stored elements, or None if nothing is stored."""
    if self.last_idx is None:
        return None
    # Once `rotated` is set the whole memory is used; before that only the
    # first `current_idx` entries are taken into account.
    stored = self.memory if self.rotated else self.memory[:self.current_idx]
    return stored.mean()

def max(self):
    """Return the largest of the stored elements, or None if nothing is stored."""
    if self.last_idx is None:
        return None
    # Once `rotated` is set the whole memory is used; before that only the
    # first `current_idx` entries are taken into account.
    stored = self.memory if self.rotated else self.memory[:self.current_idx]
    return stored.max()

def min(self):
    """Return the smallest of the stored elements, or None if nothing is stored."""
    if self.last_idx is None:
        return None
    # Once `rotated` is set the whole memory is used; before that only the
    # first `current_idx` entries are taken into account.
    stored = self.memory if self.rotated else self.memory[:self.current_idx]
    return stored.min()
33 changes: 17 additions & 16 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
import distutils

import torch
import torch.distributed as torch_distrib
Expand All @@ -28,7 +29,7 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
Expand Down Expand Up @@ -342,7 +343,7 @@ def __init__(

# training bookkeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningMean(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
self.batch_idx = 0
self.tqdm_metrics = {}
self.callback_metrics = {}
Expand Down Expand Up @@ -551,20 +552,20 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:

allowed_types = (str, float, int, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
if arg not in depr_arg_names:
for allowed_type in allowed_types:
if allowed_type in arg_types:
if allowed_type is bool:
allowed_type = lambda x: bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,
type=allowed_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
break
# Register one command-line option per non-deprecated init argument.
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
                                    if at[0] not in depr_arg_names):
    # Use the first allowed type the argument accepts. The filter has to test
    # the generator's own variable (`at`); the loop target `allowed_type` is
    # not bound yet while the generator is being evaluated.
    for allowed_type in (at for at in allowed_types if at in arg_types):
        # `allowed_type` is a class, so it is compared by identity;
        # isinstance(<class>, bool) is never true.
        if allowed_type is bool:
            # parse the textual value instead of calling bool() on the string
            allowed_type = lambda x: bool(distutils.util.strtobool(x))
        parser.add_argument(
            f'--{arg}',
            default=arg_default,
            type=allowed_type,
            dest=arg,
            help='autogenerated by pl.Trainer'
        )
        break

return parser

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningAccum

try:
from apex import amp
Expand Down Expand Up @@ -337,7 +337,7 @@ def train(self):
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# stores accumulated grad fractions per batch
self.batch_loss_value = TensorRunningMean(
self.batch_loss_value = TensorRunningAccum(
window_length=self.accumulate_grad_batches
)

Expand Down
Loading