Fix log_dir property (#5537)
* fix and update tests

* update with ModelCheckpoint

* chlog

* wip wandb fix

* all fixed

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people committed Feb 5, 2021
1 parent 2a8e9df commit 78b4d2b
Showing 10 changed files with 152 additions and 131 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -197,6 +197,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `num_classes` argument in F1 metric ([#5663](https://github.com/PyTorchLightning/pytorch-lightning/pull/5663))


- Fixed `log_dir` property ([#5537](https://github.com/PyTorchLightning/pytorch-lightning/pull/5537))


- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))

- Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697))
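In practice, the `log_dir` fix means `trainer.log_dir` is resolved from the attached logger (or from `default_root_dir` when no logger is set) and returns the same path before and after `fit()`. A minimal sketch of the expected behavior, assuming the patched pytorch-lightning is installed and a TensorBoard logger is used:

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# With a TensorBoardLogger, log_dir points at the versioned run directory
# and is already stable before training starts.
logger = TensorBoardLogger(save_dir="lightning_logs", name="demo")
trainer = Trainer(logger=logger, max_epochs=1)
assert trainer.log_dir == logger.log_dir

# Without a logger, log_dir falls back to the trainer's default_root_dir.
trainer_no_logger = Trainer(logger=False, default_root_dir=os.getcwd())
assert trainer_no_logger.log_dir == trainer_no_logger.default_root_dir
```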
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -195,7 +195,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
"""
When pretrain routine starts we build the ckpt dir on the fly
"""
self.__resolve_ckpt_dir(trainer, pl_module)
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

def on_validation_end(self, trainer, pl_module):
@@ -427,7 +427,7 @@ def format_checkpoint_name(
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer, pl_module):
def __resolve_ckpt_dir(self, trainer):
"""
Determines model checkpoint save directory at runtime. References attributes from the
trainer's logger to determine where to save checkpoints.
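The `pl_module` argument was unused by `__resolve_ckpt_dir`, so the signature and its call site now take only the trainer. A condensed, illustrative sketch of the resolution logic the docstring describes — derive the directory from the trainer's logger when one is present — not the exact library implementation:

```python
import os


def resolve_ckpt_dir_sketch(trainer) -> str:
    """Illustrative only: mirror how a checkpoint dir can be derived from the trainer."""
    if trainer.logger is not None:
        # Prefer the logger's save_dir; fall back to the trainer default.
        save_dir = trainer.logger.save_dir or trainer.default_root_dir
        version = trainer.logger.version
        version = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
    # No logger attached: save under the trainer's root directory.
    return os.path.join(trainer.default_root_dir, "checkpoints")
```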
13 changes: 12 additions & 1 deletion pytorch_lightning/loggers/wandb.py
@@ -24,6 +24,7 @@

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warning_utils import WarningCache

_WANDB_AVAILABLE = _module_available("wandb")
@@ -100,6 +101,14 @@ def __init__(
if wandb is None:
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
' install it with `pip install wandb`.')

if offline and log_model:
raise MisconfigurationException(
f'Providing log_model={log_model} and offline={offline} is an invalid configuration'
' since model checkpoints cannot be uploaded in offline mode.\n'
'Hint: Set `offline=False` to log your model.'
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -144,10 +153,12 @@ def experiment(self) -> Run:
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
if self._save_dir is None:
self._save_dir = self._experiment.dir
return self._experiment

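Two behavior changes to `WandbLogger` come out of these hunks: `log_model=True` is rejected up front when `offline=True`, and when checkpoints are to be uploaded the logger's `save_dir` now defaults to the wandb run directory if none was given. A short usage sketch of the new guard, assuming `wandb` is installed:

```python
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Valid: upload checkpoints to W&B from an online run.
online_logger = WandbLogger(project="demo", log_model=True)

# Invalid: checkpoints cannot be uploaded in offline mode, so this raises.
try:
    WandbLogger(project="demo", offline=True, log_model=True)
except MisconfigurationException as err:
    print(err)  # the hint suggests setting offline=False
```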
36 changes: 12 additions & 24 deletions pytorch_lightning/trainer/properties.py
@@ -14,31 +14,25 @@
import inspect
import os
from abc import ABC
from argparse import ArgumentParser
from argparse import Namespace
from argparse import ArgumentParser, Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.utilities.argparse import from_argparse_args
from pytorch_lightning.utilities.argparse import parse_argparser
from pytorch_lightning.utilities.argparse import parse_env_variables
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
parse_argparser,
parse_env_variables,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -75,16 +69,10 @@ class TrainerProperties(ABC):

@property
def log_dir(self):
if self.checkpoint_callback is not None:
dirpath = self.checkpoint_callback.dirpath
dirpath = os.path.split(dirpath)[0]
elif self.logger is not None:
if isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
else:
dirpath = self.logger.save_dir
if self.logger is None:
dirpath = self.default_root_dir
else:
dirpath = self._default_root_dir
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

if self.accelerator_backend is not None:
dirpath = self.accelerator_backend.broadcast(dirpath)
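Because the rendered diff interleaves removed and added lines, the new `log_dir` property is easier to read in one piece. Reconstructed from the hunk above (the trailing `return` is assumed, since the hunk is cut off at the broadcast call):

```python
@property
def log_dir(self):
    # Resolve from the logger when one is attached, otherwise use the trainer default.
    if self.logger is None:
        dirpath = self.default_root_dir
    else:
        dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

    # Keep the path consistent across processes in distributed runs.
    if self.accelerator_backend is not None:
        dirpath = self.accelerator_backend.broadcast(dirpath)
    return dirpath
```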
34 changes: 7 additions & 27 deletions tests/loggers/test_comet.py
@@ -19,7 +19,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
from tests.base import BoringModel, EvalModelTemplate


def _patch_comet_atexit(monkeypatch):
@@ -74,7 +74,7 @@ def test_comet_logger_online(comet):
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_no_api_key_given(comet):
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
with pytest.raises(MisconfigurationException):
with pytest.raises(MisconfigurationException, match='requires either api_key or save_dir'):
comet.config.get_api_key.return_value = None
CometLogger(workspace='dummy-test', project_name='general')

@@ -89,13 +89,10 @@ def test_comet_logger_experiment_name(comet):
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

comet_experiment().set_name.assert_called_once_with(experiment_name)


@@ -118,13 +115,10 @@ def save_os_environ(*args, **kwargs):
with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}):
with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment:
logger = CometLogger(api_key=api_key)

assert logger.version == experiment_key

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
@@ -154,19 +148,14 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch
logger.experiment.id = '1'
logger.experiment.project_name = 'test'

limit_batches = 5
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
logger=logger,
max_epochs=1,
limit_train_batches=limit_batches,
limit_val_batches=limit_batches,
)
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f'epoch=0-step={limit_batches - 1}.ckpt']
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


@patch('pytorch_lightning.loggers.comet.comet_ml')
@@ -177,11 +166,8 @@ def test_comet_name_default(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)

assert logger._experiment is None

assert logger.name == "comet-default"

assert logger._experiment is None


@@ -194,11 +180,8 @@ def test_comet_name_project_name(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None

assert logger.name == project_name

assert logger._experiment is None


@@ -212,14 +195,11 @@ def test_comet_version_without_experiment(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None

first_version = logger.version
assert first_version is not None

assert logger.version == first_version

assert logger._experiment is None

_ = logger.experiment
2 changes: 2 additions & 0 deletions tests/loggers/test_mlflow.py
@@ -111,9 +111,11 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
limit_train_batches=1,
limit_val_batches=3,
)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_mlflow_logger_dirs_creation(tmpdir):
2 changes: 2 additions & 0 deletions tests/loggers/test_neptune.py
@@ -114,7 +114,9 @@ def _run_training(logger):
limit_train_batches=0.05,
logger=logger,
)
assert trainer.log_dir is None
trainer.fit(model)
assert trainer.log_dir is None
return logger

logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
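The Neptune assertions follow directly from the new property: `NeptuneLogger` has no filesystem `save_dir`, so `trainer.log_dir` resolves to `None` rather than a local path. A minimal check, assuming neptune-client is installed:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger

logger = NeptuneLogger(offline_mode=True)
trainer = Trainer(logger=logger, max_epochs=1)

assert logger.save_dir is None   # Neptune keeps no local log directory
assert trainer.log_dir is None   # so there is nothing for log_dir to point at
```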
27 changes: 15 additions & 12 deletions tests/loggers/test_tensorboard.py
@@ -24,33 +24,30 @@

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import BoringModel
from tests.base import BoringModel, EvalModelTemplate


@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
reason="Minimal PT version is set to 1.5",
)
def test_tensorboard_hparams_reload(tmpdir):
class CustomModel(BoringModel):
def __init__(self, b1=0.5, b2=0.999):
super().__init__()
self.save_hyperparameters()
model = EvalModelTemplate()

model = CustomModel()
trainer = Trainer(max_steps=1, default_root_dir=tmpdir)
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
assert trainer.log_dir == trainer.logger.log_dir
trainer.fit(model)

folder_path = trainer.logger.log_dir
assert trainer.log_dir == trainer.logger.log_dir
folder_path = trainer.log_dir

# make sure yaml is there
with open(os.path.join(folder_path, "hparams.yaml")) as file:
# The FullLoader parameter handles the conversion from YAML
# scalar values to Python the dictionary format
yaml_params = yaml.safe_load(file)
assert yaml_params["b1"] == 0.5
assert yaml_params["b2"] == 0.999
assert len(yaml_params.keys()) == 2
assert len(yaml_params.keys()) == 10

# verify artifacts
assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
@@ -59,8 +56,14 @@ def __init__(self, b1=0.5, b2=0.999):
event_acc = EventAccumulator(folder_path)
event_acc.Reload()

data_pt_1_5 = b'\x12\x1b"\x04\n\x02b1"\x04\n\x02b2*\r\n\x0b\x12\thp_metric'
data_pt_1_6 = b'\x12\x1f"\x06\n\x02b1 \x03"\x06\n\x02b2 \x03*\r\n\x0b\x12\thp_metric'
data_pt_1_5 = b'\x12\x93\x01"\x0b\n\tdrop_prob"\x0c\n\nbatch_size"\r\n\x0bin_features"\x0f\n\rlearning_rate"' \
b'\x10\n\x0eoptimizer_name"\x0b\n\tdata_root"\x0e\n\x0cout_features"\x0c\n\nhidden_dim"' \
b'\x04\n\x02b1"\x04\n\x02b2*\r\n\x0b\x12\thp_metric'
data_pt_1_6 = b'\x12\xa7\x01"\r\n\tdrop_prob \x03"\x0e\n\nbatch_size \x03"\x0f\n\x0bin_features \x03"' \
b'\x11\n\rlearning_rate \x03"\x12\n\x0eoptimizer_name \x01"\r\n\tdata_root \x01"' \
b'\x10\n\x0cout_features \x03"\x0e\n\nhidden_dim \x03"\x06\n\x02b1 \x03"' \
b'\x06\n\x02b2 \x03*\r\n\x0b\x12\thp_metric'

hparams_data = data_pt_1_6 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0") else data_pt_1_5

assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.plugin_name == 'hparams'
28 changes: 15 additions & 13 deletions tests/loggers/test_wandb.py
@@ -13,13 +13,11 @@
# limitations under the License.
import os
import pickle
import types
from argparse import ArgumentParser
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.base import BoringModel
from tests.base import BoringModel, EvalModelTemplate


def get_warnings(recwarn):
@@ -106,6 +104,7 @@ class Experiment:
""" """
id = 'the_id'
step = 0
dir = 'wandb'

def project_name(self):
return 'the_project_name'
@@ -121,6 +120,7 @@
)
# Access the experiment to ensure it's created
assert trainer.logger.experiment, 'missing experiment'
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)

@@ -158,19 +158,14 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert not os.listdir(tmpdir)

version = logger.version
model = BoringModel()
limit_batches = 5
trainer = Trainer(
default_root_dir=tmpdir,
logger=logger,
max_epochs=1,
limit_train_batches=limit_batches,
limit_val_batches=limit_batches,
)
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, log_every_n_steps=1)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f'epoch=0-step={limit_batches - 1}.ckpt']
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_wandb_sanitize_callable_params(tmpdir):
Expand Down Expand Up @@ -201,3 +196,10 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_offline_log_model(wandb, tmpdir):
""" Test that log_model=True raises an error in offline mode """
with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'):
logger = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)