From 1f2da710697bd4b090dc3b74bf6a583f3ea3d913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Apr 2020 11:38:52 +0200 Subject: [PATCH 1/5] Improved docs for callbacks (#1370) * improved docs for callbacks * class references * make doctest pass * doctests * fix lines too long * fix line too long * fix permission error in doctest * Apply suggestions from code review Co-Authored-By: Jirka Borovec * fix doctest * fix default Co-authored-by: Jirka Borovec --- docs/source/callbacks.rst | 44 ++++++++------- docs/source/early_stopping.rst | 27 ++++----- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/callbacks/early_stopping.py | 22 ++++---- .../gradient_accumulation_scheduler.py | 13 +++-- .../callbacks/model_checkpoint.py | 56 ++++++++++--------- 6 files changed, 91 insertions(+), 75 deletions(-) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 364bf07213d196..ffb7671b7211d6 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -7,7 +7,7 @@ Callbacks ========= Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL -logic that is NOT required for your LightningModule to run. +logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run. An overall Lightning system should have: @@ -15,27 +15,29 @@ An overall Lightning system should have: 2. LightningModule for all research code. 3. Callbacks for non-essential code. -Example -.. code-block:: python - - import pytorch_lightning as pl - - class MyPrintingCallback(pl.Callback): - - def on_init_start(self, trainer): - print('Starting to init trainer!') - - def on_init_end(self, trainer): - print('trainer is init now') - - def on_train_end(self, trainer, pl_module): - print('do something when training ends') - - # pass to trainer - trainer = pl.Trainer(callbacks=[MyPrintingCallback()]) - -We successfully extended functionality without polluting our super clean LightningModule research code +Example: + +.. doctest:: + + >>> import pytorch_lightning as pl + >>> class MyPrintingCallback(pl.Callback): + ... + ... def on_init_start(self, trainer): + ... print('Starting to init trainer!') + ... + ... def on_init_end(self, trainer): + ... print('trainer is init now') + ... + ... def on_train_end(self, trainer, pl_module): + ... print('do something when training ends') + ... + >>> trainer = pl.Trainer(callbacks=[MyPrintingCallback()]) + Starting to init trainer! + trainer is init now + +We successfully extended functionality without polluting our super clean +:class:`~pytorch_lightning.core.LightningModule` research code. --------- diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst index 585627a3b0a09e..e94cb079a8ee3b 100644 --- a/docs/source/early_stopping.rst +++ b/docs/source/early_stopping.rst @@ -11,24 +11,23 @@ Enable Early Stopping --------------------- There are two ways to enable early stopping. -.. seealso:: - :class:`~pytorch_lightning.trainer.trainer.Trainer` +.. doctest:: -.. code-block:: python + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping # A) Set early_stop_callback to True. Will look for 'val_loss' # in validation_epoch_end() return dict. If it is not found an error is raised. - trainer = Trainer(early_stop_callback=True) - + >>> trainer = Trainer(early_stop_callback=True) # B) Or configure your own callback - early_stop_callback = EarlyStopping( - monitor='val_loss', - min_delta=0.00, - patience=3, - verbose=False, - mode='min' - ) - trainer = Trainer(early_stop_callback=early_stop_callback) + >>> early_stop_callback = EarlyStopping( + ... monitor='val_loss', + ... min_delta=0.00, + ... patience=3, + ... verbose=False, + ... mode='min' + ... ) + >>> trainer = Trainer(early_stop_callback=early_stop_callback) In any case, the callback will fall back to the training metrics (returned in :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, @@ -37,6 +36,8 @@ looking for a key to monitor if validation is disabled or :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` is not defined. +.. seealso:: + :class:`~pytorch_lightning.trainer.trainer.Trainer` Disable Early Stopping ---------------------- diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index c7898e9dc1cbfe..9bf576b0c1926d 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -1,7 +1,9 @@ r""" Callback Base ============= - Abstract base class used to build new callbacks. + +Abstract base class used to build new callbacks. + """ import abc diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 53d6b4cfb1888f..f477cd724b47b1 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -1,6 +1,7 @@ r""" Early Stopping ============== + Stop training when a monitored quantity has stopped improving. """ @@ -17,31 +18,30 @@ class EarlyStopping(Callback): r""" Args: - monitor (str): quantity to be monitored. Default: ``'val_loss'``. - min_delta (float): minimum change in the monitored quantity + monitor: quantity to be monitored. Default: ``'val_loss'``. + min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than `min_delta`, will count as no improvement. Default: ``0``. - patience (int): number of epochs with no improvement + patience: number of epochs with no improvement after which training will be stopped. Default: ``0``. - verbose (bool): verbosity mode. Default: ``False``. - mode (str): one of {auto, min, max}. In `min` mode, + verbose: verbosity mode. Default: ``False``. + mode: one of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. Default: ``'auto'``. - strict (bool): whether to crash the training if `monitor` is + strict: whether to crash the training if `monitor` is not found in the metrics. Default: ``True``. Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import EarlyStopping - - early_stopping = EarlyStopping('val_loss') - Trainer(early_stop_callback=early_stopping) + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping + >>> early_stopping = EarlyStopping('val_loss') + >>> trainer = Trainer(early_stop_callback=early_stopping) """ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0, diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 29565800d4ef8b..b0563f46c8c97b 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -1,7 +1,9 @@ r""" Gradient Accumulator ==================== + Change gradient accumulation factor according to scheduling. + """ import warnings @@ -22,12 +24,15 @@ class GradientAccumulationScheduler(Callback): Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import GradientAccumulationScheduler + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler # at epoch 5 start accumulating every 2 batches - accumulator = GradientAccumulationScheduler(scheduling: {5: 2}) - Trainer(accumulate_grad_batches=accumulator) + >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2}) + >>> trainer = Trainer(callbacks=[accumulator]) + + # alternatively, pass the scheduling dict directly to the Trainer + >>> trainer = Trainer(accumulate_grad_batches={5: 2}) """ def __init__(self, scheduling: dict): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 54997e9e63f3b7..90f649394fd917 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -3,6 +3,7 @@ =================== Automatically save model checkpoints during training. + """ import os @@ -26,18 +27,19 @@ class ModelCheckpoint(Callback): Example:: - # no path - ModelCheckpoint() - # saves like /my/path/epoch_0.ckpt - - # save any arbitrary metrics like and val_loss, etc in name - ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}') - # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt + # custom path + # saves a file like: my/path/epoch_0.ckpt + >>> checkpoint_callback = ModelCheckpoint('my/path/') + # save any arbitrary metrics like `val_loss`, etc. in name + # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}' + ... ) - monitor (str): quantity to monitor. - verbose (bool): verbosity mode, False or True. - save_top_k (int): if `save_top_k == k`, + monitor: quantity to monitor. + verbose: verbosity mode. Default: ``False``. + save_top_k: if `save_top_k == k`, the best k models according to the quantity monitored will be saved. if ``save_top_k == 0``, no models are saved. @@ -46,7 +48,7 @@ class ModelCheckpoint(Callback): if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with `v0`. - mode (str): one of {auto, min, max}. + mode: one of {auto, min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the @@ -54,26 +56,29 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. - save_weights_only (bool): if True, then only the model's weights will be - saved (`model.save_weights(filepath)`), else the full model - is saved (`model.save(filepath)`). - period (int): Interval (number of epochs) between checkpoints. + save_weights_only: if ``True``, then only the model's weights will be + saved (``model.save_weights(filepath)``), else the full model + is saved (``model.save(filepath)``). + period: Interval (number of epochs) between checkpoints. Example:: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import ModelCheckpoint - # saves checkpoints to my_path whenever 'val_loss' has a new min - checkpoint_callback = ModelCheckpoint(filepath='my_path') - Trainer(checkpoint_callback=checkpoint_callback) + # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min + >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') + >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) # save epoch and val_loss in name - ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}') - # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt + # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}' + ... ) + """ - def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False, + def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() @@ -137,9 +142,10 @@ def check_monitor_top_k(self, current): return self.monitor_op(current, self.best_k_models[self.kth_best_model]) def format_checkpoint_name(self, epoch, metrics, ver=None): - """Generate a filename according define template. + """Generate a filename according to the defined template. + + Example:: - Examples: >>> tmpdir = os.path.dirname(__file__) >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}')) >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) From f1e11d8b3874067016693c50ae253ec79eecda09 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 5 Apr 2020 09:56:26 -0400 Subject: [PATCH 2/5] model_checkpoint to save all models (#1359) * model_checkpoint to save all models * changelog * rise if Co-authored-by: jamesjjcondon Co-authored-by: J. Borovec --- CHANGELOG.md | 3 ++- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c329a33e914078..f445e994177144 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)). +- Fixed `model_checkpoint` when saving all models ([#1359](https://github.com/PyTorchLightning/pytorch-lightning/pull/1359)) +- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)) - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114)) - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132)) - Fixed a bug that created an extra dataloader with active `reload_dataloaders_every_epoch` ([#1181](https://github.com/PyTorchLightning/pytorch-lightning/issues/1181) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 90f649394fd917..5a2fbb1ce6c075 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -82,7 +82,7 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() - if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: + if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: warnings.warn( f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" @@ -219,7 +219,7 @@ def on_validation_end(self, trainer, pl_module): def _do_check_save(self, filepath, current, epoch): # remove kth - if len(self.best_k_models) == self.save_top_k: + if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0: delpath = self.kth_best_model self.best_k_models.pop(self.kth_best_model) self._del_model(delpath) From fdcf9cd51834b72d2000aeec686c565f8e7b8f73 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 5 Apr 2020 17:05:13 +0200 Subject: [PATCH 3/5] add forgotten change logs (#1380) * forgot change logs * more missing * more missing --- CHANGELOG.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f445e994177144..4435c4e6ccdff5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307)) - Made `evaluate` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260)) - Simplify the PL examples structure (shallower and more readable) ([#1247](https://github.com/PyTorchLightning/pytorch-lightning/pull/1247)) +- Changed min max gpu memory to be on their own plots ([#1358](https://github.com/PyTorchLightning/pytorch-lightning/pull/1358)) +- Remove `.item` which causes sync issues ([#1254](https://github.com/PyTorchLightning/pytorch-lightning/pull/1254)) +- Changed smoothing in TQDM to decrease variability of time remaining between training / eval ([#1194](https://github.com/PyTorchLightning/pytorch-lightning/pull/1194)) +- Change default logger to dedicated one ([#1064](https://github.com/PyTorchLightning/pytorch-lightning/pull/1064)) ### Deprecated @@ -67,7 +71,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `WandbLogger.watch` with `wandb.init()` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311)) - Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)). - Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)). -- Fixed `WandbLogger.watch` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311)) +- Fixed `WandbLogger.watch` - use of the watch method without importing `wandb` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311)) +- Fixed `WandbLogger` to be used with 'ddp' - allow reinits in sub-processes ([#1149](https://github.com/PyTorchLightning/pytorch-lightning/pull/1149), [#1360](https://github.com/PyTorchLightning/pytorch-lightning/pull/1360)) +- Made `training_epoch_end` behave like `validation_epoch_end` ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357)) +- Fixed `fast_dev_run` running validation twice ([#1365](https://github.com/PyTorchLightning/pytorch-lightning/pull/1365)) +- Fixed pickle error from quick patch `__code__` ([#1352](https://github.com/PyTorchLightning/pytorch-lightning/pull/1352)) +- Fixed memory leak on GPU0 ([#1094](https://github.com/PyTorchLightning/pytorch-lightning/pull/1094), [#1349](https://github.com/PyTorchLightning/pytorch-lightning/pull/1349)) +- Fixed checkpointing interval ([#1272](https://github.com/PyTorchLightning/pytorch-lightning/pull/1272)) +- Fixed validation and training loops run the partial dataset ([#1192](https://github.com/PyTorchLightning/pytorch-lightning/pull/1192)) +- Fixed running `on_validation_end` only on main process in DDP ([#1125](https://github.com/PyTorchLightning/pytorch-lightning/pull/1125)) ## [0.7.1] - 2020-03-07 From b18accc64ccd24095c11fdbd64cc924456134592 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 5 Apr 2020 16:07:16 +0100 Subject: [PATCH 4/5] Add warning for few workers (#1378) * Add warning for few workers * Fix style issue * Update CHANGELOG.md * Update test * formatting * formatting Co-authored-by: Jirka Borovec --- CHANGELOG.md | 1 + pytorch_lightning/trainer/data_loading.py | 15 ++++++-- tests/trainer/test_dataloaders.py | 42 +++++++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4435c4e6ccdff5..795233c6c908f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269)) - Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279)) +- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378)) ### Changed diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index fe1adf75c3fdc4..66b83fd4bfd005 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,8 +1,9 @@ +import warnings from abc import ABC, abstractmethod from typing import Union, List, Tuple, Callable import torch.distributed as torch_distrib -from torch.utils.data import SequentialSampler, DataLoader +from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule @@ -73,6 +74,12 @@ def _percent_range_check(self, name: str) -> None: if not 0. <= value <= 1.: raise ValueError(msg) + def _worker_check(self, dataloader: DataLoader, name: str) -> None: + if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2: + warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.' + ' Consider increasing the value of the `num_workers` argument`' + ' in the `DataLoader` init to improve performance.') + def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader @@ -112,11 +119,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None: model: The current `LightningModule` """ self.train_dataloader = self.request_dataloader(model.train_dataloader) + self.num_training_batches = 0 # automatically add samplers self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) + self._worker_check(self.train_dataloader, 'train dataloader') self._percent_range_check('train_percent_check') if not _has_len(self.train_dataloader): @@ -176,10 +185,10 @@ def _reset_eval_dataloader(self, model: LightningModule, # determine number of batches # datasets could be none, 1 or 2+ if len(dataloaders) != 0: - for dataloader in dataloaders: + for i, dataloader in enumerate(dataloaders): + self._worker_check(dataloader, f'{mode} dataloader {i}') if not _has_len(dataloader): num_batches = float('inf') - break percent_check = getattr(self, f'{mode}_percent_check') diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d0da0044f217ef..408774430c398c 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -15,6 +15,7 @@ LightValStepFitMultipleDataloadersMixin, LightValStepFitSingleDataloaderMixin, LightTrainDataloader, + LightValidationDataloader, LightInfTrainDataloader, LightInfValDataloader, LightInfTestDataloader, @@ -485,6 +486,47 @@ class CurrentTestModel( trainer.fit(model) +def test_warning_with_few_workers(tmpdir): + """ Test that error is raised if dataloader with only a few workers is used """ + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + LightTestFitSingleTestDataloadersMixin, + LightEmptyTestStep, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + fit_options = dict(train_dataloader=model._dataloader(train=True), + val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) + + trainer = Trainer(**trainer_options) + + # fit model + with pytest.warns(UserWarning, match='train'): + trainer.fit(model, **fit_options) + + with pytest.warns(UserWarning, match='val'): + trainer.fit(model, **fit_options) + + with pytest.warns(UserWarning, match='test'): + trainer.test() + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') def test_dataloader_reinit_for_subclass(): From 38c56081ac23f866bfb00786fbd7d3e947f5cd77 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Date: Sun, 5 Apr 2020 11:10:44 -0400 Subject: [PATCH 5/5] fix docs on saving checkpoints manually (#1373) --- docs/source/weights_loading.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst index 339469fe33e27e..5f3e4389dd6d04 100644 --- a/docs/source/weights_loading.rst +++ b/docs/source/weights_loading.rst @@ -74,12 +74,14 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the Li Manual saving ^^^^^^^^^^^^^ - -To save your own checkpoint call: +You can manually save checkpoints and restore your model from the checkpointed state. .. code-block:: python - model.save_checkpoint(PATH) + model = MyModel(hparams) + trainer.fit(model) + trainer.save_checkpoint("example.ckpt") + new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt") Checkpoint Loading ------------------