From d5bc21a33ea05ff310fdbb68d60023064383e38c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 16 Jun 2023 18:12:02 +0000 Subject: [PATCH] Update pytorch-lightning requirement from <1.9.0,>1.7.0 to >1.7.0,<2.0.0 in /requirements (#1006) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- requirements/base.txt | 3 ++- .../detection/retinanet/retinanet_module.py | 7 +++---- src/pl_bolts/utils/__init__.py | 16 +++++++++------- tests/callbacks/test_variational_callbacks.py | 6 +++++- tests/conftest.py | 3 +-- tests/models/test_scripts.py | 5 ++--- 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index fd4250619..40d503aed 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,5 +1,6 @@ numpy <1.25.0 -pytorch-lightning >1.7.0, <1.9.0 # strict +pytorch-lightning >1.7.0, <2.0.0 # strict torchmetrics <0.11.0 # strict lightning-utilities >0.3.1 # this is needed for PL 1.7 torchvision >=0.10.0 # todo: move to topic related extras +tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger` diff --git a/src/pl_bolts/models/detection/retinanet/retinanet_module.py b/src/pl_bolts/models/detection/retinanet/retinanet_module.py index bd255dc9d..c0415cc30 100644 --- a/src/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/src/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -135,11 +135,10 @@ def configure_optimizers(self): @under_review() def cli_main(): - # Backward compatibility for Lightning CLI - try: - from pytorch_lightning.utilities.cli import LightningCLI + try: # Backward compatibility for Lightning CLI + from pytorch_lightning.cli import LightningCLI # PL v1.9+ except ImportError: - from pytorch_lightning.cli import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8 from pl_bolts.datamodules import VOCDetectionDataModule diff --git a/src/pl_bolts/utils/__init__.py b/src/pl_bolts/utils/__init__.py index 6ef266820..f3c9df53f 100644 --- a/src/pl_bolts/utils/__init__.py +++ b/src/pl_bolts/utils/__init__.py @@ -1,4 +1,5 @@ import operator +import platform import torch from lightning_utilities.core.imports import compare_version, module_available @@ -6,21 +7,22 @@ from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore _NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") +_IS_WINDOWS = platform.system() == "Windows" +_TORCH_ORT_AVAILABLE = module_available("torch_ort") +_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0") +_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0") _TORCHVISION_AVAILABLE: bool = module_available("torchvision") +_TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1") +_TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0") _TORCHMETRICS_DETECTION_AVAILABLE: bool = module_available("torchmetrics.detection") +_PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0") +_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5") _GYM_AVAILABLE: bool = module_available("gym") _SKLEARN_AVAILABLE: bool = module_available("sklearn") _PIL_AVAILABLE: bool = module_available("PIL") _OPENCV_AVAILABLE: bool = module_available("cv2") _WANDB_AVAILABLE: bool = module_available("wandb") _MATPLOTLIB_AVAILABLE: bool = module_available("matplotlib") -_TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1") -_TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0") -_PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0") -_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5") -_TORCH_ORT_AVAILABLE = module_available("torch_ort") -_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0") -_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0") _SPARSEML_AVAILABLE = module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML _JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0") diff --git a/tests/callbacks/test_variational_callbacks.py b/tests/callbacks/test_variational_callbacks.py index 21ad33293..6994beb10 100644 --- a/tests/callbacks/test_variational_callbacks.py +++ b/tests/callbacks/test_variational_callbacks.py @@ -1,6 +1,10 @@ from pl_bolts.callbacks import LatentDimInterpolator from pl_bolts.models.gans import GAN -from pytorch_lightning.loggers.base import DummyLogger + +try: + from pytorch_lightning.loggers.logger import DummyLogger # PL v1.9+ +except ModuleNotFoundError: + from pytorch_lightning.loggers.base import DummyLogger # PL v1.8 def test_latent_dim_interpolator(): diff --git a/tests/conftest.py b/tests/conftest.py index 78e2853a8..7637c059b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,9 @@ import pytest import torch -from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 +from pl_bolts.utils import _IS_WINDOWS, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13 from pl_bolts.utils.stability import UnderReviewWarning from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector -from pytorch_lightning.utilities.imports import _IS_WINDOWS # GitHub Actions use this path to cache datasets. # Use `datadir` fixture where possible and use `DATASETS_PATH` in diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 00ec81276..503b4a962 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -111,9 +111,8 @@ def test_cli_run_vision_image_gpt(cli_args): @pytest.mark.parametrize("cli_args", [_DEFAULT_LIGHTNING_CLI_ARGS + " --trainer.gpus 1"]) @pytest.mark.skipif(**_MARK_REQUIRE_GPU) -@pytest.mark.skipif( - not _JSONARGPARSE_GREATER_THAN_4_16_0, reason="Failing on CI, need to be fixed" -) # see https://github.com/omni-us/jsonargparse/issues/187 +# FixMe; see https://github.com/omni-us/jsonargparse/issues/187 +@pytest.mark.skipif(not _JSONARGPARSE_GREATER_THAN_4_16_0, reason="Failing on CI, need to be fixed") def test_cli_run_retinanet(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.detection.retinanet.retinanet_module import cli_main