Commit 6175d4e
Use tensorboard.compat.gfile to support remote writing
f4hy committed Aug 7, 2020
1 parent 2cc60c6 commit 6175d4e
Showing 9 changed files with 552 additions and 321 deletions.
2 changes: 2 additions & 0 deletions environment.yml
@@ -34,6 +34,8 @@ dependencies:
 - pillow<7.0.0
 - scikit-image
 - nltk>=3.3
+- boto3
+- moto>=1.3.14

 # Optional
 - scikit-learn>=0.20.0
25 changes: 18 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs


 class ModelCheckpoint(Callback):
@@ -100,7 +101,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -112,12 +115,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -159,16 +163,23 @@ def kth_best_model(self):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented, so if running this
+                # against an actual remote file system and the correct remote
+                # dependencies exist then this will work fine.
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)

     def _save_model(self, filepath, trainer, pl_module):

         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            makedirs(os.path.dirname(filepath))

         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):

         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
-        while os.path.isfile(filepath):
+        while gfile.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
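With every filesystem check in ModelCheckpoint now routed through gfile, a checkpoint path on a remote store becomes possible in principle. A minimal sketch under stated assumptions — the s3:// bucket is hypothetical, and remote schemes require tensorboard >= 2.0 plus the matching backend dependencies (e.g. boto3 for S3):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# hypothetical bucket; any scheme gfile understands (s3://, gs://) should behave the same
checkpoint_callback = ModelCheckpoint(
    filepath="s3://my-bucket/checkpoints/{epoch}",
    monitor="val_loss",
    save_top_k=3,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)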
35 changes: 19 additions & 16 deletions pytorch_lightning/core/saving.py
@@ -1,5 +1,6 @@
 import ast
 import csv
+import io
 import inspect
 import os

@@ -11,6 +12,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +275,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
-        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+    if not gfile.exists(tags_csv):
+        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}

-    with open(tags_csv) as fp:
-        csv_reader = csv.reader(fp, delimiter=',')
+    with cloud_open(tags_csv, "r") as fp:
+        csv_reader = csv.reader(fp.read(), delimiter=",")
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

     return tags


 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
+    if not gfile.isdir(os.path.dirname(tags_csv)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with open(tags_csv, 'w', newline='') as fp:
-        fieldnames = ['key', 'value']
+    with cloud_open(tags_csv, "w", newline="") as fp:
+        fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
-        writer.writerow({'key': 'key', 'value': 'value'})
+        writer.writerow({"key": "key", "value": "value"})
         for k, v in hparams.items():
-            writer.writerow({'key': k, 'value': v})
+            writer.writerow({"key": k, "value": v})


 def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +312,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
-        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
+    if not gfile.exists(config_yaml):
+        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}

-    with open(config_yaml) as fp:
+    with cloud_open(config_yaml, "r") as fp:
         tags = yaml.load(fp)

     return tags
@@ -326,11 +328,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not os.path.isdir(os.path.dirname(config_yaml)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+    if not gfile.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

     if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
+
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return

@@ -341,7 +344,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
     assert isinstance(hparams, dict)

-    with open(config_yaml, 'w', newline='') as fp:
+    with cloud_open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)

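Because the hparams helpers now open files through cloud_open, a round trip against a remote path might look like the sketch below (the bucket path is hypothetical; with only minimal dependencies installed, cloud_open falls back to the local open()):

from pytorch_lightning.core.saving import (
    load_hparams_from_yaml,
    save_hparams_to_yaml,
)

hparams = {"learning_rate": 0.02, "batch_size": 32}
save_hparams_to_yaml("s3://my-bucket/run_0/hparams.yaml", hparams)  # hypothetical path
assert load_hparams_from_yaml("s3://my-bucket/run_0/hparams.yaml") == hparams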
12 changes: 7 additions & 5 deletions pytorch_lightning/loggers/tensorboard.py
@@ -16,6 +16,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
             return self._experiment

         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        os.makedirs(self.root_dir, exist_ok=True)
+        if self.root_dir and not gfile.exists(str(self.root_dir)):
+            makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir

         # prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)

-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))

         if len(existing_versions) == 0:
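The logger now creates and lists its version directories through gfile as well, so a remote save_dir can be sketched like this (bucket hypothetical; actual remote writes go through tensorboard's own filesystem layer and assume its cloud dependencies are installed):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# hypothetical bucket; _get_next_version() lists it with gfile.listdir
logger = TensorBoardLogger(save_dir="s3://my-bucket/tb_logs", name="experiment")
trainer = Trainer(logger=logger)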
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/training_io.py
@@ -105,6 +105,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     import torch_xla
@@ -409,9 +410,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False

         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

         # if hpc weights exist restore model
@@ -490,15 +491,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)

         # save logger to make sure we get all the metrics
         logger.save()

         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

         # give model a chance to do something on hpc_save
@@ -551,7 +554,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')

     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
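hpc_save and restore_hpc_weights_if_needed now create and list the weights folder through gfile, so weights_save_path can in principle point at a remote store. A hedged sketch (bucket hypothetical):

from pytorch_lightning import Trainer

# hypothetical bucket; hpc_ckpt_N.ckpt files in it are found via gfile.listdir
trainer = Trainer(weights_save_path="s3://my-bucket/weights")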
53 changes: 49 additions & 4 deletions pytorch_lightning/utilities/cloud_io.py
@@ -1,11 +1,56 @@
-import torch
-
+import sys
+import os
+from typing import Union
 from pathlib import Path
 from urllib.parse import urlparse
+import torch
+
+import tensorboard
+from packaging import version
+from pytorch_lightning import _logger as log
+
+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned-down version which works for some file backends
+# but not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+
+pathlike = Union[Path, str]
+
+# older versions of tensorboard had buggy gfile compatibility layers;
+# only support remote cloud paths on newer versions
+modern_gfile = version.parse(tensorboard.version.VERSION) >= version.parse('2.0')


 def load(path_or_url: str, map_location=None):
     if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def cloud_open(path: pathlike, mode: str, newline: str = None):
+    if not modern_gfile:
+        log.debug(
+            "tensorboard.compat gfile does not work on older versions "
+            "of tensorboard; falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    if sys.platform == "win32":
+        log.debug(
+            "gfile does not handle newlines correctly on Windows, so remote files "
+            "are not supported; falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    try:
+        return gfile.GFile(path, mode)
+    except NotImplementedError:
+        # minimal dependencies are installed and only local files will work
+        return open(path, mode)
+
+
+def makedirs(path: pathlike):
+    if modern_gfile and hasattr(gfile, "makedirs"):
+        return gfile.makedirs(str(path))
+    # otherwise minimal dependencies are installed and only local files will work
+    return os.makedirs(path, exist_ok=True)
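Taken together, cloud_open and makedirs give the rest of the codebase a single entry point for local and remote paths. A minimal sketch of direct use, assuming a modern tensorboard with an S3 backend available (the bucket is hypothetical; with minimal dependencies both helpers fall back to their local equivalents):

from pytorch_lightning.utilities.cloud_io import cloud_open, makedirs

makedirs("s3://my-bucket/runs/0")  # gfile.makedirs for remote paths, os.makedirs locally
with cloud_open("s3://my-bucket/runs/0/notes.txt", "w") as fp:  # hypothetical path
    fp.write("written through tf.io.gfile")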
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -7,3 +7,4 @@ future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1  # OmegaConf requirement >=5.1
 tqdm>=4.41.0
+packaging
3 changes: 3 additions & 0 deletions requirements/test.txt
@@ -12,4 +12,7 @@ black==19.10b0
 pre-commit>=1.0

 cloudpickle>=1.2
+
+boto3
+moto>=1.3.14
 nltk>=3.3