Fix log_dir property #5537

Merged on Feb 2, 2021 (37 commits)
Commits
ca4b20b  fix and update tests (rohitgr7, Jan 15, 2021)
df17236  update with ModelCheckpoint (rohitgr7, Jan 15, 2021)
a346e87  Merge branch 'master' into bugfix/log_dir_prop (rohitgr7, Jan 15, 2021)
80187ff  chlog (rohitgr7, Jan 15, 2021)
ca248cd  wip wandb fix (rohitgr7, Jan 16, 2021)
4167fff  all fixed (rohitgr7, Jan 16, 2021)
a2ac5a0  Merge branch 'master' into bugfix/log_dir_prop (rohitgr7, Jan 16, 2021)
5d52fed  Merge branch 'master' into bugfix/log_dir_prop (tchaton, Jan 18, 2021)
1bc3420  Merge branch 'master' into bugfix/log_dir_prop (Borda, Jan 24, 2021)
0eb7218  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 24, 2021)
59c6246  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 24, 2021)
e3b3eba  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 25, 2021)
5e289b4  Merge branch 'master' into bugfix/log_dir_prop (rohitgr7, Jan 25, 2021)
ddde5ed  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 25, 2021)
bb5daed  Merge branch 'master' into bugfix/log_dir_prop (Borda, Jan 26, 2021)
004b1a1  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 26, 2021)
62a7f43  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 26, 2021)
0aca929  Merge branch 'master' into bugfix/log_dir_prop (Borda, Jan 26, 2021)
73dba9e  Merge branch 'master' into bugfix/log_dir_prop (rohitgr7, Jan 27, 2021)
d62badc  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 27, 2021)
60d1e4c  Merge branch 'master' into bugfix/log_dir_prop (SkafteNicki, Jan 27, 2021)
dcae70e  Merge branch 'master' into bugfix/log_dir_prop (Borda, Jan 28, 2021)
c6949e9  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 28, 2021)
7f3bd24  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 29, 2021)
abf6ea5  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 29, 2021)
9f817d4  hint (Borda, Jan 29, 2021)
e53b437  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 29, 2021)
f2c3c21  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 30, 2021)
de33c41  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 31, 2021)
4441003  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Jan 31, 2021)
a2b9942  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 1, 2021)
f360ee6  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 1, 2021)
7d24cab  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 1, 2021)
250d6fb  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 1, 2021)
ec85ce0  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 1, 2021)
e16c440  Merge branch 'master' into bugfix/log_dir_prop (mergify[bot], Feb 2, 2021)
81b82c9  rev (rohitgr7, Feb 2, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))


- Fixed `log_dir` property ([#5537](https://github.com/PyTorchLightning/pytorch-lightning/pull/5537))


- Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540))


4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -193,7 +193,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
"""
When pretrain routine starts we build the ckpt dir on the fly
"""
self.__resolve_ckpt_dir(trainer, pl_module)
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

def on_validation_end(self, trainer, pl_module):
@@ -447,7 +447,7 @@ def format_checkpoint_name(
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer, pl_module):
def __resolve_ckpt_dir(self, trainer):
"""
Determines model checkpoint save directory at runtime. References attributes from the
trainer's logger to determine where to save checkpoints.
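For context, `pl_module` was unused here: the checkpoint directory is derived from the trainer and its logger alone, which is why the argument could be dropped. A rough, heavily simplified paraphrase of that resolution (assumed attribute names; not the verbatim library code):

```python
import os

def resolve_ckpt_dir(trainer, dirpath=None):
    # If the user already passed an explicit dirpath to ModelCheckpoint, keep it.
    if dirpath is not None:
        return dirpath
    # Otherwise nest checkpoints under the logger's output location.
    if trainer.logger is not None:
        save_dir = trainer.logger.save_dir or trainer.default_root_dir
        version = trainer.logger.version
        version = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
    return os.path.join(trainer.default_root_dir, "checkpoints")
```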
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -283,8 +283,8 @@ def on_after_backward(self) -> None:
def on_after_backward(self):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
params = self.state_dict()
for k, v in params.items():
params = self.named_parameters()
for k, v in params:
self.logger.experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)
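The docstring example switches from `state_dict()` to `named_parameters()` because, with default arguments, `state_dict()` returns detached tensors whose `.grad` is `None`, so there would be nothing to histogram. A standalone check of the difference (plain PyTorch, independent of Lightning):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# state_dict() returns detached copies by default, so no gradients are attached.
print(model.state_dict()['weight'].grad)  # None

# named_parameters() yields the live Parameter objects, which do carry .grad.
for name, param in model.named_parameters():
    print(name, param.grad.shape)  # weight: torch.Size([2, 4]), bias: torch.Size([2])
```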
4 changes: 1 addition & 3 deletions pytorch_lightning/loggers/mlflow.py
@@ -21,11 +21,9 @@
from time import time
from typing import Any, Dict, Optional, Union


from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, _module_available

from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn

LOCAL_FILE_URI_PREFIX = "file:"

14 changes: 12 additions & 2 deletions pytorch_lightning/loggers/wandb.py
@@ -23,7 +23,8 @@
import torch.nn as nn

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, _module_available
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warning_utils import WarningCache

_WANDB_AVAILABLE = _module_available("wandb")
@@ -98,6 +99,13 @@ def __init__(
if wandb is None:
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
' install it with `pip install wandb`.')

if offline and log_model:
raise MisconfigurationException(
f'log_model={log_model} and offline={offline} is an invalid configuration'
' since model checkpoints cannot be uploaded in offline mode.'
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -141,10 +149,12 @@ def experiment(self) -> Run:
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
if self._save_dir is None:
self._save_dir = self._experiment.dir
return self._experiment

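With `offline=True` the run never syncs to the W&B servers, so model checkpoints cannot be uploaded; the constructor now rejects that combination up front. A quick illustration of the new guard (assumes the `wandb` package and a Lightning version containing this change):

```python
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    WandbLogger(offline=True, log_model=True)
except MisconfigurationException as err:
    # "log_model=True and offline=True is an invalid configuration since model
    #  checkpoints cannot be uploaded in offline mode."
    print(err)
```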
16 changes: 5 additions & 11 deletions pytorch_lightning/trainer/properties.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from argparse import ArgumentParser, Namespace
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -63,16 +63,10 @@ class TrainerProperties(ABC):

@property
def log_dir(self):
if self.checkpoint_callback is not None:
dirpath = self.checkpoint_callback.dirpath
dirpath = os.path.split(dirpath)[0]
elif self.logger is not None:
if isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
else:
dirpath = self.logger.save_dir
if self.logger is None:
dirpath = self.default_root_dir
else:
dirpath = self._default_root_dir
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

if self.accelerator_backend is not None:
dirpath = self.accelerator_backend.broadcast(dirpath)
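Since the flattened diff above is hard to follow, here is the resulting property written out with indentation restored (a paraphrase of the new code as it reads inside `TrainerProperties`):

```python
@property
def log_dir(self):
    if self.logger is None:
        # no logger attached: fall back to the trainer's default_root_dir
        dirpath = self.default_root_dir
    else:
        # TensorBoardLogger exposes a versioned log_dir; other loggers only a save_dir
        attr = 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir'
        dirpath = getattr(self.logger, attr)

    if self.accelerator_backend is not None:
        # keep all processes in a distributed run pointing at the same directory
        dirpath = self.accelerator_backend.broadcast(dirpath)
    return dirpath
```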
21 changes: 4 additions & 17 deletions tests/loggers/test_comet.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import patch, DEFAULT
from unittest.mock import DEFAULT, patch

import pytest

@@ -74,7 +74,7 @@ def test_comet_logger_online(comet):
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_no_api_key_given(comet):
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
with pytest.raises(MisconfigurationException):
with pytest.raises(MisconfigurationException, match='requires either api_key or save_dir'):
comet.config.get_api_key.return_value = None
CometLogger(workspace='dummy-test', project_name='general')

@@ -89,13 +89,10 @@ def test_comet_logger_experiment_name(comet):
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

comet_experiment().set_name.assert_called_once_with(experiment_name)


@@ -118,13 +115,10 @@ def save_os_environ(*args, **kwargs):
with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}):
with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment:
logger = CometLogger(api_key=api_key)

assert logger.version == experiment_key

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
@@ -156,10 +150,12 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


@patch('pytorch_lightning.loggers.comet.comet_ml')
@@ -170,11 +166,8 @@ def test_comet_name_default(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)

assert logger._experiment is None

assert logger.name == "comet-default"

assert logger._experiment is None


@@ -187,11 +180,8 @@ def test_comet_name_project_name(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None

assert logger.name == project_name

assert logger._experiment is None


@@ -205,14 +195,11 @@ def test_comet_version_without_experiment(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None

first_version = logger.version
assert first_version is not None

assert logger.version == first_version

assert logger._experiment is None

_ = logger.experiment
5 changes: 3 additions & 2 deletions tests/loggers/test_mlflow.py
@@ -13,11 +13,10 @@
# limitations under the License.
import importlib.util
import os

from unittest import mock
from unittest.mock import MagicMock
import pytest

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
@@ -113,9 +112,11 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
limit_train_batches=1,
limit_val_batches=3,
)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_mlflow_logger_dirs_creation(tmpdir):
4 changes: 3 additions & 1 deletion tests/loggers/test_neptune.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import torch

@@ -112,7 +112,9 @@ def _run_training(logger):
limit_train_batches=0.05,
logger=logger,
)
assert trainer.log_dir is None
trainer.fit(model)
assert trainer.log_dir is None
return logger

logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
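The added assertions cover the degenerate case: a logger without a `save_dir` (Neptune in offline mode here) gives the trainer nothing to derive a local directory from, so `log_dir` stays `None` before and after training. A minimal sketch of that expectation (assumes the neptune-client package is installed):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger

# NeptuneLogger exposes no save_dir, so the trainer cannot derive a local log_dir.
logger = NeptuneLogger(offline_mode=True)
trainer = Trainer(logger=logger, max_epochs=1)
assert logger.save_dir is None
assert trainer.log_dir is None
```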
4 changes: 3 additions & 1 deletion tests/loggers/test_tensorboard.py
@@ -35,9 +35,11 @@ def test_tensorboard_hparams_reload(tmpdir):
model = EvalModelTemplate()

trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
assert trainer.log_dir == trainer.logger.log_dir
trainer.fit(model)

folder_path = trainer.logger.log_dir
assert trainer.log_dir == trainer.logger.log_dir
folder_path = trainer.log_dir

# make sure yaml is there
with open(os.path.join(folder_path, "hparams.yaml")) as file:
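With the default TensorBoard logger, `trainer.log_dir` now mirrors `logger.log_dir` (the per-version subdirectory) and is already resolved before `fit()` runs. A small check along the lines of the test (the root directory name is only an example):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(default_root_dir="lightning_runs", max_epochs=1)
assert isinstance(trainer.logger, TensorBoardLogger)

# log_dir is usable even before fit(), e.g. lightning_runs/lightning_logs/version_0
assert trainer.log_dir == trainer.logger.log_dir
print(trainer.log_dir)
```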
22 changes: 18 additions & 4 deletions tests/loggers/test_wandb.py
@@ -13,13 +13,16 @@
# limitations under the License.
import os
import pickle
from unittest import mock
from argparse import ArgumentParser
import types
from argparse import ArgumentParser
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.base import EvalModelTemplate, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate


def get_warnings(recwarn):
@@ -94,6 +97,7 @@ class Experiment:
""" """
id = 'the_id'
step = 0
dir = 'wandb'

def project_name(self):
return 'the_project_name'
@@ -109,6 +113,7 @@ def project_name(self):
)
# Access the experiment to ensure it's created
assert trainer.logger.experiment, 'missing experiment'
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)

@@ -147,11 +152,13 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

version = logger.version
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, log_every_n_steps=1)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_wandb_sanitize_callable_params(tmpdir):
@@ -182,3 +189,10 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_offline_log_model(wandb, tmpdir):
""" Test that log_model=True raises an error in offline mode """
with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'):
logger = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)