diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 313541b8b5bd31..4c07f6e4877bbe 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -155,7 +155,7 @@ formatting errors. In certain cases, a missing blank line or a wrong indent can Run these commands ```bash -pip install ".[docs]" +pip install -r requirements/docs.txt cd docs make html ``` diff --git a/.pyrightconfig.json b/.pyrightconfig.json index 5f5c753023c9df..97000d69dd29d3 100644 --- a/.pyrightconfig.json +++ b/.pyrightconfig.json @@ -7,7 +7,7 @@ "pytorch_lightning/__init__.py", "pytorch_lightning/callbacks", "pytorch_lightning/core", - "pytorch_lightning/accelerator_backends", + "pytorch_lightning/accelerators", "pytorch_lightning/loggers", "pytorch_lightning/logging", "pytorch_lightning/metrics", diff --git a/.run_local_tests.sh b/.run_local_tests.sh index c0c030a78ec939..2ff2c56b76c7e2 100644 --- a/.run_local_tests.sh +++ b/.run_local_tests.sh @@ -6,11 +6,6 @@ export SLURM_LOCALID=0 # use this to run tests rm -rf _ckpt_* -rm -rf ./tests/save_dir* -rm -rf ./tests/mlruns_* -rm -rf ./tests/cometruns* -rm -rf ./tests/wandb* -rm -rf ./tests/tests/* rm -rf ./lightning_logs python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --flake8 python -m coverage report -m diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e9788334c614c..ad17f9f4ad4829 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821)) +- Fixed `ModelCheckpoint` not saving the latest information when `save_last=True` ([#2881](https://github.com/PyTorchLightning/pytorch-lightning/pull/2881)) + ## [0.8.5] - 2020-07-09 ### Added diff --git a/MANIFEST.in b/MANIFEST.in index fc2610caea6cfa..6c79a2ac700001 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -26,7 +26,7 @@ exclude tests recursive-exclude docs * exclude docs recursive-include docs/source/_images/logos/ * -recursive-include docs/source/_images/general/ pl_overview* tf_* tutorial_* +recursive-include docs/source/_images/general/ pl_overview* tf_* tutorial_* PTL101_* # Include the Requirements recursive-include requirements *.txt diff --git a/README.md b/README.md index 08a3a519bac1a7..6dda80b68f52e0 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,11 @@ Once you do this, you can train on multiple-GPUs, TPUs, CPUs and even in 16-bit Get started with our [QUICK START PAGE](https://pytorch-lightning.readthedocs.io/en/stable/new-project.html) +--- +### [Tune in for our PyTorch Lightning 101 class with William Falcon and Alfredo Canziani! 
New episodes every week!](https://www.youtube.com/watch?v=DbESHcCoWbM&list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2) +[![IMAGE ALT TEXT HERE](docs/source/_images/general/PTL101_youtube_thumbnail.jpg)](https://www.youtube.com/watch?v=DbESHcCoWbM&list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2) +--- + ## Refactoring your PyTorch code + benefits + full walk-through [![Watch the video](docs/source/_images/general/tutorial_cover.jpg)](https://www.youtube.com/watch?v=QHww1JH7IDU) diff --git a/dockers/cuda-extras/Dockerfile b/dockers/cuda-extras/Dockerfile index a1aaff0e7d1d25..f3c435ff5333a4 100644 --- a/dockers/cuda-extras/Dockerfile +++ b/dockers/cuda-extras/Dockerfile @@ -39,7 +39,6 @@ RUN apt-get update && \ && \ # Install AMP - # TODO: skip this instrall for PT >= 1.6 bash install_AMP.sh && \ # Install all requirements pip install -r requirements.txt && \ diff --git a/docs/source/_images/general/PTL101_youtube_thumbnail.jpg b/docs/source/_images/general/PTL101_youtube_thumbnail.jpg new file mode 100644 index 00000000000000..a09dc43d47bb71 Binary files /dev/null and b/docs/source/_images/general/PTL101_youtube_thumbnail.jpg differ diff --git a/docs/source/conf.py b/docs/source/conf.py index d0670254b16b49..f62b540720c286 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -138,7 +138,7 @@ exclude_patterns = [ 'api/pytorch_lightning.rst', 'api/pl_examples.*', - 'api/pytorch_lightning.accelerator_backends.*', + 'api/pytorch_lightning.accelerators.*', 'api/modules.rst', 'PULL_REQUEST_TEMPLATE.md', ] diff --git a/docs/source/index.rst b/docs/source/index.rst index e91743849f64ec..0c887ee9196be7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ PyTorch Lightning Documentation :name: Tutorials :caption: Tutorials + PyTorch Lightning 101 class From PyTorch to PyTorch Lightning Video on how to refactor PyTorch into PyTorch Lightning diff --git a/pl_examples/basic_examples/cpu_template.py b/pl_examples/basic_examples/cpu_template.py index 613f37105a6ff6..9d1fb524957cf6 100644 --- a/pl_examples/basic_examples/cpu_template.py +++ b/pl_examples/basic_examples/cpu_template.py @@ -4,8 +4,8 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything from pl_examples.models.lightning_template import LightningTemplateModel +from pytorch_lightning import Trainer, seed_everything seed_everything(234) diff --git a/pl_examples/basic_examples/gpu_template.py b/pl_examples/basic_examples/gpu_template.py index ced4525d4db66a..64cbabb10b7b25 100644 --- a/pl_examples/basic_examples/gpu_template.py +++ b/pl_examples/basic_examples/gpu_template.py @@ -4,8 +4,8 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything from pl_examples.models.lightning_template import LightningTemplateModel +from pytorch_lightning import Trainer, seed_everything seed_everything(234) diff --git a/pl_examples/basic_examples/multi_node_ddp2_demo.py b/pl_examples/basic_examples/multi_node_ddp2_demo.py index 1dc73613fb6742..aead1fba1e4f2e 100644 --- a/pl_examples/basic_examples/multi_node_ddp2_demo.py +++ b/pl_examples/basic_examples/multi_node_ddp2_demo.py @@ -4,8 +4,8 @@ import os from argparse import ArgumentParser -from pytorch_lightning import Trainer, seed_everything from pl_examples.models.lightning_template import LightningTemplateModel +from pytorch_lightning import Trainer, seed_everything seed_everything(234) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py 
b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 7ebcbf0b7d7618..21f6644b09a5b3 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -27,13 +27,10 @@ from tempfile import TemporaryDirectory from typing import Optional, Generator, Union -from torch.nn import Module - -import pytorch_lightning as pl import torch import torch.nn.functional as F -from pytorch_lightning import _logger as log from torch import optim +from torch.nn import Module from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -42,6 +39,9 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive +import pytorch_lightning as pl +from pytorch_lightning import _logger as log + BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip' diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index 20fb1cae247320..299c078d172119 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -1,9 +1,9 @@ """ This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py """ -from argparse import ArgumentParser, Namespace import os import random +from argparse import ArgumentParser, Namespace from collections import OrderedDict import torch diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 026827f70b6f0e..cf304f9eb9261f 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -16,12 +16,9 @@ tensorboard --logdir default """ -import pytorch_lightning as pl - -from typing import Tuple, List - import argparse from collections import OrderedDict, deque, namedtuple +from typing import Tuple, List import gym import numpy as np @@ -32,6 +29,8 @@ from torch.utils.data import DataLoader from torch.utils.data.dataset import IterableDataset +import pytorch_lightning as pl + class DQN(nn.Module): """ diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 2f486c5b818275..fff3d1cf98695a 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -1,4 +1,5 @@ import os +import random from argparse import ArgumentParser, Namespace import numpy as np @@ -7,7 +8,6 @@ import torchvision.transforms as transforms from PIL import Image from torch.utils.data import DataLoader, Dataset -import random import pytorch_lightning as pl from pl_examples.models.unet import UNet diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py index 2b4b201cddc8e4..099d6b5a65c02c 100644 --- a/pl_examples/models/lightning_template.py +++ b/pl_examples/models/lightning_template.py @@ -12,7 +12,6 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST -from pytorch_lightning import _logger as log from pytorch_lightning.core import LightningModule diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 3a1d0fa8474580..580467e5555693 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,6 +1,6 @@ 
"""Root package info.""" -__version__ = '0.9.0rc9' +__version__ = '0.9.0rc11' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/accelerator_backends/__init__.py b/pytorch_lightning/accelerator_backends/__init__.py deleted file mode 100644 index d56ca53d063b30..00000000000000 --- a/pytorch_lightning/accelerator_backends/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from pytorch_lightning.accelerator_backends.gpu_backend import GPUBackend -from pytorch_lightning.accelerator_backends.tpu_backend import TPUBackend -from pytorch_lightning.accelerator_backends.dp_backend import DataParallelBackend -from pytorch_lightning.accelerator_backends.ddp_spawn_backend import DDPSpawnBackend -from pytorch_lightning.accelerator_backends.cpu_backend import CPUBackend -from pytorch_lightning.accelerator_backends.ddp_backend import DDPBackend -from pytorch_lightning.accelerator_backends.ddp2_backend import DDP2Backend diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py new file mode 100644 index 00000000000000..a8f4c14649b396 --- /dev/null +++ b/pytorch_lightning/accelerators/__init__.py @@ -0,0 +1,7 @@ +from pytorch_lightning.accelerators.cpu_backend import CPUBackend +from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend +from pytorch_lightning.accelerators.ddp_backend import DDPBackend +from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend +from pytorch_lightning.accelerators.dp_backend import DataParallelBackend +from pytorch_lightning.accelerators.gpu_backend import GPUBackend +from pytorch_lightning.accelerators.tpu_backend import TPUBackend diff --git a/pytorch_lightning/accelerator_backends/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py similarity index 97% rename from pytorch_lightning/accelerator_backends/cpu_backend.py rename to pytorch_lightning/accelerators/cpu_backend.py index 7760442a206c5c..cfee51e4dd5d8f 100644 --- a/pytorch_lightning/accelerator_backends/cpu_backend.py +++ b/pytorch_lightning/accelerators/cpu_backend.py @@ -22,7 +22,7 @@ def __init__(self, trainer): def setup(self, model): # run through amp wrapper - if self.trainer.use_amp: + if self.trainer.amp_type: raise MisconfigurationException('amp + cpu is not supported. 
Please use a GPU option') # call setup after the ddp process has connected diff --git a/pytorch_lightning/accelerator_backends/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py similarity index 94% rename from pytorch_lightning/accelerator_backends/ddp2_backend.py rename to pytorch_lightning/accelerators/ddp2_backend.py index cc14c44ebe1845..85bda4cd8deef4 100644 --- a/pytorch_lightning/accelerator_backends/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -13,10 +13,12 @@ # limitations under the License import os + import torch -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE -from pytorch_lightning.utilities.distributed import rank_zero_only + from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException try: @@ -30,9 +32,7 @@ try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDP2Backend(object): @@ -133,10 +133,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE: + # AMP - run through amp wrapper before going to distributed DP + if self.trainer.amp_type == AMPType.APEX: model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerator_backends/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py similarity index 96% rename from pytorch_lightning/accelerator_backends/ddp_backend.py rename to pytorch_lightning/accelerators/ddp_backend.py index 44ad52d34ba2f9..e499feda651d9e 100644 --- a/pytorch_lightning/accelerator_backends/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -13,16 +13,18 @@ # limitations under the License import os -import torch import subprocess import sys +from os.path import abspath from time import sleep +from typing import Optional + import numpy as np -from os.path import abspath -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE -from pytorch_lightning.utilities.distributed import rank_zero_only +import torch + from pytorch_lightning import _logger as log -from typing import Optional +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.distributed import rank_zero_only try: from hydra.utils import to_absolute_path, get_original_cwd @@ -35,9 +37,7 @@ try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDPBackend(object): @@ -200,10 +200,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE: + # AMP - run through amp wrapper before going to distributed DP + if self.trainer.amp_type == AMPType.APEX: model, optimizers = 
model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py similarity index 95% rename from pytorch_lightning/accelerator_backends/ddp_spawn_backend.py rename to pytorch_lightning/accelerators/ddp_spawn_backend.py index 704fc5558588a4..9ed68f66083ad2 100644 --- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License -import os import torch import torch.multiprocessing as mp -from pytorch_lightning.utilities.distributed import rank_zero_only + from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.distributed import rank_zero_only try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDPSpawnBackend(object): @@ -133,11 +132,9 @@ def ddp_train(self, process_idx, mp_queue, model): # set model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP + # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - if self.trainer.use_amp and not native_amp_available: + if self.trainer.amp_type == AMPType.APEX: model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerator_backends/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py similarity index 95% rename from pytorch_lightning/accelerator_backends/dp_backend.py rename to pytorch_lightning/accelerators/dp_backend.py index efb683ff4eaa9e..31791ee5ecbaf7 100644 --- a/pytorch_lightning/accelerator_backends/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -13,16 +13,16 @@ # limitations under the License. 
import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.overrides.data_parallel import LightningDataParallel from torch import optim +from pytorch_lightning.overrides.data_parallel import LightningDataParallel +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException + try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DataParallelBackend(object): @@ -49,7 +49,7 @@ def setup(self, model): self.model_autocast_original_forward = model.forward # init half precision - if self.trainer.use_amp: + if self.trainer.amp_type: model = self.__init_half_precision(model) # init torch data parallel @@ -69,9 +69,7 @@ def __init_torch_data_parallel(self, model): return model def __init_half_precision(self, model): - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - - if native_amp_available: + if self.trainer.amp_type == AMPType.NATIVE: self.__init_native_amp(model) else: model = self.__init_nvidia_apex(model) diff --git a/pytorch_lightning/accelerator_backends/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py similarity index 85% rename from pytorch_lightning/accelerator_backends/gpu_backend.py rename to pytorch_lightning/accelerators/gpu_backend.py index 7f15d3c25f4104..30920998b2f9dc 100644 --- a/pytorch_lightning/accelerator_backends/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - from pytorch_lightning.core import LightningModule +from pytorch_lightning.utilities import AMPType + try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class GPUBackend(object): + amp_type: AMPType def __init__(self, trainer): self.trainer = trainer @@ -42,9 +41,7 @@ def setup(self, model): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies - # TODO: remove with dropping NVIDIA AMP support - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available: + if self.trainer.amp_type == AMPType.APEX: model = self._setup_nvidia_apex(model) return model diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py similarity index 100% rename from pytorch_lightning/accelerator_backends/tpu_backend.py rename to pytorch_lightning/accelerators/tpu_backend.py index 2c0b172b9e2113..e879af3f5c78da 100644 --- a/pytorch_lightning/accelerator_backends/tpu_backend.py +++ b/pytorch_lightning/accelerators/tpu_backend.py @@ -17,10 +17,10 @@ import torch import torch.multiprocessing as mp +from pytorch_lightning import _logger as log from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning import _logger as log try: import torch_xla diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 7e8e0ce5bcfef3..92920ee9b74bdb 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -1,8 +1,8 @@ from 
pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.lr_logger import LearningRateLogger +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar __all__ = [ diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index b804241aa1b7ea..f241318fbfe7e7 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -47,11 +47,11 @@ def on_sanity_check_end(self, trainer, pl_module): pass def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - """Called when the validation batch begins.""" + """Called when the train batch begins.""" pass def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - """Called when the validation batch ends.""" + """Called when the train batch ends.""" pass def on_train_epoch_start(self, trainer, pl_module): diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d60e3d81584705..3b4aed8566bb65 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -7,7 +7,6 @@ """ from copy import deepcopy -import os import numpy as np import torch import torch.distributed as dist diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 997574644072d1..8d61d29856b7e8 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -7,7 +7,6 @@ """ from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn class GradientAccumulationScheduler(Callback): diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py index 27fbb81800241d..a401cc84516f1d 100755 --- a/pytorch_lightning/callbacks/lr_logger.py +++ b/pytorch_lightning/callbacks/lr_logger.py @@ -8,9 +8,8 @@ """ from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException - from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class LearningRateLogger(Callback): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 41015ff829db37..09069da8eb805e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -8,11 +8,11 @@ import os import re - -import numpy as np from typing import Optional +import numpy as np import torch + from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only @@ -97,6 +97,10 @@ class ModelCheckpoint(Callback): """ + CHECKPOINT_NAME_LAST = "last.ckpt" + CHECKPOINT_STATE_BEST_SCORE = "checkpoint_callback_best_model_score" + CHECKPOINT_STATE_BEST_PATH = "checkpoint_callback_best_model_path" + def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = 
False, save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): @@ -313,10 +317,6 @@ def on_validation_end(self, trainer, pl_module): self.epoch_last_check = epoch - if self.save_last: - filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt') - self._save_model(filepath, trainer, pl_module) - filepath = self.format_checkpoint_name(epoch, metrics) version_cnt = 0 while gfile.exists(filepath): @@ -351,6 +351,10 @@ def on_validation_end(self, trainer, pl_module): assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' self._save_model(filepath, trainer, pl_module) + if self.save_last: + filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST) + self._save_model(filepath, trainer, pl_module) + def _do_check_save(self, filepath, current, epoch, trainer, pl_module): # remove kth diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 922f500ce8c250..1695e090f031b9 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -4,15 +4,13 @@ from torch import Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import move_data_to_device, AMPType try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class ModelHooks(Module): @@ -218,7 +216,7 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None: for optimizer in optimizers: optimizer.step() model.on_before_zero_grad(optimizer) # < ---- called here - optimizer.zero_grad + optimizer.zero_grad() Args: optimizer: The optimizer for which grads should be zeroed. @@ -267,8 +265,8 @@ def backward(self, trainer, loss, optimizer, optimizer_idx): """ loss.backward() - def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx): - if NATIVE_AMP_AVALAIBLE: + def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType): + if amp_type == AMPType.NATIVE: scaled_loss = self.trainer.scaler.scale(unscaled_loss) else: scaled_loss = amp.scale_loss(unscaled_loss, optimizer) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 391192615fb43c..f55b8e026c5c8e 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -9,9 +9,7 @@ import torch.nn as nn from torch.utils.hooks import RemovableHandle - -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities import AMPType PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] UNKNOWN_SIZE = "?" 
@@ -209,8 +207,7 @@ def _forward_example_input(self) -> None: input_ = model.example_input_array input_ = model.transfer_batch_to_device(input_, model.device) - if trainer is not None and trainer.use_amp and not trainer.use_tpu: - if NATIVE_AMP_AVALAIBLE: + if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu: model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index dae8e2e2ec07fa..429ebd298c7795 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -3,11 +3,11 @@ import io import inspect import os +from argparse import Namespace +from typing import Union, Dict, Any, Optional, Callable, MutableMapping import torch import yaml -from argparse import Namespace -from typing import Union, Dict, Any, Optional, Callable, MutableMapping from pytorch_lightning import _logger as log from pytorch_lightning.utilities import rank_zero_warn, AttributeDict diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 172930fd4ad9a8..84f9669cf40c23 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -1,8 +1,10 @@ import numbers +from copy import copy from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any -from torch import Tensor + import torch -from copy import copy +from torch import Tensor + from pytorch_lightning.metrics.converters import _sync_ddp_if_available diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py index 5f2f3044d0a65d..e4480603355a2f 100644 --- a/pytorch_lightning/loggers/__init__.py +++ b/pytorch_lightning/loggers/__init__.py @@ -1,9 +1,8 @@ from os import environ from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection -from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.loggers.csv_logs import CSVLogger - +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger __all__ = [ 'LightningLoggerBase', diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 1e395abadb2937..8d3c69dbf34284 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -5,13 +5,14 @@ CSV logger for basic experiment logging that does not require opening ports """ +import csv import io import os -import csv -import torch from argparse import Namespace from typing import Optional, Dict, Any, Union +import torch + from pytorch_lightning import _logger as log from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 2a107c639bd3fb..65898a5fedb5a5 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,13 +1,3 @@ -from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric -from pytorch_lightning.metrics.regression import ( - MAE, - MSE, - PSNR, - RMSE, - RMSLE, - SSIM -) from pytorch_lightning.metrics.classification import ( Accuracy, AveragePrecision, @@ -24,12 +14,22 @@ PrecisionRecall, IoU, ) +from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric +from pytorch_lightning.metrics.metric import Metric, TensorMetric, 
NumpyMetric +from pytorch_lightning.metrics.nlp import BLEUScore +from pytorch_lightning.metrics.regression import ( + MAE, + MSE, + PSNR, + RMSE, + RMSLE, + SSIM +) from pytorch_lightning.metrics.sklearns import ( AUC, PrecisionRecallCurve, SklearnMetric, ) -from pytorch_lightning.metrics.nlp import BLEUScore __classification_metrics = [ "AUC", diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index d01ee4e6db5eb4..c58a3f55f0ddef 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -10,8 +10,9 @@ import numpy as np import torch from torch.utils.data._utils.collate import np_str_obj_array_pattern -from pytorch_lightning.utilities.apply_func import apply_to_collection + from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection try: from torch.distributed import ReduceOp diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 4d940ad18bd6ad..926803b5045e1a 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -20,6 +20,7 @@ to_onehot, iou, ) +from pytorch_lightning.metrics.functional.nlp import bleu_score from pytorch_lightning.metrics.functional.regression import ( mae, mse, @@ -28,4 +29,3 @@ rmsle, ssim ) -from pytorch_lightning.metrics.functional.nlp import bleu_score diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index d12509d5885299..0a77dd6b67682c 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -1,4 +1,3 @@ -import sys from functools import wraps from typing import Callable, Optional, Sequence, Tuple @@ -182,9 +181,9 @@ def stat_scores_multiple_classes( num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) if pred.dtype != torch.bool: - pred.clamp_max_(max=num_classes) + pred = pred.clamp_max(max=num_classes) if target.dtype != torch.bool: - target.clamp_max_(max=num_classes) + target = target.clamp_max(max=num_classes) possible_reductions = ('none', 'sum', 'elementwise_mean') if reduction not in possible_reductions: diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index e1bb86ab19c36f..22645bb5494b67 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -3,8 +3,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Sequence, List from collections import Counter +from typing import Sequence, List import torch diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py index a2cbaaf4f822a9..1edb388fe293d5 100644 --- a/pytorch_lightning/metrics/regression.py +++ b/pytorch_lightning/metrics/regression.py @@ -101,7 +101,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class MAE(Metric): """ - Computes the root mean absolute loss or L1-loss. + Computes the mean absolute loss or L1-loss. 
Example: diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 3945d770fe8d42..c6e8cd8fa64133 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -1,13 +1,14 @@ import itertools import threading -from itertools import chain from collections import Mapping, Iterable +from itertools import chain import torch from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel._functions import Gather + from pytorch_lightning.core.step_result import Result diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 8dcec8eb305110..6b8a58ace5b92c 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -818,7 +818,7 @@ def on_train_end(self, trainer, pl_module): ^^^^^^^^^^^^^^^^^^^ Enables auto adding of distributed sampler. By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize -it, you can set ``replace_ddp_sampler=False`` and add your own distributed sampler. +it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. .. testcode:: @@ -864,6 +864,19 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(sync_batchnorm=True) +amp_type +^^^^^^^^ + +Define the preferred mixed precision backend, either NVIDIA Apex ("apex") or PyTorch built-in ("native") AMP, which is supported from PyTorch 1.6. + +.. testcode:: + + # using NVIDIA Apex + trainer = Trainer(amp_type='apex') + + # using PyTorch built-in AMP + trainer = Trainer(amp_type='native') + val_percent_check ^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/auto_mix_precision.py index 570bcc3aafd2b4..06e823117caa76 100644 --- a/pytorch_lightning/trainer/auto_mix_precision.py +++ b/pytorch_lightning/trainer/auto_mix_precision.py @@ -1,8 +1,7 @@ from abc import ABC from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import rank_zero_warn, APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE -from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType class TrainerAMPMixin(ABC): @@ -11,26 +10,39 @@ class TrainerAMPMixin(ABC): # the proper values/initialisation should be done in child class precision: int - def init_amp(self): - if NATIVE_AMP_AVALAIBLE: - log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)") - - assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' - - if self.use_amp and NATIVE_AMP_AVALAIBLE: - log.info('Using native 16bit precision.') + def _setup_amp_type(self, amp_type: str): + self.amp_type = None + if self.precision != 16: + # no AMP requested, so we can leave now return - - if self.use_amp and not APEX_AVAILABLE: # pragma: no-cover + amp_type = amp_type.lower() + assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' + if amp_type == 'native': + if not NATIVE_AMP_AVALAIBLE: + rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.' + ' Consider upgrading with `pip install torch>=1.6`.'
+ ' We will attempt to use NVIDIA Apex for this session.') + amp_type = 'apex' + else: + log.info('Using native 16bit precision.') + self.amp_type = AMPType.NATIVE + if amp_type == 'apex': + if not APEX_AVAILABLE: + rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' + ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') + else: + log.info('Using APEX 16bit precision.') + self.amp_type = AMPType.APEX + if not self.amp_type: raise ModuleNotFoundError( - "You set `use_amp=True` but do not have apex installed." - " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" - " and rerun with `use_amp=True`." - " This run will NOT use 16 bit precision." + f'You have asked for AMP support {amp_type}, but there is no support on your side yet.' + f' Consider installing torch >= 1.6 or NVIDIA Apex.' ) - if self.use_amp: - log.info('Using APEX 16bit precision.') + def init_amp(self, amp_type: str): + assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' + + self._setup_amp_type(amp_type) @property def use_amp(self) -> bool: diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 8600449d86a946..6547e7348a5cb3 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -1,7 +1,5 @@ -import os from abc import ABC, abstractmethod -from typing import List, Callable, Optional - +from typing import List, Optional from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.loggers import LightningLoggerBase diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 3c33f33fe50ab1..f4144563a7c8f5 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -1,6 +1,6 @@ from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class ConfigValidator(object): diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 38a1118118a403..575e28354b5de3 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -22,9 +22,7 @@ try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 1f92e61e94a394..c550fb648f0ca6 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -142,9 +142,8 @@ def train_fx(trial_hparams, cluster_manager, _): import torch from pytorch_lightning import _logger as log from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info from pytorch_lightning.core.datamodule import LightningDataModule from 
pytorch_lightning.core.lightning import LightningModule @@ -152,9 +151,7 @@ def train_fx(trial_hparams, cluster_manager, _): try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import horovod.torch as hvd @@ -213,11 +210,6 @@ def call_setup_hook(self, *args): def num_gpus(self) -> int: """Warning: this is just empty shell for code implemented in other class.""" - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod def copy_trainer_model_properties(self, *args): """Warning: this is just empty shell for code implemented in other class.""" diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7d5a00523ef9e6..f76c6f1b008dea 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -33,17 +33,14 @@ LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import move_data_to_device, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only -from pytorch_lightning.utilities import rank_zero_warn try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla.core.xla_model as xm @@ -82,11 +79,7 @@ class TrainerDPMixin(ABC): on_colab_kaggle: str save_spawn_weights: Callable logger: ... - - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" + amp_type: AMPType @abstractmethod def call_setup_hook(self, *args): @@ -130,7 +123,7 @@ def copy_trainer_model_properties(self, model): m.use_dp = self.use_dp m.use_ddp2 = self.use_ddp2 m.use_ddp = self.use_ddp - m.use_amp = self.use_amp + m.use_amp = self.amp_type is not None m.testing = self.testing m.use_single_gpu = self.use_single_gpu m.use_tpu = self.use_tpu @@ -212,7 +205,7 @@ def horovod_train(self, model): if isinstance(scheduler, _LRScheduler): scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] - if self.use_amp: + if self.amp_type: model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) self.optimizers = optimizers self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 75d09db2cec655..5ce7b7718c2a7e 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -131,9 +131,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel -from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict -from torch import distributed as dist +from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -182,6 +180,7 @@ class TrainerEvaluationLoopMixin(ABC): tpu_id: int verbose_test: bool running_sanity_check: bool + amp_type: AMPType # Callback system 
on_validation_batch_start: Callable @@ -319,7 +318,7 @@ def _evaluate( # ----------------- # RUN EVALUATION STEP # ----------------- - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: with torch.cuda.amp.autocast(): output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) else: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index aced8b64a47ed2..c90ba59abf7352 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,4 +1,3 @@ -import os from abc import ABC from typing import Union, Iterable diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 3b3e992e0ed35b..0c3a9047d91724 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -1,8 +1,8 @@ """ Trainer Learning Rate Finder """ -import os import importlib +import os from abc import ABC, abstractmethod from typing import Optional, Sequence, Tuple, List, Union diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b19481986e11bf..b8d246b5c8f252 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -20,9 +20,10 @@ import torch import torch.distributed as torch_distrib -import torch.multiprocessing as mp from torch.utils.data import DataLoader +from pytorch_lightning.accelerators import ( + GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend) from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -30,9 +31,10 @@ from pytorch_lightning.core.step_result import EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler -from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin +from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin +from pytorch_lightning.trainer.configuration_validator import ConfigValidator from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin @@ -47,12 +49,9 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.trainer.configuration_validator import ConfigValidator -from pytorch_lightning.accelerator_backends import ( - GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend) # warnings to ignore in trainer 
warnings.filterwarnings( @@ -62,9 +61,7 @@ try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla @@ -200,6 +197,7 @@ def __init__( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: bool = True, + amp_type: str = 'native', amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0 val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 @@ -309,6 +307,7 @@ def __init__( Defaults to `default_root_dir`. amp_level: The optimization level to use (O1, O2, etc...). + .. warning:: .. deprecated:: v0.7.4 num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: 2 @@ -331,7 +330,7 @@ def __init__( replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it, - you can set ``replace_ddp_sampler=False`` and add your own distributed sampler. + you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. benchmark: If true enables cudnn.benchmark. @@ -588,7 +587,7 @@ def __init__( self.scaler = None self.amp_level = amp_level - self.init_amp() + self.init_amp(amp_type) self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') @@ -1129,7 +1128,7 @@ def run_pretrain_routine(self, model: LightningModule): self.copy_trainer_model_properties(ref_model) # init amp. Must be done here instead of __init__ to allow ddp to work - if NATIVE_AMP_AVALAIBLE and self.precision == 16 and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and self.precision == 16 and not self.use_tpu: self.scaler = torch.cuda.amp.GradScaler() # log hyper-parameters diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 87fc4ee8c41d30..7a1613b919a267 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -89,21 +89,20 @@ from abc import ABC from distutils.version import LooseVersion from subprocess import call -from pkg_resources import parse_version import torch import torch.distributed as torch_distrib import pytorch_lightning from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import gfile, makedirs @@ -119,9 +118,7 @@ try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import horovod.torch as hvd @@ -159,8 +156,9 @@ class TrainerIOMixin(ABC): on_tpu: bool num_training_batches: int accumulate_grad_batches: int - use_amp: bool scaler: ... 
+ use_tpu: bool + amp_type: AMPType def get_model(self): is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel)) @@ -325,9 +323,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool): model.cuda(self.root_gpu) # restore amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint: + if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint: + elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint: amp.load_state_dict(checkpoint['amp_scaling_state']) # load training state (affects trainer only) @@ -357,8 +355,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: if checkpoint_callbacks: # we add the official checkpoint callback to the end of the list # extra user provided callbacks will not be persisted yet - checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score - checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path + checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score + checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path if early_stopping_callbacks and checkpoint_callbacks: # we add the official early stopping callback to the end of the list @@ -378,9 +376,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint['lr_schedulers'] = lr_schedulers # save native amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() - elif self.use_amp and not NATIVE_AMP_AVALAIBLE: + elif self.amp_type == AMPType.APEX: checkpoint['amp_scaling_state'] = amp.state_dict() # add the module_arguments and state_dict from the model @@ -439,8 +437,8 @@ def restore_training_state(self, checkpoint): early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)] if checkpoint_callbacks: - if 'checkpoint_callback_best_model_score' in checkpoint: - checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score'] + if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint: + checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] else: # Old naming until version 0.7.6 rank_zero_warn( @@ -448,7 +446,7 @@ def restore_training_state(self, checkpoint): 'this will not be supported in the future.' 
) checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best'] - checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path'] + checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] if early_stopping_callbacks: state = checkpoint['early_stop_callback_state_dict'] @@ -537,9 +535,9 @@ def hpc_load(self, folderpath, on_gpu): model.load_state_dict(checkpoint['state_dict']) # restore amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint: + if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint: + elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint: amp.load_state_dict(checkpoint['amp_scaling_state']) if self.root_gpu is not None: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5df7437348aa6c..72178b8e8e17a3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -157,36 +157,33 @@ def training_step(self, batch, batch_idx): trainer = Trainer(terminate_on_nan=True) """ -import os import subprocess from abc import ABC, abstractmethod +from copy import copy from typing import Callable from typing import Union, List import numpy as np import torch -from torch.utils.data import DataLoader import torch.distributed as torch_distrib -from copy import copy +from torch.utils.data import DataLoader from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator -from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result +from pytorch_lightning.utilities.parsing import AttributeDict try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla.distributed.parallel_loader as xla_pl @@ -256,6 +253,8 @@ class TrainerTrainLoopMixin(ABC): terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... 
+    amp_type: AMPType
+    on_tpu: bool
 
     # Callback system
     callbacks: List[Callback]
@@ -740,7 +739,7 @@ def run_training_batch(self, batch, batch_idx):
                     batch_idx,
                     opt_idx,
                     optimizer,
-                    self.hiddens
+                    self.hiddens,
                 )
                 using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
 
@@ -836,7 +835,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
         # CLIP GRADS
         # ------------------
-        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
+        if self.amp_type == AMPType.NATIVE and not self.use_tpu:
             self.scaler.unscale_(optimizer)
         self.clip_gradients(optimizer)
 
@@ -858,7 +857,7 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
             batch_idx,
             opt_idx,
             optimizer,
-            self.hiddens
+            self.hiddens,
         ).loss
 
         # apply TPU optimizer
@@ -870,7 +869,7 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
         elif isinstance(optimizer, torch.optim.LBFGS):
 
             # native amp + lbfgs is a no go right now
-            if self.use_amp and NATIVE_AMP_AVALAIBLE:
+            if self.amp_type == AMPType.NATIVE:
                 raise MisconfigurationException(
                     'native PyTorch amp and lbfgs are not compatible.'
                     ' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -879,12 +878,12 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
 
         # when using 16-bit
         else:
-            native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE
+            native_amp = self.amp_type == AMPType.NATIVE
             model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
                                  using_native_amp=native_amp)
 
         # in native 16-bit we need to update scaler after optimizer step
-        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
+        if self.amp_type == AMPType.NATIVE and not self.use_tpu:
             self.scaler.update()
 
         # model hook
@@ -901,7 +900,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
        # FORWARD (TRAINING STEP + TRAIN STEP END)
        # ---------------------------
        with self.profiler.profile('model_forward'):
-            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
+            if self.amp_type == AMPType.NATIVE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens)
 
@@ -955,10 +954,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
         with self.profiler.profile('model_backward'):
             # scale loss for 16 bit
             if self.precision == 16 and not self.on_tpu:
-                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
+                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_type=self.amp_type)
 
                 # enter amp context
-                if not NATIVE_AMP_AVALAIBLE:
+                if self.amp_type == AMPType.APEX:
                     context = closure_loss
                     closure_loss = closure_loss.__enter__()
 
@@ -966,7 +965,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
             model_ref.backward(self, closure_loss, optimizer, opt_idx)
 
             # exit amp context
-            if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
+            if self.precision == 16 and self.amp_type == AMPType.APEX and not self.on_tpu:
                 a, b, c = None, None, None
                 error = context.__exit__(a, b, c)
                 if error:
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 44b66407c7645d..14214d2432a143 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -13,30 +13,25 @@
 # limitations under the License.
 
 import math
-import sys
-from abc import ABC, abstractmethod
-import gc
 import os
+from abc import ABC, abstractmethod
 from typing import Optional
 
 import torch
 from torch import Tensor
-from torch.utils.data import DataLoader
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import DummyLogger
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
 
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
@@ -51,6 +46,7 @@ class TrainerTrainingTricksMixin(ABC):
     default_root_dir: str
     progress_bar_callback: ...
     on_gpu: bool
+    amp_type: AMPType
 
     @abstractmethod
     def get_model(self) -> LightningModule:
@@ -75,7 +71,7 @@ def clip_gradients(self, optimizer):
         if self.gradient_clip_val <= 0:
             return
         model = self.get_model()
-        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
+        if self.amp_type == AMPType.APEX:
             parameters = amp.master_params(optimizer)
         else:
             parameters = model.parameters()
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 7d2d550f8b7dc1..a6ea1d8467a0d9 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -1,10 +1,11 @@
 """General utilities"""
+from enum import Enum
 
 import numpy
 import torch
 
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict
 
 try:
@@ -19,3 +20,8 @@
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
+
+
+class AMPType(Enum):
+    APEX = 'apex'
+    NATIVE = 'native'
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 75130b297ddccd..59b73f0fced3c3 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -1,3 +1,4 @@
+import importlib
 from abc import ABC
 from collections.abc import Mapping, Sequence
 from copy import copy
@@ -5,8 +6,6 @@
 
 import torch
 
-import importlib
-
 TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
 if TORCHTEXT_AVAILABLE:
     from torchtext.data import Batch
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index 7329213b20d4fe..2d9b0dfa6491e9 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -22,6 +22,8 @@
 # only support remote cloud paths if newer
 modern_gfile = version.parse(tensorboard.version.VERSION) >= version.parse('2.0')
 
+import torch
+
 
 def load(path_or_url: str, map_location=None):
     if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 537c5d8f3e73db..cd0621496fe42e 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -1,7 +1,8 @@
-from functools import wraps
+import os
 import warnings
+from functools import wraps
+
 from pytorch_lightning import _logger as log
-import os
 
 
 def rank_zero_only(fn):
diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py
index 19a473640e2d93..1202b93e03a0db 100644
--- a/pytorch_lightning/utilities/memory.py
+++ b/pytorch_lightning/utilities/memory.py
@@ -1,4 +1,5 @@
 import gc
+
 import torch
 
 
diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py
index 2d0f4620d6a279..60e188409ea3f7 100644
--- a/pytorch_lightning/utilities/seed.py
+++ b/pytorch_lightning/utilities/seed.py
@@ -1,10 +1,10 @@
 """Helper functions to help with reproducibility of models. """
 import os
-from typing import Optional, Type
+import random
+from typing import Optional
 
 import numpy as np
-import random
 import torch
 
 from pytorch_lightning import _logger as log
 
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 0cbf169d706901..7678dbdea45629 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -6,7 +6,7 @@ pandoc
 docutils
 sphinxcontrib-fulltoc
 sphinxcontrib-mockautodoc
-git+https://github.com/PytorchLightning/lightning_sphinx_theme.git
+https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#egg=pt-lightning-sphinx-theme
 # pip_shims
 sphinx-autodoc-typehints
 sphinx-paramlinks<0.4.0
diff --git a/setup.py b/setup.py
index ded68282faa43b..908bf237222909 100755
--- a/setup.py
+++ b/setup.py
@@ -26,9 +26,9 @@ def load_requirements(path_dir=PATH_ROOT, file_name='base.txt', comment_char='#'
         # filer all comments
         if comment_char in ln:
             ln = ln[:ln.index(comment_char)].strip()
-        # Make slight syntax alteration to git dependency for PL's sphinx theme
-        if ln.startswith('git') and file_name == 'docs.txt':
-            ln = f'pt_lightning_sphinx_theme @ {ln}#egg=pt-lightning-sphinx-theme'
+        # skip directly installed dependencies
+        if ln.startswith('http'):
+            continue
         if ln:  # if requirement is not empty
             reqs.append(ln)
     return reqs
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 71b38ac6980adb..dfcc1e036854d5 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -5,9 +5,10 @@
 
 import cloudpickle
 import pytest
+import torch
 
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
@@ -93,3 +94,37 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     )
     result = trainer.fit(model)
     assert 1 == result
+
+
+def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
+    """ Tests that the checkpoint saved as 'last.ckpt' contains the latest information. """
+    seed_everything(100)
+    model = EvalModelTemplate()
+    num_epochs = 3
+    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
+    )
+    trainer.fit(model)
+    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
+    path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)  # last.ckpt
+    assert path_last_epoch != path_last
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(path_last)
+    matching_keys = (
+        "epoch",
+        "global_step",
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
+    )
+    for key in matching_keys:
+        assert ckpt_last_epoch[key] == ckpt_last[key]
+
+    # it is easier to load the model objects than to iterate over the raw dict of tensors
+    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
+    model_last = EvalModelTemplate.load_from_checkpoint(path_last)
+    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
+        assert w0.eq(w1).all()
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index c0c41be42f5d69..0dc81c757a7c6e 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -299,7 +299,7 @@ def test_full_loop_ddp_spawn(tmpdir):
 
     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=3,
+        max_epochs=5,
         weights_summary=None,
         distributed_backend='ddp_spawn',
         gpus=[0, 1]
diff --git a/tests/trainer/test_trainer_steps_dict_return.py b/tests/trainer/test_trainer_steps_dict_return.py
index 290983fbf6a5c8..7d6df7a2076a6e 100644
--- a/tests/trainer/test_trainer_steps_dict_return.py
+++ b/tests/trainer/test_trainer_steps_dict_return.py
@@ -56,7 +56,11 @@ def training_step_with_step_end(tmpdir):
     model.training_step_end = model.training_step_end_dict
     model.val_dataloader = None
 
-    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
     trainer.fit(model)
 
     # make sure correct steps were called
@@ -107,8 +111,7 @@ def test_full_training_loop_dict(tmpdir):
     assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
 
     # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
+    batch_idx, batch = 0, next(iter(model.train_dataloader()))
 
     out = trainer.run_training_batch(batch, batch_idx)
     assert out.signal == 0
@@ -131,7 +134,11 @@ def test_train_step_epoch_end(tmpdir):
     model.training_epoch_end = model.training_epoch_end_dict
     model.val_dataloader = None
 
-    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
     trainer.fit(model)
 
     # make sure correct steps were called
@@ -144,8 +151,7 @@ def test_train_step_epoch_end(tmpdir):
     assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234
 
     # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
+    batch_idx, batch = 0, next(iter(model.train_dataloader()))
 
     out = trainer.run_training_batch(batch, batch_idx)
     assert out.signal == 0