
Commit

Revert "Added accumulation of loggers' metrics for the same steps (Li…
Browse files Browse the repository at this point in the history
…ghtning-AI#1278)"

This reverts commit ddbf7de.
alsrgv committed Apr 11, 2020
1 parent 3f1e4b9 commit 9b83715
Showing 8 changed files with 24 additions and 217 deletions.
2 changes: 0 additions & 2 deletions CHANGELOG.md
@@ -63,7 +63,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
- Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
@@ -85,7 +84,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed (renamed and refactored) `TensorRunningMean` -> `TensorRunningAccum`: running accumulations were generalized. ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
- Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Updated references to `self.forward()` to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
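For context on what this revert removes: the aggregation feature buffered every metrics dict reported for the same step and reduced each key (with np.mean by default) before a single dict reached log_metrics. Below is a minimal, self-contained sketch of that per-step reduction, written outside the Lightning classes; the function name and sample values are illustrative only, not part of the library.

import numpy as np

def aggregate_same_step(records):
    """Group (step, metrics) pairs by step and mean-reduce each metric key."""
    by_step = {}
    for step, metrics in records:
        by_step.setdefault(step, []).append(metrics)
    reduced = {}
    for step, dicts in by_step.items():
        keys = set().union(*dicts)
        reduced[step] = {k: float(np.mean([d[k] for d in dicts if k in d])) for k in keys}
    return reduced

records = [(0, {'loss': 0.4}), (0, {'loss': 0.6}), (1, {'loss': 0.2})]
print(aggregate_same_step(records))  # step 0 averages 0.4 and 0.6 -> 0.5; step 1 keeps 0.2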
156 changes: 4 additions & 152 deletions pytorch_lightning/loggers/base.py
@@ -1,12 +1,9 @@
import argparse
import functools
import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple
from typing import Union, Optional, Dict, Iterable, Any, Callable, List

import numpy as np
import torch


@@ -28,119 +25,22 @@ def wrapped_fn(self, *args, **kwargs):
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""

def __init__(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
"""
Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.
Notes:
`agg_key_funcs` and `agg_default_func` are used only when one logs metrics with
`LightningLoggerBase.agg_and_log_metrics` method.
"""
def __init__(self):
self._rank = 0
self._prev_step = -1
self._metrics_to_agg: List[Dict[str, float]] = []
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func

def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
"""Update aggregation methods.
Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
aggregate the metric values for the same steps.
agg_default_func:
Default function to aggregate metric values. If some metric name
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.
"""
if agg_key_funcs:
self._agg_key_funcs.update(agg_key_funcs)
if agg_default_func:
self._agg_default_func = agg_default_func

@property
@abstractmethod
def experiment(self) -> Any:
"""Return the experiment object associated with this logger"""

def _aggregate_metrics(
self, metrics: Dict[str, float], step: Optional[int] = None
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
Returns:
Step and aggregated metrics. The return value could be None, in which case the metrics
are added to the aggregation list, but not aggregated yet.
"""
# if still receiving metrics from the same step, just accumulate them
if step == self._prev_step:
self._metrics_to_agg.append(metrics)
return step, None

# compute the metrics
agg_step, agg_mets = self._finalize_agg_metrics()

# a new step was received, so reset the accumulator
self._metrics_to_agg = [metrics]
self._prev_step = step
return agg_step, agg_mets

def _finalize_agg_metrics(self):
"""Aggregate accumulated metrics. This shall be called in close."""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
elif len(self._metrics_to_agg) == 1:
agg_mets = self._metrics_to_agg[0]
else:
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
return self._prev_step, agg_mets

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Aggregates and records metrics.
This method doesn't log the passed metrics immediately; instead, it aggregates
them and logs them only once they are ready to be logged.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Records metrics.
This method logs metrics as soon as it receives them. If you want to aggregate
metrics for one specific `step`, use the `agg_and_log_metrics` method.
"""Record metrics.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
pass

@staticmethod
def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
@@ -231,10 +131,7 @@ def finalize(self, status: str) -> None:

def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
agg_step, metrics_to_log = self._finalize_agg_metrics()

if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
pass

@property
def rank(self) -> int:
@@ -303,48 +200,3 @@ def name(self) -> str:
@property
def version(self) -> str:
return '_'.join([str(logger.version) for logger in self._logger_iterable])


def merge_dicts(
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
default_func: Callable[[Sequence[float]], float] = np.mean
) -> Dict:
"""Merge a sequence with dictionaries into one dictionary by aggregating the
same keys with some given function.
Args:
dicts:
Sequence of dictionaries to be merged.
agg_key_funcs:
Mapping from key name to function. This function will aggregate a
list of values, obtained from the same key of all dictionaries.
If some key has no specified aggregation function, the default one
will be used. Default is: None (all keys will be aggregated by the
default function).
default_func:
Default function to aggregate keys, which are not presented in the
`agg_key_funcs` map.
Returns:
Dictionary with merged values.
Examples:
>>> import pprint
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
>>> d3 = {'a': 1.1, 'v': 2.3}
>>> dflt_func = min
>>> agg_funcs = {'a': np.mean, 'v': max}
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
"""

keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
d_out = {}
for k in keys:
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
d_out[k] = agg_val

return d_out
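With the aggregation machinery reverted, a custom logger subclasses LightningLoggerBase and implements the remaining abstract members directly; nothing is buffered in the base class anymore. Below is a minimal sketch mirroring the shape of the CustomLogger used in tests/loggers/test_base.py; the class name and the history dict are illustrative, and the exact set of abstract members (e.g. log_hyperparams, name, version) depends on the Lightning version in use.

from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

class DictLogger(LightningLoggerBase):
    """Toy logger that records every metrics dict it receives, keyed by step."""

    def __init__(self):
        super().__init__()
        self.history = {}

    @property
    def experiment(self):
        # stand-in for a real experiment handle (e.g. a SummaryWriter)
        return self.history

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        self.history.setdefault(step, {}).update(metrics)

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @property
    def name(self):
        return 'dict-logger'

    @property
    def version(self):
        return '0'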
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
@@ -71,7 +71,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
step = step if step is not None else self.global_step
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.log_metrics(scalar_metrics, step=step)
self.logger.save()

def add_tqdm_metrics(self, metrics):
35 changes: 9 additions & 26 deletions pytorch_lightning/trainer/supporters.py
@@ -1,12 +1,13 @@
import torch


class TensorRunningAccum(object):
"""Tracks a running accumulation values (min, max, mean) without graph
references.
class TensorRunningMean(object):
"""
Tracks a running mean without graph references.
Round robin for the mean
Examples:
>>> accum = TensorRunningAccum(5)
>>> accum = TensorRunningMean(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
@@ -17,8 +18,8 @@ class TensorRunningAccum(object):
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
>>> accum.last(), accum.mean()
(tensor(12.), tensor(10.))
"""

def __init__(self, window_length: int):
@@ -29,16 +30,13 @@ def __init__(self, window_length: int):
self.rotated: bool = False

def reset(self) -> None:
"""Empty the accumulator."""
self = TensorRunningAccum(self.window_length)
self = TensorRunningMean(self.window_length)

def last(self):
"""Get the last added element."""
if self.last_idx is not None:
return self.memory[self.last_idx]

def append(self, x):
"""Add an element to the accumulator."""
# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)
@@ -57,20 +55,5 @@ def append(self, x):
self.rotated = True

def mean(self):
"""Get mean value from stored elements."""
return self._agg_memory('mean')

def max(self):
"""Get maximal value from stored elements."""
return self._agg_memory('max')

def min(self):
"""Get minimal value from stored elements."""
return self._agg_memory('min')

def _agg_memory(self, how: str):
if self.last_idx is not None:
if self.rotated:
return getattr(self.memory, how)()
else:
return getattr(self.memory[:self.current_idx], how)()
return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
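The restored TensorRunningMean keeps a fixed-size circular buffer: append() overwrites the oldest slot once the window is full (rotated flips to True), after which mean() averages the whole buffer rather than only the filled prefix. The trainer uses the same class for its running loss with window_length=20. A short usage sketch of the windowing behavior (the window length and values here are arbitrary):

import torch
from pytorch_lightning.trainer.supporters import TensorRunningMean

running = TensorRunningMean(window_length=3)
for value in [1.0, 2.0, 3.0, 10.0]:
    running.append(torch.tensor(value))

# after four appends into a window of three, 1.0 has been overwritten,
# so the buffer now holds 10.0, 2.0, 3.0
print(running.last())  # tensor(10.)
print(running.mean())  # tensor(5.)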
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -32,7 +32,7 @@
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
@@ -378,7 +378,7 @@ def __init__(

# training bookkeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningAccum(window_length=20)
self.running_loss = TensorRunningMean(window_length=20)
self.batch_idx = 0
self.tqdm_metrics = {}
self.callback_metrics = {}
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -146,8 +146,8 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
from apex import amp
@@ -337,7 +337,7 @@ def train(self):
self.accumulation_scheduler.on_epoch_start(self, self.get_model())

# stores accumulated grad fractions per batch
self.batch_loss_value = TensorRunningAccum(
self.batch_loss_value = TensorRunningMean(
window_length=self.accumulate_grad_batches
)

31 changes: 1 addition & 30 deletions tests/loggers/test_base.py
@@ -1,9 +1,6 @@
import pickle
from collections import OrderedDict
from unittest.mock import MagicMock

import numpy as np

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only, LoggerCollection
@@ -59,18 +56,6 @@ def version(self):
return "1"


class StoreHistoryLogger(CustomLogger):
def __init__(self):
super().__init__()
self.history = {}

@rank_zero_only
def log_metrics(self, metrics, step):
if step not in self.history:
self.history[step] = {}
self.history[step].update(metrics)


def test_custom_logger(tmpdir):
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
@@ -168,19 +153,5 @@ def decorated(metrics, step):
num_sanity_val_steps=0,
)
trainer = Trainer(**trainer_options)
trainer.logger.log_metrics = _log_metrics_decorator(
trainer.logger.log_metrics)
trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
trainer.fit(model)


def test_with_accumulate_grad_batches():
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""
logger = StoreHistoryLogger()

np.random.seed(42)
for i, loss in enumerate(np.random.random(10)):
logger.agg_and_log_metrics({'loss': loss}, step=int(i / 5))

assert logger.history == {0: {'loss': 0.5623850983416314}}
logger.close()
assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}}
7 changes: 5 additions & 2 deletions tests/trainer/test_trainer.py
@@ -1,14 +1,17 @@
import glob
import math
import os
from argparse import Namespace, ArgumentParser
from argparse import Namespace

import pytest
import torch

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
)
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
