Simplify the compat.gfile import and fix the tests
The tests all use the `tmpdir` fixture, which provides a `py.path.local` object that is
incompatible with `compat.gfile`. The contract in many places is type `str` or
`Optional[str]`, which `py.path.local` is not.

I hope that folks are not passing in `path.local` objects; if they are, this change will
break them. The type annotations say to use `str`, so this should be OK. The other option
is to explicitly convert to `str`, so as not to break people using an incorrect type
(as the tests were doing).
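
For illustration, a minimal sketch of the mismatch (the test function below is hypothetical;
`ModelCheckpoint` and the pytest `tmpdir` fixture are the real pieces this change touches):

    from pytorch_lightning.callbacks import ModelCheckpoint

    def test_checkpoint_filepath(tmpdir):
        # tmpdir is a py.path.local; the ModelCheckpoint contract is Optional[str],
        # so convert explicitly at the call site, as the tests now do
        checkpoint = ModelCheckpoint(filepath=str(tmpdir))
        assert isinstance(checkpoint.dirpath, str)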
f4hy committed Jun 13, 2020
1 parent 6dcc9e6 commit b211832
Showing 12 changed files with 109 additions and 96 deletions.
29 changes: 18 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,10 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf
from pytorch_lightning.utilities.io import gfile


class ModelCheckpoint(Callback):
@@ -101,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if save_top_k > 0 and filepath is not None and tf.io.gfile.isdir(filepath) and len(tf.io.gfile.listdir(filepath)) > 0:
if filepath:
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -113,11 +112,12 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if tf.io.gfile.isdir(filepath):
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
self.dirpath, self.filename = os.path.split(filepath)
os.makedirs(self.dirpath, exist_ok=True)
if not gfile.exists(self.dirpath):
gfile.makedirs(self.dirpath)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
@@ -159,12 +159,19 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
if tf.io.gfile.exists(filepath):
os.remove(filepath)
if gfile.exists(filepath):
try:
# in compat mode, remove is not implemented, so if running this
# against an actual remote file system and the correct remote
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
os.remove(filepath)

def _save_model(self, filepath):
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
if not gfile.exists(os.path.dirname(filepath)):
gfile.makedirs(os.path.dirname(filepath))

# delegate the saving to the model
if self.save_function is not None:
@@ -252,7 +259,7 @@ def on_validation_end(self, trainer, pl_module):

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
while tf.io.gfile.exists(filepath):
while gfile.exists(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
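The try/except in `_del_model` above can be read as a general pattern for working against
the pruned compat `gfile`; a rough, self-contained sketch (the `remove_checkpoint` helper
is made up for illustration, not part of this commit):

    import os

    from tensorboard.compat import tf

    gfile = tf.io.gfile  # the same shim exposed by pytorch_lightning.utilities.io

    def remove_checkpoint(filepath: str) -> None:
        if gfile.exists(filepath):
            try:
                # full TF (or the right remote dependencies) provides gfile.remove,
                # which also handles remote backends
                gfile.remove(filepath)
            except AttributeError:
                # the pruned compat gfile may not implement remove(); fall back to os
                os.remove(filepath)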
4 changes: 2 additions & 2 deletions pytorch_lightning/core/saving.py
@@ -278,7 +278,7 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
return {}

with tf.io.gfile.GFile(tags_csv) as fp:
with tf.io.gfile.GFile(tags_csv, "r") as fp:
csv_reader = csv.reader(fp, delimiter=',')
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

@@ -315,7 +315,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
return {}

with tf.io.gfile.GFile(config_yaml) as fp:
with tf.io.gfile.GFile(config_yaml, "r") as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)

return tags
17 changes: 7 additions & 10 deletions pytorch_lightning/loggers/tensorboard.py
@@ -13,15 +13,11 @@
from pkg_resources import parse_version
from torch.utils.tensorboard import SummaryWriter

# we want this for tf.iogfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf

from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.io import gfile


class TensorBoardLogger(LightningLoggerBase):
@@ -102,7 +98,8 @@ def experiment(self) -> SummaryWriter:
if self._experiment is not None:
return self._experiment

tf.io.gfile.makedirs(self.root_dir)
if not gfile.exists(self.root_dir):
gfile.makedirs(self.root_dir)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@@ -150,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
def save(self) -> None:
super().save()
dir_path = self.log_dir
if not tf.io.gfile.isdir(dir_path):
if not gfile.isdir(dir_path):
dir_path = self.save_dir

# prepare the file path
@@ -176,13 +173,13 @@ def version(self) -> int:
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)

if not tf.io.gfile.isdir(root_dir):
if not gfile.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
for d in tf.io.gfile.listdir(root_dir):
if tf.io.gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
for d in gfile.listdir(root_dir):
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))

if len(existing_versions) == 0:
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/callback_config.py
@@ -6,6 +6,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.io import gfile


class TrainerCallbackConfigMixin(ABC):
@@ -67,7 +68,8 @@ def configure_checkpoint_callback(self):
monitor_key = 'loss' if train_step_only else 'val_loss'

if self.checkpoint_callback is True:
os.makedirs(ckpt_path, exist_ok=True)
if not gfile.exists(ckpt_path):
gfile.makedirs(ckpt_path)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
@@ -77,7 +79,9 @@
and self.checkpoint_callback.dirpath is None:
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
if not gfile.exists(self.checkpoint_callback.dirpath):
gfile.makedirs(self.checkpoint_callback.dirpath)

elif self.checkpoint_callback is False:
self.checkpoint_callback = None

14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/training_io.py
@@ -102,6 +102,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn, parsing
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.io import gfile

try:
import torch_xla
@@ -375,8 +376,8 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):

# look for hpc weights
folderpath = self.weights_save_path
if os.path.exists(folderpath):
files = os.listdir(folderpath)
if gfile.exists(folderpath):
files = gfile.listdir(folderpath)
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
@@ -451,15 +452,16 @@ def restore_training_state(self, checkpoint):
# ----------------------------------
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
os.makedirs(folderpath, exist_ok=True)
if not gfile.exists(folderpath):
gfile.makedirs(folderpath)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

if not os.path.exists(folderpath):
os.makedirs(folderpath, exist_ok=True)
if not gfile.exists(folderpath):
gfile.makedirs(folderpath)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
@@ -509,7 +511,7 @@ def hpc_load(self, folderpath, on_gpu):
log.info(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = os.listdir(path)
files = gfile.listdir(path)
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
7 changes: 7 additions & 0 deletions pytorch_lightning/utilities/io.py
@@ -2,6 +2,13 @@

from urllib.parse import urlparse

# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf

gfile = tf.io.gfile


def load(path_or_url: str, map_location=None):
parsed = urlparse(path_or_url)
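To show how the new `gfile` re-export above is meant to be consumed by the rest of the
codebase, a short sketch (the `ensure_dir` helper below is hypothetical, not part of this
commit):

    import os

    from pytorch_lightning.utilities.io import gfile  # the shim added above

    def ensure_dir(path: str) -> None:
        # mirrors the makedirs pattern used throughout this commit: gfile works on
        # local paths with the pruned compat module, and on remote backends when the
        # full TensorFlow install (or the right remote dependencies) is present
        if not gfile.exists(path):
            gfile.makedirs(path)

    ensure_dir(os.path.join("lightning_logs", "version_0"))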
6 changes: 3 additions & 3 deletions tests/loggers/test_all.py
@@ -46,7 +46,7 @@ def log_metrics(self, metrics, step):
super().log_metrics(metrics, step)
self.history.append((step, metrics))

logger_args = _get_logger_args(logger_class, tmpdir)
logger_args = _get_logger_args(logger_class, str(tmpdir))
logger = StoreHistoryLogger(**logger_args)

trainer = Trainer(
@@ -82,7 +82,7 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)

logger_args = _get_logger_args(logger_class, tmpdir)
logger_args = _get_logger_args(logger_class, str(tmpdir))
logger = logger_class(**logger_args)

# test pickling loggers
@@ -109,7 +109,7 @@ def test_logger_reset_correctly(tmpdir, extra_params):
model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
**extra_params
)
logger1 = trainer.logger
36 changes: 16 additions & 20 deletions tests/models/test_cpu.py
@@ -19,7 +19,7 @@ def test_cpu_slurm_save_load(tmpdir):
model = EvalModelTemplate(**hparams)

# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
logger = tutils.get_default_logger(str(tmpdir))
version = logger.version

# fit model
@@ -28,7 +28,7 @@ def test_cpu_slurm_save_load(tmpdir):
logger=logger,
train_percent_check=0.2,
val_percent_check=0.2,
checkpoint_callback=ModelCheckpoint(tmpdir)
checkpoint_callback=ModelCheckpoint(str(tmpdir)),
)
result = trainer.fit(model)
real_global_step = trainer.global_step
@@ -54,17 +54,13 @@ def test_cpu_slurm_save_load(tmpdir):

# test HPC saving
# simulate snapshot on slurm
saved_filepath = trainer.hpc_save(tmpdir, logger)
saved_filepath = trainer.hpc_save(str(tmpdir), logger)
assert os.path.exists(saved_filepath)

# new logger file to get meta
logger = tutils.get_default_logger(tmpdir, version=version)
logger = tutils.get_default_logger(str(tmpdir), version=version)

trainer = Trainer(
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir),
)
trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(str(tmpdir)),)
model = EvalModelTemplate(**hparams)

# set the epoch start hook so we can predict before the model does the full training
@@ -87,7 +83,7 @@ def test_early_stopping_cpu_model(tmpdir):
"""Test each of the trainer options."""
stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
early_stop_callback=stopping,
max_epochs=2,
gradient_clip_val=1.0,
@@ -116,7 +112,7 @@ def test_multi_cpu_model_ddp(tmpdir):
tutils.set_random_master_port()

trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
@@ -133,7 +129,7 @@ def test_lbfgs_cpu_model(tmpdir):
def test_lbfgs_cpu_model(tmpdir):
"""Test each of the trainer options."""
trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
max_epochs=1,
progress_bar_refresh_rate=0,
weights_summary='top',
@@ -152,7 +148,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
def test_default_logger_callbacks_cpu_model(tmpdir):
"""Test each of the trainer options."""
trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
@@ -174,14 +170,14 @@ def test_running_test_after_fitting(tmpdir):
model = EvalModelTemplate()

# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
logger = tutils.get_default_logger(str(tmpdir))

# logger file to get weights
checkpoint = tutils.init_checkpoint_callback(logger)

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
@@ -205,7 +201,7 @@ def test_running_test_no_val(tmpdir):
model = EvalModelTemplate()

# logger file to get meta
logger = tutils.get_default_logger(tmpdir)
logger = tutils.get_default_logger(str(tmpdir))

# logger file to get weights
checkpoint = tutils.init_checkpoint_callback(logger)
@@ -284,7 +280,7 @@ def test_simple_cpu(tmpdir):

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.1,
@@ -298,7 +294,7 @@ def test_cpu_model(tmpdir):
def test_cpu_model(tmpdir):
"""Make sure model trains on CPU."""
trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
@@ -313,7 +309,7 @@ def test_all_features_cpu_model(tmpdir):
def test_all_features_cpu_model(tmpdir):
"""Test each of the trainer options."""
trainer_options = dict(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
@@ -387,7 +383,7 @@ def train_dataloader(self):

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,