Skip to content

Commit

Permalink
Revert AugmentReduceBase neuron selection mode
Browse files Browse the repository at this point in the history
Use the previous mechanism of overwriting the subanalyzer's `neuron_selection_mode`. Add comments and an assertion on the subanalyzer's neuron selection mode.
  • Loading branch information
adrhill committed Jul 27, 2021
1 parent 909a7e9 commit 7419aa4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 61 deletions.
58 changes: 41 additions & 17 deletions src/innvestigate/analyzer/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,32 @@ class IntegratedGradients(PathIntegrator):
:param steps: Number of steps to use average along integration path.
"""

def __init__(
    self,
    model,
    steps=64,
    neuron_selection_mode="max_activation",
    postprocess=None,
    **kwargs
):
    """Instantiate an IntegratedGradients analyzer.

    :param model: The Keras model to analyze.
    :param steps: Number of steps to use average along integration path.
    :param neuron_selection_mode: Forwarded both to the wrapped ``Gradient``
        subanalyzer and to the ``PathIntegrator`` base class so the two
        stay consistent.
    :param postprocess: Forwarded to the wrapped ``Gradient`` subanalyzer only.
    """
    # If initialized through serialization, reuse the restored subanalyzer:
    if "subanalyzer" in kwargs:
        subanalyzer = kwargs.pop("subanalyzer")
    # If initialized normally, build a fresh Gradient subanalyzer:
    else:
        subanalyzer = Gradient(
            model,
            neuron_selection_mode=neuron_selection_mode,
            postprocess=postprocess,
        )

    # NOTE(review): `postprocess` is deliberately not passed on to the base
    # class -- it only configures the wrapped Gradient analyzer.
    super().__init__(
        subanalyzer,
        steps=steps,
        neuron_selection_mode=neuron_selection_mode,
        **kwargs
    )


###############################################################################
Expand All @@ -298,17 +310,29 @@ class SmoothGrad(GaussianSmoother):
:param augment_by_n: Number of distortions to average for smoothing.
"""

def __init__(
    self,
    model,
    augment_by_n=64,
    neuron_selection_mode="max_activation",
    postprocess=None,
    **kwargs
):
    """Instantiate a SmoothGrad analyzer.

    :param model: The Keras model to analyze.
    :param augment_by_n: Number of distortions to average for smoothing.
    :param neuron_selection_mode: Forwarded both to the wrapped ``Gradient``
        subanalyzer and to the ``GaussianSmoother`` base class so the two
        stay consistent.
    :param postprocess: Forwarded to the wrapped ``Gradient`` subanalyzer only.
    """
    # If initialized through serialization, reuse the restored subanalyzer:
    if "subanalyzer" in kwargs:
        subanalyzer = kwargs.pop("subanalyzer")
    # If initialized normally, build a fresh Gradient subanalyzer:
    else:
        subanalyzer = Gradient(
            model,
            neuron_selection_mode=neuron_selection_mode,
            postprocess=postprocess,
        )

    # NOTE(review): `postprocess` is deliberately not passed on to the base
    # class -- it only configures the wrapped Gradient analyzer.
    super().__init__(
        subanalyzer,
        augment_by_n=augment_by_n,
        neuron_selection_mode=neuron_selection_mode,
        **kwargs
    )
98 changes: 54 additions & 44 deletions src/innvestigate/analyzer/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(self, subanalyzer: AnalyzerBase, *args, **kwargs):
# and the subanalyzer model is passed to `AnalyzerBase`.
kwargs.pop("model", None)
super().__init__(subanalyzer._model, *args, **kwargs)

self._subanalyzer_name = subanalyzer.__class__.__name__
self._subanalyzer = subanalyzer

def analyze(self, *args, **kwargs):
Expand Down Expand Up @@ -76,32 +78,32 @@ class AugmentReduceBase(WrapperBase):
"""

def __init__(
    self,
    subanalyzer: AnalyzerNetworkBase,
    *args,
    augment_by_n: int = 2,
    neuron_selection_mode="max_activation",
    **kwargs,
):
    """Wrap `subanalyzer`, augmenting the input and reducing the results.

    :param subanalyzer: Keras-based analyzer whose analyses get augmented
        and reduced; must be an ``AnalyzerNetworkBase``.
    :param augment_by_n: Number of augmented samples created per input.
    :param neuron_selection_mode: Mode of this wrapper. "max_activation" is
        internally mapped onto the subanalyzer's "index" mode (see below).
    :raises NotImplementedError: If `subanalyzer` is not Keras-based.
    """
    if not isinstance(subanalyzer, AnalyzerNetworkBase):
        raise NotImplementedError("Keras-based subanalyzer required.")

    if neuron_selection_mode == "max_activation":
        # TODO: find a more transparent way.
        #
        # Since AugmentReduceBase analyzers augment the input,
        # it is possible that the neuron w/ max activation changes.
        # As a workaround, the index of the maximally activated neuron
        # w.r.t. the "unperturbed" input is computed and used in combination
        # with neuron_selection_mode = "index" in the subanalyzer.
        #
        # NOTE:
        # The analyzer will still have neuron_selection_mode = "max_activation"!
        subanalyzer._neuron_selection_mode = "index"

    # The base class keeps a reference to the subanalyzer; only the
    # augmentation count needs to be stored here.
    super().__init__(
        subanalyzer, *args, neuron_selection_mode=neuron_selection_mode, **kwargs
    )

    self._augment_by_n: int = augment_by_n  # number of samples to create

def create_analyzer_model(self):
self._subanalyzer.create_analyzer_model()
Expand Down Expand Up @@ -141,28 +143,36 @@ def create_analyzer_model(self):
def analyze(
    self, X: OptionalList[np.ndarray], *args, **kwargs
) -> OptionalList[np.ndarray]:
    """Augment input `X`, analyze it with the subanalyzer, and reduce.

    :param X: Input(s) to analyze.
    :param args: In mode "index", the first positional argument may carry
        the neuron indices (alternatively passed as `neuron_selection`).
    :raises AssertionError: If the subanalyzer's neuron_selection_mode is
        not "index" while this analyzer's mode requires index selection.
    """
    # Lazily build the subanalyzer's model on first use.
    if self._subanalyzer._analyzer_model is None:
        self.create_analyzer_model()

    ns_mode = self._neuron_selection_mode
    if ns_mode == "all":
        return self._subanalyzer.analyze(X, *args, **kwargs)

    # As described in the AugmentReduceBase init,
    # both ns_mode "max_activation" and "index" make use
    # of a subanalyzer using neuron_selection_mode="index".
    elif ns_mode == "max_activation":
        # Obtain the maximally activated neuron per sample w.r.t. the
        # unperturbed input, before any augmentation happens.
        pred = self._subanalyzer._model.predict(X)
        indices = np.argmax(pred, axis=1)
    elif ns_mode == "index":
        # TODO: make neuron_selection arg or kwarg, not both
        if len(args):
            # NOTE(review): the popped index remains in `args` when the
            # subanalyzer is called below -- confirm this is intended.
            arglist = list(args)
            indices = arglist.pop(0)
        else:
            indices = kwargs.pop("neuron_selection")

    if not self._subanalyzer._neuron_selection_mode == "index":
        raise AssertionError(
            'Subanalyzer neuron_selection_mode has to be "index" '
            'when using analyzer with neuron_selection_mode != "all".'
        )
    # broadcast to match augmented samples.
    indices = np.repeat(indices, self._augment_by_n)
    kwargs["neuron_selection"] = indices
    return self._subanalyzer.analyze(X, *args, **kwargs)

def _keras_get_constant_inputs(self):
Expand Down Expand Up @@ -242,8 +252,8 @@ class PathIntegrator(AugmentReduceBase):
This analyzer:
* creates a path from input to reference image.
* creates steps number of intermediate inputs and
crests an analysis for them.
* creates `steps` number of intermediate inputs and
creates an analysis for them.
* sums the analyses and multiplies them with the input-reference_input.
This wrapper is used to implement Integrated Gradients.
Expand Down

0 comments on commit 7419aa4

Please sign in to comment.