Feature/log computational graph (#3003)
* add methods

* log in trainer

* add tests

* changelog

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* text

* added argument

* update tests

* fix styling

* improve testing
SkafteNicki authored Aug 19, 2020
1 parent 7b917de commit cefc7f7
Showing 7 changed files with 98 additions and 5 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added saving test predictions on multiple GPUs ([#2926](https://github.com/PyTorchLightning/pytorch-lightning/pull/2926))

- Auto log the computational graph for loggers that support this ([#3003](https://github.com/PyTorchLightning/pytorch-lightning/pull/3003))

### Changed

- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
@@ -110,7 +112,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))

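A usage sketch of the feature recorded in the changelog entry above (not part of this diff): a model that defines `example_input_array` gets its graph logged to TensorBoard automatically at the start of training. The `LitClassifier` model, the random data, and the directory name are illustrative assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)
        # the trainer-triggered graph logging picks this up when no input_array is given
        self.example_input_array = torch.rand(2, 28 * 28)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        data = TensorDataset(torch.rand(64, 28 * 28), torch.randint(0, 10, (64,)))
        return DataLoader(data, batch_size=8)


# log_graph=True is the default introduced by this PR
logger = TensorBoardLogger('lightning_logs', log_graph=True)
trainer = pl.Trainer(logger=logger, max_epochs=1)
trainer.fit(LitClassifier())
```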
15 changes: 15 additions & 0 deletions pytorch_lightning/loggers/base.py
@@ -10,6 +10,7 @@
import torch

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.core.lightning import LightningModule


class LightningLoggerBase(ABC):
@@ -220,6 +221,16 @@ def log_hyperparams(self, params: argparse.Namespace):
params: :class:`~argparse.Namespace` containing the hyperparameters
"""

def log_graph(self, model: LightningModule, input_array=None) -> None:
"""
Record model graph
Args:
model: lightning model
input_array: input passed to `model.forward`
"""
pass

def save(self) -> None:
"""Save log data."""
self._finalize_agg_metrics()
@@ -296,6 +307,10 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
logger.log_hyperparams(params)

def log_graph(self, model: LightningModule, input_array=None) -> None:
for logger in self._logger_iterable:
logger.log_graph(model, input_array)

def save(self) -> None:
for logger in self._logger_iterable:
logger.save()
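The new `log_graph` hook on `LightningLoggerBase` is a no-op by default, so third-party loggers can opt in by overriding it. A minimal sketch of such an override; the `GraphAwareLogger` class and its behaviour are assumptions for illustration, not part of the library.

```python
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only


class GraphAwareLogger(LightningLoggerBase):
    """Illustrative logger that only records whether a graph could be logged."""

    def __init__(self):
        super().__init__()
        self.logged_graph = False

    @property
    def experiment(self):
        return None  # no real backend in this sketch

    @rank_zero_only
    def log_graph(self, model, input_array=None):
        # mirror the built-in loggers: fall back to model.example_input_array
        if input_array is None:
            input_array = getattr(model, 'example_input_array', None)
        self.logged_graph = input_array is not None

    def log_metrics(self, metrics, step=None):
        pass

    def log_hyperparams(self, params):
        pass

    @property
    def name(self):
        return 'graph-aware'

    @property
    def version(self):
        return 0
```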
26 changes: 25 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
@@ -15,8 +15,9 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.core.lightning import LightningModule

try:
from omegaconf import Container, OmegaConf
@@ -47,6 +48,9 @@ class TensorBoardLogger(LightningLoggerBase):
directory for existing versions, then automatically assigns the next available version.
If it is a string then it is used as the run-specific subdirectory name,
otherwise ``'version_${version}'`` is used.
log_graph: Adds the computational graph to TensorBoard. This requires that
the user has defined the `self.example_input_array` attribute in their
model.
\**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
"""
@@ -56,11 +60,13 @@ def __init__(self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None,
log_graph: bool = True,
**kwargs):
super().__init__()
self._save_dir = save_dir
self._name = name or ''
self._version = version
self._log_graph = log_graph

self._experiment = None
self.hparams = {}
@@ -160,6 +166,24 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
v = v.item()
self.experiment.add_scalar(k, v, step)

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            self.experiment.add_graph(
                model,
                model.transfer_batch_to_device(input_array, model.device)
            )
        else:
            rank_zero_warn('Could not log computational graph since the'
                           ' `model.example_input_array` attribute is not set'
                           ' or `input_array` was not given',
                           UserWarning)

@rank_zero_only
def save(self) -> None:
super().save()
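`TensorBoardLogger.log_graph` can also be called directly, in which case an explicit `input_array` takes the place of `model.example_input_array`. A short sketch; the `TinyModel` class and the save directory are assumptions.

```python
import torch
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


logger = TensorBoardLogger('tb_logs', name='graph_demo')
# an explicit input_array takes the place of model.example_input_array
logger.log_graph(TinyModel(), input_array=torch.rand(1, 4))
```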
29 changes: 26 additions & 3 deletions pytorch_lightning/loggers/test_tube.py
@@ -13,7 +13,8 @@
_TEST_TUBE_AVAILABLE = False

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.core.lightning import LightningModule


class TestTubeLogger(LightningLoggerBase):
@@ -51,7 +52,9 @@ class TestTubeLogger(LightningLoggerBase):
version: Experiment version. If version is not specified the logger inspects the save
directory for existing versions, then automatically assigns the next available version.
create_git_tag: If ``True`` creates a git tag to save the code used in this experiment.
log_graph: Adds the computational graph to TensorBoard. This requires that
the user has defined the `self.example_input_array` attribute in their
model.
"""

__test__ = False
@@ -62,7 +65,8 @@ def __init__(self,
description: Optional[str] = None,
debug: bool = False,
version: Optional[int] = None,
create_git_tag: bool = False):
create_git_tag: bool = False,
log_graph: bool = True):

if not _TEST_TUBE_AVAILABLE:
raise ImportError('You want to use `test_tube` logger which is not installed yet,'
@@ -74,6 +78,7 @@
self.debug = debug
self._version = version
self.create_git_tag = create_git_tag
self._log_graph = log_graph
self._experiment = None

@property
@@ -117,6 +122,24 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
self.experiment.debug = self.debug
self.experiment.log(metrics, global_step=step)

@rank_zero_only
def log_graph(self, model: LightningModule, input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            self.experiment.add_graph(
                model,
                model.transfer_batch_to_device(input_array, model.device)
            )
        else:
            rank_zero_warn('Could not log computational graph since the'
                           ' `model.example_input_array` attribute is not set'
                           ' or `input_array` was not given',
                           UserWarning)

@rank_zero_only
def save(self) -> None:
super().save()
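The same hook is available on `TestTubeLogger`, where the graph ends up in the experiment's TensorBoard files. A parallel sketch, assuming `test_tube` is installed; the model class and directory names are illustrative.

```python
import torch
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TestTubeLogger


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
model.example_input_array = torch.rand(1, 4)  # used when no input_array is passed
logger = TestTubeLogger('tt_logs', name='graph_demo')
logger.log_graph(model)
```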
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1151,6 +1151,7 @@ def run_pretrain_routine(self, model: LightningModule):
if self.logger is not None:
# save exp to get started
self.logger.log_hyperparams(ref_model.hparams)
self.logger.log_graph(ref_model)
self.logger.save()

if self.use_ddp or self.use_ddp2:
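The single line added to `run_pretrain_routine` above means graph logging happens automatically at the start of training; if neither `model.example_input_array` nor an explicit `input_array` is available, the loggers fall back to a `UserWarning` rather than failing. A small sketch of that fallback behaviour; the `NoExampleModel` class is an assumption.

```python
import warnings

from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class NoExampleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.example_input_array = None  # nothing for the logger to trace with

    def forward(self, x):
        return self.layer(x)


logger = TensorBoardLogger('tb_logs')
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    logger.log_graph(NoExampleModel())
assert any('Could not log computational graph' in str(w.message) for w in caught)
```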
25 changes: 25 additions & 0 deletions tests/loggers/test_tensorboard.py
@@ -156,3 +156,28 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):

metrics = {"abc": torch.tensor([0.54])}
logger.log_hyperparams(hparams, metrics)


@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 28 * 28)])
def test_tensorboard_log_graph(tmpdir, example_input_array):
""" test that log graph works with both model.example_input_array and
if array is passed externaly
"""
model = EvalModelTemplate()
if example_input_array is None:
model.example_input_array = None
logger = TensorBoardLogger(tmpdir)
logger.log_graph(model, example_input_array)


def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
""" test that log graph throws warning if model.example_input_array is None """
model = EvalModelTemplate()
model.example_input_array = None
logger = TensorBoardLogger(tmpdir)
with pytest.warns(
UserWarning,
match='Could not log computational graph since the `model.example_input_array`'
' attribute is not set or `input_array` was not given'
):
logger.log_graph(model)
3 changes: 3 additions & 0 deletions tests/models/test_cpu.py
@@ -346,6 +346,7 @@ def train_dataloader(self):
)

model = BpttTestModel(**hparams)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

# fit model
trainer = Trainer(
@@ -424,6 +425,7 @@ def train_dataloader(self):
)

model = BpttTestModel(**hparams)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

# fit model
trainer = Trainer(
@@ -494,6 +496,7 @@ def train_dataloader(self):
)

model = BpttTestModel(**hparams)
model.example_input_array = torch.randn(5, truncated_bptt_steps)

# fit model
trainer = Trainer(
