From da4b5eb09b5b36b934dec7c2dc22a6dd60e452d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 2 Mar 2021 10:47:55 +0100 Subject: [PATCH 01/10] fix duplicate console logging bug v2 (#6275) Co-authored-by: chaton Co-authored-by: Jirka Borovec (cherry picked from commit bc577ca7923387384c015d7e04d5b0c4e5f1afd9) --- CHANGELOG.md | 3 +++ docs/source/extensions/logging.rst | 14 ++++++++++---- .../computer_vision_fine_tuning.py | 3 ++- pytorch_lightning/__init__.py | 17 ++++++++++------- pytorch_lightning/callbacks/finetuning.py | 4 +++- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/callbacks/pruning.py | 4 +++- pytorch_lightning/core/lightning.py | 4 +++- pytorch_lightning/core/saving.py | 3 ++- pytorch_lightning/loggers/comet.py | 3 ++- pytorch_lightning/loggers/csv_logs.py | 4 +++- pytorch_lightning/loggers/mlflow.py | 4 ++-- pytorch_lightning/loggers/neptune.py | 3 ++- pytorch_lightning/loggers/tensorboard.py | 4 +++- .../plugins/environments/slurm_environment.py | 4 +++- .../environments/torchelastic_environment.py | 4 +++- pytorch_lightning/plugins/training_type/ddp.py | 5 ++++- .../plugins/training_type/ddp_spawn.py | 4 +++- pytorch_lightning/profiler/profilers.py | 4 +++- .../trainer/connectors/accelerator_connector.py | 4 +++- .../trainer/connectors/slurm_connector.py | 3 ++- pytorch_lightning/trainer/trainer.py | 3 ++- pytorch_lightning/trainer/training_tricks.py | 3 ++- pytorch_lightning/tuner/batch_size_scaling.py | 4 +++- pytorch_lightning/tuner/lr_finder.py | 4 +++- pytorch_lightning/utilities/distributed.py | 3 ++- pytorch_lightning/utilities/seed.py | 4 +++- .../utilities/upgrade_checkpoint.py | 4 +++- tests/__init__.py | 6 ++++-- tests/callbacks/test_early_stopping.py | 5 ++++- tests/checkpointing/test_model_checkpoint.py | 4 +++- tests/test_profiler.py | 5 +++-- 32 files changed, 103 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 42100f29ee6a4..84a512202b0a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -180,6 +180,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) + + ## [1.2.1] - 2021-02-23 ### Fixed diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 026040f03a330..bfeed22fd4e66 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -259,13 +259,19 @@ Configure console logging ************************* Lightning logs useful information about the training process and user warnings to the console. -You can retrieve the Lightning logger and change it to your liking. For example, increase the logging level -to see fewer messages like so: +You can retrieve the Lightning logger and change it to your liking. For example, adjust the logging level +or redirect output for certain modules to log files: -.. code-block:: python +.. testcode:: import logging - logging.getLogger("lightning").setLevel(logging.ERROR) + + # configure logging at the root level of lightning + logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) + + # configure logging on module level, redirect to file + logger = logging.getLogger("pytorch_lightning.core") + logger.addHandler(logging.FileHandler("core.log")) Read more about custom Python logging `here `_. diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 65bf1bde141fa..823efaa53a5e5 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -38,6 +38,7 @@ See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html """ import argparse +import logging import os from pathlib import Path from typing import Union @@ -54,11 +55,11 @@ import pytorch_lightning as pl from pl_examples import cli_lightning_logo -from pytorch_lightning import _logger as log from pytorch_lightning import LightningDataModule from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.utilities import rank_zero_info +log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" # --- Finetuning Callback --- diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index fcde9037aee72..824f5ec7dbcaa 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,7 +1,8 @@ """Root package info.""" -import logging as python_logging +import logging import os +import sys import time _this_year = time.strftime("%Y") @@ -37,10 +38,14 @@ - https://pytorch-lightning.readthedocs.io/en/latest - https://pytorch-lightning.readthedocs.io/en/stable """ +_root_logger = logging.getLogger() +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) -_logger = python_logging.getLogger("lightning") -_logger.addHandler(python_logging.StreamHandler()) -_logger.setLevel(python_logging.INFO) +# if root logger has handlers, propagate messages up and let root logger process them +if not _root_logger.hasHandlers(): + _logger.addHandler(logging.StreamHandler()) + _logger.propagate = False _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) @@ -53,9 +58,7 @@ except NameError: __LIGHTNING_SETUP__: bool = False -if __LIGHTNING_SETUP__: - import sys # pragma: no-cover - +if __LIGHTNING_SETUP__: # pragma: no-cover sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover # We are not importing the rest of the lightning during the build process, as it may not be compiled yet else: diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 9f2697a9f9635..b25e5e06e8b86 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -16,6 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^ Freeze and unfreeze models for finetuning purposes """ +import logging from typing import Callable, Generator, Iterable, List, Optional, Union import torch @@ -24,12 +25,13 @@ from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential from torch.optim.optimizer import Optimizer -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +log = logging.getLogger(__name__) + def multiplicative(epoch): return 2 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 383e1caa6a7dc..43f7a66dca313 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -18,7 +18,7 @@ Automatically save model checkpoints during training. """ - +import logging import os import re from copy import deepcopy @@ -29,13 +29,13 @@ import torch import yaml -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache +log = logging.getLogger(__name__) warning_cache = WarningCache() diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 72c32a8f5b738..3f82ab3565403 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -16,6 +16,7 @@ ^^^^^^^^^^^^ """ import inspect +import logging from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -24,12 +25,13 @@ import torch.nn.utils.prune as pytorch_prune from torch import nn -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +log = logging.getLogger(__name__) + _PYTORCH_PRUNING_FUNCTIONS = { "ln_structured": pytorch_prune.ln_structured, "l1_unstructured": pytorch_prune.l1_unstructured, diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6a83b7b1f8637..d1a0a87c37f33 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -16,6 +16,7 @@ import collections import copy import inspect +import logging import os import re import tempfile @@ -31,7 +32,6 @@ from torch.nn import Module from torch.optim.optimizer import Optimizer -from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary @@ -44,6 +44,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args +log = logging.getLogger(__name__) + class LightningModule( ABC, diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 2b470f43eaf3d..280eca55260a7 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -15,6 +15,7 @@ import ast import csv import inspect +import logging import os from argparse import Namespace from copy import deepcopy @@ -25,13 +26,13 @@ import torch import yaml -from pytorch_lightning import _logger as log from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.parsing import parse_class_init_keys +log = logging.getLogger(__name__) PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 9356552cbea4f..788c34fb9d58b 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -16,6 +16,7 @@ ------------ """ +import logging import os from argparse import Namespace from typing import Any, Dict, Optional, Union @@ -23,12 +24,12 @@ import torch from torch import is_tensor -from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException +log = logging.getLogger(__name__) _COMET_AVAILABLE = _module_available("comet_ml") if _COMET_AVAILABLE: diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index a78440143167b..4df672fa6e3b5 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -20,17 +20,19 @@ """ import csv import io +import logging import os from argparse import Namespace from typing import Any, Dict, Optional, Union import torch -from pytorch_lightning import _logger as log from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn +log = logging.getLogger(__name__) + class ExperimentWriter(object): r""" diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 8ae59581fe006..f99dcf441c3d1 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -15,17 +15,17 @@ MLflow Logger ------------- """ +import logging import re from argparse import Namespace from time import time from typing import Any, Dict, Optional, Union -from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn +log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" - _MLFLOW_AVAILABLE = _module_available("mlflow") try: import mlflow diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 3960a983d929b..e3939b2c5d15a 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -15,16 +15,17 @@ Neptune Logger -------------- """ +import logging from argparse import Namespace from typing import Any, Dict, Iterable, Optional, Union import torch from torch import is_tensor -from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only +log = logging.getLogger(__name__) _NEPTUNE_AVAILABLE = _module_available("neptune") if _NEPTUNE_AVAILABLE: diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 0485868fa2ef1..72d1731f80ec5 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -16,6 +16,7 @@ ------------------ """ +import logging import os from argparse import Namespace from typing import Any, Dict, Optional, Union @@ -24,13 +25,14 @@ from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams -from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +log = logging.getLogger(__name__) + if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 59ab27cd4c323..7f9586cab0ace 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import re -from pytorch_lightning import _logger as log from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +log = logging.getLogger(__name__) + class SLURMEnvironment(ClusterEnvironment): diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index bb77760e9dd61..5ac7d9f1c9a40 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os -from pytorch_lightning import _logger as log from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.utilities import rank_zero_warn +log = logging.getLogger(__name__) + class TorchElasticEnvironment(ClusterEnvironment): diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index e0a52fc7609d6..007f898a27cc7 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import subprocess import sys @@ -23,7 +24,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -from pytorch_lightning import _logger as log from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward @@ -44,6 +44,9 @@ from hydra.utils import get_original_cwd, to_absolute_path +log = logging.getLogger(__name__) + + class DDPPlugin(ParallelPlugin): """ Plugin for multi-process single-device training on one or multiple nodes. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index cde2f3dea711c..fdb88a3c5cba5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import re from typing import Any, Dict, List, Optional, Union @@ -21,7 +22,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer -from pytorch_lightning import _logger as log from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward @@ -39,6 +39,8 @@ ) from pytorch_lightning.utilities.seed import seed_everything +log = logging.getLogger(__name__) + class DDPSpawnPlugin(ParallelPlugin): diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 24cf9af8e5802..ddef02e283578 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -14,6 +14,7 @@ """Profiler to check if there are any bottlenecks in your code.""" import cProfile import io +import logging import os import pstats import time @@ -24,9 +25,10 @@ import numpy as np -from pytorch_lightning import _logger as log from pytorch_lightning.utilities.cloud_io import get_filesystem +log = logging.getLogger(__name__) + class BaseProfiler(ABC): """ diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index e5c17614474ee..7d5e5fb9c358c 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from typing import List, Optional, Sequence, Union import torch -from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator @@ -61,6 +61,8 @@ if _HOROVOD_AVAILABLE: import horovod.torch as hvd +log = logging.getLogger(__name__) + class AcceleratorConnector(object): diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index 5086bab25593a..f2bb00abd84bd 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -1,8 +1,9 @@ +import logging import os import signal from subprocess import call -from pytorch_lightning import _logger as log +log = logging.getLogger(__name__) class SLURMConnector: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cffb1914c69f9..f378ee830d261 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Trainer to automate the training.""" +import logging import warnings from itertools import count from pathlib import Path @@ -20,7 +21,6 @@ import torch from torch.utils.data import DataLoader -from pytorch_lightning import _logger as log from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule @@ -64,6 +64,7 @@ from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden +log = logging.getLogger(__name__) # warnings to ignore in trainer warnings.filterwarnings( 'ignore', message='torch.distributed.reduce_op is deprecated, ' diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 6b388f7137ce1..54731977cbee9 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from abc import ABC import torch from torch import Tensor -from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule EPSILON = 1e-6 EPSILON_FP16 = 1e-5 +log = logging.getLogger(__name__) class TrainerTrainingTricksMixin(ABC): diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index c29cffc42607b..a07de29324b24 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +import logging import os from typing import Optional, Tuple -from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn @@ -24,6 +24,8 @@ from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr +log = logging.getLogger(__name__) + def scale_batch_size( trainer, diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 0975fdcbb6a79..c30e2cdac59cb 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import logging import os from functools import wraps from typing import Callable, List, Optional, Sequence, Union @@ -22,7 +23,6 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -39,6 +39,8 @@ else: from tqdm import tqdm +log = logging.getLogger(__name__) + def _determine_lr_attr_name(trainer, model: LightningModule) -> str: if isinstance(trainer.auto_lr_find, str): diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 61f581a5b5571..9e47af26f53d5 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import warnings from functools import wraps @@ -19,7 +20,7 @@ import torch -from pytorch_lightning import _logger as log +log = logging.getLogger(__name__) if torch.distributed.is_available(): from torch.distributed import group, ReduceOp diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index da98e00b71e60..8129075f99f4d 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -13,6 +13,7 @@ # limitations under the License. """Helper functions to help with reproducibility of models. """ +import logging import os import random from typing import Optional @@ -20,9 +21,10 @@ import numpy as np import torch -from pytorch_lightning import _logger as log from pytorch_lightning.utilities import rank_zero_warn +log = logging.getLogger(__name__) + def seed_everything(seed: Optional[int] = None) -> int: """ diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py index 2e767542cd9bd..4896845f10263 100644 --- a/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import logging from shutil import copyfile import torch -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint KEYS_MAPPING = { @@ -27,6 +27,8 @@ "early_stop_callback_patience": (EarlyStopping, "patience"), } +log = logging.getLogger(__name__) + def upgrade_checkpoint(filepath): checkpoint = torch.load(filepath) diff --git a/tests/__init__.py b/tests/__init__.py index a833da7cbd890..7f88230f3296e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import numpy as np import torch -from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE +logging.basicConfig(level=logging.ERROR) + +from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE # noqa: E402 _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) @@ -36,7 +39,6 @@ os.mkdir(_TEMP_PATH) _MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines - _SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4") _SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine") _SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 078b831a89908..7062fe35bbcb7 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import pickle import sys @@ -21,7 +22,7 @@ import pytest import torch -from pytorch_lightning import _logger, seed_everything, Trainer +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,6 +30,8 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.simple_models import ClassificationModel +_logger = logging.getLogger(__name__) + class EarlyStoppingTestRestore(EarlyStopping): # this class has to be defined outside the test function, otherwise we get pickle error diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 29eaebc031e3c..3b4ea00ecb0ba 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import math import os import pickle @@ -676,7 +677,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last)], max_epochs=max_epochs, ) - trainer.fit(model) + with caplog.at_level(logging.INFO): + trainer.fit(model) assert caplog.messages.count('Saving latest checkpoint...') == save_last diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 701b2a4bfb900..667e153a9edd4 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging import os import time from pathlib import Path @@ -95,7 +95,8 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): def test_simple_profiler_describe(caplog, simple_profiler): """Ensure the profiler won't fail when reporting the summary.""" - simple_profiler.describe() + with caplog.at_level(logging.INFO): + simple_profiler.describe() assert "Profiler Report" in caplog.text From 0003f66bb15758b3eb24d818d89493750b3699ce Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 18 Mar 2021 17:14:38 +0000 Subject: [PATCH 02/10] [doc] Update Dict Train Loader doc. (#6579) * update doc * update example (cherry picked from commit 8853a36d457bf2b226f6dd28a816712f02c0069f) --- docs/source/advanced/multiple_loaders.rst | 50 +++++++++++++++++++---- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index 2e3e3201b2181..6e5974ef1a0f8 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -9,7 +9,7 @@ Multiple Datasets Lightning supports multiple dataloaders in a few ways. 1. Create a dataloader that iterates multiple datasets under the hood. -2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning +2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning will automatically combine the batches from different loaders. 3. In the validation and test loop you also have the option to return multiple dataloaders which lightning will call sequentially. @@ -75,13 +75,13 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer loader_a = torch.utils.data.DataLoader(range(6), batch_size=4) loader_b = torch.utils.data.DataLoader(range(15), batch_size=5) - + # pass loaders as a dict. This will create batches like this: # {'a': batch from loader_a, 'b': batch from loader_b} loaders = {'a': loader_a, 'b': loader_b} - # OR: + # OR: # pass loaders as sequence. This will create batches like this: # [batch from loader_a, batch from loader_b] loaders = [loader_a, loader_b] @@ -89,7 +89,24 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer return loaders Furthermore, Lightning also supports that nested lists and dicts (or a combination) can -be returned +be returned. + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) + + return {'a': loader_a, 'b': loader_b} + + def training_step(self, batch, batch_idx): + # access a dictionnary with a batch from each dataloader + batch_a = batch["a"] + batch_b = batch["b"] + .. testcode:: @@ -103,12 +120,29 @@ be returned loader_c = torch.utils.data.DataLoader(range(64), batch_size=4) # pass loaders as a nested dict. This will create batches like this: - # {'loader_a_b': {'a': batch from loader a, 'b': batch from loader b}, - # 'loader_c_d': {'c': batch from loader c, 'd': batch from loader d}} - loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}, - 'loaders_c_d': {'c': loader_c, 'd': loader_d}} + loaders = { + 'loaders_a_b': { + 'a': loader_a, + 'b': loader_b + }, + 'loaders_c_d': { + 'c': loader_c, + 'd': loader_d + } + } return loaders + def training_step(self, batch, batch_idx): + # access the data + batch_a_b = batch["loaders_a_b"] + batch_c_d = batch["loaders_c_d"] + + batch_a = batch_a_b["a"] + batch_b = batch_a_b["a"] + + batch_c = batch_c_d["c"] + batch_d = batch_c_d["d"] + ---------- Test/Val dataloaders From b31e75e8b6a41655f0a7f37ecbb9c705ddbf458c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 19 Mar 2021 17:26:58 +0100 Subject: [PATCH 03/10] Fix all_gather for tpu_cores=8 (#6587) (cherry picked from commit 983a888f498131911a1296400448b004a53c1717) --- CHANGELOG.md | 35 ++++++++++++++++++--------- pytorch_lightning/accelerators/tpu.py | 8 +++--- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84a512202b0a0..5ba84e666d2c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,6 +119,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) +- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) + + +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) + + +- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) + + + +## [1.2.5] - 2021-03-23 + +### Changed + + +### Fixed + +- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) + + ## [1.2.4] - 2021-03-16 ### Changed @@ -139,9 +162,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) -- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) - - ## [1.2.3] - 2021-03-09 ### Fixed @@ -180,9 +200,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) -- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) - - ## [1.2.1] - 2021-02-23 ### Fixed @@ -192,12 +209,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) -- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) - - -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) - - ## [1.2.0] - 2021-02-18 ### Added diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index d285a197c49fe..0fd1d4254e5a1 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -35,12 +35,12 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s Function to gather a tensor from several distributed processes Args: tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op + group: not available with TPUs + sync_grads: not available with TPUs Return: A tensor of shape (world_size, batch, ...) """ # todo: Add support for backward with all_gather - if torch.distributed.is_initialized(): - return xm.all_gather(tensor, group=group, sync_grads=sync_grads) + if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: + return xm.all_gather(tensor).view(-1, *tensor.shape) return tensor From 02b40d2d2f4ef363e8b2b60f37be2add039bdcf9 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 19 Mar 2021 20:32:57 +0100 Subject: [PATCH 04/10] Update Gradient Clipping for TPU Accelerator (#6576) (cherry picked from commit 87c03b10389bb88d1b1a4e5fbc40e8e02091fd04) --- CHANGELOG.md | 2 ++ pytorch_lightning/accelerators/tpu.py | 16 +++++++++++ .../plugins/precision/precision_plugin.py | 1 - tests/models/test_tpu.py | 28 +++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ba84e666d2c1..c8a5071065270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -136,6 +136,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) + ### Fixed diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 0fd1d4254e5a1..772c9f354ac9f 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -12,6 +12,9 @@ if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm + from torch_xla._patched_functions import clip_grad_norm_ + + xla_clip_grad_norm_ = clip_grad_norm_ class TPUAccelerator(Accelerator): @@ -44,3 +47,16 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: return xm.all_gather(tensor).view(-1, *tensor.shape) return tensor + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): + + model = self.lightning_module + parameters = model.parameters() + + grad_clip_val = float(clip_val) + if grad_clip_val <= 0: + return + + max_norm = grad_clip_val + + xla_clip_grad_norm_(parameters, max_norm, norm_type) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 34879e514a6f2..ad96a2c81bcfd 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -88,7 +88,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None: """Clips the gradients to a specific value""" - # TODO: separate TPU case from here if clip_val is None: return diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index db96e6854db90..0554d924e6e9f 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -347,3 +347,31 @@ def test_reduce(rank): assert result.item() == 8 xmp.spawn(test_reduce, nprocs=8, start_method='fork') + + +@pytest.mark.parametrize("clip_val", [0, 10]) +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@pl_multi_process_test +@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") +def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): + """ + Ensure that clip gradients is only called if the value is greater than 0. + """ + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=1, + precision=16, + limit_train_batches=4, + limit_val_batches=4, + gradient_clip_val=clip_val, + ) + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + if clip_val > 0: + mock_clip_grad_norm.assert_called() + else: + mock_clip_grad_norm.assert_not_called() From 5a4529a02e1be25fc3d945c24dcae33609579dce Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 20 Mar 2021 19:58:59 +0100 Subject: [PATCH 05/10] fixing examples (#6600) * try Azure * -e * path (cherry picked from commit cb590392880aefcb0830bf00ec08e4beef6d4f7e) --- azure-pipelines.yml | 10 +++++----- pl_examples/basic_examples/submit_ddp2_job.sh | 2 +- pl_examples/basic_examples/submit_ddp_job.sh | 2 +- tests/__init__.py | 4 ++-- tests/base/model_template.py | 3 ++- tests/checkpointing/test_legacy_checkpoints.py | 4 ++-- tests/helpers/advanced_models.py | 4 +++- tests/helpers/datasets.py | 15 +++++---------- tests/helpers/test_datasets.py | 11 ++++++++--- 9 files changed, 29 insertions(+), 26 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6dfddda0295fe..1447176c7ea70 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -95,12 +95,12 @@ jobs: python -m pytest benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' - - bash: | + - script: | + set -e python -m pytest pl_examples -v --maxfail=2 --durations=0 python setup.py install --user --quiet bash pl_examples/run_ddp-example.sh - cd pl_examples/basic_examples - bash submit_ddp_job.sh - bash submit_ddp2_job.sh - pip uninstall -y pytorch-lightning + # cd pl_examples/basic_examples + # bash submit_ddp_job.sh + # bash submit_ddp2_job.sh displayName: 'Examples' diff --git a/pl_examples/basic_examples/submit_ddp2_job.sh b/pl_examples/basic_examples/submit_ddp2_job.sh index 6fed6afef0d1c..026589a604c36 100755 --- a/pl_examples/basic_examples/submit_ddp2_job.sh +++ b/pl_examples/basic_examples/submit_ddp2_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 +srun python3 simple_image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/pl_examples/basic_examples/submit_ddp_job.sh b/pl_examples/basic_examples/submit_ddp_job.sh index 383579c4346b6..b4f5ff0a64d92 100755 --- a/pl_examples/basic_examples/submit_ddp_job.sh +++ b/pl_examples/basic_examples/submit_ddp_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 +srun python3 simple_image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/tests/__init__.py b/tests/__init__.py index 7f88230f3296e..e002e36518661 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -24,8 +24,8 @@ _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) _TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp') -DATASETS_PATH = os.path.join(_PROJECT_ROOT, 'Datasets') -LEGACY_PATH = os.path.join(_PROJECT_ROOT, 'legacy') +PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') +PATH_LEGACY = os.path.join(_PROJECT_ROOT, 'legacy') # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""): diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 1d36df8f5ef50..991ed03a4b7a3 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS from tests.base.model_optimizers import ConfigureOptimizersPool from tests.base.model_test_dataloaders import TestDataloaderVariations from tests.base.model_test_epoch_ends import TestEpochEndVariations @@ -28,7 +29,7 @@ from tests.base.model_valid_dataloaders import ValDataloaderVariations from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations from tests.base.model_valid_steps import ValidationStepVariations -from tests.helpers.datasets import PATH_DATASETS, TrialMNIST +from tests.helpers.datasets import TrialMNIST class EvalModelTemplate( diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index b5d22372ff15f..f40b849dd2b36 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,9 +18,9 @@ import pytest from pytorch_lightning import Trainer -from tests import LEGACY_PATH +from tests import PATH_LEGACY -LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints') +LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" diff --git a/tests/helpers/advanced_models.py b/tests/helpers/advanced_models.py index 7ad678b3046fd..2b0146e1ee099 100644 --- a/tests/helpers/advanced_models.py +++ b/tests/helpers/advanced_models.py @@ -20,6 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST @@ -165,7 +166,7 @@ def configure_optimizers(self): return [opt_g, opt_d], [] def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) + return DataLoader(TrialMNIST(root=PATH_DATASETS, train=True, download=True), batch_size=16) class ParityModuleRNN(LightningModule): @@ -223,6 +224,7 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST( + root=PATH_DATASETS, train=True, download=True, ), batch_size=128, num_workers=1) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index e7bdad0f1538c..77035796ca3b1 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -22,11 +22,6 @@ from torch import Tensor from torch.utils.data import Dataset -from tests import _PROJECT_ROOT - -#: local path to test datasets -PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') - class MNIST(Dataset): """ @@ -47,7 +42,7 @@ class MNIST(Dataset): downloaded again. Examples: - >>> dataset = MNIST(download=True) + >>> dataset = MNIST(".", download=True) >>> len(dataset) 60000 >>> torch.bincount(dataset.targets) @@ -65,7 +60,7 @@ class MNIST(Dataset): def __init__( self, - root: str = PATH_DATASETS, + root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, @@ -152,7 +147,7 @@ class TrialMNIST(MNIST): kwargs: Same as MNIST Examples: - >>> dataset = TrialMNIST(download=True) + >>> dataset = TrialMNIST(".", download=True) >>> len(dataset) 300 >>> sorted(set([d.item() for d in dataset.targets])) @@ -161,7 +156,7 @@ class TrialMNIST(MNIST): tensor([100, 100, 100]) """ - def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): # number of examples per class self.num_samples = num_samples # take just a subset of MNIST dataset @@ -169,7 +164,7 @@ def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2 self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" - super().__init__(normalize=(0.5, 1.0), **kwargs) + super().__init__(root, normalize=(0.5, 1.0), **kwargs) @staticmethod def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence): diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py index 6319fdb562504..42b5df0ff91a4 100644 --- a/tests/helpers/test_datasets.py +++ b/tests/helpers/test_datasets.py @@ -16,12 +16,17 @@ import cloudpickle import pytest +from tests import PATH_DATASETS from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST -@pytest.mark.parametrize('dataset_cls', [MNIST, TrialMNIST, AverageDataset]) -def test_pickling_dataset_mnist(tmpdir, dataset_cls): - mnist = dataset_cls() +@pytest.mark.parametrize('dataset_cls,args', [ + (MNIST, dict(root=PATH_DATASETS)), + (TrialMNIST, dict(root=PATH_DATASETS)), + (AverageDataset, dict()), +]) +def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): + mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) pickle.loads(mnist_pickled) From 717ea43f8d6ffe5cac45cea30aa56e458c1fcfa5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 13:39:19 +0100 Subject: [PATCH 06/10] refactoring setup (#6590) * refactoring setup * . * docs * flake8 (cherry picked from commit 1fae10a2dc8224379eac84d6242e0847c2685565) --- docs/source/conf.py | 23 ++++--- pytorch_lightning/__init__.py | 81 +++++++------------------ pytorch_lightning/callbacks/progress.py | 3 +- pytorch_lightning/info.py | 35 +++++++++++ pytorch_lightning/setup_tools.py | 6 +- setup.py | 46 ++++++++------ 6 files changed, 101 insertions(+), 93 deletions(-) create mode 100644 pytorch_lightning/info.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 813d5ee978821..b6b97540179db 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import m2r -import builtins import glob import os import shutil @@ -27,10 +26,13 @@ FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) -if SPHINX_MOCK_REQUIREMENTS: - builtins.__LIGHTNING_SETUP__ = True -import pytorch_lightning # noqa: E402 +try: + from pytorch_lightning import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) + import info # -- Project documents ------------------------------------------------------- @@ -79,13 +81,13 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # -- Project information ----------------------------------------------------- project = 'PyTorch Lightning' -copyright = pytorch_lightning.__copyright__ -author = pytorch_lightning.__author__ +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = pytorch_lightning.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = pytorch_lightning.__version__ +release = info.__version__ # -- General configuration --------------------------------------------------- @@ -176,8 +178,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - 'pytorch_project': pytorch_lightning.__homepage__, - 'canonical_url': pytorch_lightning.__homepage__, + 'pytorch_project': info.__homepage__, + 'canonical_url': info.__homepage__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, @@ -279,6 +281,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: 'torch': ('https://pytorch.org/docs/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), + 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 824f5ec7dbcaa..b9660475bf2f7 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -2,42 +2,17 @@ import logging import os -import sys -import time -_this_year = time.strftime("%Y") -__version__ = '1.2.4' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' -__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = ( - "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." - " Scale your models. Write less boilerplate." +from pytorch_lightning.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, ) -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" _root_logger = logging.getLogger() _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -50,32 +25,20 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -try: - # This variable is injected in the __builtins__ by the build - # process. It used to enable importing subpackages of skimage when - # the binaries are not built - _ = None if __LIGHTNING_SETUP__ else None -except NameError: - __LIGHTNING_SETUP__: bool = False - -if __LIGHTNING_SETUP__: # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - from pytorch_lightning import metrics - from pytorch_lightning.callbacks import Callback - from pytorch_lightning.core import LightningDataModule, LightningModule - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.utilities.seed import seed_everything - - __all__ = [ - 'Trainer', - 'LightningDataModule', - 'LightningModule', - 'Callback', - 'seed_everything', - 'metrics', - ] +from pytorch_lightning import metrics # noqa: E402 +from pytorch_lightning.callbacks import Callback # noqa: E402 +from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 +from pytorch_lightning.trainer import Trainer # noqa: E402 +from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 + +__all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', +] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 3f401669c351e..587fee95e9cd0 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -37,8 +37,7 @@ class tqdm(_tqdm): """ - Custom tqdm progressbar where we append 0 to floating points/strings to - prevent the progress bar from flickering + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering """ @staticmethod diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py new file mode 100644 index 0000000000000..1d729ee758d02 --- /dev/null +++ b/pytorch_lightning/info.py @@ -0,0 +1,35 @@ +import time + +_this_year = time.strftime("%Y") +__version__ = '1.2.4' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' +__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. + +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. + +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index f5aed2608635e..3362ccb479895 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -16,7 +16,7 @@ import re from typing import List -from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: @@ -40,10 +40,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str: +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: """Load readme as decribtion - >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/setup.py b/setup.py index 5d619d51977b2..e53e24ebf0702 100755 --- a/setup.py +++ b/setup.py @@ -16,20 +16,22 @@ import os # Always prefer setuptools over distutils +import sys + from setuptools import find_packages, setup try: - import builtins + from pytorch_lightning import info, setup_tools except ImportError: - import __builtin__ as builtins + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("pytorch_lightning") + import info + import setup_tools # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True - -import pytorch_lightning # noqa: E402 -from pytorch_lightning.setup_tools import _load_readme_description, _load_requirements # noqa: E402 +_PATH_ROOT = os.path.dirname(__file__) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -37,10 +39,10 @@ # From local copy of repo, use like `pip install ".[dev, docs]"` extras = { # 'docs': load_requirements(file_name='docs.txt'), - 'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'), - 'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), - 'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'), - 'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt') + 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), + 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), + 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), + 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') } extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] @@ -53,6 +55,12 @@ # filter cpu only packages extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=info.__version__, +) + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -60,22 +68,22 @@ # engineer specific practices setup( name="pytorch-lightning", - version=pytorch_lightning.__version__, - description=pytorch_lightning.__docs__, - author=pytorch_lightning.__author__, - author_email=pytorch_lightning.__author_email__, - url=pytorch_lightning.__homepage__, + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=pytorch_lightning.__license__, + license=info.__license__, packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), - long_description=_load_readme_description(PATH_ROOT), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=_load_requirements(PATH_ROOT), + install_requires=setup_tools._load_requirements(_PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", From fa6b3be6e971773ce492a2a17de30beb65411142 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 08:51:45 +0100 Subject: [PATCH 07/10] fix comparing versions (#6434) * fix comparing versions * chlog * . * ... * datasets (cherry picked from commit 8cd75a4dd51939881da265752c2d81307cbe4d9e) --- .github/workflows/docs-checks.yml | 2 +- CHANGELOG.md | 3 +++ Makefile | 2 +- docs/source/conf.py | 1 + pytorch_lightning/utilities/imports.py | 22 ++++++++++++++++++---- requirements/extra.txt | 1 + 6 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 5ee4f23b4b3cc..4488c598c8ac7 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -98,7 +98,7 @@ jobs: # First run the same pipeline as Read-The-Docs cd docs make clean - make html --debug --jobs $(nproc) SPHINXOPTS="-W" + make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index c8a5071065270..8064134b77278 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -144,6 +144,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) + + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/Makefile b/Makefile index d35e0b77f8429..04b08fa2d27d1 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ test: clean docs: clean pip install --quiet -r requirements/docs.txt - python -m sphinx -b html -W docs/source docs/build + python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/docs/source/conf.py b/docs/source/conf.py index b6b97540179db..4ab675f1e4dd2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -334,6 +334,7 @@ def package_list_from_file(file): } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: + MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8024997382457..f94044ff7683d 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" +import importlib import operator import platform from distutils.version import LooseVersion from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound def _module_available(module_path: str) -> bool: @@ -41,11 +42,24 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ try: - pkg_version = LooseVersion(get_distribution(package).version) - return op(pkg_version, LooseVersion(version)) - except DistributionNotFound: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) _IS_WINDOWS = platform.system() == "Windows" diff --git a/requirements/extra.txt b/requirements/extra.txt index a05c4971ac450..715916c4e36ac 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,5 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 +# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip From e69f66fd439a5d7db4d349d000d06194bbf493b8 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sun, 21 Mar 2021 00:15:49 +0100 Subject: [PATCH 08/10] Add AMP for validation, prediction and testing (#6565) * Add Tests for val and test-steps * Add native AMP * pep8 tests * pep8 plugin * changelog (cherry picked from commit 634d83134fea4bb701c24abd5a4a38adb0eddbcd) --- CHANGELOG.md | 2 ++ .../plugins/precision/native_amp.py | 18 +++++++++++ tests/models/test_amp.py | 31 +++++++++++++++++-- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8064134b77278..ce55dccce3597 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 94e6cf376b03a..d19b05358cdd5 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -93,3 +93,21 @@ def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" with torch.cuda.amp.autocast(): yield + + @contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def predict_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 53ec32764f3ed..2f16f2fe64e75 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -17,6 +17,7 @@ import pytest import torch from torch import optim +from torch.utils.data import DataLoader import tests.helpers.utils as tutils from pytorch_lightning import Trainer @@ -24,17 +25,35 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset class AMPTestModel(BoringModel): - def training_step(self, batch, batch_idx): + def _step(self, batch, batch_idx): assert torch.is_autocast_enabled() output = self(batch) assert output.dtype == torch.float16 loss = self.loss(batch, output) - return {"loss": loss} + return loss + + def training_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"y": output} + + def predict(self, batch, batch_idx, dataloader_idx=None): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + return output @pytest.mark.skip(reason='dp + amp not supported currently') # TODO @@ -54,6 +73,8 @@ def test_amp_single_gpu_dp(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -73,6 +94,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @@ -112,6 +135,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): model = AMPTestModel() # tutils.run_model_test(trainer_options, model) trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" From 884c3abc9fabcf6677a5267f34c39979db6e920a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 23:37:28 +0100 Subject: [PATCH 09/10] v1.2.5 & chlog --- CHANGELOG.md | 134 ++------------------------------------ pytorch_lightning/info.py | 2 +- 2 files changed, 5 insertions(+), 131 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce55dccce3597..d192f814c4081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,148 +5,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [UnReleased] - 2021-MM-DD - -### Added - -- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - -- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - -- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) - - -- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - -- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - -- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) - - -- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) - - -- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) - - -- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) - - -- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) - - -### Changed - -- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) - - -- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - -- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - -- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) - - -### Deprecated - -- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - - -- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) - - -### Removed - -- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) - - -- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) - - -- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166)) - - -- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163)) - - -- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161)) - * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve` - * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` - - -- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) - - -- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167)) - - -- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207)) - - -- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) - - -### Fixed - -- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) - -- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) - - -- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070)) - - -- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109)) - - -- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) - - -- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) - - -- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) - - -- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) - - -- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) - - -- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) - - -- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) - - -- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) - - -- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) - - -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) - - ## [1.2.5] - 2021-03-23 ### Changed - Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) +- Refactored setup for typing friendly ([#6590](https://github.com/PyTorchLightning/pytorch-lightning/pull/6590)) ### Fixed - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) - - - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed duplicate logs appearing in console when using the python logging module ([#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) +- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) ## [1.2.4] - 2021-03-16 @@ -167,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) - Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511)) - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) ## [1.2.3] - 2021-03-09 diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py index 1d729ee758d02..99a5ffa9e45e9 100644 --- a/pytorch_lightning/info.py +++ b/pytorch_lightning/info.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.2.4' +__version__ = '1.2.5' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' From 944c79607ea74b6fbd04624caf735fc68d939a4f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 20:28:18 +0100 Subject: [PATCH 10/10] Prune metrics base classes 2/n (#6530) * base class * extensions * chlog * _stable_1d_sort * _check_same_shape * _input_format_classification_one_hot * utils * to_onehot * select_topk * to_categorical * get_num_classes * reduce * class_reduce * tests (cherry picked from commit 6453091b8ab3713e2d58bad7acc9a4345dc5d10b) --- .../basic_examples/conv_sequential_example.py | 4 +- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/core/step_result.py | 2 +- pytorch_lightning/metrics/compositional.py | 100 +--- .../metrics/functional/classification.py | 21 + pytorch_lightning/metrics/functional/psnr.py | 8 +- pytorch_lightning/metrics/metric.py | 464 +----------------- pytorch_lightning/metrics/utils.py | 3 + pytorch_lightning/trainer/callback_hook.py | 2 +- .../trainer/connectors/callback_connector.py | 7 +- .../logger_connector/metrics_holder.py | 3 +- requirements.txt | 1 + tests/accelerators/test_dp.py | 2 +- tests/deprecated_api/test_remove_1-5.py | 2 +- tests/metrics/classification/test_inputs.py | 2 +- .../metrics/functional/test_classification.py | 2 +- tests/metrics/functional/test_reduction.py | 3 +- tests/metrics/test_metric_lightning.py | 8 +- 18 files changed, 98 insertions(+), 538 deletions(-) diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index b558020838cdb..6cfb6109f04fc 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -189,6 +189,7 @@ def instantiate_datamodule(args): ]) cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule( + data_dir=args.data_dir, batch_size=args.batch_size, train_transforms=train_transforms, test_transforms=test_transforms, @@ -206,6 +207,7 @@ def instantiate_datamodule(args): parser = ArgumentParser(description="Pipe Example") parser.add_argument("--use_rpc_sequential", action="store_true") + parser.add_argument("--manual_optimization", action="store_true") parser = Trainer.add_argparse_args(parser) parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser) args = parser.parse_args() @@ -216,7 +218,7 @@ def instantiate_datamodule(args): if args.use_rpc_sequential: plugins = RPCSequentialPlugin() - model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization) + model = LitResnet(batch_size=args.batch_size, manual_optimization=args.manual_optimization) trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None) trainer.fit(model, cifar10_dm) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index af9ce25f902b3..5c5dc5cc6f531 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Any +from typing import Any, TYPE_CHECKING import torch diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a23..3961586f4946a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index df98d16a3ef7e..d51332c43b6b4 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -1,14 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Union import torch +from torchmetrics.metric import CompositionalMetric as _CompositionalMetric -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics import Metric +from pytorch_lightning.utilities import rank_zero_warn -class CompositionalMetric(Metric): - """Composition of two metrics with a specific operator - which will be executed upon metric's compute +class CompositionalMetric(_CompositionalMetric): + r""" + This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. + .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. """ def __init__( @@ -17,76 +33,8 @@ def __init__( metric_a: Union[Metric, int, float, torch.Tensor], metric_b: Union[Metric, int, float, torch.Tensor, None], ): - """ - - Args: - operator: the operator taking in one (if metric_b is None) - or two arguments. Will be applied to outputs of metric_a.compute() - and (optionally if metric_b is not None) metric_b.compute() - metric_a: first metric whose compute() result is the first argument of operator - metric_b: second metric whose compute() result is the second argument of operator. - For operators taking in only one input, this should be None - """ - super().__init__() - - self.op = operator - - if isinstance(metric_a, torch.Tensor): - self.register_buffer("metric_a", metric_a) - else: - self.metric_a = metric_a - - if isinstance(metric_b, torch.Tensor): - self.register_buffer("metric_b", metric_b) - else: - self.metric_b = metric_b - - def _sync_dist(self, dist_sync_fn=None): - # No syncing required here. syncing will be done in metric_a and metric_b - pass - - def update(self, *args, **kwargs): - if isinstance(self.metric_a, Metric): - self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) - - if isinstance(self.metric_b, Metric): - self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - - def compute(self): - - # also some parsing for kwargs? - if isinstance(self.metric_a, Metric): - val_a = self.metric_a.compute() - else: - val_a = self.metric_a - - if isinstance(self.metric_b, Metric): - val_b = self.metric_b.compute() - else: - val_b = self.metric_b - - if val_b is None: - return self.op(val_a) - - return self.op(val_a, val_b) - - def reset(self): - if isinstance(self.metric_a, Metric): - self.metric_a.reset() - - if isinstance(self.metric_b, Metric): - self.metric_b.reset() - - def persistent(self, mode: bool = False): - if isinstance(self.metric_a, Metric): - self.metric_a.persistent(mode=mode) - if isinstance(self.metric_b, Metric): - self.metric_b.persistent(mode=mode) - - def __repr__(self): - repr_str = ( - self.__class__.__name__ - + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." + " It will be removed in v1.5.0", DeprecationWarning ) - - return repr_str + super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index fae9e0770f88d..7281ca3f83717 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -123,6 +123,7 @@ def stat_scores( return tp, fp, tn, fn, sup +# todo: remove in 1.4 def stat_scores_multiple_classes( pred: torch.Tensor, target: torch.Tensor, @@ -136,6 +137,9 @@ def stat_scores_multiple_classes( .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores` + Raises: + ValueError: + If ``reduction`` is not one of ``"none"``, ``"sum"`` or ``"elementwise_mean"``. """ rank_zero_warn( @@ -211,6 +215,7 @@ def _confmat_normalize(cm): return cm +# todo: remove in 1.4 def precision_recall( pred: torch.Tensor, target: torch.Tensor, @@ -269,6 +274,7 @@ def precision_recall( return precision, recall +# todo: remove in 1.4 def precision( pred: torch.Tensor, target: torch.Tensor, @@ -312,6 +318,7 @@ def precision( return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] +# todo: remove in 1.4 def recall( pred: torch.Tensor, target: torch.Tensor, @@ -509,6 +516,7 @@ def auc( return __auc(x, y) +# todo: remove in 1.4 def auc_decorator() -> Callable: rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning) @@ -525,6 +533,7 @@ def new_func(*args, **kwargs) -> torch.Tensor: return wrapper +# todo: remove in 1.4 def multiclass_auc_decorator() -> Callable: rank_zero_warn( "This `multiclass_auc_decorator` was deprecated in v1.2.0." @@ -547,6 +556,7 @@ def new_func(*args, **kwargs) -> torch.Tensor: return wrapper +# todo: remove in 1.4 def auroc( pred: torch.Tensor, target: torch.Tensor, @@ -589,6 +599,7 @@ def auroc( ) +# todo: remove in 1.4 def multiclass_auroc( pred: torch.Tensor, target: torch.Tensor, @@ -612,6 +623,16 @@ def multiclass_auroc( Return: Tensor containing ROCAUC score + Raises: + ValueError: + If ``pred`` don't sum up to ``1`` over classes for ``Multiclass AUROC``. + ValueError: + If number of classes found in ``target`` does not equal the number of + columns in ``pred``. + ValueError: + If number of classes deduced from ``pred`` does not equal the number of + classes passed in ``num_classes``. + Example: >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 434b2ae60218d..87ee1a93c9d19 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -15,8 +15,8 @@ import torch -from pytorch_lightning import utilities -from pytorch_lightning.metrics import utils +from pytorch_lightning.metrics.utils import reduce +from pytorch_lightning.utilities import rank_zero_warn def _psnr_compute( @@ -28,7 +28,7 @@ def _psnr_compute( ) -> torch.Tensor: psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return utils.reduce(psnr, reduction=reduction) + return reduce(psnr, reduction=reduction) def _psnr_update(preds: torch.Tensor, @@ -93,7 +93,7 @@ def psnr( """ if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') + rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') if data_range is None: if dim is not None: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index ab198356f7279..730011d998f10 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -11,52 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import inspect -from abc import ABC, abstractmethod -from collections.abc import Sequence -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from torch import nn +from torchmetrics import Metric as _Metric +from torchmetrics import MetricCollection as _MetricCollection -from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors +from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(nn.Module, ABC): - """ - Base class for all metrics present in the Metrics API. - - Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to - handle distributed synchronization and per-step metric computation. - - Override ``update()`` and ``compute()`` functions to implement your own metric. Use - ``add_state()`` to register metric state variables which keep track of state on each - call of ``update()`` and are synchronized across processes when ``compute()`` is called. - - Note: - Metric state variables can either be ``torch.Tensors`` or an empty list which can we used - to store `torch.Tensors``. - - Note: - Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` - is valid, but it won't return the metric value at the current step. A call to ``forward()`` - automatically calls ``update()`` and also returns the metric value at the current step. - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None +class Metric(_Metric): + r""" + This implementation refers to :class:`~torchmetrics.Metric`. + + .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. """ def __init__( @@ -66,356 +34,78 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__() - - self.dist_sync_on_step = dist_sync_on_step - self.compute_on_step = compute_on_step - self.process_group = process_group - self.dist_sync_fn = dist_sync_fn - self._to_sync = True - - self._update_signature = inspect.signature(self.update) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - self._computed = None - self._forward_cache = None - - # initialize state - self._defaults = {} - self._persistent = {} - self._reductions = {} - - def add_state( - self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False - ): - """ - Adds metric state variable. Only used by subclasses. - - Args: - name: The name of the state variable. The variable will then be accessible at ``self.name``. - default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be - reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. - persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. - Default is ``False``. - - Note: - Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. - However, there won't be any reduction function applied to the synchronized metric state. - - The metric states would be synced as follows - - - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across - the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric - state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. - - - If the metric state is a ``list``, the synced value will be a ``list`` containing the - combined elements from all processes. - - Note: - When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow - the format discussed in the above note. - - """ - if ( - not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 - or (isinstance(default, list) and len(default) != 0) # noqa: W503 - ): - raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") - - if dist_reduce_fx == "sum": - dist_reduce_fx = dim_zero_sum - elif dist_reduce_fx == "mean": - dist_reduce_fx = dim_zero_mean - elif dist_reduce_fx == "cat": - dist_reduce_fx = dim_zero_cat - elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): - raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") - - setattr(self, name, default) - - self._defaults[name] = deepcopy(default) - self._persistent[name] = persistent - self._reductions[name] = dist_reduce_fx - - @torch.jit.unused - def forward(self, *args, **kwargs): - """ - Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. - """ - # add current step - with torch.no_grad(): - self.update(*args, **kwargs) - self._forward_cache = None - - if self.compute_on_step: - self._to_sync = self.dist_sync_on_step - - # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # call reset, update, compute, on single batch - self.reset() - self.update(*args, **kwargs) - self._forward_cache = self.compute() - - # restore context - for attr, val in cache.items(): - setattr(self, attr, val) - self._to_sync = True - self._computed = None - - return self._forward_cache - - def _sync_dist(self, dist_sync_fn=gather_all_tensors): - input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} - output_dict = apply_to_collection( - input_dict, - torch.Tensor, - dist_sync_fn, - group=self.process_group, + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." + " It will be removed in v1.5.0", DeprecationWarning + ) + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - - for attr, reduction_fn in self._reductions.items(): - # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], torch.Tensor): - output_dict[attr] = torch.stack(output_dict[attr]) - elif isinstance(output_dict[attr][0], list): - output_dict[attr] = _flatten(output_dict[attr]) - - assert isinstance(reduction_fn, (Callable)) or reduction_fn is None - reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] - setattr(self, attr, reduced) - - def _wrap_update(self, update): - - @functools.wraps(update) - def wrapped_func(*args, **kwargs): - self._computed = None - return update(*args, **kwargs) - - return wrapped_func - - def _wrap_compute(self, compute): - - @functools.wraps(compute) - def wrapped_func(*args, **kwargs): - # return cached value - if self._computed is not None: - return self._computed - - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) - - return self._computed - - return wrapped_func - - @abstractmethod - def update(self) -> None: # pylint: disable=E0202 - """ - Override this method to update the state variables of your metric class. - """ - pass - - @abstractmethod - def compute(self): # pylint: disable=E0202 - """ - Override this method to compute the final metric value from state variables - synchronized across the distributed backend. - """ - pass - - def reset(self): - """ - This method automatically resets the metric state variables to their default value. - """ - for attr, default in self._defaults.items(): - current_val = getattr(self, attr) - if isinstance(default, torch.Tensor): - setattr(self, attr, deepcopy(default).to(current_val.device)) - else: - setattr(self, attr, deepcopy(default)) - - def clone(self): - """ Make a copy of the metric """ - return deepcopy(self) - - def __getstate__(self): - # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} - - def __setstate__(self, state): - # manually restore update and compute functions for pickling - self.__dict__.update(state) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - - def _apply(self, fn): - """Overwrite _apply function such that we can also move metric states - to the correct device when `.to`, `.cuda`, etc methods are called - """ - self = super()._apply(fn) - # Also apply fn to metric states - for key in self._defaults.keys(): - current_val = getattr(self, key) - if isinstance(current_val, torch.Tensor): - setattr(self, key, fn(current_val)) - elif isinstance(current_val, Sequence): - setattr(self, key, [fn(cur_v) for cur_v in current_val]) - else: - raise TypeError( - "Expected metric state to be either a torch.Tensor" - f"or a list of torch.Tensor, but encountered {current_val}" - ) - return self - - def persistent(self, mode: bool = False): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for key in self._persistent.keys(): - self._persistent[key] = mode - - def state_dict(self, destination=None, prefix='', keep_vars=False): - destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - # Register metric states to be part of the state_dict - for key in self._defaults.keys(): - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination - - def _filter_kwargs(self, **kwargs): - """ filter kwargs such that they match the update signature of the metric """ - - # filter all parameters based on update signature except those of - # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) - _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - filtered_kwargs = { - k: v - for k, v in kwargs.items() if k in self._update_signature.parameters.keys() - and self._update_signature.parameters[k].kind not in _params - } - - # if no kwargs filtered, return al kwargs as default - if not filtered_kwargs: - filtered_kwargs = kwargs - return filtered_kwargs def __hash__(self): - hash_vals = [self.__class__.__name__] - - for key in self._defaults.keys(): - val = getattr(self, key) - # Special case: allow list values, so long - # as their elements are hashable - if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend(val) - else: - hash_vals.append(val) - - return hash(tuple(hash_vals)) + return super().__hash__() def __add__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.add, self, other) def __and__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_and, self, other) def __eq__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.eq, self, other) def __floordiv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.floor_divide, self, other) def __ge__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.ge, self, other) def __gt__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.gt, self, other) def __le__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.le, self, other) def __lt__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.lt, self, other) def __matmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.matmul, self, other) def __mod__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.fmod, self, other) def __mul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.mul, self, other) def __ne__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.ne, self, other) def __or__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_or, self, other) def __pow__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.pow, self, other) def __radd__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.add, other, self) def __rand__(self, other: Any): @@ -426,72 +116,58 @@ def __rand__(self, other: Any): def __rfloordiv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.floor_divide, other, self) def __rmatmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.matmul, other, self) def __rmod__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.fmod, other, self) def __rmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.mul, other, self) def __ror__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_or, other, self) def __rpow__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.pow, other, self) def __rsub__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.sub, other, self) def __rtruediv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.true_divide, other, self) def __rxor__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_xor, other, self) def __sub__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.sub, self, other) def __truediv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.true_divide, self, other) def __xor__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_xor, self, other) def __abs__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.abs, self, None) def __inv__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_not, self, None) def __invert__(self): @@ -499,12 +175,10 @@ def __invert__(self): def __neg__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(_neg, self, None) def __pos__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.abs, self, None) @@ -512,100 +186,16 @@ def _neg(tensor: torch.Tensor): return -torch.abs(tensor) -class MetricCollection(nn.ModuleDict): - """ - MetricCollection class can be used to chain metrics that have the same - call pattern into one single class. - - Args: - metrics: One of the following - - * list or tuple: if metrics are passed in as a list, will use the - metrics class name as key for output dict. Therefore, two metrics - of the same class cannot be chained this way. - - * dict: if metrics are passed in as a dict, will use each key in the - dict as key for output dict. Use this format if you want to chain - together multiple of the same metric with different parameters. - - Example (input as list): - - >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - >>> metrics = MetricCollection([Accuracy(), - ... Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')]) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - - Example (input as dict): - - >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), - ... 'macro_recall': Recall(num_classes=3, average='macro')}) - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} +class MetricCollection(_MetricCollection): + r""" + This implementation refers to :class:`~torchmetrics.MetricCollection`. + .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. """ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - super().__init__() - if isinstance(metrics, dict): - # Check all values are metrics - for name, metric in metrics.items(): - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name}" - " is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, (tuple, list)): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" - " of `pl.metrics.Metric`" - ) - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") - - def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 - """ - Iteratively call forward for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} - - def update(self, *args, **kwargs): # pylint: disable=E0202 - """ - Iteratively call update for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - for _, m in self.items(): - m_kwargs = m._filter_kwargs(**kwargs) - m.update(*args, **m_kwargs) - - def compute(self) -> Dict[str, Any]: - return {k: m.compute() for k, m in self.items()} - - def reset(self): - """ Iteratively call reset for each metric """ - for _, m in self.items(): - m.reset() - - def clone(self): - """ Make a copy of the metric collection """ - return deepcopy(self) - - def persistent(self, mode: bool = True): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for _, m in self.items(): - m.persistent(mode) + rank_zero_warn( + "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." + " It will be removed in v1.5.0", DeprecationWarning + ) + super().__init__(metrics=metrics) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 7abf260a822ef..5084294bfbf98 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -245,6 +245,9 @@ def class_reduce( - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - ``'none'`` or ``None``: returns calculated metric per class + Raises: + ValueError: + If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. """ valid_reduction = ("micro", "macro", "weighted", "none", None) if class_reduction == "micro": diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 60e9183ac42f7..d33338055a5b1 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import List, Dict, Any, Type, Callable +from typing import Any, Callable, Dict, List, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 14694c8f77811..ee768c05cc8a2 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -14,12 +14,7 @@ import os from typing import List, Union -from pytorch_lightning.callbacks import ( - Callback, - ModelCheckpoint, - ProgressBar, - ProgressBarBase, -) +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a927485..554f1d3faf9ed 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -15,8 +15,7 @@ from typing import Any import torch - -from pytorch_lightning.metrics.metric import Metric +from torchmetrics import Metric class MetricsHolder: diff --git a/requirements.txt b/requirements.txt index bdfd6601ba4c2..f196b5e639bf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 +torchmetrics>=0.2.0 diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 15faf98d94d57..8252aac9e9092 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -23,10 +23,10 @@ from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core import memory from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.simple_models import ClassificationModel -from tests.base import EvalModelTemplate class CustomClassificationModelDP(ClassificationModel): diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 415f1d040ba70..c6b2dc24b35ff 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -17,7 +17,7 @@ import pytest -from pytorch_lightning import Trainer, Callback +from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import WandbLogger from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index a78d799b1a07d..2b7be8caa7a0d 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -1,9 +1,9 @@ import pytest import torch from torch import rand, randint +from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.metrics.utils import select_topk, to_onehot from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 39622c4cd3550..bca50867dcb44 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,10 +1,10 @@ import pytest import torch +from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot def test_onehot(): diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py index 03a34f6c5a25b..9949c8086a44a 100644 --- a/tests/metrics/functional/test_reduction.py +++ b/tests/metrics/functional/test_reduction.py @@ -1,7 +1,6 @@ import pytest import torch - -from pytorch_lightning.metrics.utils import class_reduce, reduce +from torchmetrics.utilities import class_reduce, reduce def test_reduce(): diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..e52e39cb16488 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,11 +1,13 @@ import torch +from torchmetrics import Metric as TMetric from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection +from pytorch_lightning.metrics import Metric as PLMetric +from pytorch_lightning.metrics import MetricCollection from tests.helpers.boring_model import BoringModel -class SumMetric(Metric): +class SumMetric(TMetric): def __init__(self): super().__init__() @@ -18,7 +20,7 @@ def compute(self): return self.x -class DiffMetric(Metric): +class DiffMetric(PLMetric): def __init__(self): super().__init__()