Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the compute_output_shape method on DenseVariational. #1515

Merged
merged 4 commits into from
Jun 7, 2022

Conversation

Frightera
Copy link
Contributor

Related to the issue #1505. dense_variational_v2.py does not have this method so it throws:

NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
Please run in eager mode or implement the `compute_output_shape` method on your layer (DenseVariational).

Defining compute_output_shape method fixes this issue. Please see the gist for the reproducible example, fixed example and unit test result.

@Frightera
Copy link
Contributor Author

@ColCarroll Thanks for approving. This should still be an issue with TFP 0.16 as

class DenseVariational(tf.keras.layers.Layer):

does not have compute_output_shape method. Old layers such as DenseFlipout etc has this method:

def compute_output_shape(self, input_shape):

@copybara-service copybara-service bot merged commit 0efbee9 into tensorflow:main Jun 7, 2022
@ColCarroll
Copy link
Contributor

Thanks for the contribution, @Frightera! I'm not sure I understand your comment -- I think you added this method here, so it should be better now, right?

@Frightera
Copy link
Contributor Author

@ColCarroll Yes, it should be better. The last stable version of TFP was 0.15 when this method was implemented. I wanted to point out that it is still an issue with TFP 0.16.

Also, issue #1505 can be closed now.

Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants