Skip to content

Commit

Permalink
Fix is_layer_at_idx for LRP (#308)
Browse files Browse the repository at this point in the history
* Fix `is_layer_at_idx` for LRP
* add composite lrp test
  • Loading branch information
Rubinjo committed May 5, 2023
1 parent 7397864 commit 98260ea
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ def __init__(
# apply rule to first self._until_layer_idx layers
if self._until_layer_rule is not None and self._until_layer_idx is not None:
for i in range(self._until_layer_idx + 1):
is_at_idx: LayerCheck = lambda layer: ichecks.is_layer_at_idx(layer, i)
is_at_idx: LayerCheck = lambda layer, i=i: ichecks.is_layer_at_idx(
model, layer, i
)
rules.insert(0, (is_at_idx, self._until_layer_rule))

# create a BoundedRule for input layer handling from given tuple
Expand Down
10 changes: 4 additions & 6 deletions src/innvestigate/backend/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tensorflow import Module, keras

import innvestigate.backend as ibackend
from innvestigate.backend.types import Layer
from innvestigate.backend.types import Layer, Model

__all__ = [
"get_activation_search_safe_layers",
Expand Down Expand Up @@ -325,8 +325,6 @@ def is_input_layer(layer: Layer, ignore_reshape_layers: bool = True) -> bool:
return all(isinstance(x, klayers.InputLayer) for x in layer_inputs)


def is_layer_at_idx(layer: Layer, index, ignore_reshape_layers=True) -> bool:
"""Checks if layer is a layer at index index,
by repeatedly applying is_input_layer()."""
# TODO: implement layer index check
raise NotImplementedError("Layer index checking hasn't been implemented yet.")
def is_layer_at_idx(model: Model, layer: Layer, index) -> bool:
"""Checks if layer is a layer at specified index of model."""
return layer == model.layers[index]
42 changes: 42 additions & 0 deletions tests/backend/test_layer_idx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import tensorflow.keras.layers as klayers
import tensorflow.keras.models as kmodels

from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP


@pytest.mark.graph
@pytest.mark.fast
@pytest.mark.lrp
def test_composite_lrp():
model = kmodels.Sequential(
[
klayers.Input(shape=(28, 28, 3)),
klayers.Conv2D(8, 3, activation="relu"),
klayers.Conv2D(4, 5, activation="relu"),
klayers.Flatten(),
klayers.Dense(16, activation="relu"),
klayers.Dense(2, activation="softmax"),
]
)
analyzer = LRP(
model,
rule="Z",
input_layer_rule="Flat",
until_layer_idx=2,
until_layer_rule="Epsilon",
)
correct_rules = [
"Flat",
"Epsilon",
"Epsilon",
"Z",
"Z",
] # Correct rules corresponding to analyzer input args

for i, layer in enumerate(model.layers):
for condition, rule in analyzer._rules:
if condition(layer):
rule_class = rule
break
assert rule_class == correct_rules[i]

0 comments on commit 98260ea

Please sign in to comment.