diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d6cbc3d087812..11a24b8924c5a3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
+- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d1c0b54b0a7f10..ce8508f18515df 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -297,7 +297,8 @@ def setup(app):
MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())
# TODO: better parse from package since the import name and package name may differ
-MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune']
+MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube',
+ 'mlflow', 'comet_ml', 'wandb', 'neptune', 'trains']
autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
# for mod_name in MOCK_REQUIRE_PACKAGES:
# sys.modules[mod_name] = mock.Mock()
diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst
index 1edd067511b48c..b3fe825aede77e 100644
--- a/docs/source/experiment_logging.rst
+++ b/docs/source/experiment_logging.rst
@@ -62,6 +62,34 @@ The Neptune.ai is available anywhere except ``__init__`` in your LightningModule
some_img = fake_image()
self.logger.experiment.add_image('generated_images', some_img, 0)
+allegro.ai TRAINS
+^^^^^^^^^^^^^^^^^
+
+`allegro.ai `_ is a third-party logger.
+To use TRAINS as your logger do the following.
+
+.. note:: See: :ref:`trains` docs.
+
+.. code-block:: python
+
+ from pytorch_lightning.loggers import TrainsLogger
+
+ trains_logger = TrainsLogger(
+ project_name="examples",
+ task_name="pytorch lightning test"
+ )
+ trainer = Trainer(logger=trains_logger)
+
+The TrainsLogger is available anywhere in your LightningModule
+
+.. code-block:: python
+
+ class MyModule(pl.LightningModule):
+
+ def __init__(self, ...):
+ some_img = fake_image()
+ self.logger.log_image('debug', 'generated_image_0', some_img, 0)
+
Tensorboard
^^^^^^^^^^^
diff --git a/docs/source/experiment_reporting.rst b/docs/source/experiment_reporting.rst
index a738a234c96745..aa5642ab204571 100644
--- a/docs/source/experiment_reporting.rst
+++ b/docs/source/experiment_reporting.rst
@@ -32,7 +32,7 @@ want to log using this trainer flag.
Log metrics
^^^^^^^^^^^
-To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
+To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, TRAINS, etc...)
1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.
diff --git a/environment.yml b/environment.yml
index 6dae860611b03c..1aa7c8f14820e7 100644
--- a/environment.yml
+++ b/environment.yml
@@ -32,3 +32,4 @@ dependencies:
- comet_ml>=1.0.56
- wandb>=0.8.21
- neptune-client>=0.4.4
+ - trains>=0.13.3
diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py
index adcba876d26f5b..93d1b737aab9b6 100644
--- a/pytorch_lightning/loggers/__init__.py
+++ b/pytorch_lightning/loggers/__init__.py
@@ -119,3 +119,9 @@ def any_lightning_module_function_or_hook(...):
__all__.append('WandbLogger')
except ImportError:
pass
+
+try:
+ from .trains import TrainsLogger
+ __all__.append('TrainsLogger')
+except ImportError:
+ pass
diff --git a/pytorch_lightning/loggers/trains.py b/pytorch_lightning/loggers/trains.py
new file mode 100644
index 00000000000000..a56c6c126798e5
--- /dev/null
+++ b/pytorch_lightning/loggers/trains.py
@@ -0,0 +1,283 @@
+"""
+Log using `allegro.ai TRAINS '_
+
+.. code-block:: python
+
+ from pytorch_lightning.loggers import TrainsLogger
+ trains_logger = TrainsLogger(
+ project_name="pytorch lightning",
+ task_name="default",
+ )
+ trainer = Trainer(logger=trains_logger)
+
+
+Use the logger anywhere in you LightningModule as follows:
+
+.. code-block:: python
+
+ def train_step(...):
+ # example
+ self.logger.experiment.whatever_trains_supports(...)
+
+ def any_lightning_module_function_or_hook(...):
+ self.logger.experiment.whatever_trains_supports(...)
+
+"""
+
+import logging as log
+from argparse import Namespace
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import PIL
+import numpy as np
+import pandas as pd
+import torch
+
+try:
+ import trains
+except ImportError:
+ raise ImportError('You want to use `TRAINS` logger which is not installed yet,'
+ ' install it with `pip install trains`.')
+
+from .base import LightningLoggerBase, rank_zero_only
+
+
+class TrainsLogger(LightningLoggerBase):
+ """Logs using TRAINS
+
+ Args:
+ project_name: The name of the experiment's project. Defaults to None.
+ task_name: The name of the experiment. Defaults to None.
+ task_type: The name of the experiment. Defaults to 'training'.
+ reuse_last_task_id: Start with the previously used task id. Defaults to True.
+ output_uri: Default location for output models. Defaults to None.
+ auto_connect_arg_parser: Automatically grab the ArgParser
+ and connect it with the task. Defaults to True.
+ auto_connect_frameworks: If True, automatically patch to trains backend. Defaults to True.
+ auto_resource_monitoring: If true, machine vitals will be
+ sent along side the task scalars. Defaults to True.
+ """
+
+ def __init__(
+ self, project_name: Optional[str] = None, task_name: Optional[str] = None,
+ task_type: str = 'training', reuse_last_task_id: bool = True,
+ output_uri: Optional[str] = None, auto_connect_arg_parser: bool = True,
+ auto_connect_frameworks: bool = True, auto_resource_monitoring: bool = True) -> None:
+ super().__init__()
+ self._trains = trains.Task.init(
+ project_name=project_name, task_name=task_name, task_type=task_type,
+ reuse_last_task_id=reuse_last_task_id, output_uri=output_uri,
+ auto_connect_arg_parser=auto_connect_arg_parser,
+ auto_connect_frameworks=auto_connect_frameworks,
+ auto_resource_monitoring=auto_resource_monitoring
+ )
+
+ @property
+ def experiment(self) -> trains.Task:
+ r"""Actual TRAINS object. To use TRAINS features do the following.
+
+ Example:
+ .. code-block:: python
+ self.logger.experiment.some_trains_function()
+
+ """
+ return self._trains
+
+ @property
+ def id(self) -> Union[str, None]:
+ """
+ ID is a uuid (string) representing this specific experiment in the entire system.
+ """
+ if not self._trains:
+ return None
+ return self._trains.id
+
+ @rank_zero_only
+ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+ """Log hyperparameters (numeric values) in TRAINS experiments
+
+ Args:
+ params:
+ The hyperparameters that passed through the model.
+ """
+ if not self._trains:
+ return None
+ if not params:
+ return
+ if isinstance(params, dict):
+ self._trains.connect(params)
+ else:
+ self._trains.connect(vars(params))
+
+ @rank_zero_only
+ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+ """Log metrics (numeric values) in TRAINS experiments.
+ This method will be called by Trainer.
+
+ Args:
+ metrics:
+ The dictionary of the metrics.
+ If the key contains "/", it will be split by the delimiter,
+ then the elements will be logged as "title" and "series" respectively.
+ step: Step number at which the metrics should be recorded. Defaults to None.
+ """
+ if not self._trains:
+ return None
+
+ if not step:
+ step = self._trains.get_last_iteration()
+
+ for k, v in metrics.items():
+ if isinstance(v, str):
+ log.warning("Discarding metric with string value {}={}".format(k, v))
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ parts = k.split('/')
+ if len(parts) <= 1:
+ series = title = k
+ else:
+ title = parts[0]
+ series = '/'.join(parts[1:])
+ self._trains.get_logger().report_scalar(
+ title=title, series=series, value=v, iteration=step)
+
+ @rank_zero_only
+ def log_metric(self, title: str, series: str, value: float, step: Optional[int] = None) -> None:
+ """Log metrics (numeric values) in TRAINS experiments.
+ This method will be called by the users.
+
+ Args:
+ title: The title of the graph to log, e.g. loss, accuracy.
+ series: The series name in the graph, e.g. classification, localization.
+ value: The value to log.
+ step: Step number at which the metrics should be recorded. Defaults to None.
+ """
+ if not self._trains:
+ return None
+
+ if not step:
+ step = self._trains.get_last_iteration()
+
+ if isinstance(value, torch.Tensor):
+ value = value.item()
+ self._trains.get_logger().report_scalar(
+ title=title, series=series, value=value, iteration=step)
+
+ @rank_zero_only
+ def log_text(self, text: str) -> None:
+ """Log console text data in TRAINS experiment
+
+ Args:
+ text: The value of the log (data-point).
+ """
+ if not self._trains:
+ return None
+
+ self._trains.get_logger().report_text(text)
+
+ @rank_zero_only
+ def log_image(
+ self, title: str, series: str,
+ image: Union[str, np.ndarray, PIL.Image.Image, torch.Tensor],
+ step: Optional[int] = None) -> None:
+ """Log Debug image in TRAINS experiment
+
+ Args:
+ title: The title of the debug image, i.e. "failed", "passed".
+ series: The series name of the debug image, i.e. "Image 0", "Image 1".
+ image:
+ Debug image to log. Can be one of the following types:
+ Torch, Numpy, PIL image, path to image file (str)
+ If Numpy or Torch, the image is assume to be the following:
+ shape: CHW
+ color space: RGB
+ value range: [0., 1.] (float) or [0, 255] (uint8)
+ step:
+ Step number at which the metrics should be recorded. Defaults to None.
+ """
+ if not self._trains:
+ return None
+
+ if not step:
+ step = self._trains.get_last_iteration()
+
+ if isinstance(image, str):
+ self._trains.get_logger().report_image(
+ title=title, series=series, local_path=image, iteration=step)
+ else:
+ if isinstance(image, torch.Tensor):
+ image = image.cpu().numpy()
+ if isinstance(image, np.ndarray):
+ image = image.transpose(1, 2, 0)
+ self._trains.get_logger().report_image(
+ title=title, series=series, image=image, iteration=step)
+
+ @rank_zero_only
+ def log_artifact(
+ self, name: str,
+ artifact: Union[str, Path, Dict[str, Any], pd.DataFrame, np.ndarray, PIL.Image.Image],
+ metadata: Optional[Dict[str, Any]] = None, delete_after_upload: bool = False) -> None:
+ """Save an artifact (file/object) in TRAINS experiment storage.
+
+ Args:
+ name: Artifact name. Notice! it will override previous artifact
+ if name already exists
+ artifact: Artifact object to upload. Currently supports:
+ - string / pathlib2.Path are treated as path to artifact file to upload
+ If wildcard or a folder is passed, zip file containing the
+ local files will be created and uploaded
+ - dict will be stored as .json file and uploaded
+ - pandas.DataFrame will be stored as .csv.gz (compressed CSV file) and uploaded
+ - numpy.ndarray will be stored as .npz and uploaded
+ - PIL.Image will be stored to .png file and uploaded
+ metadata:
+ Simple key/value dictionary to store on the artifact. Defaults to None.
+ delete_after_upload:
+ If True local artifact will be deleted (only applies if artifact_object is a
+ local file). Defaults to False.
+ """
+ if not self._trains:
+ return None
+
+ self._trains.upload_artifact(
+ name=name, artifact_object=artifact, metadata=metadata,
+ delete_after_upload=delete_after_upload
+ )
+
+ def save(self) -> None:
+ pass
+
+ @rank_zero_only
+ def finalize(self, status: str) -> None:
+ if not self._trains:
+ return None
+ self._trains.close()
+ self._trains = None
+
+ @property
+ def name(self) -> Union[str, None]:
+ """
+ Name is a human readable non-unique name (str) of the experiment.
+ """
+ if not self._trains:
+ return None
+ return self._trains.name
+
+ @property
+ def version(self) -> Union[str, None]:
+ if not self._trains:
+ return None
+ return self._trains.id
+
+ def __getstate__(self) -> Union[str, None]:
+ if not self._trains:
+ return None
+ return self._trains.id
+
+ def __setstate__(self, state: str) -> None:
+ self._rank = 0
+ self._trains = None
+ if state:
+ self._trains = trains.Task.get_task(task_id=state)
diff --git a/requirements-extra.txt b/requirements-extra.txt
index dd153091052e85..1265bc654a6e4e 100644
--- a/requirements-extra.txt
+++ b/requirements-extra.txt
@@ -2,4 +2,5 @@ neptune-client>=0.4.4
comet-ml>=1.0.56
mlflow>=1.0.0
test_tube>=0.7.5
-wandb>=0.8.21
\ No newline at end of file
+wandb>=0.8.21
+trains>=0.13.3
diff --git a/tests/loggers/test_trains.py b/tests/loggers/test_trains.py
new file mode 100644
index 00000000000000..1c8ca4167462a4
--- /dev/null
+++ b/tests/loggers/test_trains.py
@@ -0,0 +1,48 @@
+import pickle
+
+import tests.models.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import TrainsLogger
+from tests.models import LightningTestModel
+
+
+def test_trains_logger(tmpdir):
+ """Verify that basic functionality of TRAINS logger works."""
+ tutils.reset_seed()
+
+ hparams = tutils.get_hparams()
+ model = LightningTestModel(hparams)
+ logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")
+
+ trainer_options = dict(
+ default_save_path=tmpdir,
+ max_epochs=1,
+ train_percent_check=0.05,
+ logger=logger
+ )
+ trainer = Trainer(**trainer_options)
+ result = trainer.fit(model)
+
+ print('result finished')
+ assert result == 1, "Training failed"
+
+
+def test_trains_pickle(tmpdir):
+ """Verify that pickling trainer with TRAINS logger works."""
+ tutils.reset_seed()
+
+ # hparams = tutils.get_hparams()
+ # model = LightningTestModel(hparams)
+
+ logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")
+
+ trainer_options = dict(
+ default_save_path=tmpdir,
+ max_epochs=1,
+ logger=logger
+ )
+
+ trainer = Trainer(**trainer_options)
+ pkl_bytes = pickle.dumps(trainer)
+ trainer2 = pickle.loads(pkl_bytes)
+ trainer2.logger.log_metrics({"acc": 1.0})