Skip to content

Commit

Permalink
fix logging on rank 0 only (#2425)
Browse files Browse the repository at this point in the history
* fix and test for ddp block logging rank > 0

* rename

* use the dummy logger

* dummy logger test

* set the logger in model

* decorator for rank zero experiment

* simplify check

* simplify

* fix problem with None in checkpoint path

* revert configure logger

* unused import

* offline

* try rank 0 decorator in checkpoint

* try fix test

* imgs

* add asserts to make sure only rank zero saves checkpoints

* add asserts to make sure only rank zero saves checkpoints

* add asserts to make sure only rank zero saves checkpoints

* add asserts to make sure only rank zero saves checkpoints

* add asserts to make sure only rank zero saves checkpoints

* fix tpu tests

* fix tpu tests

Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
awaelchli and williamFalcon authored Jun 30, 2020
1 parent 04e68f0 commit 145670f
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 14 deletions.
8 changes: 7 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
return filepath

@rank_zero_only
def on_train_start(self, trainer, pl_module):
"""
Determine model checkpoint save directory at runtime. References attributes from the
Expand All @@ -236,7 +237,7 @@ def on_train_start(self, trainer, pl_module):

self.filename = '{epoch}'

if trainer.logger is not None and trainer.logger.experiment is not None:
if trainer.logger is not None:
# weights_save_path overrides anything
if getattr(trainer, 'weights_save_path', None) is not None:
save_dir = trainer.weights_save_path
Expand All @@ -257,6 +258,9 @@ def on_train_start(self, trainer, pl_module):
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'

os.makedirs(self.dirpath, exist_ok=True)
trainer.ckpt_path = ckpt_path
trainer.weights_save_path = ckpt_path
Expand Down Expand Up @@ -312,6 +316,8 @@ def on_validation_end(self, trainer, pl_module):
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath)

def _do_check_save(self, filepath, current, epoch):
Expand Down
15 changes: 14 additions & 1 deletion pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping

import numpy as np
import torch

from pytorch_lightning.utilities import rank_zero_only


class LightningLoggerBase(ABC):
"""
Expand All @@ -32,7 +35,6 @@ def __init__(
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean
):
self._rank = 0
self._prev_step: int = -1
self._metrics_to_agg: List[Dict[str, float]] = []
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
Expand Down Expand Up @@ -377,3 +379,14 @@ def merge_dicts(
d_out[k] = (fn or default_func)(values_to_agg)

return d_out


def rank_zero_experiment(fn: Callable) -> Callable:
    """ Returns the real experiment on rank 0 and otherwise the DummyExperiment.

    Decorator for a logger's ``experiment`` getter. The wrapped getter is run
    through ``rank_zero_only``; whenever that call produces ``None`` a fresh
    ``DummyExperiment`` is handed back instead.

    Args:
        fn: the getter to wrap. It is called with the logger instance
            (``self``) as its only argument.

    Returns:
        A getter taking ``self`` that carries the metadata of ``fn``
        (copied with ``functools.wraps``).
    """
    @wraps(fn)
    def experiment(self):
        # NOTE(review): relies on ``rank_zero_only`` yielding None when the
        # wrapped call is skipped -- confirm against its implementation.
        @rank_zero_only
        def get_experiment():
            return fn(self)

        real_experiment = get_experiment()
        # Compare against None explicitly: a plain truthiness test would also
        # replace a real experiment object that merely evaluates as falsy
        # (for example one defining ``__len__`` or ``__bool__``).
        if real_experiment is None:
            return DummyExperiment()
        return real_experiment
    return experiment
5 changes: 4 additions & 1 deletion pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch import is_tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_only

Expand Down Expand Up @@ -140,6 +140,7 @@ def __init__(self,
self._kwargs = kwargs

@property
@rank_zero_experiment
def experiment(self) -> CometBaseExperiment:
r"""
Actual Comet object. To use Comet features in your
Expand Down Expand Up @@ -192,6 +193,8 @@ def log_metrics(
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_MLFLOW_AVAILABLE = False

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(self,
self.tags = tags

@property
@rank_zero_experiment
def experiment(self) -> MlflowClient:
r"""
Actual MLflow object. To use mlflow features in your
Expand Down Expand Up @@ -113,6 +114,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import is_tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


Expand Down Expand Up @@ -211,6 +211,7 @@ def __getstate__(self):
return state

@property
@rank_zero_experiment
def experiment(self) -> Experiment:
r"""
Actual Neptune object. To use neptune features in your
Expand Down Expand Up @@ -249,6 +250,7 @@ def log_metrics(
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded, must be strictly increasing
"""
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
for key, val in metrics.items():
self.log_metric(key, val, step=step)

Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


Expand Down Expand Up @@ -83,6 +83,7 @@ def log_dir(self) -> str:
return log_dir

@property
@rank_zero_experiment
def experiment(self) -> SummaryWriter:
r"""
Actual tensorboard object. To use TensorBoard features in your
Expand All @@ -96,6 +97,7 @@ def experiment(self) -> SummaryWriter:
if self._experiment is not None:
return self._experiment

assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
os.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment
Expand Down Expand Up @@ -135,6 +137,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

for k, v in metrics.items():
if isinstance(v, torch.Tensor):
v = v.item()
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Experiment = None
_TEST_TUBE_AVAILABLE = False

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.distributed import rank_zero_only


Expand Down Expand Up @@ -77,6 +77,7 @@ def __init__(self,
self._experiment = None

@property
@rank_zero_experiment
def experiment(self) -> Experiment:
r"""
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Run = None
_WANDB_AVAILABLE = False

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


Expand Down Expand Up @@ -95,6 +95,7 @@ def __getstate__(self):
return state

@property
@rank_zero_experiment
def experiment(self) -> Run:
r"""
Expand Down Expand Up @@ -128,6 +129,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
    """Send ``metrics`` to ``self.experiment.log``.

    Args:
        metrics: mapping of metric names to values.
        step: when not ``None``, it is added to the logged payload under the
            ``'global_step'`` key.
    """
    assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

    if step is not None:
        payload = {'global_step': step, **metrics}
    else:
        payload = metrics
    self.experiment.log(payload)

@property
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class TrainerDPMixin(ABC):
tpu_id: Optional[int]
on_colab_kaggle: str
save_spawn_weights: Callable
logger: ...

@property
@abstractmethod
Expand Down Expand Up @@ -106,6 +107,7 @@ def copy_trainer_model_properties(self, model):

for m in [model, ref_model]:
m.trainer = self
m.logger = self.logger
m.use_dp = self.use_dp
m.use_ddp2 = self.use_ddp2
m.use_ddp = self.use_ddp
Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
import warnings

# warnings to ignore
warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, '
'please use torch.distributed.ReduceOp instead')

try:
from apex import amp
Expand Down Expand Up @@ -359,8 +364,6 @@ def __init__(
default_root_dir = os.getcwd()
self.default_root_dir = default_root_dir

self.configure_logger(logger)

# init callbacks
self.callbacks = callbacks or []

Expand Down Expand Up @@ -500,9 +503,9 @@ def __init__(
self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

# logging
self.configure_logger(logger)
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval

self.row_log_interval = row_log_interval

# how much of the data to use
Expand Down Expand Up @@ -843,7 +846,6 @@ def fit(
"""
# bind logger and other properties
model.logger = self.logger
self.copy_trainer_model_properties(model)

# clean hparams
Expand Down
53 changes: 51 additions & 2 deletions tests/loggers/test_all.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import inspect
import pickle
import platform

import pytest

import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.loggers import (
TensorBoardLogger, MLFlowLogger, NeptuneLogger, TestTubeLogger, CometLogger)
TensorBoardLogger,
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
CometLogger,
WandbLogger,
)
from pytorch_lightning.loggers.base import DummyExperiment
from tests.base import EvalModelTemplate


Expand All @@ -16,6 +24,8 @@ def _get_logger_args(logger_class, save_dir):
logger_args.update(save_dir=str(save_dir))
if 'offline_mode' in inspect.getfullargspec(logger_class).args:
logger_args.update(offline_mode=True)
if 'offline' in inspect.getfullargspec(logger_class).args:
logger_args.update(offline=True)
return logger_args


Expand Down Expand Up @@ -119,3 +129,42 @@ def test_logger_reset_correctly(tmpdir, extra_params):
'Finder altered the logger of trainer'
assert logger2 == logger3, \
'Finder altered the logger of model'


class RankZeroLoggerCheck(Callback):
    """Callback asserting which processes see a ``DummyExperiment``.

    On every batch start it checks that the global-zero process holds a
    non-dummy experiment and that every other process holds a
    ``DummyExperiment`` whose arbitrary method calls return ``None``.
    """
    # Keep this class at module level: defining it inside the test function
    # leads to a pickle error because of the way the ddp process is launched.

    def on_batch_start(self, trainer, pl_module):
        experiment_is_dummy = isinstance(trainer.logger.experiment, DummyExperiment)
        if not trainer.is_global_zero:
            assert experiment_is_dummy
            # any attribute call on the dummy is expected to yield None
            assert pl_module.logger.experiment.something(foo="bar") is None
            return
        assert not experiment_is_dummy


@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.parametrize("logger_class", [
    TensorBoardLogger,
    CometLogger,
    #MLFlowLogger,
    NeptuneLogger,
    TestTubeLogger,
    WandbLogger,
])
def test_logger_created_on_rank_zero_only(tmpdir, logger_class):
    """Fit one step on two ``ddp_cpu`` processes with ``RankZeroLoggerCheck`` attached.

    The callback carries the per-rank assertions; this test only builds the
    logger and trainer and checks that ``fit`` returns ``1``.
    """
    logger = logger_class(**_get_logger_args(logger_class, tmpdir))
    model = EvalModelTemplate()
    trainer_options = dict(
        logger=logger,
        default_root_dir=tmpdir,
        distributed_backend='ddp_cpu',
        num_processes=2,
        max_steps=1,
        checkpoint_callback=True,
        callbacks=[RankZeroLoggerCheck()],
    )
    trainer = Trainer(**trainer_options)
    assert trainer.fit(model) == 1
2 changes: 2 additions & 0 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _run_horovod(trainer_options, on_gpu=False):
assert exit_code == 0


@pytest.mark.skipif(True, reason="Need to deconflict what happens to file paths in ddp")
@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_cpu(tmpdir):
Expand All @@ -70,6 +71,7 @@ def test_horovod_cpu(tmpdir):
_run_horovod(trainer_options)


@pytest.mark.skipif(True, reason="Need to deconflict what happens to file paths in ddp")
@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_cpu_implicit(tmpdir):
Expand Down

0 comments on commit 145670f

Please sign in to comment.