From f2d9c97f987aabf772a0312f60b14204134d7fa3 Mon Sep 17 00:00:00 2001 From: Brendan Fahy Date: Fri, 12 Jun 2020 15:51:56 -0400 Subject: [PATCH] use gfile to support remote directories The tests all use the `tmpdir` fixture, which provides a py.path.local that is incompatible with compat.gfile. The contract in many places is type str or Optional[str], which py.path.local is not. I hope that folks are not passing in py.path.local objects; if so, this change will break them. The type annotations say to use str, so this should be ok. The other option is to just explicitly convert to str so as not to break people using an incorrect type (like the tests were doing). --- .../callbacks/model_checkpoint.py | 25 +++++++--- pytorch_lightning/core/saving.py | 21 +++++--- pytorch_lightning/loggers/tensorboard.py | 12 +++-- pytorch_lightning/trainer/callback_config.py | 8 ++- pytorch_lightning/trainer/training_io.py | 14 +++--- pytorch_lightning/utilities/io.py | 7 +++ tests/loggers/test_all.py | 6 +-- tests/models/test_cpu.py | 36 ++++++------- tests/models/test_gpu.py | 6 +-- tests/trainer/test_lr_finder.py | 18 +++---- tests/trainer/test_trainer.py | 50 +++++++++---------- tests/trainer/test_trainer_tricks.py | 10 ++-- 12 files changed, 120 insertions(+), 93 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 672670da72b55f..6bc64bdcbd8a85 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -16,6 +16,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only +from pytorch_lightning.utilities.io import gfile class ModelCheckpoint(Callback): @@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() - if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: + if filepath: + filepath = str(filepath) # the tests pass in a py.path.local but we want a str + if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0: rank_zero_warn( f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" @@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve if filepath is None: # will be determined by trainer at runtime self.dirpath, self.filename = None, None else: - if os.path.isdir(filepath): + if gfile.isdir(filepath): self.dirpath, self.filename = filepath, '{epoch}' else: filepath = os.path.realpath(filepath) self.dirpath, self.filename = os.path.split(filepath) - os.makedirs(self.dirpath, exist_ok=True) + if not gfile.exists(self.dirpath): + gfile.makedirs(self.dirpath) self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only @@ -156,12 +160,19 @@ def kth_best_model(self): return self.kth_best_model_path def _del_model(self, filepath): - if os.path.isfile(filepath): - os.remove(filepath) + if gfile.exists(filepath): + try: + # in compat mode, remove is not implemented, so if running this + # against an actual remote file system and the correct remote + # dependencies exist then this will work fine. 
+ gfile.remove(filepath) + except AttributeError: + os.remove(filepath) def _save_model(self, filepath): # make paths - os.makedirs(os.path.dirname(filepath), exist_ok=True) + if not gfile.exists(os.path.dirname(filepath)): + gfile.makedirs(os.path.dirname(filepath)) # delegate the saving to the model if self.save_function is not None: @@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module): filepath = self.format_checkpoint_name(epoch, metrics) version_cnt = 0 - while os.path.isfile(filepath): + while gfile.exists(filepath): filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt) # this epoch called before version_cnt += 1 diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index ed299bb8a816e0..542bcb230aa9a5 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -12,6 +12,11 @@ from pytorch_lightning.utilities import rank_zero_warn, AttributeDict from pytorch_lightning.utilities.io import load as pl_load +# we want this for tf.io.gfile, which if tf is installed gives full tf, +# otherwise gives a pruned down version which works for some file backends but +# not all +from tensorboard.compat import tf + PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace) try: @@ -269,11 +274,11 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: True >>> os.remove(path_csv) """ - if not os.path.isfile(tags_csv): + if not tf.io.gfile.exists(tags_csv): rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning) return {} - with open(tags_csv) as fp: + with tf.io.gfile.GFile(tags_csv, "r") as fp: csv_reader = csv.reader(fp, delimiter=',') tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} @@ -281,13 +286,13 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: - if not os.path.isdir(os.path.dirname(tags_csv)): + if not tf.io.gfile.isdir(os.path.dirname(tags_csv)): raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.') if isinstance(hparams, Namespace): hparams = vars(hparams) - with open(tags_csv, 'w') as fp: + with tf.io.gfile.GFile(tags_csv, 'w') as fp: fieldnames = ['key', 'value'] writer = csv.DictWriter(fp, fieldnames=fieldnames) writer.writerow({'key': 'key', 'value': 'value'}) @@ -306,24 +311,24 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: True >>> os.remove(path_yaml) """ - if not os.path.isfile(config_yaml): + if not tf.io.gfile.exists(config_yaml): rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning) return {} - with open(config_yaml) as fp: + with tf.io.gfile.GFile(config_yaml, "r") as fp: tags = yaml.load(fp, Loader=yaml.SafeLoader) return tags def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: - if not os.path.isdir(os.path.dirname(config_yaml)): + if not tf.io.gfile.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.') if isinstance(hparams, Namespace): hparams = vars(hparams) - with open(config_yaml, 'w', newline='') as fp: + with tf.io.gfile.GFile(config_yaml, 'w') as fp: yaml.dump(hparams, fp) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 12ce47398d3b37..a82c601176bc31 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -17,6 +17,7 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml 
from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.io import gfile class TensorBoardLogger(LightningLoggerBase): @@ -97,7 +98,8 @@ def experiment(self) -> SummaryWriter: if self._experiment is not None: return self._experiment - os.makedirs(self.root_dir, exist_ok=True) + if not gfile.exists(self.root_dir): + gfile.makedirs(self.root_dir) self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment @@ -145,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> def save(self) -> None: super().save() dir_path = self.log_dir - if not os.path.isdir(dir_path): + if not gfile.isdir(dir_path): dir_path = self.save_dir # prepare the file path @@ -171,13 +173,13 @@ def version(self) -> int: def _get_next_version(self): root_dir = os.path.join(self.save_dir, self.name) - if not os.path.isdir(root_dir): + if not gfile.isdir(root_dir): log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for d in os.listdir(root_dir): - if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): + for d in gfile.listdir(root_dir): + if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): existing_versions.append(int(d.split("_")[1])) if len(existing_versions) == 0: diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 5e490a106826bd..a42043255d13f1 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -6,6 +6,7 @@ from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.io import gfile class TrainerCallbackConfigMixin(ABC): @@ -67,7 +68,8 @@ def configure_checkpoint_callback(self): monitor_key = 'loss' if train_step_only else 'val_loss' if self.checkpoint_callback is True: - os.makedirs(ckpt_path, exist_ok=True) + if not gfile.exists(ckpt_path): + gfile.makedirs(ckpt_path) self.checkpoint_callback = ModelCheckpoint( filepath=ckpt_path, monitor=monitor_key @@ -77,7 +79,9 @@ def configure_checkpoint_callback(self): and self.checkpoint_callback.dirpath is None: self.checkpoint_callback.dirpath = ckpt_path self.checkpoint_callback.filename = '{epoch}' - os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True) + if not gfile.exists(self.checkpoint_callback.dirpath): + gfile.makedirs(self.checkpoint_callback.dirpath) + elif self.checkpoint_callback is False: self.checkpoint_callback = None diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 6f4e85d5b28e44..39e45a134a4cb4 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -102,6 +102,7 @@ ) from pytorch_lightning.utilities import rank_zero_warn, parsing from pytorch_lightning.utilities.io import load as pl_load +from pytorch_lightning.utilities.io import gfile try: import torch_xla @@ -375,8 +376,8 @@ def restore_hpc_weights_if_needed(self, model: LightningModule): # look for hpc weights folderpath = self.weights_save_path - if os.path.exists(folderpath): - files = os.listdir(folderpath) + if gfile.exists(folderpath): + files = gfile.listdir(folderpath) hpc_weight_paths = [x for x in files if 'hpc_ckpt' in 
x] # if hpc weights exist restore model @@ -451,15 +452,16 @@ def restore_training_state(self, checkpoint): # ---------------------------------- def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists - os.makedirs(folderpath, exist_ok=True) + if not gfile.exists(folderpath): + gfile.makedirs(folderpath) # save logger to make sure we get all the metrics logger.save() ckpt_number = self.max_ckpt_in_folder(folderpath) + 1 - if not os.path.exists(folderpath): - os.makedirs(folderpath, exist_ok=True) + if not gfile.exists(folderpath): + gfile.makedirs(folderpath) filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') # give model a chance to do something on hpc_save @@ -509,7 +511,7 @@ def hpc_load(self, folderpath, on_gpu): log.info(f'restored hpc model from: {filepath}') def max_ckpt_in_folder(self, path, name_key='ckpt_'): - files = os.listdir(path) + files = gfile.listdir(path) files = [x for x in files if name_key in x] if len(files) == 0: return 0 diff --git a/pytorch_lightning/utilities/io.py b/pytorch_lightning/utilities/io.py index b2f9ffa788541f..c24ba2ae8edf46 100644 --- a/pytorch_lightning/utilities/io.py +++ b/pytorch_lightning/utilities/io.py @@ -2,6 +2,13 @@ from urllib.parse import urlparse +# we want this for tf.io.gfile, which if tf is installed gives full tf, +# otherwise gives a pruned down version which works for some file backends but +# not all +from tensorboard.compat import tf + +gfile = tf.io.gfile + def load(path_or_url: str, map_location=None): parsed = urlparse(path_or_url) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index f8a8fead41f586..a6d0f683516e11 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -46,7 +46,7 @@ def log_metrics(self, metrics, step): super().log_metrics(metrics, step) self.history.append((step, metrics)) - logger_args = _get_logger_args(logger_class, tmpdir) + logger_args = _get_logger_args(logger_class, str(tmpdir)) logger = StoreHistoryLogger(**logger_args) trainer = Trainer( @@ -82,7 +82,7 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class): import atexit monkeypatch.setattr(atexit, 'register', lambda _: None) - logger_args = _get_logger_args(logger_class, tmpdir) + logger_args = _get_logger_args(logger_class, str(tmpdir)) logger = logger_class(**logger_args) # test pickling loggers @@ -109,7 +109,7 @@ def test_logger_reset_correctly(tmpdir, extra_params): model = EvalModelTemplate() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), **extra_params ) logger1 = trainer.logger diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 051db81d1b1657..0656723b3f334c 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -19,7 +19,7 @@ def test_cpu_slurm_save_load(tmpdir): model = EvalModelTemplate(**hparams) # logger file to get meta - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) version = logger.version # fit model @@ -28,7 +28,7 @@ def test_cpu_slurm_save_load(tmpdir): logger=logger, train_percent_check=0.2, val_percent_check=0.2, - checkpoint_callback=ModelCheckpoint(tmpdir) + checkpoint_callback=ModelCheckpoint(str(tmpdir)), ) result = trainer.fit(model) real_global_step = trainer.global_step @@ -54,17 +54,13 @@ def test_cpu_slurm_save_load(tmpdir): # test HPC saving # simulate snapshot on slurm - saved_filepath = trainer.hpc_save(tmpdir, logger) + saved_filepath = trainer.hpc_save(str(tmpdir), logger) assert os.path.exists(saved_filepath) 
# new logger file to get meta - logger = tutils.get_default_logger(tmpdir, version=version) + logger = tutils.get_default_logger(str(tmpdir), version=version) - trainer = Trainer( - max_epochs=1, - logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir), - ) + trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(str(tmpdir)),) model = EvalModelTemplate(**hparams) # set the epoch start hook so we can predict before the model does the full training @@ -87,7 +83,7 @@ def test_early_stopping_cpu_model(tmpdir): """Test each of the trainer options.""" stopping = EarlyStopping(monitor='val_loss', min_delta=0.1) trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), early_stop_callback=stopping, max_epochs=2, gradient_clip_val=1.0, @@ -116,7 +112,7 @@ def test_multi_cpu_model_ddp(tmpdir): tutils.set_random_master_port() trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, @@ -133,7 +129,7 @@ def test_multi_cpu_model_ddp(tmpdir): def test_lbfgs_cpu_model(tmpdir): """Test each of the trainer options.""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, progress_bar_refresh_rate=0, weights_summary='top', @@ -152,7 +148,7 @@ def test_lbfgs_cpu_model(tmpdir): def test_default_logger_callbacks_cpu_model(tmpdir): """Test each of the trainer options.""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, gradient_clip_val=1.0, overfit_pct=0.20, @@ -174,14 +170,14 @@ def test_running_test_after_fitting(tmpdir): model = EvalModelTemplate() # logger file to get meta - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) # logger file to get weights checkpoint = tutils.init_checkpoint_callback(logger) # fit model trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.4, @@ -205,7 +201,7 @@ def test_running_test_no_val(tmpdir): model = EvalModelTemplate() # logger file to get meta - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) # logger file to get weights checkpoint = tutils.init_checkpoint_callback(logger) @@ -284,7 +280,7 @@ def test_simple_cpu(tmpdir): # fit model trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, val_percent_check=0.1, train_percent_check=0.1, @@ -298,7 +294,7 @@ def test_simple_cpu(tmpdir): def test_cpu_model(tmpdir): """Make sure model trains on CPU.""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, @@ -313,7 +309,7 @@ def test_cpu_model(tmpdir): def test_all_features_cpu_model(tmpdir): """Test each of the trainer options.""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), gradient_clip_val=1.0, overfit_pct=0.20, track_grad_norm=2, @@ -387,7 +383,7 @@ def train_dataloader(self): # fit model trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, truncated_bptt_steps=truncated_bptt_steps, val_percent_check=0, diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 80249a727ccbbd..90f543bc2fb1c2 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -18,7 +18,7 @@ def test_single_gpu_model(tmpdir, gpus): """Make sure single 
GPU works (DP mode).""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.1, @@ -38,7 +38,7 @@ def test_multi_gpu_model(tmpdir, backend): tutils.set_random_master_port() trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, train_percent_check=0.4, val_percent_check=0.2, @@ -84,7 +84,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): def test_multi_gpu_none_backend(tmpdir): """Make sure when using multiple GPUs the user can't use `distributed_backend = None`.""" trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.1, diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index d0becff0918c65..c635d87788363b 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -14,7 +14,7 @@ def test_error_on_more_than_1_optimizer(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1 ) @@ -29,7 +29,7 @@ def test_model_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1 ) @@ -51,7 +51,7 @@ def test_trainer_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1 ) @@ -81,7 +81,7 @@ def test_trainer_arg_bool(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=2, auto_lr_find=True ) @@ -100,7 +100,7 @@ def test_trainer_arg_str(tmpdir): before_lr = model.my_fancy_lr # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=2, auto_lr_find='my_fancy_lr' ) @@ -120,7 +120,7 @@ def test_call_to_trainer_method(tmpdir): before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=2, ) @@ -144,7 +144,7 @@ def test_accumulation_and_early_stopping(tmpdir): before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), accumulate_grad_batches=2, ) @@ -167,7 +167,7 @@ def test_suggestion_parameters_work(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=3, ) @@ -187,7 +187,7 @@ def test_suggestion_with_non_finite_values(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=3 ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f391260c139f42..0c1063ea4ca8ee 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -30,12 +30,12 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): model = EvalModelTemplate() # logger file to get meta - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) trainer = Trainer( max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + checkpoint_callback=ModelCheckpoint(str(tmpdir)) ) # fit model result = trainer.fit(model) @@ -51,7 +51,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in ckpt.keys(), 
'module_arguments missing from checkpoints' # load new model - hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = tutils.get_data_path(logger, path_dir=str(tmpdir)) hparams_path = os.path.join(hparams_path, 'hparams.yaml') ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( @@ -65,18 +65,18 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir - monkeypatch.setenv('TORCH_HOME', tmpdir) + monkeypatch.setenv('TORCH_HOME', str(tmpdir)) model = EvalModelTemplate() # logger file to get meta - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) # fit model trainer = Trainer( max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + checkpoint_callback=ModelCheckpoint(str(tmpdir)) ) result = trainer.fit(model) @@ -88,7 +88,7 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): trainer.save_checkpoint(new_weights_path) # load new model - hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = tutils.get_data_path(logger, path_dir=str(tmpdir)) hparams_path = os.path.join(hparams_path, 'hparams.yaml') ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path model_2 = EvalModelTemplate.load_from_checkpoint( @@ -163,7 +163,7 @@ def _optimizer_step(self, epoch, batch_idx, optimizer, train_percent_check=0.1, val_percent_check=0.1, max_epochs=2, - default_root_dir=tmpdir) + default_root_dir=str(tmpdir)) # for the test trainer.optimizer_step = _optimizer_step @@ -179,13 +179,13 @@ def test_loading_meta_tags(tmpdir): hparams = EvalModelTemplate.get_default_hparams() # save tags - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0)) logger.log_hyperparams(hparams) logger.save() # load hparams - path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) + path_expt_dir = tutils.get_data_path(logger, path_dir=str(tmpdir)) hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE) hparams = load_hparams_from_yaml(hparams_path) @@ -204,13 +204,13 @@ def test_loading_yaml(tmpdir): hparams = EvalModelTemplate.get_default_hparams() # save tags - logger = tutils.get_default_logger(tmpdir) + logger = tutils.get_default_logger(str(tmpdir)) logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0)) logger.log_hyperparams(hparams) logger.save() # load hparams - path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) + path_expt_dir = tutils.get_data_path(logger, path_dir=str(tmpdir)) hparams_path = os.path.join(path_expt_dir, 'hparams.yaml') tags = load_hparams_from_yaml(hparams_path) @@ -262,7 +262,7 @@ def mock_save_function(filepath, *args): # simulated losses losses = [10, 9, 2.8, 5, 2.5] - checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last, + checkpoint_callback = ModelCheckpoint(str(tmpdir), save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = 
Trainer() @@ -291,7 +291,7 @@ def test_model_checkpoint_only_weights(tmpdir): trainer = Trainer( max_epochs=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True) + checkpoint_callback=ModelCheckpoint(str(tmpdir), save_weights_only=True) ) # fit model result = trainer.fit(model) @@ -367,8 +367,8 @@ def increment_on_load_checkpoint(self, _): max_epochs=2, train_percent_check=0.65, val_percent_check=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), - default_root_dir=tmpdir, + checkpoint_callback=ModelCheckpoint(str(tmpdir), save_top_k=-1), + default_root_dir=str(tmpdir), early_stop_callback=False, val_check_interval=1., ) @@ -423,7 +423,7 @@ def test_trainer_max_steps_and_epochs(tmpdir): # define less train steps than epochs trainer_options.update( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=3, max_steps=num_train_samples + 10 ) @@ -458,7 +458,7 @@ def test_trainer_min_steps_and_epochs(tmpdir): # define callback for stopping the model and default epochs trainer_options.update( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0), val_check_interval=2, min_epochs=1, @@ -501,7 +501,7 @@ def test_benchmark_option(tmpdir): # fit model trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, benchmark=True, ) @@ -609,7 +609,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # fit model trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_steps=(model.test_batch_inf_loss + 1), terminate_on_nan=True ) @@ -634,7 +634,7 @@ def on_after_backward(self): model = CurrentModel() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_steps=(model.test_batch_nan + 1), terminate_on_nan=True ) @@ -669,7 +669,7 @@ def on_batch_start(self, trainer, pl_module): train_percent_check=0.2, progress_bar_refresh_rate=0, logger=False, - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), ) assert not trainer.interrupted trainer.fit(model) @@ -693,7 +693,7 @@ def _optimizer_step(*args, **kwargs): max_steps=1, max_epochs=1, gradient_clip_val=1.0, - default_root_dir=tmpdir + default_root_dir=str(tmpdir) ) # for the test @@ -705,7 +705,7 @@ def _optimizer_step(*args, **kwargs): def test_gpu_choice(tmpdir): trainer_options = dict( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), ) # Only run if CUDA is available if not torch.cuda.is_available(): @@ -849,7 +849,7 @@ def __init__(self, **kwargs): def test_trainer_pickle(tmpdir): trainer = Trainer( max_epochs=1, - default_root_dir=tmpdir + default_root_dir=str(tmpdir) ) pickle.dumps(trainer) cloudpickle.dumps(trainer) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 973ed32e7cd927..5f301aa8e0ced3 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -15,7 +15,7 @@ def test_model_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1 ) @@ -38,7 +38,7 @@ def test_trainer_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1 ) @@ -77,7 +77,7 @@ def test_trainer_arg(tmpdir, scale_arg): before_batch_size = hparams.get('batch_size') # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, 
auto_scale_batch_size=scale_arg, ) @@ -99,7 +99,7 @@ def test_call_to_trainer_method(tmpdir, scale_method): before_batch_size = hparams.get('batch_size') # logger file to get meta trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, ) @@ -118,7 +118,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): # only train passed to fit model = EvalModelTemplate() trainer = Trainer( - default_root_dir=tmpdir, + default_root_dir=str(tmpdir), max_epochs=1, val_percent_check=0.1, train_percent_check=0.2,