Skip to content

Commit

Permalink
External callback registry through entry points for Fabric (#17756)
Browse files Browse the repository at this point in the history
  • Loading branch information
lightningforever authored and lantiga committed Jun 7, 2023
1 parent 1ddd690 commit f2f187f
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 47 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [UnReleased] - 2023-04-DD

- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))


### Changed

-
Expand Down
11 changes: 9 additions & 2 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
has_iterable_dataset,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning
Expand Down Expand Up @@ -105,8 +106,7 @@ def __init__(
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator
self._precision: Precision = self._strategy.precision
callbacks = callbacks if callbacks is not None else []
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
self._callbacks = self._configure_callbacks(callbacks)
loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0
Expand Down Expand Up @@ -846,6 +846,13 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")

@staticmethod
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
callbacks = callbacks if callbacks is not None else []
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
return callbacks


def _is_using_cli() -> bool:
return bool(int(os.environ.get("LT_CLI_USED", "0")))
Expand Down
3 changes: 3 additions & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
44 changes: 43 additions & 1 deletion src/lightning/fabric/utilities/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any
import logging
from typing import Any, List, Union

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0

_log = logging.getLogger(__name__)


def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
Expand All @@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
return False

return mod_attr.__code__ is not super_attr.__code__


def _load_external_callbacks(group: str) -> List[Any]:
"""Collect external callbacks registered through entry points.
The entry points are expected to be functions returning a list of callbacks.
Args:
group: The entry point group name to load callbacks from.
Return:
A list of all callbacks collected from external factories.
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points

factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points

factories = iter_entry_points(group) # type: ignore[assignment]

external_callbacks: List[Any] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Any], Any] = callback_factory()
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks
40 changes: 2 additions & 38 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Dict, List, Optional, Sequence, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import (
Callback,
Checkpoint,
Expand All @@ -33,7 +34,6 @@
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info

Expand Down Expand Up @@ -75,7 +75,7 @@ def on_trainer_init(
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary)

self.trainer.callbacks.extend(_configure_external_callbacks())
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks)

# push all model checkpoint callbacks to the end
Expand Down Expand Up @@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
return tuner_callbacks + other_callbacks + checkpoint_callbacks


def _configure_external_callbacks() -> List[Callback]:
"""Collect external callbacks registered through entry points.
The entry points are expected to be functions returning a list of callbacks.
Return:
A list of all callbacks collected from external factories.
"""
group = "lightning.pytorch.callbacks_factory"

if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points

factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points

factories = iter_entry_points(group) # type: ignore[assignment]

external_callbacks: List[Callback] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Callback], Callback] = callback_factory()
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks


def _validate_callbacks_list(callbacks: List[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
seen_callbacks = set()
Expand Down
3 changes: 1 addition & 2 deletions src/lightning/pytorch/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.rank_zero import rank_zero_info

# copied from signal.pyi
Expand Down
2 changes: 0 additions & 2 deletions src/lightning/pytorch/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import torch
from lightning_utilities.core.imports import package_available, RequirementCache

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
Expand Down
64 changes: 64 additions & 0 deletions tests/tests_fabric/utilities/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import contextlib
from unittest import mock
from unittest.mock import Mock

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks


class ExternalCallback:
"""A callback in another library that gets registered through entry points."""

pass


def test_load_external_callbacks():
"""Test that the connector collects Callback instances from factories registered through entry points."""

def factory_no_callback():
return []

def factory_one_callback():
return ExternalCallback()

def factory_one_callback_list():
return [ExternalCallback()]

def factory_multiple_callbacks_list():
return [ExternalCallback(), ExternalCallback()]

with _make_entry_point_query_mock(factory_no_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert callbacks == []

with _make_entry_point_query_mock(factory_one_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)

with _make_entry_point_query_mock(factory_one_callback_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)

with _make_entry_point_query_mock(factory_multiple_callbacks_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)
assert isinstance(callbacks[1], ExternalCallback)


@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
query_mock = Mock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_10_0:
query_mock.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
elif _PYTHON_GREATER_EQUAL_3_8_0:
query_mock().get.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
else:
query_mock.return_value = [entry_point]
import_path = "pkg_resources.iter_entry_points"
with mock.patch(import_path, query_mock):
yield
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
import torch

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import (
EarlyStopping,
Expand All @@ -32,7 +33,6 @@
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0


def test_checkpoint_callbacks_are_last(tmpdir):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
import torch
from torch import Tensor

from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import callbacks, Trainer
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loops import _EvaluationLoop
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests_pytorch.helpers.runif import RunIf

if _RICH_AVAILABLE:
Expand Down

0 comments on commit f2f187f

Please sign in to comment.