
Commit

Fix LightningModule step methods bypassing DDP wrapper in Fabric (#17424)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 0ee71d6)
awaelchli authored and lantiga committed Apr 24, 2023
1 parent 36c7710 commit 24ddce0
Showing 3 changed files with 140 additions and 3 deletions.
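
For context on the bug this commit fixes: `_FabricModule` forwards unknown attribute lookups to the original module, so calling a `LightningModule` step method such as `training_step` on the Fabric-wrapped model used to invoke it directly on the unwrapped module. That skips the strategy wrapper's `forward` (e.g. `DistributedDataParallel`) and with it the hooks that make `.backward()` behave correctly. A minimal, self-contained sketch of that bypass follows; `StrategyWrapper` and `Model` are illustrative stand-ins, not Fabric or PyTorch classes.

import torch
from torch import nn


class StrategyWrapper(nn.Module):
    """Illustrative stand-in for DDP: its hooks only run when forward() is called."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.forward_calls = 0

    def forward(self, *args, **kwargs):
        self.forward_calls += 1  # real DDP would synchronize gradients around this call
        return self.module(*args, **kwargs)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch):
        return self(batch).sum()


model = Model()
wrapped = StrategyWrapper(model)

# Resolving `training_step` on the inner module and calling it directly means the
# wrapper's forward (and its hooks) never runs -- the behavior this commit fixes.
model.training_step(torch.randn(1, 2))
assert wrapped.forward_calls == 0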
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed an issue with `LightningModule.*_step` methods bypassing the DDP/FSDP wrapper ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))


## [2.0.1] - 2023-03-30
43 changes: 42 additions & 1 deletion src/lightning/fabric/wrappers.py
@@ -15,6 +15,7 @@
from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union

import torch
from lightning_utilities import WarningCache
from lightning_utilities.core.apply_func import apply_to_collection
from torch import nn as nn
from torch import Tensor
@@ -28,8 +29,11 @@
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.types import Optimizable
from lightning.fabric.utilities.warnings import PossibleUserWarning

warning_cache = WarningCache()
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")


class _FabricOptimizer:
@@ -132,15 +136,52 @@ def state_dict(
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> _IncompatibleKeys:
return self._original_module.load_state_dict(state_dict=state_dict, strict=strict)

def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward

def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling the method `method_name`,
# because it may itself want to call the real `forward`
self._original_module.forward = original_forward
# Call the actual method e.g. `.training_step(...)`
method = getattr(self._original_module, method_name)
return method(*args, **kwargs)

# We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
# have patched to redirect back to `original_module.method_name()`.
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
# Patch the original_module's forward so we can redirect the arguments back to the real method
self._original_module.forward = wrapped_forward
return self._forward_module(*args, **kwargs)

return call_forward_module

def _validate_method_access(self, name: str, attribute: Any) -> None:
if inspect.ismethod(attribute) and self._forward_module != self._original_module:
warning_cache.warn(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
f" `.backward()`. You should pass your inputs through `{type(self._original_module)}.forward()`.",
category=PossibleUserWarning,
)

def __getattr__(self, item: Any) -> Any:
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
# Special support for `LightningModule`, to prevent bypassing DDP's forward
return self._redirection_through_forward(item)

try:
# __getattr__ gets called as a last resort if the attribute does not exist
# call nn.Module's implementation first
return super().__getattr__(item)
except AttributeError:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
original_module = super().__getattr__("_original_module")
- return getattr(original_module, item)
+ attr = getattr(original_module, item)
+ self._validate_method_access(item, attr)
+ return attr


class _FabricDataLoader:
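
To make the new `_redirection_through_forward` logic above concrete, here is a sketch of the same technique, reusing the illustrative `StrategyWrapper` and `Model` stand-ins from the earlier sketch (it is not Fabric's implementation). The inner module's `forward` is temporarily repointed at the requested step method and the call is sent through the wrapper's `__call__`, so the wrapper's hooks run; the patch is undone before the step method executes, because the step method typically calls the real `forward` itself.

def redirect_through_forward(wrapper, original_module, method_name):
    original_forward = original_module.forward

    def wrapped_forward(*args, **kwargs):
        # Undo the patch first: the step method usually calls self(...), i.e. the real forward.
        original_module.forward = original_forward
        return getattr(original_module, method_name)(*args, **kwargs)

    def call_forward_module(*args, **kwargs):
        # Point the inner forward at the step method, then go through the wrapper,
        # so whatever the wrapper does in forward() happens before the step method runs.
        original_module.forward = wrapped_forward
        return wrapper(*args, **kwargs)

    return call_forward_module


model = Model()
wrapper = StrategyWrapper(model)
training_step = redirect_through_forward(wrapper, model, "training_step")

training_step(torch.randn(1, 2))
assert wrapper.forward_calls == 1           # the wrapper was not bypassed this time
assert model.forward.__name__ == "forward"  # the temporary patch was removed again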
98 changes: 97 additions & 1 deletion tests/tests_fabric/test_wrappers.py
@@ -16,6 +16,7 @@

import pytest
import torch
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader

@@ -59,13 +60,43 @@ def __init__(self):
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
assert fabric_module.attribute == 1
assert fabric_module.layer is original_module.layer
assert fabric_module.method() == 2
assert fabric_module.forward.__self__.__class__ == _FabricModule

with pytest.raises(AttributeError):
_ = fabric_module.not_exists


def test_fabric_module_method_lookup():
"""Test that access to methods warns about improper use when a wrapper from a strategy is involved."""
from lightning.fabric.wrappers import warning_cache

class OriginalModule(torch.nn.Module):
def method(self):
return 100

class ModuleWrapper(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.wrapped = module

# Regular case: forward_module == original_module -> no warnings
original_module = OriginalModule()
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with no_warning_call(UserWarning):
assert fabric_module.method() == 100
assert not warning_cache

# Special case: original module wrapped by forward module: -> warn
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method\(\)` from outside the"):
assert fabric_module.method() == 100
warning_cache.clear()


def test_fabric_module_state_dict_access():
"""Test that state_dict access passes through to the original module."""

@@ -353,3 +384,68 @@ def test_is_wrapped():
assert not is_wrapped(dataloader)
wrapped = _FabricDataLoader(dataloader)
assert is_wrapped(wrapped)


def test_step_method_redirection():
"""Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
module."""

class DDP(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)

class LightningModule(torch.nn.Module):
def forward(self):
return "forward_return"

def training_step(self, arg, kwarg=None):
assert self() == "forward_return"
assert arg == "train_arg"
assert kwarg == "train_kwarg"
return "training_step_return"

def validation_step(self, arg, kwarg=None):
assert self() == "forward_return"
assert arg == "val_arg"
assert kwarg == "val_kwarg"
return "validation_step_return"

def normal_method(self):
pass

original_module = LightningModule()
forward_module = DDP(original_module)
fabric_module = _FabricModule(forward_module=forward_module, precision=Mock(), original_module=original_module)

# Regular methods on the original_module are visible and identical on the fabric_module ...
assert fabric_module.normal_method == original_module.normal_method

# ... but special methods like training_step get redirected to the forward_module
assert fabric_module.training_step.__name__ == "call_forward_module"
assert fabric_module.validation_step.__name__ == "call_forward_module"
assert fabric_module.test_step.__name__ == "call_forward_module"
assert fabric_module.predict_step.__name__ == "call_forward_module"

with pytest.raises(AttributeError, match="has no attribute 'predict_step'"):
# A special method that does not exist will raise its AttributeError when being called
fabric_module.predict_step()

# The forward method on the original module remains untouched
assert original_module.forward.__name__ == "forward"

# The special methods get redirected correctly to produce the expected output
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return" # call 2nd time
assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"

# The forward method remains untouched/unpatched after the special methods have been called
assert original_module.forward.__name__ == "forward"

# Special case: forward_module == original_module -> no special treatment applied
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
assert fabric_module.training_step == original_module.training_step
assert fabric_module.validation_step == original_module.validation_step
