-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a Python alternative to seq2seq.gather_tree #1925
Conversation
Marking this PR as draft. The fixture from #1929 is needed to run tests for both the custom op and the py op. |
Tests are now running for this new implementation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change. I didn't verify the detail of the beam search implementation, since the newly added test should ensure that the py implementation will have the same behavior as the fused gpu kernel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM Thx!
* Add a Python alternative to seq2seq.gather_tree * Enable tests for the Python op
* Add a Python alternative to seq2seq.gather_tree * Enable tests for the Python op
This PR adds a pure Python TensorFlow implementation of
tfa.seq2seq.gather_tree
, which can be enabled with the global flagTF_ADDONS_PY_OP
. This is useful when Addons custom ops are not readily available (e.g. TensorFlow Serving).I tested the performance on a real world application: beam search decoding of a neural machine translation model. I'm using a custom beam search implementation but it is close to the
BeamSearchDecoder
included in Addons.I did not find significant performance impact. And I think this makes sense: the function is only used at the very end of the decoding and does not involve very complex ops.