Revision callbacks.data_monitor #848

Open
wants to merge 80 commits into base: master
Changes from 9 commits
Commits (80)
14cf441
add type hint
Jul 31, 2022
7488302
minor changes data_monitor
Jul 31, 2022
c00b131
review tests data_monitor consistency
Jul 31, 2022
5c967fc
pre-commit
Aug 2, 2022
d3cf978
type hint, under review
Aug 2, 2022
05c9343
add catch_warnings
Aug 2, 2022
72edfa3
catch_warnings to all tests
luca-medeiros Aug 9, 2022
3d0332c
Merge branch 'master' into master
otaj Aug 11, 2022
2d778cc
Merge branch 'master' into master
otaj Aug 12, 2022
7e1513c
Merge branch 'master' into master
mergify[bot] Aug 16, 2022
1be9dbe
Merge branch 'master' into master
mergify[bot] Aug 23, 2022
2a0fc9f
Merge branch 'master' into master
mergify[bot] Aug 23, 2022
79ae592
Merge branch 'master' into master
mergify[bot] Aug 25, 2022
16ef8f4
Merge branch 'master' into master
mergify[bot] Aug 26, 2022
84b7a76
Merge branch 'master' into master
mergify[bot] Sep 9, 2022
6cb9d89
Merge branch 'master' into master
mergify[bot] Sep 12, 2022
ffabd79
Merge branch 'master' into master
mergify[bot] Sep 15, 2022
ef902a9
Merge branch 'master' into master
mergify[bot] Sep 15, 2022
3e87f18
Merge branch 'master' into master
mergify[bot] Sep 16, 2022
788d382
Merge branch 'master' into master
mergify[bot] Sep 19, 2022
c2f3194
Merge branch 'master' into master
mergify[bot] Sep 19, 2022
0256374
Merge branch 'master' into master
mergify[bot] Sep 19, 2022
f73dded
Merge branch 'master' into master
mergify[bot] Sep 19, 2022
a62e904
fix minor changes
luca-medeiros Sep 19, 2022
625b79d
Merge branch 'master' into master
mergify[bot] Sep 21, 2022
4c26c74
Merge branch 'master' into master
mergify[bot] Sep 21, 2022
4ceb2d5
Merge branch 'master' into master
mergify[bot] Sep 22, 2022
354bb23
Merge branch 'master' into master
mergify[bot] Sep 23, 2022
60b0242
Merge branch 'master' into master
mergify[bot] Sep 23, 2022
d9ade5e
Merge branch 'master' into master
mergify[bot] Sep 23, 2022
e3c3848
Merge branch 'master' into master
mergify[bot] Sep 28, 2022
3f931ba
Merge branch 'master' into master
mergify[bot] Oct 8, 2022
308c18d
Merge branch 'master' into master
mergify[bot] Oct 11, 2022
a6142a3
Merge branch 'master' into master
mergify[bot] Oct 11, 2022
d2a07b7
Merge branch 'master' into master
mergify[bot] Oct 18, 2022
ed5c661
Merge branch 'master' into master
Borda Oct 27, 2022
76b4398
Merge branch 'master' into master
mergify[bot] Oct 28, 2022
165d4a4
Merge branch 'master' into master
mergify[bot] Oct 31, 2022
f742694
Merge branch 'master' into master
otaj Nov 2, 2022
c36a496
precommit
Nov 2, 2022
5792f42
Merge branch 'master' into master
mergify[bot] Nov 3, 2022
d5730ba
Merge branch 'master' into master
mergify[bot] Nov 3, 2022
feb03fc
Merge branch 'master' into master
Borda Jan 8, 2023
513a908
Merge branch 'master' into master
Borda Mar 28, 2023
dedc928
Merge branch 'master' into master
Borda Mar 28, 2023
bd23c27
update mergify team
Borda May 19, 2023
2b0b649
Merge branch 'master' into luca-medeiros/master
Borda May 19, 2023
86a7479
Merge branch 'master' into master
Borda May 19, 2023
0bfee0e
Merge branch 'master' into master
mergify[bot] May 19, 2023
511c903
Merge branch 'master' into master
mergify[bot] May 20, 2023
75652fb
Merge branch 'master' into master
mergify[bot] May 20, 2023
ac338e6
Merge branch 'master' into master
mergify[bot] May 20, 2023
69e0a03
Merge branch 'master' into master
mergify[bot] May 20, 2023
bb2d912
Merge branch 'master' into master
mergify[bot] May 20, 2023
7bbac45
Merge branch 'master' into master
mergify[bot] May 20, 2023
c2127d0
Merge branch 'master' into luca-medeiros/master
Borda May 20, 2023
90338ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2023
6fcd2c0
Merge branch 'master' into master
mergify[bot] May 20, 2023
d0c8a44
Merge branch 'master' into master
mergify[bot] May 21, 2023
8a0860b
Merge branch 'master' into master
mergify[bot] May 21, 2023
015a25a
Merge branch 'master' into master
mergify[bot] May 22, 2023
d211e09
Merge branch 'master' into master
mergify[bot] May 22, 2023
ac15d8d
Merge branch 'master' into master
mergify[bot] May 29, 2023
5f99435
Merge branch 'master' into master
mergify[bot] May 30, 2023
4386ada
Merge branch 'master' into master
Borda May 31, 2023
33824d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
ef7ad92
Merge branch 'master' into master
mergify[bot] May 31, 2023
e559476
Merge branch 'master' into master
Borda May 31, 2023
7b8097a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
3bcdc56
Merge branch 'master' into master
mergify[bot] Jun 12, 2023
a0ab248
Merge branch 'master' into master
mergify[bot] Jun 16, 2023
a79ec7f
Merge branch 'master' into master
mergify[bot] Jun 16, 2023
9caa021
Merge branch 'master' into master
Borda Jun 29, 2023
b4e44f6
drop LoggerCollection
Borda Jun 30, 2023
b052d9f
logger
Borda Jun 30, 2023
f3c9722
use lightning_utilities
Borda Jun 30, 2023
85c8397
params
Borda Jun 30, 2023
c33dabc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
a5f33e7
Merge branch 'master' into master
Borda Jul 4, 2023
8ae8eba
Merge branch 'master' into master
Borda Jul 4, 2023
32 changes: 15 additions & 17 deletions pl_bolts/callbacks/data_monitor.py
@@ -11,7 +11,6 @@
from torch.utils.hooks import RemovableHandle

from pl_bolts.utils import _WANDB_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _WANDB_AVAILABLE:
@@ -20,15 +19,14 @@
warn_missing_pkg("wandb")


@under_review()
class DataMonitorBase(Callback):

supported_loggers = (
TensorBoardLogger,
WandbLogger,
)

def __init__(self, log_every_n_steps: int = None):
def __init__(self, log_every_n_steps: Optional[int] = None):
"""Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the
Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data
gets collected.
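
For orientation, a minimal, hypothetical usage sketch of one of these callbacks attached to a Trainer with a supported logger; `LitModel` is a placeholder, not part of this PR:

# Hypothetical sketch: attach a data monitor to a Trainer that has a supported
# logger (TensorBoardLogger or WandbLogger); without a logger nothing is logged.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from pl_bolts.callbacks import TrainingDataMonitor

monitor = TrainingDataMonitor(log_every_n_steps=25)  # None falls back to the Trainer's interval
trainer = Trainer(
    logger=TensorBoardLogger("logs/"),
    callbacks=[monitor],
    max_epochs=1,
)
# trainer.fit(LitModel())  # LitModel: any LightningModule with a train dataloader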
@@ -97,7 +95,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None:

logger.experiment.log(data={name: wandb.Histogram(tensor)}, commit=False)

def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
def _is_logger_available(self, logger: Optional[LightningLoggerBase]) -> bool:
available = True
if not logger:
rank_zero_warn("Cannot log histograms because Trainer has no logger.")
@@ -111,7 +109,6 @@ def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
return available


@under_review()
class ModuleDataMonitor(DataMonitorBase):

GROUP_NAME_INPUT = "input"
@@ -120,9 +117,9 @@ class ModuleDataMonitor(DataMonitorBase):
def __init__(
self,
submodules: Optional[Union[bool, List[str]]] = None,
log_every_n_steps: int = None,
log_every_n_steps: Optional[int] = None,
):
"""
"""Logs the in- and output histogram of submodules.
Args:
submodules: If `True`, logs the in- and output histograms of every submodule in the
LightningModule, including the root module itself.
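
As a quick illustration of the `submodules` argument described above (a sketch, not taken from the diff):

from pl_bolts.callbacks import ModuleDataMonitor

monitor_root = ModuleDataMonitor()                # default: root module's forward only
monitor_all = ModuleDataMonitor(submodules=True)  # every named submodule, incl. the root
monitor_some = ModuleDataMonitor(submodules=["layer1", "layer2.sub_layer"])  # explicit names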
@@ -157,8 +154,7 @@ def __init__(
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
super().on_train_start(trainer, pl_module)
submodule_dict = dict(pl_module.named_modules())
self._hook_handles = []
for name in self._get_submodule_names(pl_module):
for name in self._get_submodule_names(submodule_dict):
if name not in submodule_dict:
rank_zero_warn(
f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__},"
@@ -172,23 +168,23 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
for handle in self._hook_handles:
handle.remove()

def _get_submodule_names(self, root_module: nn.Module) -> List[str]:
def _get_submodule_names(self, named_modules: dict) -> List[str]:
# default is the root module only
names = [""]

if isinstance(self._submodule_names, list):
names = self._submodule_names

if self._submodule_names is True:
names = [name for name, _ in root_module.named_modules()]
names = list(named_modules.keys())

return names

def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle:
input_group_name = f"{self.GROUP_NAME_INPUT}/{module_name}" if module_name else self.GROUP_NAME_INPUT
output_group_name = f"{self.GROUP_NAME_OUTPUT}/{module_name}" if module_name else self.GROUP_NAME_OUTPUT

def hook(_: Module, inp: Sequence, out: Sequence) -> None:
def hook(_: Module, inp: Any, out: Any) -> None:
inp = inp[0] if len(inp) == 1 else inp
self.log_histograms(inp, group=input_group_name)
self.log_histograms(out, group=output_group_name)
@@ -197,12 +193,11 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None:
return handle


@under_review()
class TrainingDataMonitor(DataMonitorBase):

GROUP_NAME = "training_step"

def __init__(self, log_every_n_steps: int = None):
def __init__(self, log_every_n_steps: Optional[int] = None):
"""Callback that logs the histogram of values in the batched data passed to `training_step`.

Args:
@@ -230,7 +225,11 @@ def on_train_batch_start(
self.log_histograms(batch, group=self.GROUP_NAME)


def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
def collect_and_name_tensors(
data: Union[Tensor, dict, Sequence],
output: Dict[str, Tensor],
parent_name: str = "input",
) -> None:
"""Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data
in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer.
The shape of the tensor gets appended to the name as well.
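
A small sketch of the naming rules this docstring describes; the expected keys simply follow the key/index/shape scheme above:

import torch

from pl_bolts.callbacks.data_monitor import collect_and_name_tensors

output = {}
batch = {"img": torch.rand(2, 3), "meta": [torch.rand(4), torch.rand(5)]}
collect_and_name_tensors(batch, output)
# dict entries are named by key, sequence entries by index, and the shape is appended:
# set(output) == {"input/img/[2, 3]", "input/meta/0/[4]", "input/meta/1/[5]"}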
@@ -261,7 +260,6 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}")


@under_review()
def shape2str(tensor: Tensor) -> str:
"""Returns the shape of a tensor in bracket notation as a string.

@@ -271,4 +269,4 @@
>>> shape2str(torch.rand(4))
'[4]'
"""
return "[" + ", ".join(map(str, tensor.shape)) + "]"
return str(list(tensor.shape))
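
The return-statement simplification above is behavior-preserving; a quick sanity check (not part of the diff):

import torch

t = torch.rand(4, 2)
old = "[" + ", ".join(map(str, t.shape)) + "]"  # previous implementation
new = str(list(t.shape))                        # new implementation
assert old == new == "[4, 2]"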
77 changes: 63 additions & 14 deletions tests/callbacks/test_data_monitor.py
@@ -1,10 +1,13 @@
import warnings
from unittest import mock
from unittest.mock import call

import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch import nn

from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
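
The updated tests below take a `catch_warnings` fixture that is not part of this diff; the following is only a guess at what such a conftest-level fixture could look like, not the repository's actual definition:

import warnings

import pytest


@pytest.fixture
def catch_warnings():
    # Assumption: isolate each test's warning filters so the
    # warnings.filterwarnings(...) calls below do not leak across tests.
    with warnings.catch_warnings():
        yield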
@@ -13,15 +16,29 @@

@pytest.mark.parametrize(["log_every_n_steps", "max_steps", "expected_calls"], [pytest.param(3, 10, 3)])
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
def test_base_log_interval_override(
log_histogram,
tmpdir,
log_every_n_steps,
max_steps,
expected_calls,
datadir,
catch_warnings,
):
"""Test logging interval set by log_every_n_steps argument."""
warnings.filterwarnings(
"ignore",
message=".*does not have many workers which may be a bottleneck.*",
category=PossibleUserWarning,
)
monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
model = LitMNIST(data_dir=datadir, num_workers=0)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
max_steps=max_steps,
callbacks=[monitor],
accelerator="auto",
)

trainer.fit(model)
@@ -38,45 +55,71 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma
],
)
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
def test_base_log_interval_fallback(
log_histogram,
tmpdir,
log_every_n_steps,
max_steps,
expected_calls,
datadir,
catch_warnings,
):
"""Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer."""
warnings.filterwarnings(
"ignore",
message=".*does not have many workers which may be a bottleneck.*",
category=PossibleUserWarning,
)
monitor = TrainingDataMonitor()
model = LitMNIST(data_dir=datadir, num_workers=0)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=log_every_n_steps,
max_steps=max_steps,
callbacks=[monitor],
accelerator="auto",
)
trainer.fit(model)
assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call


def test_base_no_logger_warning():
def test_base_no_logger_warning(catch_warnings):
"""Test a warning is displayed when Trainer has no logger."""
monitor = TrainingDataMonitor()
trainer = Trainer(logger=False, callbacks=[monitor])
trainer = Trainer(logger=False, callbacks=[monitor], accelerator="auto", max_epochs=-1)
with pytest.warns(UserWarning, match="Cannot log histograms because Trainer has no logger"):
monitor.on_train_start(trainer, pl_module=None)
monitor.on_train_start(trainer, pl_module=LightningModule())


def test_base_unsupported_logger_warning(tmpdir):
def test_base_unsupported_logger_warning(tmpdir, catch_warnings):
"""Test a warning is displayed when an unsupported logger is used."""
warnings.filterwarnings(
"ignore",
message=".*is deprecated in v1.6.*",
category=LightningDeprecationWarning,
)
monitor = TrainingDataMonitor()
trainer = Trainer(logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor])
trainer = Trainer(
logger=LoggerCollection([TensorBoardLogger(tmpdir)]),
callbacks=[monitor],
accelerator="auto",
max_epochs=1,
)
with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"):
monitor.on_train_start(trainer, pl_module=None)
monitor.on_train_start(trainer, pl_module=LightningModule())


@mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram")
def test_training_data_monitor(log_histogram, tmpdir, datadir):
def test_training_data_monitor(log_histogram, tmpdir, datadir, catch_warnings):
"""Test that the TrainingDataMonitor logs histograms of data points going into training_step."""
monitor = TrainingDataMonitor()
model = LitMNIST(data_dir=datadir)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
callbacks=[monitor],
accelerator="auto",
max_epochs=1,
)
monitor.on_train_start(trainer, model)

@@ -121,7 +164,7 @@ def forward(self, *args, **kwargs):
return self.sub_layer(*args, **kwargs)


class ModuleDataMonitorModel(nn.Module):
class ModuleDataMonitorModel(LightningModule):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(12, 5)
@@ -141,14 +184,16 @@ def forward(self, x):


@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
def test_module_data_monitor_forward(log_histogram, tmpdir):
def test_module_data_monitor_forward(log_histogram, tmpdir, catch_warnings):
"""Test that the default ModuleDataMonitor logs inputs and outputs of model's forward."""
monitor = ModuleDataMonitor(submodules=None)
model = ModuleDataMonitorModel()
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
callbacks=[monitor],
accelerator="auto",
max_epochs=1,
)
monitor.on_train_start(trainer, model)
monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
@@ -162,14 +207,16 @@


@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
def test_module_data_monitor_submodules_all(log_histogram, tmpdir, catch_warnings):
"""Test that the ModuleDataMonitor logs the inputs and outputs of each submodule."""
monitor = ModuleDataMonitor(submodules=True)
model = ModuleDataMonitorModel()
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
callbacks=[monitor],
accelerator="auto",
max_epochs=1,
)
monitor.on_train_start(trainer, model)
monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
@@ -189,14 +236,16 @@


@mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
def test_module_data_monitor_submodules_specific(log_histogram, tmpdir):
def test_module_data_monitor_submodules_specific(log_histogram, tmpdir, catch_warnings):
"""Test that the ModuleDataMonitor logs the inputs and outputs of selected submodules."""
monitor = ModuleDataMonitor(submodules=["layer1", "layer2.sub_layer"])
model = ModuleDataMonitorModel()
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
callbacks=[monitor],
accelerator="auto",
max_epochs=1,
)
monitor.on_train_start(trainer, model)
monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
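
The diff is truncated before this test's assertions; for orientation, a sketch of the histogram group names `_register_hook` would produce for the selected submodules (inferred from the callback code above, not copied from the hidden test body):

GROUP_NAME_INPUT, GROUP_NAME_OUTPUT = "input", "output"


def group_names(module_name: str):
    # Mirrors the naming logic in ModuleDataMonitor._register_hook.
    input_group = f"{GROUP_NAME_INPUT}/{module_name}" if module_name else GROUP_NAME_INPUT
    output_group = f"{GROUP_NAME_OUTPUT}/{module_name}" if module_name else GROUP_NAME_OUTPUT
    return input_group, output_group


assert group_names("layer1") == ("input/layer1", "output/layer1")
assert group_names("layer2.sub_layer") == ("input/layer2.sub_layer", "output/layer2.sub_layer")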