[Ready for review] use gfile to support remote directories #2164

Merged · 1 commit · Aug 9, 2020
25 changes: 18 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs


 class ModelCheckpoint(Callback):
@@ -104,7 +105,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -116,12 +119,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -163,16 +167,23 @@ def kth_best_model(self):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented, so when running
+                # against an actual remote file system with the correct remote
+                # dependencies installed this will work fine
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)

     def _save_model(self, filepath, trainer, pl_module):

         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            makedirs(os.path.dirname(filepath))

         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):

         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
-        while os.path.isfile(filepath):
+        while gfile.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
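Note for reviewers: with these changes a `ModelCheckpoint` can point at a remote directory. A minimal sketch of the intended usage (the `gs://` bucket is a made-up example, and remote schemes assume a gfile backend for them is installed, e.g. via tensorflow):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# gfile.isdir / gfile.exists / makedirs now handle this path, so the
# checkpoint directory no longer has to be on the local filesystem.
checkpoint_callback = ModelCheckpoint(
    filepath="gs://my-bucket/checkpoints/{epoch}",  # hypothetical bucket
    monitor="val_loss",
    save_top_k=3,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```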
34 changes: 18 additions & 16 deletions pytorch_lightning/core/saving.py
@@ -11,6 +11,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +274,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
-        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+    if not gfile.exists(tags_csv):
+        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}

-    with open(tags_csv) as fp:
-        csv_reader = csv.reader(fp, delimiter=',')
+    with cloud_open(tags_csv, "r", newline="") as fp:
+        csv_reader = csv.reader(fp, delimiter=",")
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

     return tags


 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
+    if not gfile.isdir(os.path.dirname(tags_csv)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with open(tags_csv, 'w', newline='') as fp:
-        fieldnames = ['key', 'value']
+    with cloud_open(tags_csv, "w", newline="") as fp:
+        fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
-        writer.writerow({'key': 'key', 'value': 'value'})
+        writer.writerow({"key": "key", "value": "value"})
         for k, v in hparams.items():
-            writer.writerow({'key': k, 'value': v})
+            writer.writerow({"key": k, "value": v})


 def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +311,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
-        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
+    if not gfile.exists(config_yaml):
+        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}

-    with open(config_yaml) as fp:
+    with cloud_open(config_yaml, "r") as fp:
         tags = yaml.load(fp)

     return tags
@@ -326,11 +327,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not os.path.isdir(os.path.dirname(config_yaml)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+    if not gfile.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

     if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
+
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return

@@ -341,7 +343,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
     assert isinstance(hparams, dict)

-    with open(config_yaml, 'w', newline='') as fp:
+    with cloud_open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)
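The same applies to the hparams helpers; a sketch of the round trip through `cloud_open` (the bucket path is hypothetical, and remote writes assume the full gfile dependencies rather than the pruned-down compat layer):

```python
from argparse import Namespace

from pytorch_lightning.core.saving import (
    load_hparams_from_yaml,
    save_hparams_to_yaml,
)

hparams = Namespace(learning_rate=0.02, batch_size=32)

# save_hparams_to_yaml now opens the target via cloud_open, which returns a
# gfile.GFile for remote paths and a plain file object for local ones
save_hparams_to_yaml("gs://my-bucket/run0/hparams.yaml", hparams)

loaded = load_hparams_from_yaml("gs://my-bucket/run0/hparams.yaml")
assert loaded["learning_rate"] == 0.02
```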
12 changes: 7 additions & 5 deletions pytorch_lightning/loggers/tensorboard.py
@@ -16,6 +16,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
             return self._experiment

         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        os.makedirs(self.root_dir, exist_ok=True)
+        if self.root_dir and not gfile.exists(str(self.root_dir)):
+            makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir

         # prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)

-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))

         if len(existing_versions) == 0:
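For the logger, version discovery now works against remote directories too; a sketch (hypothetical bucket, same gfile-backend assumption as above):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# _get_next_version() lists the remote root with gfile.listdir, so the
# usual version_0, version_1, ... numbering carries over to object stores.
logger = TensorBoardLogger(save_dir="gs://my-bucket/tb_logs", name="my_model")
trainer = Trainer(logger=logger)
```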
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/training_io.py
@@ -104,6 +104,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     import torch_xla
@@ -407,9 +408,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False

         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

             # if hpc weights exist restore model
@@ -488,15 +489,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)

         # save logger to make sure we get all the metrics
         logger.save()

         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

         # give model a chance to do something on hpc_save
@@ -549,7 +552,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')

     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
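To make the `max_ckpt_in_folder` change concrete, here is a standalone paraphrase of the numbering logic it feeds (`hpc_save` writes `hpc_ckpt_<n+1>.ckpt` next); this is my reading of the elided method body, and the file list stands in for what `gfile.listdir` would return, local or remote:

```python
import re

def max_ckpt_in_folder(files, name_key='ckpt_'):
    # keep entries containing the key and take the largest trailing number,
    # mirroring Trainer.max_ckpt_in_folder as I read it (sketch, not the PR code)
    ckpts = [f for f in files if name_key in f]
    if not ckpts:
        return 0
    return max(int(re.sub('[^0-9]', '', f.split(name_key)[-1])) for f in ckpts)

assert max_ckpt_in_folder(['hpc_ckpt_1.ckpt', 'hpc_ckpt_7.ckpt']) == 7
assert max_ckpt_in_folder(['events.out.tfevents']) == 0
```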
60 changes: 57 additions & 3 deletions pytorch_lightning/utilities/cloud_io.py
@@ -1,11 +1,65 @@
+import sys
+import os
+from typing import Union
 from pathlib import Path
 from urllib.parse import urlparse

 import torch

-
+import tensorboard
+from packaging import version
+from pytorch_lightning import _logger as log
+
+# we want this for tf.io.gfile, which if tf is installed gives full tf,
+# otherwise gives a pruned-down version which works for some file backends
+# but not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+
+pathlike = Union[Path, str]
+
+# older versions of tensorboard had buggy gfile compatibility layers;
+# only support remote cloud paths on newer versions


 def load(path_or_url: str, map_location=None):
     if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def modern_gfile():
+    """Check the version number of tensorboard.
+
+    Checking to see if it has the gfile compatibility layers needed for remote
+    file operations.
+    """
+    tb_version = version.parse(tensorboard.version.VERSION)
+    return tb_version >= version.parse('2.0')
+
+
+def cloud_open(path: pathlike, mode: str, newline: str = None):
+    if sys.platform == "win32":
+        log.debug(
+            "gfile does not handle newlines correctly on windows so remote files are not "
+            "supported, falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    if not modern_gfile():
+        log.debug(
+            "tensorboard.compat gfile does not work on older versions "
+            "of tensorboard for remote files, using normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    try:
+        return gfile.GFile(path, mode)
+    except NotImplementedError:
+        # minimal dependencies are installed and only local files will work
+        return open(path, mode, newline=newline)
+
+
+def makedirs(path: pathlike):
+    if hasattr(gfile, "makedirs") and modern_gfile():
+        return gfile.makedirs(str(path))
+    # otherwise minimal dependencies are installed and only local files will work
+    return os.makedirs(path, exist_ok=True)
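Taken together, the three helpers behave like this; a usage sketch assuming a writable, made-up `gs://` path and tensorboard >= 2.0 with the relevant backend installed (otherwise everything falls back to plain local I/O):

```python
from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

run_dir = "gs://my-bucket/run0"  # hypothetical; any scheme gfile supports

# makedirs routes through gfile.makedirs on modern tensorboard and falls
# back to os.makedirs when only minimal dependencies are installed
if not gfile.exists(run_dir):
    makedirs(run_dir)

# cloud_open yields a gfile.GFile remotely, a normal file object locally
with cloud_open(run_dir + "/notes.txt", "w") as fp:
    fp.write("checkpoints and hparams can now live on an object store")

print(gfile.listdir(run_dir))
```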
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -7,3 +7,4 @@ future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1  # OmegaConf requirement >=5.1
 tqdm>=4.41.0
+packaging