Add Fabric internal hooks (#17759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and Borda committed Jun 7, 2023
1 parent 5491e15 commit 55dcc2b
Showing 3 changed files with 42 additions and 5 deletions.
src/lightning/fabric/fabric.py (10 additions, 2 deletions)
@@ -212,14 +212,19 @@ def setup(
         # Update the _DeviceDtypeModuleMixin's device parameter
         module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
 
-        optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        optimizers = [
+            _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+            for optimizer in optimizers
+        ]
 
         self._models_setup += 1
 
         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
             original_module._fabric = self  # type: ignore[assignment]
             original_module._fabric_optimizers = optimizers  # type: ignore[assignment]
 
+        self.call("on_after_setup", fabric=self, module=module)
+
         if optimizers:
             # join both types in a tuple for API convenience
             return (module, *optimizers)
@@ -276,7 +281,10 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
         """
         self._validate_setup_optimizers(optimizers)
         optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
-        optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        optimizers = [
+            _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+            for optimizer in optimizers
+        ]
         return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
 
     def setup_dataloaders(
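Taken together, the two fabric.py hunks mean any object passed via Fabric's callbacks argument can observe the end of setup. A minimal usage sketch, assuming the public lightning.fabric.Fabric API; the MyCallback class and its print statements are illustrative, not part of this commit:

import torch
from lightning.fabric import Fabric


class MyCallback:
    # Hypothetical user callback: Fabric hooks are duck-typed, so defining
    # a method with the matching name is all that is required.
    def on_after_setup(self, fabric, module):
        print(f"setup finished for {type(module).__name__}")

    def on_after_optimizer_step(self, strategy, optimizer):
        print("optimizer stepped")


fabric = Fabric(accelerator="cpu", callbacks=[MyCallback()])
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)  # fires on_after_setup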
src/lightning/fabric/wrappers.py (9 additions, 3 deletions)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, overload, TypeVar, Union
 
 import torch
 from lightning_utilities import WarningCache
@@ -38,7 +38,7 @@
 
 
 class _FabricOptimizer:
-    def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
+    def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
         """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the
         optimizer step calls to the strategy plugin.
@@ -54,6 +54,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
+        self._callbacks = callbacks or []
 
     @property
     def optimizer(self) -> Optimizer:
@@ -69,10 +70,15 @@ def step(self, closure: Optional[Callable] = None) -> Any:
             optimizer = self._strategy.model
         else:
             optimizer = self.optimizer
-        return self._strategy.optimizer_step(
+        output = self._strategy.optimizer_step(
             optimizer,
             **kwargs,
         )
+        for callback in self._callbacks:
+            hook = getattr(callback, "on_after_optimizer_step", None)
+            if callable(hook):
+                hook(strategy=self._strategy, optimizer=optimizer)
+        return output
 
 
 class _FabricModule(_DeviceDtypeModuleMixin):
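The hook dispatch in step is duck-typed: the wrapper probes each callback with getattr and only invokes the hook when it exists, so callbacks never have to subclass anything. A self-contained sketch of the same pattern, with illustrative names (run_hooks, StepLogger) not taken from the commit:

from typing import Any, List


def run_hooks(callbacks: List[Any], name: str, **kwargs: Any) -> None:
    # Mirrors the loop in _FabricOptimizer.step: probe each callback for a
    # method called `name`, invoke it if present, silently skip otherwise.
    for callback in callbacks:
        hook = getattr(callback, name, None)
        if callable(hook):
            hook(**kwargs)


class StepLogger:
    def on_after_optimizer_step(self, strategy, optimizer):
        print(f"stepped {type(optimizer).__name__}")


# object() defines no hook and is skipped without error.
run_hooks([StepLogger(), object()], "on_after_optimizer_step", strategy=None, optimizer=None)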
tests/tests_fabric/test_fabric.py (23 additions, 0 deletions)
@@ -753,6 +753,29 @@ def test_call():
     assert not callback1.mock_calls
 
 
+def test_special_callbacks():
+    """Tests special callbacks that have hooks for internal Fabric events."""
+
+    class SpecialCallback:
+        def on_after_optimizer_step(self, strategy, optimizer):
+            pass
+
+        def on_after_setup(self, fabric, module):
+            pass
+
+    callback = Mock(wraps=SpecialCallback())
+    fabric = Fabric(accelerator="cpu", callbacks=[callback])
+
+    model = torch.nn.Linear(2, 2)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
+    callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model)
+
+    model(torch.randn(2, 2)).sum().backward()
+    fabric_optimizer.step()
+    callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer)
+
+
 def test_loggers_input():
     """Test the various ways in which loggers can be registered with Fabric."""
     logger0 = Mock()
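The test leans on unittest.mock.Mock(wraps=...), which records calls while still delegating to the wrapped object, so the assert_called_once_with checks run against SpecialCallback's real methods. A minimal sketch of that pattern in isolation (the Greeter class is illustrative):

from unittest.mock import Mock


class Greeter:
    def greet(self, name):
        return f"hello {name}"


spy = Mock(wraps=Greeter())
assert spy.greet("fabric") == "hello fabric"  # the real method still runs
spy.greet.assert_called_once_with("fabric")  # and the call was recorded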
