Skip to content

Commit

Permalink
Add TRAINS experiment manager support (Lightning-AI#1122)
Browse files Browse the repository at this point in the history
* Add allegro.ai TRAINS experiment manager support

* improve docstring and type hinting, fix the bug in log_metrics, add support torch.Tensor to input into log_image

* complete missing docstring of constructor's arguments

* fix docs

* pep8

* pep8

* remove redundant typing
use logging
fix typing and pep8

* remove deprecated interface

* add TrainsLogger test

* add TrainsLogger PR in CHANGELOG

* add id/name property documentation

* change logging as log

Co-authored-by: bmartinn <>
Co-authored-by: Sou Uchida <s.aiueo32@gmail.com>
  • Loading branch information
2 people authored and tullie committed Apr 3, 2020
1 parent 0c48ba3 commit 923f4eb
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
Expand Down
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def setup(app):
MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())

# TODO: better parse from package since the import name and package name may differ
MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune']
MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube',
'mlflow', 'comet_ml', 'wandb', 'neptune', 'trains']
autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
# for mod_name in MOCK_REQUIRE_PACKAGES:
# sys.modules[mod_name] = mock.Mock()
Expand Down
28 changes: 28 additions & 0 deletions docs/source/experiment_logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,34 @@ The Neptune.ai is available anywhere except ``__init__`` in your LightningModule
some_img = fake_image()
self.logger.experiment.add_image('generated_images', some_img, 0)
allegro.ai TRAINS
^^^^^^^^^^^^^^^^^

`allegro.ai <https://github.com/allegroai/trains/>`_ is a third-party logger.
To use TRAINS as your logger do the following.

.. note:: See: :ref:`trains` docs.

.. code-block:: python
from pytorch_lightning.loggers import TrainsLogger
trains_logger = TrainsLogger(
project_name="examples",
task_name="pytorch lightning test"
)
trainer = Trainer(logger=trains_logger)
The TrainsLogger is available anywhere in your LightningModule

.. code-block:: python
class MyModule(pl.LightningModule):
def __init__(self, ...):
some_img = fake_image()
self.logger.log_image('debug', 'generated_image_0', some_img, 0)
Tensorboard
^^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion docs/source/experiment_reporting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ want to log using this trainer flag.
Log metrics
^^^^^^^^^^^

To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, TRAINS, etc...)

1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.

Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ dependencies:
- comet_ml>=1.0.56
- wandb>=0.8.21
- neptune-client>=0.4.4
- trains>=0.13.3
6 changes: 6 additions & 0 deletions pytorch_lightning/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,9 @@ def any_lightning_module_function_or_hook(...):
__all__.append('WandbLogger')
except ImportError:
pass

try:
from .trains import TrainsLogger
__all__.append('TrainsLogger')
except ImportError:
pass
283 changes: 283 additions & 0 deletions pytorch_lightning/loggers/trains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
"""
Log using `allegro.ai TRAINS <https://github.com/allegroai/trains>'_
.. code-block:: python
from pytorch_lightning.loggers import TrainsLogger
trains_logger = TrainsLogger(
project_name="pytorch lightning",
task_name="default",
)
trainer = Trainer(logger=trains_logger)
Use the logger anywhere in you LightningModule as follows:
.. code-block:: python
def train_step(...):
# example
self.logger.experiment.whatever_trains_supports(...)
def any_lightning_module_function_or_hook(...):
self.logger.experiment.whatever_trains_supports(...)
"""

import logging as log
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, Union

import PIL
import numpy as np
import pandas as pd
import torch

try:
import trains
except ImportError:
raise ImportError('You want to use `TRAINS` logger which is not installed yet,'
' install it with `pip install trains`.')

from .base import LightningLoggerBase, rank_zero_only


class TrainsLogger(LightningLoggerBase):
"""Logs using TRAINS
Args:
project_name: The name of the experiment's project. Defaults to None.
task_name: The name of the experiment. Defaults to None.
task_type: The name of the experiment. Defaults to 'training'.
reuse_last_task_id: Start with the previously used task id. Defaults to True.
output_uri: Default location for output models. Defaults to None.
auto_connect_arg_parser: Automatically grab the ArgParser
and connect it with the task. Defaults to True.
auto_connect_frameworks: If True, automatically patch to trains backend. Defaults to True.
auto_resource_monitoring: If true, machine vitals will be
sent along side the task scalars. Defaults to True.
"""

def __init__(
self, project_name: Optional[str] = None, task_name: Optional[str] = None,
task_type: str = 'training', reuse_last_task_id: bool = True,
output_uri: Optional[str] = None, auto_connect_arg_parser: bool = True,
auto_connect_frameworks: bool = True, auto_resource_monitoring: bool = True) -> None:
super().__init__()
self._trains = trains.Task.init(
project_name=project_name, task_name=task_name, task_type=task_type,
reuse_last_task_id=reuse_last_task_id, output_uri=output_uri,
auto_connect_arg_parser=auto_connect_arg_parser,
auto_connect_frameworks=auto_connect_frameworks,
auto_resource_monitoring=auto_resource_monitoring
)

@property
def experiment(self) -> trains.Task:
r"""Actual TRAINS object. To use TRAINS features do the following.
Example:
.. code-block:: python
self.logger.experiment.some_trains_function()
"""
return self._trains

@property
def id(self) -> Union[str, None]:
"""
ID is a uuid (string) representing this specific experiment in the entire system.
"""
if not self._trains:
return None
return self._trains.id

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
"""Log hyperparameters (numeric values) in TRAINS experiments
Args:
params:
The hyperparameters that passed through the model.
"""
if not self._trains:
return None
if not params:
return
if isinstance(params, dict):
self._trains.connect(params)
else:
self._trains.connect(vars(params))

@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in TRAINS experiments.
This method will be called by Trainer.
Args:
metrics:
The dictionary of the metrics.
If the key contains "/", it will be split by the delimiter,
then the elements will be logged as "title" and "series" respectively.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if not self._trains:
return None

if not step:
step = self._trains.get_last_iteration()

for k, v in metrics.items():
if isinstance(v, str):
log.warning("Discarding metric with string value {}={}".format(k, v))
continue
if isinstance(v, torch.Tensor):
v = v.item()
parts = k.split('/')
if len(parts) <= 1:
series = title = k
else:
title = parts[0]
series = '/'.join(parts[1:])
self._trains.get_logger().report_scalar(
title=title, series=series, value=v, iteration=step)

@rank_zero_only
def log_metric(self, title: str, series: str, value: float, step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in TRAINS experiments.
This method will be called by the users.
Args:
title: The title of the graph to log, e.g. loss, accuracy.
series: The series name in the graph, e.g. classification, localization.
value: The value to log.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if not self._trains:
return None

if not step:
step = self._trains.get_last_iteration()

if isinstance(value, torch.Tensor):
value = value.item()
self._trains.get_logger().report_scalar(
title=title, series=series, value=value, iteration=step)

@rank_zero_only
def log_text(self, text: str) -> None:
"""Log console text data in TRAINS experiment
Args:
text: The value of the log (data-point).
"""
if not self._trains:
return None

self._trains.get_logger().report_text(text)

@rank_zero_only
def log_image(
self, title: str, series: str,
image: Union[str, np.ndarray, PIL.Image.Image, torch.Tensor],
step: Optional[int] = None) -> None:
"""Log Debug image in TRAINS experiment
Args:
title: The title of the debug image, i.e. "failed", "passed".
series: The series name of the debug image, i.e. "Image 0", "Image 1".
image:
Debug image to log. Can be one of the following types:
Torch, Numpy, PIL image, path to image file (str)
If Numpy or Torch, the image is assume to be the following:
shape: CHW
color space: RGB
value range: [0., 1.] (float) or [0, 255] (uint8)
step:
Step number at which the metrics should be recorded. Defaults to None.
"""
if not self._trains:
return None

if not step:
step = self._trains.get_last_iteration()

if isinstance(image, str):
self._trains.get_logger().report_image(
title=title, series=series, local_path=image, iteration=step)
else:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
if isinstance(image, np.ndarray):
image = image.transpose(1, 2, 0)
self._trains.get_logger().report_image(
title=title, series=series, image=image, iteration=step)

@rank_zero_only
def log_artifact(
self, name: str,
artifact: Union[str, Path, Dict[str, Any], pd.DataFrame, np.ndarray, PIL.Image.Image],
metadata: Optional[Dict[str, Any]] = None, delete_after_upload: bool = False) -> None:
"""Save an artifact (file/object) in TRAINS experiment storage.
Args:
name: Artifact name. Notice! it will override previous artifact
if name already exists
artifact: Artifact object to upload. Currently supports:
- string / pathlib2.Path are treated as path to artifact file to upload
If wildcard or a folder is passed, zip file containing the
local files will be created and uploaded
- dict will be stored as .json file and uploaded
- pandas.DataFrame will be stored as .csv.gz (compressed CSV file) and uploaded
- numpy.ndarray will be stored as .npz and uploaded
- PIL.Image will be stored to .png file and uploaded
metadata:
Simple key/value dictionary to store on the artifact. Defaults to None.
delete_after_upload:
If True local artifact will be deleted (only applies if artifact_object is a
local file). Defaults to False.
"""
if not self._trains:
return None

self._trains.upload_artifact(
name=name, artifact_object=artifact, metadata=metadata,
delete_after_upload=delete_after_upload
)

def save(self) -> None:
pass

@rank_zero_only
def finalize(self, status: str) -> None:
if not self._trains:
return None
self._trains.close()
self._trains = None

@property
def name(self) -> Union[str, None]:
"""
Name is a human readable non-unique name (str) of the experiment.
"""
if not self._trains:
return None
return self._trains.name

@property
def version(self) -> Union[str, None]:
if not self._trains:
return None
return self._trains.id

def __getstate__(self) -> Union[str, None]:
if not self._trains:
return None
return self._trains.id

def __setstate__(self, state: str) -> None:
self._rank = 0
self._trains = None
if state:
self._trains = trains.Task.get_task(task_id=state)
3 changes: 2 additions & 1 deletion requirements-extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ neptune-client>=0.4.4
comet-ml>=1.0.56
mlflow>=1.0.0
test_tube>=0.7.5
wandb>=0.8.21
wandb>=0.8.21
trains>=0.13.3
Loading

0 comments on commit 923f4eb

Please sign in to comment.