Add way to strip arbitrary output activations (#310)
Rubinjo committed May 5, 2023
1 parent 98260ea commit 4c0f355
Showing 2 changed files with 28 additions and 10 deletions.
5 changes: 4 additions & 1 deletion src/innvestigate/__init__.py
@@ -1,6 +1,9 @@
 from innvestigate import analyzer  # noqa
 from innvestigate.analyzer import create_analyzer  # noqa
 from innvestigate.analyzer.base import NotAnalyzeableModelException  # noqa
-from innvestigate.backend.graph import model_wo_softmax  # noqa
+from innvestigate.backend.graph import (  # noqa
+    model_wo_softmax,
+    model_wo_output_activation,
+)
 
 __version__ = "2.0.1"
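With this change, both helpers are re-exported at the package top level, so downstream code can import them directly. A minimal sketch, assuming innvestigate >= 2.0.1 is installed:

    from innvestigate import model_wo_output_activation, model_wo_softmax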
33 changes: 24 additions & 9 deletions src/innvestigate/backend/graph.py
@@ -30,8 +30,9 @@
     "get_layer_neuronwise_io",
     "copy_layer_wo_activation",
     "copy_layer",
-    "pre_softmax_tensors",
+    "pre_output_tensors",
     "model_wo_softmax",
+    "model_wo_output_activation",
     "get_model_layers",
     "model_contains",
     "trace_model_execution",
@@ -338,32 +339,46 @@ def copy_layer(
     return get_layer_from_config(layer, config, weights=weights, **kwargs)
 
 
-def pre_softmax_tensors(Xs: Tensor, should_find_softmax: bool = True) -> list[Tensor]:
-    """Finds the tensors that were preceeding a potential softmax."""
-    softmax_found = False
+def pre_output_tensors(Xs: Tensor, activation: str | None = None) -> list[Tensor]:
+    """Finds the tensors preceding a potential output activation."""
+    activation_found = False
 
     Xs = ibackend.to_list(Xs)
     ret = []
     for x in Xs:
         layer, node_index, _tensor_index = x._keras_history
-        if ichecks.contains_activation(layer, activation="softmax"):
-            softmax_found = True
+        if ichecks.contains_activation(layer, activation=activation):
+            activation_found = True
             if isinstance(layer, klayers.Activation):
                 ret.append(layer.get_input_at(node_index))
             else:
                 layer_wo_act = copy_layer_wo_activation(layer)
                 ret.append(layer_wo_act(layer.get_input_at(node_index)))
 
-    if should_find_softmax and not softmax_found:
-        raise Exception("No softmax found.")
+    if not activation_found:
+        if activation is None:
+            raise Exception("No output activation found.")
+        else:
+            raise Exception(f"No {activation} found.")
 
     return ret
 
 
 def model_wo_softmax(model: Model) -> Model:
     """Creates a new model w/o the final softmax activation."""
     return kmodels.Model(
-        inputs=model.inputs, outputs=pre_softmax_tensors(model.outputs), name=model.name
+        inputs=model.inputs,
+        outputs=pre_output_tensors(model.outputs, activation="softmax"),
+        name=model.name,
     )
 
 
+def model_wo_output_activation(model: Model) -> Model:
+    """Creates a new model w/o the final output activation."""
+    return kmodels.Model(
+        inputs=model.inputs,
+        outputs=pre_output_tensors(model.outputs),
+        name=model.name,
+    )
+
+
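For context, a minimal usage sketch of the new helper. The model below is hypothetical; this assumes TensorFlow/Keras and innvestigate >= 2.0.1 are installed, and that ichecks.contains_activation matches any final activation when activation is None:

    import tensorflow as tf
    import innvestigate

    # Hypothetical binary classifier whose output layer ends in a
    # sigmoid rather than a softmax.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
            tf.keras.layers.Dense(1, activation="sigmoid"),
        ]
    )

    # model_wo_softmax(model) would raise "No softmax found." here;
    # the new helper strips whatever final activation is present.
    logit_model = innvestigate.model_wo_output_activation(model)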
