From 1ddd6909ef6a8565dd4a29a0143000cea9019bbe Mon Sep 17 00:00:00 2001
From: "M. Fox" <120434191+lightningforever@users.noreply.github.com>
Date: Tue, 6 Jun 2023 18:04:19 +0200
Subject: [PATCH] Add Fabric internal hooks (#17759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/lightning/fabric/fabric.py    | 12 ++++++++++--
 src/lightning/fabric/wrappers.py  | 12 +++++++++---
 tests/tests_fabric/test_fabric.py | 23 +++++++++++++++++++++++
 3 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index f12dc56ce707e..e8d9297a5def1 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -212,7 +212,10 @@ def setup(
             # Update the _DeviceDtypeModuleMixin's device parameter
             module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
 
-        optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        optimizers = [
+            _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+            for optimizer in optimizers
+        ]
 
         self._models_setup += 1
 
@@ -220,6 +223,8 @@ def setup(
             original_module._fabric = self  # type: ignore[assignment]
             original_module._fabric_optimizers = optimizers  # type: ignore[assignment]
 
+        self.call("on_after_setup", fabric=self, module=module)
+
         if optimizers:
             # join both types in a tuple for API convenience
             return (module, *optimizers)
@@ -276,7 +281,10 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
         """
         self._validate_setup_optimizers(optimizers)
         optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
-        optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        optimizers = [
+            _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+            for optimizer in optimizers
+        ]
         return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
 
     def setup_dataloaders(
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 947510244d0df..8ed10f00c5511 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, overload, TypeVar, Union
 
 import torch
 from lightning_utilities import WarningCache
@@ -38,7 +38,7 @@
 
 
 class _FabricOptimizer:
-    def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
         """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
         step calls to the strategy plugin.
 
@@ -54,6 +54,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
+        self._callbacks = callbacks or []
 
     @property
     def optimizer(self) -> Optimizer:
@@ -69,10 +70,15 @@ def step(self, closure: Optional[Callable] = None) -> Any:
             optimizer = self._strategy.model
         else:
             optimizer = self.optimizer
-        return self._strategy.optimizer_step(
+        output = self._strategy.optimizer_step(
             optimizer,
             **kwargs,
         )
+        for callback in self._callbacks:
+            hook = getattr(callback, "on_after_optimizer_step", None)
+            if callable(hook):
+                hook(strategy=self._strategy, optimizer=optimizer)
+        return output
 
 
 class _FabricModule(_DeviceDtypeModuleMixin):
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 58ab0e34166ed..015d1aba41359 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -753,6 +753,29 @@ def test_call():
     assert not callback1.mock_calls
 
 
+def test_special_callbacks():
+    """Tests special callbacks that have hooks for internal Fabric events."""
+
+    class SpecialCallback:
+        def on_after_optimizer_step(self, strategy, optimizer):
+            pass
+
+        def on_after_setup(self, fabric, module):
+            pass
+
+    callback = Mock(wraps=SpecialCallback())
+    fabric = Fabric(accelerator="cpu", callbacks=[callback])
+
+    model = torch.nn.Linear(2, 2)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
+    callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model)
+
+    model(torch.randn(2, 2)).sum().backward()
+    fabric_optimizer.step()
+    callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer)
+
+
 def test_loggers_input():
     """Test the various ways in which loggers can be registered with Fabric."""
     logger0 = Mock()
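
Usage sketch (not part of the patch): the two hooks introduced here land on plain callback
objects registered through Fabric's existing `callbacks` argument, exactly as the test above
exercises. Hook names and signatures are taken from the diff; the `PrintingCallback` class and
its print bodies are hypothetical, for illustration only.

    import torch
    from lightning.fabric import Fabric

    class PrintingCallback:
        # Invoked once per `fabric.setup(...)`, via `self.call("on_after_setup", ...)` in fabric.py.
        def on_after_setup(self, fabric, module):
            print(f"setup finished: {type(module).__name__}")

        # Invoked after every optimizer step, from `_FabricOptimizer.step` in wrappers.py.
        def on_after_optimizer_step(self, strategy, optimizer):
            print(f"stepped: {type(optimizer).__name__}")

    fabric = Fabric(accelerator="cpu", callbacks=[PrintingCallback()])
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)  # fires on_after_setup

    model(torch.randn(2, 2)).sum().backward()
    optimizer.step()  # fires on_after_optimizer_step

Because `_FabricOptimizer` looks hooks up with `getattr(callback, "on_after_optimizer_step", None)`,
callbacks that do not define a given hook are simply skipped.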