Skip to content

Commit

Permalink
Update pytorch-lightning requirement from <1.9.0,>1.7.0 to >1.7.0,<2.…
Browse files Browse the repository at this point in the history
…0.0 in /requirements (#1006)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
dependabot[bot] and Borda authored Jun 16, 2023
1 parent 6656dae commit d5bc21a
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 18 deletions.
3 changes: 2 additions & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
numpy <1.25.0
pytorch-lightning >1.7.0, <1.9.0 # strict
pytorch-lightning >1.7.0, <2.0.0 # strict
torchmetrics <0.11.0 # strict
lightning-utilities >0.3.1 # this is needed for PL 1.7
torchvision >=0.10.0 # todo: move to topic related extras
tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger`
7 changes: 3 additions & 4 deletions src/pl_bolts/models/detection/retinanet/retinanet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,10 @@ def configure_optimizers(self):

@under_review()
def cli_main():
# Backward compatibility for Lightning CLI
try:
from pytorch_lightning.utilities.cli import LightningCLI
try: # Backward compatibility for Lightning CLI
from pytorch_lightning.cli import LightningCLI # PL v1.9+
except ImportError:
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8

from pl_bolts.datamodules import VOCDetectionDataModule

Expand Down
16 changes: 9 additions & 7 deletions src/pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import operator
import platform

import torch
from lightning_utilities.core.imports import compare_version, module_available

from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore

_NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_IS_WINDOWS = platform.system() == "Windows"
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0")
_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
_TORCHVISION_AVAILABLE: bool = module_available("torchvision")
_TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1")
_TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0")
_TORCHMETRICS_DETECTION_AVAILABLE: bool = module_available("torchmetrics.detection")
_PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0")
_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
_GYM_AVAILABLE: bool = module_available("gym")
_SKLEARN_AVAILABLE: bool = module_available("sklearn")
_PIL_AVAILABLE: bool = module_available("PIL")
_OPENCV_AVAILABLE: bool = module_available("cv2")
_WANDB_AVAILABLE: bool = module_available("wandb")
_MATPLOTLIB_AVAILABLE: bool = module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1")
_TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0")
_PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0")
_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0")
_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
_SPARSEML_AVAILABLE = module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML
_JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0")

Expand Down
6 changes: 5 additions & 1 deletion tests/callbacks/test_variational_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from pl_bolts.callbacks import LatentDimInterpolator
from pl_bolts.models.gans import GAN
from pytorch_lightning.loggers.base import DummyLogger

try:
from pytorch_lightning.loggers.logger import DummyLogger # PL v1.9+
except ModuleNotFoundError:
from pytorch_lightning.loggers.base import DummyLogger # PL v1.8


def test_latent_dim_interpolator():
Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import pytest
import torch
from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
from pl_bolts.utils import _IS_WINDOWS, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
from pl_bolts.utils.stability import UnderReviewWarning
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.imports import _IS_WINDOWS

# GitHub Actions use this path to cache datasets.
# Use `datadir` fixture where possible and use `DATASETS_PATH` in
Expand Down
5 changes: 2 additions & 3 deletions tests/models/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def test_cli_run_vision_image_gpt(cli_args):

@pytest.mark.parametrize("cli_args", [_DEFAULT_LIGHTNING_CLI_ARGS + " --trainer.gpus 1"])
@pytest.mark.skipif(**_MARK_REQUIRE_GPU)
@pytest.mark.skipif(
not _JSONARGPARSE_GREATER_THAN_4_16_0, reason="Failing on CI, need to be fixed"
) # see https://github.com/omni-us/jsonargparse/issues/187
# FixMe; see https://github.com/omni-us/jsonargparse/issues/187
@pytest.mark.skipif(not _JSONARGPARSE_GREATER_THAN_4_16_0, reason="Failing on CI, need to be fixed")
def test_cli_run_retinanet(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.detection.retinanet.retinanet_module import cli_main
Expand Down

0 comments on commit d5bc21a

Please sign in to comment.