use fsspec instead of gfile for all IO (#3320)
* use fsspec instead of gfile for all IO

This better supports remote (and local) file operations with a dedicated package

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
4 people committed Sep 3, 2020
1 parent d521c1b commit 2d8c1b7
Showing 9 changed files with 73 additions and 126 deletions.
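
The diffs below repeatedly import a new `get_filesystem` helper from `pytorch_lightning.utilities.cloud_io`; that module is among the changed files not expanded in this view. As a rough, illustrative sketch only (not the actual implementation), such a helper can resolve an fsspec filesystem from a path or URL like this:

```python
# Hypothetical sketch of a get_filesystem-style helper built on fsspec.
# The real helper lives in pytorch_lightning/utilities/cloud_io.py and may
# differ; this only illustrates resolving a filesystem from a path or URL.
from pathlib import Path
from typing import Union

import fsspec
from fsspec.spec import AbstractFileSystem


def get_filesystem_sketch(path: Union[str, Path]) -> AbstractFileSystem:
    path = str(path)
    if "://" in path:
        # e.g. "s3://bucket/dir" -> the filesystem registered for "s3"
        # (requires the matching backend package, such as s3fs)
        return fsspec.filesystem(path.split("://", 1)[0])
    # plain paths (and "") fall back to the local filesystem
    return fsspec.filesystem("file")
```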
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
+- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))

### Changed

+- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))

### Deprecated

1 change: 1 addition & 0 deletions environment.yml
@@ -31,6 +31,7 @@ dependencies:
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
+- fsspec>=0.8.0
- nvidia-apex

# For dev and testing
45 changes: 17 additions & 28 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -30,7 +30,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem


class ModelCheckpoint(Callback):
@@ -119,9 +119,11 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
-if(filepath):
-filepath = str(filepath) # the tests pass in a py.path.local but we want a str
-if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
+if filepath:
+self._fs = get_filesystem(filepath)
+else:
+self._fs = get_filesystem("") # will give local fileystem
+if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -133,13 +135,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
-if gfile.isdir(filepath):
-self.dirpath, self.filename = filepath, '{epoch}'
+if self._fs.isdir(filepath):
+self.dirpath, self.filename = filepath, "{epoch}"
else:
-if not is_remote_path(filepath): # dont normalize remote paths
+if self._fs.protocol == "file": # dont normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
-makedirs(self.dirpath) # calls with exist_ok
+self._fs.makedirs(self.dirpath, exist_ok=True)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
@@ -182,28 +184,16 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
-if gfile.exists(filepath):
-try:
-# in compat mode, remove is not implemented so if running this
-# against an actual remove file system and the correct remote
-# dependencies exist then this will work fine.
-gfile.remove(filepath)
-except AttributeError:
-if is_remote_path(filepath):
-log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
-" Please install tensorflow to run gfile in full mode"
-" if writing checkpoints to remote locations")
-else:
-os.remove(filepath)
+if self._fs.exists(filepath):
+self._fs.rm(filepath)

def _save_model(self, filepath, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
-if not gfile.exists(os.path.dirname(filepath)):
-makedirs(os.path.dirname(filepath))
+self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the model
if self.save_function is not None:
@@ -308,9 +298,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):

self.dirpath = ckpt_path

-assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
-if not gfile.exists(self.dirpath):
-makedirs(self.dirpath)
+assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
+self._fs.makedirs(self.dirpath, exist_ok=True)

def __warn_deprecated_monitor_key(self):
using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@@ -359,7 +348,7 @@ def on_validation_end(self, trainer, pl_module):
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
-while gfile.exists(filepath):
+while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
@@ -435,4 +424,4 @@ def on_save_checkpoint(self, trainer, pl_module):

def on_load_checkpoint(self, checkpointed_state):
self.best_model_score = checkpointed_state['best_model_score']
-self.best_model_path = checkpointed_state['best_model_path']
\ No newline at end of file
+self.best_model_path = checkpointed_state['best_model_path']
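
Not part of the diff: a small, illustrative exercise of the fsspec calls the checkpoint callback now relies on (`exists`, `isdir`, `ls`, `makedirs`, `rm`), run against a local temporary directory; a remote backend such as s3fs exposes the same interface.

```python
# Illustrative only: the AbstractFileSystem methods used by ModelCheckpoint
# above, exercised on the local backend.
import os
import tempfile

import fsspec

fs = fsspec.filesystem("file")  # s3/gcs/... expose the same methods
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoints")
    fs.makedirs(ckpt_dir, exist_ok=True)   # replaces the old makedirs() helper
    path = os.path.join(ckpt_dir, "epoch=0.ckpt")
    with fs.open(path, "wb") as fp:
        fp.write(b"fake checkpoint")
    assert fs.exists(path) and fs.isdir(ckpt_dir)
    print(fs.ls(ckpt_dir))                 # full paths, unlike gfile.listdir()
    fs.rm(path)                            # replaces the gfile.remove() try/except
```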
24 changes: 15 additions & 9 deletions pytorch_lightning/core/saving.py
@@ -19,13 +19,15 @@
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping

+import fsspec
import torch
import yaml

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
+from pytorch_lightning.utilities.cloud_io import get_filesystem


PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -290,25 +292,27 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
True
>>> os.remove(path_csv)
"""
-if not gfile.exists(tags_csv):
+fs = get_filesystem(tags_csv)
+if not fs.exists(tags_csv):
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}

with cloud_open(tags_csv, "r", newline="") as fp:
with fs.open(tags_csv, "r", newline="") as fp:
csv_reader = csv.reader(fp, delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-if not gfile.isdir(os.path.dirname(tags_csv)):
+fs = get_filesystem(tags_csv)
+if not fs.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with cloud_open(tags_csv, "w", newline="") as fp:
with fs.open(tags_csv, "w", newline="") as fp:
fieldnames = ["key", "value"]
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({"key": "key", "value": "value"})
@@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
True
>>> os.remove(path_yaml)
"""
-if not gfile.exists(config_yaml):
+fs = get_filesystem(config_yaml)
+if not fs.exists(config_yaml):
rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
return {}

with cloud_open(config_yaml, "r") as fp:
with fs.open(config_yaml, "r") as fp:
tags = yaml.load(fp)

return tags
@@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
config_yaml: path to new YAML file
hparams: parameters to be saved
"""
-if not gfile.isdir(os.path.dirname(config_yaml)):
+fs = get_filesystem(config_yaml)
+if not fs.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

# convert Namespace or AD to dict
@@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:

# saving the standard way
assert isinstance(hparams, dict)
-with cloud_open(config_yaml, 'w', newline='') as fp:
+with fs.open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)


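
Not part of the diff: a minimal sketch of the `fs.open()` pattern that replaces `cloud_open` in `saving.py`, shown for a local YAML file; a remote URL such as `s3://bucket/hparams.yaml` would work the same way once the matching fsspec backend is installed.

```python
# Illustrative only: writing and reading hparams through fsspec file handles.
import os
import tempfile

import fsspec
import yaml

fs = fsspec.filesystem("file")
with tempfile.TemporaryDirectory() as tmp:
    config_yaml = os.path.join(tmp, "hparams.yaml")
    hparams = {"learning_rate": 0.02, "batch_size": 32}
    with fs.open(config_yaml, "w", newline="") as fp:  # same call as save_hparams_to_yaml
        yaml.dump(hparams, fp)
    with fs.open(config_yaml, "r") as fp:
        print(yaml.safe_load(fp))                      # round-trips the dict
```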
19 changes: 10 additions & 9 deletions pytorch_lightning/loggers/tensorboard.py
@@ -30,7 +30,7 @@
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.core.lightning import LightningModule

try:
@@ -87,6 +87,7 @@ def __init__(
self._version = version
self._log_graph = log_graph
self._default_hp_metric = default_hp_metric
+self._fs = get_filesystem(save_dir)

self._experiment = None
self.hparams = {}
@@ -136,8 +137,8 @@ def experiment(self) -> SummaryWriter:
return self._experiment

assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-if self.root_dir and not gfile.exists(str(self.root_dir)):
-makedirs(self.root_dir)
+if self.root_dir:
+self._fs.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@@ -207,7 +208,7 @@ def log_graph(self, model: LightningModule, input_array=None):
def save(self) -> None:
super().save()
dir_path = self.log_dir
-if not gfile.isdir(dir_path):
+if not self._fs.isdir(dir_path):
dir_path = self.save_dir

# prepare the file path
@@ -233,16 +234,16 @@ def version(self) -> int:
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)

-if not gfile.isdir(root_dir):
+if not self._fs.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
-for d in gfile.listdir(root_dir):
-if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
-dir_ver = d.split("_")[1].replace('/', '')
+for d in self._fs.ls(root_dir):
+bn = os.path.basename(d)
+if self._fs.isdir(d) and bn.startswith("version_"):
+dir_ver = bn.split("_")[1].replace('/', '')
existing_versions.append(int(dir_ver))

if len(existing_versions) == 0:
return 0

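
Not part of the diff: one behavioural difference worth noting above is that fsspec's `ls()` returns full paths, whereas `gfile.listdir()` returned bare names; that is why `_get_next_version` now calls `os.path.basename`. A small, illustrative reproduction:

```python
# Illustrative only: fs.ls() yields full paths, so version folders are
# matched on their basename, as in _get_next_version() above.
import os
import tempfile

import fsspec

fs = fsspec.filesystem("file")
with tempfile.TemporaryDirectory() as root:
    for name in ("version_0", "version_3", "not_a_version"):
        fs.makedirs(os.path.join(root, name))
    versions = []
    for d in fs.ls(root):                   # e.g. ".../version_3"
        bn = os.path.basename(d)
        if fs.isdir(d) and bn.startswith("version_"):
            versions.append(int(bn.split("_")[1]))
    print(max(versions) if versions else 0)  # -> 3
```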
16 changes: 7 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -51,7 +51,7 @@
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
@@ -915,21 +915,19 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
-if is_remote_path(self._default_root_dir):
-# it is a remote uri, use as is
-return self._default_root_dir
-return os.path.normpath(self._default_root_dir)
+if get_filesystem(self._default_root_dir).protocol == "file":
+return os.path.normpath(self._default_root_dir)
+return self._default_root_dir

@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
-if is_remote_path(self._weights_save_path):
-# it is a remote uri, use as is
-return self._weights_save_path
-return os.path.normpath(self._weights_save_path)
+if get_filesystem(self._weights_save_path).protocol == "file":
+return os.path.normpath(self._weights_save_path)
+return self._weights_save_path

def tune(
self,
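
Not part of the diff: the protocol check above keeps `os.path.normpath` for local paths while leaving remote URIs untouched. A standalone sketch of the same idea, using fsspec's built-in `memory` filesystem as a stand-in for a remote backend so no extra dependencies are needed:

```python
# Illustrative only: normalize local paths, pass remote URIs through unchanged.
import os

import fsspec


def normalized_root_dir_sketch(fs, path: str) -> str:
    # fs.protocol is "file" (older fsspec) or a tuple containing "file" for
    # the local backend; anything else is treated as remote.
    if fs.protocol == "file" or "file" in fs.protocol:
        return os.path.normpath(path)
    return path


print(normalized_root_dir_sketch(fsspec.filesystem("file"), "logs/../lightning_logs"))  # -> "lightning_logs"
print(normalized_root_dir_sketch(fsspec.filesystem("memory"), "memory://logs/run/"))    # unchanged
```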
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/training_io.py
@@ -114,9 +114,8 @@
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
from pytorch_lightning.utilities import AMPType, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
+from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import makedirs
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

try:
@@ -391,8 +390,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):

# look for hpc weights
folderpath = str(self.weights_save_path)
-if gfile.exists(folderpath):
-files = gfile.listdir(folderpath)
+fs = get_filesystem(folderpath)
+if fs.exists(folderpath):
+files = [os.path.basename(f) for f in fs.ls(folderpath)]
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
@@ -463,16 +463,15 @@ def restore_training_state(self, checkpoint):
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
-if not gfile.exists(folderpath):
-makedirs(folderpath)
+fs = get_filesystem(folderpath)
+fs.makedirs(folderpath, exist_ok=True)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-if not gfile.exists(folderpath):
-makedirs(folderpath)
+fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
@@ -525,7 +524,8 @@ def hpc_load(self, folderpath, on_gpu):
log.info(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-files = gfile.listdir(str(path))
+fs = get_filesystem(path)
+files = [os.path.basename(f) for f in fs.ls(path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
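
Not part of the diff: with directory handling and file handles going through fsspec, a checkpoint can be written to (and read back from) any supported destination by handing an `fs.open()` handle to `torch.save`/`torch.load`. A local-only sketch of the pattern used by `hpc_save` above:

```python
# Illustrative only: saving/loading a checkpoint through an fsspec file handle.
import os
import tempfile

import fsspec
import torch

fs = fsspec.filesystem("file")
with tempfile.TemporaryDirectory() as folderpath:
    fs.makedirs(folderpath, exist_ok=True)             # mirrors hpc_save()
    filepath = os.path.join(folderpath, "hpc_ckpt_1.ckpt")
    with fs.open(filepath, "wb") as fp:
        torch.save({"epoch": 3, "state_dict": {}}, fp)
    with fs.open(filepath, "rb") as fp:
        print(torch.load(fp)["epoch"])                 # -> 3
```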
(The remaining changed files are not expanded in this view.)