Skip to content

Commit

Permalink
Added accumulation of loggers' metrics for the same steps (#1278)
Browse files Browse the repository at this point in the history
* `add_argparse_args` method fixed (argument types added)

* autopep8 fixes

* --gpus=0 removed from test (for ci tests)

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Joe Davison <joe@huggingface.co>

* test_with_accumulate_grad_batches added

* agg_and_log_metrics logic added to the base logger class

* small format fix

* agg metrics strategies removed (not to complicate stuff)

* agg metrics: handle zero step

* autopep8

* changelog upd

* flake fix

* metrics aggregators factored out, metrics_agg.py added + tests

* metrics agg default value added

* Update pytorch_lightning/loggers/metrics_agg.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* metrics aggregators factored out, metrics_agg.py added + tests

* metrics agg default value added

* Update pytorch_lightning/loggers/metrics_agg.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* remove .item which causes sync issues (#1254)

* remove .item which causes sync issues

* fixed gradient acc sched

* fixed gradient acc sched

* test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored

* test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored

* autopep8

* loggers base.py types fixed

* test

* test

* metrics aggregation for loggers: each key now has a specific function (or default one)

* metrics aggregation for loggers: each key now has a specific function (or default one)

* docstrings upd

* manual typehints removed from docstrings

* batch_size decreased for test `test_with_accumulate_grad_batches`

* extend running accum

* refactor

* fix tests

* fix tests

* allowed_types generator scoped

* trainer.py distutils was imported twice, fixed

* TensorRunningAccum refactored

* TensorRunningAccum added to change log (Changed)

* change log pull link added

Co-authored-by: Joe Davison <joe@huggingface.co>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
5 people committed Apr 8, 2020
1 parent 471499c commit ddbf7de
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 38 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
- Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
Expand All @@ -30,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed (renamed and refatored) `TensorRunningMean` -> `TensorRunningAccum`: running accumulations were generalized. ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
- Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
Expand Down
156 changes: 152 additions & 4 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import argparse
import functools
import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple

import numpy as np
import torch


Expand All @@ -25,22 +28,119 @@ def wrapped_fn(self, *args, **kwargs):
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""

def __init__(self):
def __init__(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
"""
Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.
Notes:
`agg_key_funcs` and `agg_default_func` are used only when one logs metrics with
`LightningLoggerBase.agg_and_log_metrics` method.
"""
self._rank = 0
self._prev_step = -1
self._metrics_to_agg: List[Dict[str, float]] = []
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func

def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
"""Update aggregation methods.
Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.
"""
if agg_key_funcs:
self._agg_key_funcs.update(agg_key_funcs)
if agg_default_func:
self._agg_default_func = agg_default_func

@property
@abstractmethod
def experiment(self) -> Any:
"""Return the experiment object associated with this logger"""

def _aggregate_metrics(
self, metrics: Dict[str, float], step: Optional[int] = None
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
Returns:
sStep and aggregated metrics. The return value could be None. In such case, metrics
are added to the aggregation list, but not aggregated yet.
"""
# if you still receiving metric from the same step, just accumulate it
if step == self._prev_step:
self._metrics_to_agg.append(metrics)
return step, None

# compute the metrics
agg_step, agg_mets = self._finalize_agg_metrics()

# as new step received reset accumulator
self._metrics_to_agg = [metrics]
self._prev_step = step
return agg_step, agg_mets

def _finalize_agg_metrics(self):
"""Aggregate accumulated metrics. This shall be called in close."""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
elif len(self._metrics_to_agg) == 1:
agg_mets = self._metrics_to_agg[0]
else:
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
return self._prev_step, agg_mets

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Aggregates and records metrics.
This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Record metrics.
"""Records metrics.
This method logs metrics as as soon as it received them. If you want to aggregate
metrics for one specific `step`, use the `agg_and_log_metrics` method.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
pass

@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
Expand Down Expand Up @@ -131,7 +231,10 @@ def finalize(self, status: str) -> None:

def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
pass
agg_step, metrics_to_log = self._finalize_agg_metrics()

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)

@property
def rank(self) -> int:
Expand Down Expand Up @@ -200,3 +303,48 @@ def name(self) -> str:
@property
def version(self) -> str:
return '_'.join([str(logger.version) for logger in self._logger_iterable])


def merge_dicts(
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
default_func: Callable[[Sequence[float]], float] = np.mean
) -> Dict:
"""Merge a sequence with dictionaries into one dictionary by aggregating the
same keys with some given function.
Args:
dicts:
Sequence of dictionaries to be merged.
agg_key_funcs:
Mapping from key name to function. This function will aggregate a
list of values, obtained from the same key of all dictionaries.
If some key has no specified aggregation function, the default one
will be used. Default is: None (all keys will be aggregated by the
default function).
default_func:
Default function to aggregate keys, which are not presented in the
`agg_key_funcs` map.
Returns:
Dictionary with merged values.
Examples:
>>> import pprint
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
>>> d3 = {'a': 1.1, 'v': 2.3}
>>> dflt_func = min
>>> agg_funcs = {'a': np.mean, 'v': max}
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
"""

keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
d_out = {}
for k in keys:
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
d_out[k] = agg_val

return d_out
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
step = step if step is not None else self.global_step
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step=step)
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

def add_tqdm_metrics(self, metrics):
Expand Down
35 changes: 26 additions & 9 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch


class TensorRunningMean(object):
"""
Tracks a running mean without graph references.
Round robbin for the mean
class TensorRunningAccum(object):
"""Tracks a running accumulation values (min, max, mean) without graph
references.
Examples:
>>> accum = TensorRunningMean(5)
>>> accum = TensorRunningAccum(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
Expand All @@ -18,8 +17,8 @@ class TensorRunningMean(object):
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean()
(tensor(12.), tensor(10.))
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""
def __init__(self, window_length: int):
self.window_length = window_length
Expand All @@ -29,13 +28,16 @@ def __init__(self, window_length: int):
self.rotated: bool = False

def reset(self) -> None:
self = TensorRunningMean(self.window_length)
"""Empty the accumulator."""
self = TensorRunningAccum(self.window_length)

def last(self):
"""Get the last added element."""
if self.last_idx is not None:
return self.memory[self.last_idx]

def append(self, x):
"""Add an element to the accumulator."""
# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)
Expand All @@ -54,5 +56,20 @@ def append(self, x):
self.rotated = True

def mean(self):
"""Get mean value from stored elements."""
return self._agg_memory('mean')

def max(self):
"""Get maximal value from stored elements."""
return self._agg_memory('max')

def min(self):
"""Get minimal value from stored elements."""
return self._agg_memory('min')

def _agg_memory(self, how: str):
if self.last_idx is not None:
return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
if self.rotated:
return getattr(self.memory, how)()
else:
return getattr(self.memory[:self.current_idx], how)()
31 changes: 15 additions & 16 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
Expand Down Expand Up @@ -342,7 +342,7 @@ def __init__(

# training bookeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningMean(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
self.batch_idx = 0
self.tqdm_metrics = {}
self.callback_metrics = {}
Expand Down Expand Up @@ -551,20 +551,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:

allowed_types = (str, float, int, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
if arg not in depr_arg_names:
for allowed_type in allowed_types:
if allowed_type in arg_types:
if allowed_type is bool:
allowed_type = lambda x: bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,
type=allowed_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
break
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names):
for allowed_type in (at for at in allowed_types if at in arg_types):
if isinstance(allowed_type, bool):
allowed_type = lambda x: bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,
type=allowed_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
break

return parser

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningAccum

try:
from apex import amp
Expand Down Expand Up @@ -337,7 +337,7 @@ def train(self):
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# stores accumulated grad fractions per batch
self.batch_loss_value = TensorRunningMean(
self.batch_loss_value = TensorRunningAccum(
window_length=self.accumulate_grad_batches
)

Expand Down
Loading

0 comments on commit ddbf7de

Please sign in to comment.