Skip to content

Commit

Permalink
Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (
Browse files Browse the repository at this point in the history
…#1379)

* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI)

* fix

* test ci

* debug

* debug CI

* Fix CircleCI

* Fix Any CI environment switch to bypass mode

* Removed debug prints

* Improve code coverage

* Improve code coverage

* Reverted

* Improve code coverage

* Test CI

* test codecov

* Codecov fix

* remove pragma

Co-authored-by: bmartinn <>
  • Loading branch information
bmartinn committed Apr 8, 2020
1 parent 2ae2bd2 commit fb8d085
Showing 1 changed file with 59 additions and 22 deletions.
81 changes: 59 additions & 22 deletions pytorch_lightning/loggers/trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def any_lightning_module_function_or_hook(...):
"""
from argparse import Namespace
from os import environ
from pathlib import Path
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
sent along side the task scalars. Defaults to True.
Examples:
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS
>>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
TRAINS Task: ...
TRAINS results page: https://demoapp.trains.allegro.ai/.../log
TRAINS results page: ...
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
>>> logger.log_text("sample test")
sample test
Expand All @@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
"""

_bypass = False
_bypass = None

def __init__(
self,
Expand All @@ -83,8 +84,24 @@ def __init__(
auto_resource_monitoring: bool = True
) -> None:
super().__init__()
if self._bypass:
if self.bypass_mode():
self._trains = None
print('TRAINS Task: running in bypass mode')
print('TRAINS results page: disabled')

class _TaskStub(object):
def __call__(self, *args, **kwargs):
return self

def __getattr__(self, attr):
if attr in ('name', 'id'):
return ''
return self

def __setattr__(self, attr, val):
pass

self._trains = _TaskStub()
else:
self._trains = Task.init(
project_name=project_name,
Expand Down Expand Up @@ -114,8 +131,9 @@ def id(self) -> Union[str, None]:
"""
ID is a uuid (string) representing this specific experiment in the entire system.
"""
if self._bypass or not self._trains:
if not self._trains:
return None

return self._trains.id

@rank_zero_only
Expand All @@ -126,8 +144,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params:
The hyperparameters that passed through the model.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
if not params:
return

Expand All @@ -147,8 +165,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
then the elements will be logged as "title" and "series" respectively.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return

if not step:
step = self._trains.get_last_iteration()
Expand Down Expand Up @@ -179,8 +197,8 @@ def log_metric(self, title: str, series: str, value: float, step: Optional[int]
value: The value to log.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return

if not step:
step = self._trains.get_last_iteration()
Expand All @@ -197,8 +215,12 @@ def log_text(self, text: str) -> None:
Args:
text: The value of the log (data-point).
"""
if self._bypass or not self._trains:
return None
if self.bypass_mode():
print(text)
return

if not self._trains:
return

self._trains.get_logger().report_text(text)

Expand All @@ -222,8 +244,8 @@ def log_image(
step:
Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return

if not step:
step = self._trains.get_last_iteration()
Expand Down Expand Up @@ -265,8 +287,8 @@ def log_artifact(
If True local artifact will be deleted (only applies if artifact_object is a
local file). Defaults to False.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return

self._trains.upload_artifact(
name=name, artifact_object=artifact, metadata=metadata,
Expand All @@ -278,8 +300,9 @@ def save(self) -> None:

@rank_zero_only
def finalize(self, status: str = None) -> None:
if self._bypass or not self._trains:
return None
if self.bypass_mode() or not self._trains:
return

self._trains.close()
self._trains = None

Expand All @@ -288,14 +311,16 @@ def name(self) -> Union[str, None]:
"""
Name is a human readable non-unique name (str) of the experiment.
"""
if self._bypass or not self._trains:
if not self._trains:
return ''

return self._trains.name

@property
def version(self) -> Union[str, None]:
if self._bypass or not self._trains:
if not self._trains:
return None

return self._trains.id

@classmethod
Expand Down Expand Up @@ -327,9 +352,21 @@ def set_bypass_mode(cls, bypass: bool) -> None:
"""
cls._bypass = bypass

@classmethod
def bypass_mode(cls) -> bool:
"""
bypass_mode returns the bypass mode state.
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
unless overridden specifically with set_bypass_mode(False)
:return: If True, all outside communication is skipped
"""
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))

def __getstate__(self) -> Union[str, None]:
if self._bypass or not self._trains:
if self.bypass_mode() or not self._trains:
return ''

return self._trains.id

def __setstate__(self, state: str) -> None:
Expand Down

0 comments on commit fb8d085

Please sign in to comment.