From c0bedd25872d12d1e1f55c3f02b31f78d006cdf1 Mon Sep 17 00:00:00 2001
From: "Martin.B" <51887611+bmartinn@users.noreply.github.com>
Date: Sat, 14 Mar 2020 19:02:14 +0200
Subject: [PATCH] Add TRAINS experiment manager support (#1122)

* Add allegro.ai TRAINS experiment manager support

* improve docstring and type hinting, fix the bug in log_metrics, add support torch.Tensor to input into log_image

* complete missing docstring of constructor's arguments

* fix docs

* pep8

* pep8

* remove redundant typing use logging fix typing and pep8

* remove deprecated interface

* add TrainsLogger test

* add TrainsLogger PR in CHANGELOG

* add id/name property documentation

* change logging as log

Co-authored-by: bmartinn <>
Co-authored-by: Sou Uchida
---
 CHANGELOG.md                          |   1 +
 docs/source/conf.py                   |   3 +-
 docs/source/experiment_logging.rst    |  28 +++
 docs/source/experiment_reporting.rst  |   2 +-
 environment.yml                       |   1 +
 pytorch_lightning/loggers/__init__.py |   6 +
 pytorch_lightning/loggers/trains.py   | 283 ++++++++++++++++++++++++++
 requirements-extra.txt                |   3 +-
 tests/loggers/test_trains.py          |  48 +++++
 9 files changed, 372 insertions(+), 3 deletions(-)
 create mode 100644 pytorch_lightning/loggers/trains.py
 create mode 100644 tests/loggers/test_trains.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d6cbc3d08781..11a24b8924c5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
 - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
 - Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d1c0b54b0a7f1..ce8508f18515d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -297,7 +297,8 @@ def setup(app):
             MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())
 
 # TODO: better parse from package since the import name and package name may differ
-MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune']
+MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube',
+                        'mlflow', 'comet_ml', 'wandb', 'neptune', 'trains']
 autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
 # for mod_name in MOCK_REQUIRE_PACKAGES:
 #     sys.modules[mod_name] = mock.Mock()
diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst
index 1edd067511b48..b3fe825aede77 100644
--- a/docs/source/experiment_logging.rst
+++ b/docs/source/experiment_logging.rst
@@ -62,6 +62,34 @@ The Neptune.ai is available anywhere except ``__init__`` in your LightningModule
          some_img = fake_image()
          self.logger.experiment.add_image('generated_images', some_img, 0)
 
+allegro.ai TRAINS
+^^^^^^^^^^^^^^^^^
+
+`allegro.ai <https://github.com/allegroai/trains/>`_ is a third-party logger.
+To use TRAINS as your logger do the following.
+
+.. note:: See: :ref:`trains` docs.
+
+.. code-block:: python
+
+    from pytorch_lightning.loggers import TrainsLogger
+
+    trains_logger = TrainsLogger(
+        project_name="examples",
+        task_name="pytorch lightning test"
+    )
+    trainer = Trainer(logger=trains_logger)
+
+The TrainsLogger is available anywhere except ``__init__`` in your LightningModule
+
+.. code-block:: python
+
+    class MyModule(pl.LightningModule):
+
+        def any_lightning_module_function_or_hook(self, ...):
+            some_img = fake_image()
+            self.logger.log_image('debug', 'generated_image_0', some_img, 0)
+
 Tensorboard
 ^^^^^^^^^^^
 
diff --git a/docs/source/experiment_reporting.rst b/docs/source/experiment_reporting.rst
index a738a234c9674..aa5642ab20457 100644
--- a/docs/source/experiment_reporting.rst
+++ b/docs/source/experiment_reporting.rst
@@ -32,7 +32,7 @@ want to log using this trainer flag.
 Log metrics
 ^^^^^^^^^^^
 
-To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
+To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, TRAINS, etc...)
 
 1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.
 
diff --git a/environment.yml b/environment.yml
index 6dae860611b03..1aa7c8f14820e 100644
--- a/environment.yml
+++ b/environment.yml
@@ -32,3 +32,4 @@ dependencies:
     - comet_ml>=1.0.56
     - wandb>=0.8.21
     - neptune-client>=0.4.4
+    - trains>=0.13.3
diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py
index adcba876d26f5..93d1b737aab9b 100644
--- a/pytorch_lightning/loggers/__init__.py
+++ b/pytorch_lightning/loggers/__init__.py
@@ -119,3 +119,9 @@ def any_lightning_module_function_or_hook(...):
     __all__.append('WandbLogger')
 except ImportError:
     pass
+
+try:
+    from .trains import TrainsLogger
+    __all__.append('TrainsLogger')
+except ImportError:
+    pass
diff --git a/pytorch_lightning/loggers/trains.py b/pytorch_lightning/loggers/trains.py
new file mode 100644
index 0000000000000..a56c6c126798e
--- /dev/null
+++ b/pytorch_lightning/loggers/trains.py
@@ -0,0 +1,283 @@
+"""
+Log using `allegro.ai TRAINS <https://github.com/allegroai/trains>`_
+
+.. code-block:: python
+
+    from pytorch_lightning.loggers import TrainsLogger
+    trains_logger = TrainsLogger(
+        project_name="pytorch lightning",
+        task_name="default",
+    )
+    trainer = Trainer(logger=trains_logger)
+
+
+Use the logger anywhere in your LightningModule as follows:
+
+.. code-block:: python
+
+    def train_step(...):
+        # example
+        self.logger.experiment.whatever_trains_supports(...)
+
+    def any_lightning_module_function_or_hook(...):
+        self.logger.experiment.whatever_trains_supports(...)
+
+"""
+
+import logging as log
+from argparse import Namespace
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import PIL
+import numpy as np
+import pandas as pd
+import torch
+
+try:
+    import trains
+except ImportError:
+    raise ImportError('You want to use `TRAINS` logger which is not installed yet,'
+                      ' install it with `pip install trains`.')
+
+from .base import LightningLoggerBase, rank_zero_only
+
+
+class TrainsLogger(LightningLoggerBase):
+    """Logs using TRAINS
+
+    Args:
+        project_name: The name of the experiment's project. Defaults to None.
+        task_name: The name of the experiment. Defaults to None.
+        task_type: The type of the experiment. Defaults to 'training'.
+        reuse_last_task_id: Start with the previously used task id. Defaults to True.
+        output_uri: Default location for output models. Defaults to None.
+        auto_connect_arg_parser: Automatically grab the ArgParser
+            and connect it with the task. Defaults to True.
+        auto_connect_frameworks: If True, automatically patch to trains backend. Defaults to True.
+        auto_resource_monitoring: If True, machine vitals will be
+            sent alongside the task scalars. Defaults to True.
+    """
+
+    def __init__(
+            self, project_name: Optional[str] = None, task_name: Optional[str] = None,
+            task_type: str = 'training', reuse_last_task_id: bool = True,
+            output_uri: Optional[str] = None, auto_connect_arg_parser: bool = True,
+            auto_connect_frameworks: bool = True, auto_resource_monitoring: bool = True) -> None:
+        super().__init__()
+        self._trains = trains.Task.init(
+            project_name=project_name, task_name=task_name, task_type=task_type,
+            reuse_last_task_id=reuse_last_task_id, output_uri=output_uri,
+            auto_connect_arg_parser=auto_connect_arg_parser,
+            auto_connect_frameworks=auto_connect_frameworks,
+            auto_resource_monitoring=auto_resource_monitoring
+        )
+
+    @property
+    def experiment(self) -> trains.Task:
+        r"""Actual TRAINS object. To use TRAINS features do the following.
+
+        Example:
+            .. code-block:: python
+
+                self.logger.experiment.some_trains_function()
+        """
+        return self._trains
+
+    @property
+    def id(self) -> Union[str, None]:
+        """
+        ID is a uuid (string) representing this specific experiment in the entire system.
+        """
+        if not self._trains:
+            return None
+        return self._trains.id
+
+    @rank_zero_only
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        """Log hyperparameters (numeric values) in TRAINS experiments
+
+        Args:
+            params:
+                The hyperparameters that were passed through the model.
+        """
+        if not self._trains:
+            return None
+        if not params:
+            return
+        if isinstance(params, dict):
+            self._trains.connect(params)
+        else:
+            self._trains.connect(vars(params))
+
+    @rank_zero_only
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        """Log metrics (numeric values) in TRAINS experiments.
+        This method will be called by Trainer.
+
+        Args:
+            metrics:
+                The dictionary of the metrics.
+                If the key contains "/", it will be split by the delimiter,
+                then the elements will be logged as "title" and "series" respectively.
+            step: Step number at which the metrics should be recorded. Defaults to None.
+        """
+        if not self._trains:
+            return None
+
+        if step is None:
+            step = self._trains.get_last_iteration()
+
+        for k, v in metrics.items():
+            if isinstance(v, str):
+                log.warning("Discarding metric with string value {}={}".format(k, v))
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            parts = k.split('/')
+            if len(parts) <= 1:
+                series = title = k
+            else:
+                title = parts[0]
+                series = '/'.join(parts[1:])
+            self._trains.get_logger().report_scalar(
+                title=title, series=series, value=v, iteration=step)
+
+    @rank_zero_only
+    def log_metric(self, title: str, series: str, value: float, step: Optional[int] = None) -> None:
+        """Log metrics (numeric values) in TRAINS experiments.
+        This method will be called by the users.
+
+        Args:
+            title: The title of the graph to log, e.g. loss, accuracy.
+            series: The series name in the graph, e.g. classification, localization.
+            value: The value to log.
+            step: Step number at which the metrics should be recorded. Defaults to None.
+        """
+        if not self._trains:
+            return None
+
+        if step is None:
+            step = self._trains.get_last_iteration()
+
+        if isinstance(value, torch.Tensor):
+            value = value.item()
+        self._trains.get_logger().report_scalar(
+            title=title, series=series, value=value, iteration=step)
+
+    @rank_zero_only
+    def log_text(self, text: str) -> None:
+        """Log console text data in TRAINS experiment
+
+        Args:
+            text: The value of the log (data-point).
+        """
+        if not self._trains:
+            return None
+
+        self._trains.get_logger().report_text(text)
+
+    @rank_zero_only
+    def log_image(
+            self, title: str, series: str,
+            image: Union[str, np.ndarray, PIL.Image.Image, torch.Tensor],
+            step: Optional[int] = None) -> None:
+        """Log Debug image in TRAINS experiment
+
+        Args:
+            title: The title of the debug image, e.g. "failed", "passed".
+            series: The series name of the debug image, e.g. "Image 0", "Image 1".
+            image:
+                Debug image to log. Can be one of the following types:
+                    Torch, Numpy, PIL image, path to image file (str)
+                If Numpy or Torch, the image is assumed to be the following:
+                    shape: CHW
+                    color space: RGB
+                    value range: [0., 1.] (float) or [0, 255] (uint8)
+            step:
+                Step number at which the image should be recorded. Defaults to None.
+        """
+        if not self._trains:
+            return None
+
+        if step is None:
+            step = self._trains.get_last_iteration()
+
+        if isinstance(image, str):
+            self._trains.get_logger().report_image(
+                title=title, series=series, local_path=image, iteration=step)
+        else:
+            if isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
+            if isinstance(image, np.ndarray):
+                image = image.transpose(1, 2, 0)
+            self._trains.get_logger().report_image(
+                title=title, series=series, image=image, iteration=step)
+
+    @rank_zero_only
+    def log_artifact(
+            self, name: str,
+            artifact: Union[str, Path, Dict[str, Any], pd.DataFrame, np.ndarray, PIL.Image.Image],
+            metadata: Optional[Dict[str, Any]] = None, delete_after_upload: bool = False) -> None:
+        """Save an artifact (file/object) in TRAINS experiment storage.
+
+        Args:
+            name: Artifact name. Notice! it will override previous artifact
+                if name already exists
+            artifact: Artifact object to upload. Currently supports:
+                - string / pathlib.Path are treated as path to artifact file to upload
+                    If wildcard or a folder is passed, zip file containing the
+                    local files will be created and uploaded
+                - dict will be stored as .json file and uploaded
+                - pandas.DataFrame will be stored as .csv.gz (compressed CSV file) and uploaded
+                - numpy.ndarray will be stored as .npz and uploaded
+                - PIL.Image will be stored to .png file and uploaded
+            metadata:
+                Simple key/value dictionary to store on the artifact. Defaults to None.
+            delete_after_upload:
+                If True local artifact will be deleted (only applies if artifact is a
+                local file). Defaults to False.
+        """
+        if not self._trains:
+            return None
+
+        self._trains.upload_artifact(
+            name=name, artifact_object=artifact, metadata=metadata,
+            delete_after_upload=delete_after_upload
+        )
+
+    def save(self) -> None:
+        pass
+
+    @rank_zero_only
+    def finalize(self, status: str) -> None:
+        if not self._trains:
+            return None
+        self._trains.close()
+        self._trains = None
+
+    @property
+    def name(self) -> Union[str, None]:
+        """
+        Name is a human readable non-unique name (str) of the experiment.
+        """
+        if not self._trains:
+            return None
+        return self._trains.name
+
+    @property
+    def version(self) -> Union[str, None]:
+        if not self._trains:
+            return None
+        return self._trains.id
+
+    def __getstate__(self) -> Union[str, None]:
+        if not self._trains:
+            return ''  # not None: pickle skips __setstate__ when the state is None
+        return self._trains.id
+
+    def __setstate__(self, state: str) -> None:
+        self._rank = 0
+        self._trains = None
+        if state:
+            self._trains = trains.Task.get_task(task_id=state)
diff --git a/requirements-extra.txt b/requirements-extra.txt
index dd153091052e8..1265bc654a6e4 100644
--- a/requirements-extra.txt
+++ b/requirements-extra.txt
@@ -2,4 +2,5 @@ neptune-client>=0.4.4
 comet-ml>=1.0.56
 mlflow>=1.0.0
 test_tube>=0.7.5
-wandb>=0.8.21
\ No newline at end of file
+wandb>=0.8.21
+trains>=0.13.3
diff --git a/tests/loggers/test_trains.py b/tests/loggers/test_trains.py
new file mode 100644
index 0000000000000..1c8ca4167462a
--- /dev/null
+++ b/tests/loggers/test_trains.py
@@ -0,0 +1,48 @@
+import pickle
+
+import tests.models.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import TrainsLogger
+from tests.models import LightningTestModel
+
+
+def test_trains_logger(tmpdir):
+    """Verify that basic functionality of TRAINS logger works."""
+    tutils.reset_seed()
+
+    hparams = tutils.get_hparams()
+    model = LightningTestModel(hparams)
+    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")
+
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        train_percent_check=0.05,
+        logger=logger
+    )
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    print('result finished')
+    assert result == 1, "Training failed"
+
+
+def test_trains_pickle(tmpdir):
+    """Verify that pickling trainer with TRAINS logger works."""
+    tutils.reset_seed()
+
+    # hparams = tutils.get_hparams()
+    # model = LightningTestModel(hparams)
+
+    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")
+
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        logger=logger
+    )
+
+    trainer = Trainer(**trainer_options)
+    pkl_bytes = pickle.dumps(trainer)
+    trainer2 = pickle.loads(pkl_bytes)
+    trainer2.logger.log_metrics({"acc": 1.0})