[BUG] AnalyzerNetworkBase analyzers error when using BatchNormalization layers #292

Open
adrhill opened this issue Oct 11, 2022 · 2 comments

adrhill commented Oct 11, 2022

On iNNvestigate v2.0.1, analyzers inheriting from AnalyzerNetworkBase raise an error when the model contains a BatchNormalization layer, e.g.:

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]

This might be because BatchNormalization layers keep moving averages of the mean and variance of the training data, which could cause problems with the Keras history when the computational graph is reversed in iNNvestigate's create_analyzer_model.
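
To make the "moving averages" concrete, here is a minimal NumPy sketch of the running statistics a batch normalisation layer maintains. The function name `update_moving_stats` is hypothetical, and the default momentum of 0.99 mirrors the Keras BatchNormalization default. It only illustrates the extra state such a layer carries; it does not reproduce the error or iNNvestigate's internals.

```python
import numpy as np


def update_moving_stats(batch, moving_mean, moving_var, momentum=0.99):
    """One training-mode update of the running mean and variance (illustrative)."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    new_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return new_mean, new_var


rng = np.random.default_rng(0)
batch = rng.random((100, 10))

# The running statistics start at mean 0 and variance 1.
moving_mean = np.zeros(10)
moving_var = np.ones(10)

# Repeated updates on the same batch pull the running statistics
# towards that batch's statistics.
for _ in range(2000):
    moving_mean, moving_var = update_moving_stats(batch, moving_mean, moving_var)
```

These running statistics are stored as non-trainable weights and are used in place of the batch statistics at inference time.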

Minimal example reproducing the issue

import numpy as np
import tensorflow as tf
from keras.layers import BatchNormalization, Dense
from keras.models import Sequential

import innvestigate

tf.compat.v1.disable_eager_execution()

input_shape = (50,)
x = np.random.rand(100, *input_shape)
y = np.random.rand(100, 2)

model1 = Sequential()
model1.add(Dense(10, input_shape=input_shape))
model1.add(Dense(2))

model2 = Sequential()
model2.add(Dense(10, input_shape=input_shape))
model2.add(BatchNormalization())
model2.add(Dense(2))


def run_analysis(model):
    model.compile(optimizer="adam", loss="mse")
    model.fit(x, y, epochs=10, verbose=0)

    analyzer = innvestigate.create_analyzer("gradient", model)
    analyzer.analyze(x)


print("Model without BatchNormalization:")  # passes
run_analysis(model1)
print("Model with BatchNormalization:")     # errors
run_analysis(model2)

Full stacktrace

Model with BatchNormalization:
/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py:1480: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  ret = tf_session.TF_SessionRunCallable(self._session._session,
Traceback (most recent call last):
  File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 35, in <module>
    run_analysis(model2)
  File "/Users/funks/Developer/innvestigate-issues/open/issue_238_v3", line 29, in run_analysis
    analyzer.analyze(x)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 250, in analyze
    self.create_analyzer_model()
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/innvestigate/analyzer/network_base.py", line 196, in create_analyzer_model
    self._analyzer_model = kmodels.Model(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 146, in __init__
    self._init_graph_network(inputs, outputs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/training/tracking/base.py", line 629, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/functional.py", line 181, in _init_graph_network
    base_layer_utils.create_keras_history(self._nested_outputs)
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 175, in create_keras_history
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 253, in _create_keras_history_helper
    processed_ops, created_layers = _create_keras_history_helper(
  [Previous line repeated 3 more times]
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/engine/base_layer_utils.py", line 251, in _create_keras_history_helper
    constants[i] = backend.function([], op_input)([])
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/keras/backend.py", line 4275, in __call__
    fetched = self._callable_fn(*array_vals,
  File "/Users/funks/Library/Caches/pypoetry/virtualenvs/innvestigate-issues-W0iScZgu-py3.10/lib/python3.10/site-packages/tensorflow/python/client/session.py", line 1480, in __call__
    ret = tf_session.TF_SessionRunCallable(self._session._session,
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_input' with dtype float and shape [?,50]
         [[{{node dense_2_input}}]]

@yap231995

Hello, I am also encountering this issue. Is there a workaround? I saw that you can try switching to a Dense layer. Is there code I could reference?

adrhill commented Feb 23, 2023

The workaround using a Dense layer is described here: #283 (comment)
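
The linked comment is not reproduced here. As a rough sketch of the general idea, and not necessarily the exact code from #283: at inference time a batch normalisation layer is an affine map per feature, so it can be expressed as a Dense layer with a diagonal kernel. The helper name `fold_batchnorm_to_dense` is hypothetical, and the epsilon of 1e-3 is assumed to match the Keras BatchNormalization default.

```python
import numpy as np


def fold_batchnorm_to_dense(gamma, beta, moving_mean, moving_var, epsilon=1e-3):
    """Return (kernel, bias) of a Dense layer equivalent to inference-mode batch norm."""
    scale = gamma / np.sqrt(moving_var + epsilon)
    kernel = np.diag(scale)            # one independent scale per feature
    bias = beta - moving_mean * scale  # shift that absorbs the running mean
    return kernel, bias


# Random stand-ins for the four weight arrays of a trained layer.
rng = np.random.default_rng(0)
n_features = 10
gamma = rng.normal(size=n_features)
beta = rng.normal(size=n_features)
moving_mean = rng.normal(size=n_features)
moving_var = rng.uniform(0.5, 2.0, size=n_features)

x = rng.normal(size=(4, n_features))

# Inference-mode batch normalisation, written out directly.
bn_out = gamma * (x - moving_mean) / np.sqrt(moving_var + 1e-3) + beta

# The same computation as a Dense layer: x @ kernel + bias.
kernel, bias = fold_batchnorm_to_dense(gamma, beta, moving_mean, moving_var)
dense_out = x @ kernel + bias
```

In Keras, and assuming the layer's default `scale=True` and `center=True`, the four arrays would come from `BatchNormalization.get_weights()` in the order gamma, beta, moving mean, moving variance. The result could then be loaded into a `Dense(n_features)` layer with `set_weights([kernel, bias])` in a rebuilt model that omits the BatchNormalization layer.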
