use fsspec instead of gfile for all IO #3320

Merged 3 commits on Sep 3, 2020

3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))

### Changed

- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))

### Deprecated

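
As a rough illustration of the change recorded above: each IO call site touched by this PR first resolves an fsspec filesystem for the path it is given (via `pytorch_lightning.utilities.cloud_io.get_filesystem`, whose implementation is not shown in this diff) and then performs reads, writes, listings, and deletes through that object. The sketch below is only an assumption of what such a resolver could look like, not the merged helper.

```python
from pathlib import Path
from typing import Union

import fsspec
from fsspec import AbstractFileSystem


def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
    # Hypothetical resolver; the real helper lives in pytorch_lightning.utilities.cloud_io.
    path = str(path)
    if "://" in path:
        # URLs like "s3://bucket/key" select the matching fsspec backend,
        # provided the corresponding optional dependency (e.g. s3fs) is installed.
        return fsspec.filesystem(path.split("://", 1)[0])
    # Plain paths (and the empty string) fall back to the local filesystem.
    return fsspec.filesystem("file")


fs = get_filesystem("lightning_logs")
print(type(fs).__name__)  # LocalFileSystem
```
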
1 change: 1 addition & 0 deletions environment.yml
@@ -31,6 +31,7 @@ dependencies:
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
- fsspec>=0.8.0
- nvidia-apex

# For dev and testing
45 changes: 17 additions & 28 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -30,7 +30,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
from pytorch_lightning.utilities.cloud_io import get_filesystem


class ModelCheckpoint(Callback):
@@ -119,9 +119,11 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if(filepath):
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
if filepath:
self._fs = get_filesystem(filepath)
else:
self._fs = get_filesystem("") # will give local filesystem
if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -133,13 +135,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
if self._fs.isdir(filepath):
self.dirpath, self.filename = filepath, "{epoch}"
else:
if not is_remote_path(filepath): # dont normalize remote paths
if self._fs.protocol == "file": # dont normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
makedirs(self.dirpath) # calls with exist_ok
self._fs.makedirs(self.dirpath, exist_ok=True)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
@@ -182,28 +184,16 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
if gfile.exists(filepath):
try:
# in compat mode, remove is not implemented so if running this
# against an actual remove file system and the correct remote
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
if is_remote_path(filepath):
log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
" Please install tensorflow to run gfile in full mode"
" if writing checkpoints to remote locations")
else:
os.remove(filepath)
if self._fs.exists(filepath):
self._fs.rm(filepath)

def _save_model(self, filepath, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
if not gfile.exists(os.path.dirname(filepath)):
makedirs(os.path.dirname(filepath))
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the model
if self.save_function is not None:
@@ -308,9 +298,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):

self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)
assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
self._fs.makedirs(self.dirpath, exist_ok=True)

def __warn_deprecated_monitor_key(self):
using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@@ -359,7 +348,7 @@ def on_validation_end(self, trainer, pl_module):
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
while gfile.exists(filepath):
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
@@ -435,4 +424,4 @@ def on_save_checkpoint(self, trainer, pl_module):

def on_load_checkpoint(self, checkpointed_state):
self.best_model_score = checkpointed_state['best_model_score']
self.best_model_path = checkpointed_state['best_model_path']
self.best_model_path = checkpointed_state['best_model_path']
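
The net effect of the ModelCheckpoint changes above is that existence checks, directory creation, stale-checkpoint removal, and filename versioning all go through one fsspec filesystem bound to the checkpoint path, instead of `gfile` plus a compatibility-mode fallback. A small standalone sketch of that call pattern, using a throwaway local directory and made-up file names rather than anything from the PR:

```python
import os
import tempfile

import fsspec

# The updated callback relies on exists/isdir/ls/makedirs/rm on one filesystem object;
# paths and file names below are invented for illustration.
fs = fsspec.filesystem("file")  # what get_filesystem(dirpath) yields for a local path
dirpath = tempfile.mkdtemp()

fs.makedirs(dirpath, exist_ok=True)          # replaces cloud_io.makedirs
ckpt = os.path.join(dirpath, "epoch=0.ckpt")
with fs.open(ckpt, "wb") as fp:
    fp.write(b"fake checkpoint bytes")

if fs.isdir(dirpath) and len(fs.ls(dirpath, detail=False)) > 0:
    print("directory not empty: the save_top_k warning above would fire")
if fs.exists(ckpt):
    fs.rm(ckpt)                              # replaces the gfile.remove / os.remove fallback
```
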
24 changes: 15 additions & 9 deletions pytorch_lightning/core/saving.py
@@ -19,13 +19,15 @@
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping

import fsspec
import torch
import yaml

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
from pytorch_lightning.utilities.cloud_io import get_filesystem


PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -290,25 +292,27 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
True
>>> os.remove(path_csv)
"""
if not gfile.exists(tags_csv):
fs = get_filesystem(tags_csv)
if not fs.exists(tags_csv):
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}

with cloud_open(tags_csv, "r", newline="") as fp:
with fs.open(tags_csv, "r", newline="") as fp:
csv_reader = csv.reader(fp, delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
if not gfile.isdir(os.path.dirname(tags_csv)):
fs = get_filesystem(tags_csv)
if not fs.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with cloud_open(tags_csv, "w", newline="") as fp:
with fs.open(tags_csv, "w", newline="") as fp:
fieldnames = ["key", "value"]
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({"key": "key", "value": "value"})
@@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
True
>>> os.remove(path_yaml)
"""
if not gfile.exists(config_yaml):
fs = get_filesystem(config_yaml)
if not fs.exists(config_yaml):
rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
return {}

with cloud_open(config_yaml, "r") as fp:
with fs.open(config_yaml, "r") as fp:
tags = yaml.load(fp)

return tags
@@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
config_yaml: path to new YAML file
hparams: parameters to be saved
"""
if not gfile.isdir(os.path.dirname(config_yaml)):
fs = get_filesystem(config_yaml)
if not fs.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

# convert Namespace or AD to dict
@@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:

# saving the standard way
assert isinstance(hparams, dict)
with cloud_open(config_yaml, 'w', newline='') as fp:
with fs.open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)


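
The hparams helpers above now open their CSV and YAML targets through `fs.open`, so the same read/write code path should work for local files and, assuming the matching fsspec backend is installed, for remote URLs as well. A minimal local sketch of that pattern (file name and contents are invented for illustration):

```python
import os
import tempfile

import fsspec
import yaml

path = os.path.join(tempfile.mkdtemp(), "hparams.yaml")
fs = fsspec.filesystem("file")

with fs.open(path, "w", newline="") as fp:
    yaml.dump({"lr": 1e-3, "batch_size": 32}, fp)

with fs.open(path, "r") as fp:
    print(yaml.safe_load(fp))  # -> {'batch_size': 32, 'lr': 0.001}
```
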
19 changes: 10 additions & 9 deletions pytorch_lightning/loggers/tensorboard.py
@@ -30,7 +30,7 @@
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.core.lightning import LightningModule

try:
@@ -87,6 +87,7 @@ def __init__(
self._version = version
self._log_graph = log_graph
self._default_hp_metric = default_hp_metric
self._fs = get_filesystem(save_dir)

self._experiment = None
self.hparams = {}
@@ -136,8 +137,8 @@ def experiment(self) -> SummaryWriter:
return self._experiment

assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
if self.root_dir and not gfile.exists(str(self.root_dir)):
makedirs(self.root_dir)
if self.root_dir:
self._fs.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@@ -207,7 +208,7 @@ def log_graph(self, model: LightningModule, input_array=None):
def save(self) -> None:
super().save()
dir_path = self.log_dir
if not gfile.isdir(dir_path):
if not self._fs.isdir(dir_path):
dir_path = self.save_dir

# prepare the file path
@@ -233,16 +234,16 @@ def version(self) -> int:
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)

if not gfile.isdir(root_dir):
if not self._fs.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
for d in gfile.listdir(root_dir):
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
dir_ver = d.split("_")[1].replace('/', '')
for d in self._fs.ls(root_dir):
bn = os.path.basename(d)
if self._fs.isdir(d) and bn.startswith("version_"):
dir_ver = bn.split("_")[1].replace('/', '')
existing_versions.append(int(dir_ver))

if len(existing_versions) == 0:
return 0

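
One behavioral detail worth noting in the `_get_next_version` change above: `gfile.listdir` returned bare entry names, whereas fsspec's `ls` returns full paths, which is why the new code takes `os.path.basename` before matching the `version_` prefix. A small sketch of that scan against a throwaway directory:

```python
import os
import tempfile

import fsspec

fs = fsspec.filesystem("file")
root_dir = tempfile.mkdtemp()
for v in (0, 1, 3):
    fs.makedirs(os.path.join(root_dir, f"version_{v}"), exist_ok=True)

existing_versions = []
for d in fs.ls(root_dir, detail=False):   # entries are full paths, e.g. '/tmp/.../version_3'
    bn = os.path.basename(d)
    if fs.isdir(d) and bn.startswith("version_"):
        existing_versions.append(int(bn.split("_")[1]))

# _get_next_version returns 0 for an empty folder, else the max existing version + 1.
print(max(existing_versions) + 1 if existing_versions else 0)  # -> 4
```
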
16 changes: 7 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -51,7 +51,7 @@
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
@@ -915,21 +915,19 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if is_remote_path(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir

@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if is_remote_path(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

def tune(
self,
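
The two trainer properties above keep the old behavior, normalizing only local paths and returning remote URIs untouched, but now key the decision off the resolved filesystem's protocol rather than an `is_remote_path` helper. A sketch of that check, assuming the fsspec versions targeted here report `"file"` for the local filesystem (newer releases report a tuple, hence the membership test below):

```python
import os

import fsspec

root_dir = "lightning_logs/./checkpoints"
fs = fsspec.filesystem("file")  # stand-in for get_filesystem(root_dir) on a local path

protocols = fs.protocol if isinstance(fs.protocol, (list, tuple)) else (fs.protocol,)
if "file" in protocols:
    root_dir = os.path.normpath(root_dir)   # normalize local paths only
# a remote URI such as s3://bucket/dir would be returned unchanged

print(root_dir)  # lightning_logs/checkpoints
```
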
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/training_io.py
@@ -114,9 +114,8 @@
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import makedirs
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

try:
@@ -391,8 +390,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):

# look for hpc weights
folderpath = str(self.weights_save_path)
if gfile.exists(folderpath):
files = gfile.listdir(folderpath)
fs = get_filesystem(folderpath)
if fs.exists(folderpath):
files = [os.path.basename(f) for f in fs.ls(folderpath)]
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
@@ -463,16 +463,15 @@ def restore_training_state(self, checkpoint):
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
if not gfile.exists(folderpath):
makedirs(folderpath)
fs = get_filesystem(folderpath)
fs.makedirs(folderpath, exist_ok=True)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

if not gfile.exists(folderpath):
makedirs(folderpath)
fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
@@ -525,7 +524,8 @@ def hpc_load(self, folderpath, on_gpu):
log.info(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = gfile.listdir(str(path))
fs = get_filesystem(path)
files = [os.path.basename(f) for f in fs.ls(path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
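
Finally, the training-IO changes above swap `gfile.listdir` for `fs.ls` when locating HPC checkpoints, again taking basenames because `ls` returns full paths. A compact, self-contained sketch of the `max_ckpt_in_folder` pattern, with made-up checkpoint names and a simple regex standing in for the PR's own index parsing:

```python
import os
import re
import tempfile

import fsspec

fs = fsspec.filesystem("file")
folderpath = tempfile.mkdtemp()
for n in (1, 2, 5):                       # made-up checkpoint files
    with fs.open(os.path.join(folderpath, f"hpc_ckpt_{n}.ckpt"), "wb") as fp:
        fp.write(b"")

files = [os.path.basename(f) for f in fs.ls(folderpath, detail=False)]
ckpts = [int(re.search(r"ckpt_(\d+)", name).group(1)) for name in files if "ckpt_" in name]
print(max(ckpts) if ckpts else 0)  # -> 5, so the next save would be hpc_ckpt_6.ckpt
```
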