Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bump: compatible with Lightning v2 #1094

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Revision of the MoCo SSL model ([#928](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/928))

- Updated lightning dependency to support lightning 2.x ([#1094](https://github.com/Lightning-AI/pytorch-lightning/pull/2671))


### Deprecated

Expand Down
6 changes: 3 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy <1.26.0
pytorch-lightning >1.7.0, <2.0.0 # strict
torchmetrics >=0.10.0, <0.12.0
lightning >=2.0.0 # strict
torchmetrics >=0.7.0, <1.3.0
lightning-utilities >0.3.1 # this is needed for PL 1.7
torchvision >=0.10.0 # todo: move to topic related extras
torchvision >=0.15.0, <0.19.0 # todo: move to topic related extras
tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger`
2 changes: 1 addition & 1 deletion requirements/models.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchvision >=0.10.0
torchvision >=0.15.0, <0.19.0
scikit-learn >=1.0.2
Pillow >9.0.0
gym[atari] >=0.17.2, <0.22.0 # strict
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/callbacks/byol_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Sequence, Union

import torch.nn as nn
from pytorch_lightning import Callback, LightningModule, Trainer
from lightning import Callback, LightningModule, Trainer
from torch import Tensor


Expand Down
14 changes: 7 additions & 7 deletions src/pl_bolts/callbacks/data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from lightning import Callback, LightningModule, Trainer
from lightning.fabric.utilities import rank_zero_warn
from lightning.fabric.utilities.apply_func import apply_to_collection
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from torch import Tensor, nn
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
Expand All @@ -16,9 +16,9 @@

# Backward compatibility for Lightning Logger
try:
from pytorch_lightning.loggers import Logger
from lightning.pytorch.loggers import Logger
except ImportError:
from pytorch_lightning.loggers import LightningLoggerBase as Logger
from lightning.pytorch.loggers import LightningLoggerBase as Logger

if _WANDB_AVAILABLE:
import wandb
Expand Down Expand Up @@ -112,7 +112,7 @@ def _is_logger_available(self, logger: Logger) -> bool:
if not isinstance(logger, self.supported_loggers):
rank_zero_warn(
f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}."
f" Supported loggers are: {', '.join((str(x.__name__) for x in self.supported_loggers))}"
f" Supported loggers are: {', '.join(str(x.__name__) for x in self.supported_loggers)}"
)
available = False
return available
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Tuple, Union

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.accelerators import Accelerator
from lightning import Callback, LightningModule, Trainer
from lightning.pytorch.accelerators import Accelerator
from torch import Tensor
from torch.nn import functional as F # noqa: N812

Expand Down
6 changes: 3 additions & 3 deletions src/pl_bolts/callbacks/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info
from lightning import LightningModule, Trainer
from lightning.fabric.utilities import rank_zero_info
from lightning.pytorch.callbacks import Callback

from pl_bolts.utils.stability import under_review

Expand Down
10 changes: 7 additions & 3 deletions src/pl_bolts/callbacks/sparseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
from typing import Any, Optional

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning import Callback, LightningModule, Trainer
from lightning.fabric.utilities.exceptions import MisconfigurationException

from pl_bolts.utils import _SPARSEML_AVAILABLE, _SPARSEML_TORCH_SATISFIED, _SPARSEML_TORCH_SATISFIED_ERROR
from pl_bolts.utils import (
_SPARSEML_AVAILABLE,
_SPARSEML_TORCH_SATISFIED,
_SPARSEML_TORCH_SATISFIED_ERROR,
)

if _SPARSEML_TORCH_SATISFIED:
from sparseml.pytorch.optim import ScheduledModifierManager
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/ssl_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_warn
from lightning import Callback, LightningModule, Trainer
from lightning.fabric.utilities import rank_zero_warn
from torch import Tensor, nn
from torch.nn import functional as F # noqa: N812
from torch.optim import Optimizer
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/torch_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning import Callback, LightningModule, Trainer
from lightning.fabric.utilities.exceptions import MisconfigurationException

from pl_bolts.utils import _TORCH_ORT_AVAILABLE

Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor

from pl_bolts.utils import _TORCHVISION_AVAILABLE
Expand Down
6 changes: 3 additions & 3 deletions src/pl_bolts/callbacks/verification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Any, Optional

import torch.nn as nn
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from lightning import Callback, LightningModule
from lightning.fabric.utilities import move_data_to_device, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

from pl_bolts.utils.stability import under_review

Expand Down
11 changes: 7 additions & 4 deletions src/pl_bolts/callbacks/verification/batch_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning import LightningModule, Trainer
from lightning.fabric.utilities.apply_func import apply_to_collection
from lightning.fabric.utilities.exceptions import MisconfigurationException
from torch import Tensor

from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase
from pl_bolts.callbacks.verification.base import (
VerificationBase,
VerificationCallbackBase,
)
from pl_bolts.utils.stability import under_review


Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/callbacks/vision/confused_logit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Sequence

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from lightning import Callback, LightningModule, Trainer
from torch import Tensor, nn

from pl_bolts.utils import _MATPLOTLIB_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Tuple

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from lightning import Callback, LightningModule, Trainer

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/callbacks/vision/sr_image_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F # noqa: N812
from pytorch_lightning import Callback
from lightning import Callback

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, Optional

from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.utils import _TORCHVISION_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from argparse import ArgumentParser
from typing import Any, Callable, Optional

from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets import UnlabeledImagenet
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Optional

import torch
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Tuple

import numpy as np
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils import _SKLEARN_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/sr_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils.stability import under_review
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Callable, Optional

from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.datasets import UnlabeledImagenet
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional

import torch
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets import ConcatDataset
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/vision_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, List, Optional, Union

import torch
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split


Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/datasets/array_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple, Union

from pytorch_lightning.utilities import exceptions
from lightning.fabric.utilities import exceptions
from torch.utils.data import Dataset

from pl_bolts.datasets.base_dataset import DataModel, TArrays
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/losses/self_supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def forward(self, anchor_maps, positive_maps):
Example:

>>> import torch
>>> from pytorch_lightning import seed_everything
>>> from lightning import seed_everything
>>> seed_everything(0)
0
>>> a1 = torch.rand(3, 5, 2, 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from argparse import ArgumentParser

import torch
from pytorch_lightning import LightningModule, Trainer
from lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F # noqa: N812

Expand Down Expand Up @@ -154,7 +154,11 @@ def add_model_specific_args(parent_parser):

@under_review()
def cli_main(args=None):
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.datamodules import (
CIFAR10DataModule,
ImagenetDataModule,
STL10DataModule,
)

parser = ArgumentParser()
parser.add_argument("--dataset", default="cifar10", type=str, choices=["cifar10", "stl10", "imagenet"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from argparse import ArgumentParser

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from lightning import LightningModule, Trainer, seed_everything
from torch import nn
from torch.nn import functional as F # noqa: N812

Expand Down Expand Up @@ -187,7 +187,11 @@ def add_model_specific_args(parent_parser):

@under_review()
def cli_main(args=None):
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.datamodules import (
CIFAR10DataModule,
ImagenetDataModule,
STL10DataModule,
)

seed_everything()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
from typing import Any, Optional, Union

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from lightning import LightningModule, Trainer, seed_everything

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import (
FasterRCNN as torchvision_FasterRCNN,
)
from torchvision.models.detection.faster_rcnn import (
FastRCNNPredictor,
fasterrcnn_resnet50_fpn,
)
from torchvision.ops import box_iou
else: # pragma: no cover
warn_missing_pkg("torchvision")
Expand Down
Loading
Loading