Add support of transformer-based retrieval models #1128
Conversation
@@ -206,6 +206,8 @@ def call(
        "You should call the `index` method first to " "set the _candidates index."
    )

+   if isinstance(inputs, tf.RaggedTensor):
This assumes that we evaluate only on the last item in the session (which is the default mode during inference too). We might need to extend it to evaluate other items in the sequence in future work.
Thanks @sararb. I was trying to understand why we would have 1 as the 2nd dim of inputs, but from your explanation I now understand the reason is that we have predictions only for the last position.
I think it would be useful to add this remark as a comment.
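The behavior discussed above, keeping only the last item of each ragged session as the query, can be sketched as follows. This is a minimal illustration and not the PR's actual code; `inputs` here is a toy ragged batch of item embeddings:

```python
import tensorflow as tf

# Toy ragged batch: 2 sessions of item embeddings with different lengths.
inputs = tf.ragged.constant(
    [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # session with 3 items
     [[0.7, 0.8]]],                          # session with 1 item
    ragged_rank=1,
)

# Keep only the embedding of the last item in each session, which is the
# default evaluation/inference mode discussed above.
dense = inputs.to_tensor()                   # (batch, max_len, dim), zero-padded
lengths = inputs.row_lengths()               # [3, 1]
rows = tf.cast(tf.range(tf.shape(dense)[0]), lengths.dtype)
last_items = tf.gather_nd(dense, tf.stack([rows, lengths - 1], axis=1))

print(last_items.shape)  # (2, 2): one query vector per session
```

This is why the second dimension of the evaluated inputs is effectively 1: only one position per session is scored.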
@sararb I can save the XLNet model, but I cannot load it back; I am getting an error. Do you think you can add a test for that in the unit tests? Also a test to showcase how one can do offline prediction? Thanks.
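A save/load round trip like the one being requested could be tested along these lines. This is a generic Keras sketch with a placeholder model and path, not the actual XLNet retrieval model:

```python
import os
import tempfile
import tensorflow as tf

# Placeholder model standing in for the trained retrieval model.
inp = tf.keras.Input(shape=(8,))
out = tf.keras.layers.Dense(4)(inp)
model = tf.keras.Model(inp, out)

# Save the model, load it back, and run an offline prediction with the copy.
path = os.path.join(tempfile.mkdtemp(), "retrieval_model.keras")
model.save(path)
loaded = tf.keras.models.load_model(path)
preds = loaded.predict(tf.random.uniform((2, 8)), verbose=0)
assert preds.shape == (2, 4)
```

A unit test along these lines would catch serialization issues in the model's custom layers before users hit them at inference time.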
Note: by making the bias term optional in the weight-tying layer EmbeddingTablePrediction, we make sure that the training and inference models have the same score calculation, because the contrastive output head is not exported during inference.
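A minimal sketch of what such a weight-tying head with an optional bias can look like (hypothetical code, not Merlin's actual EmbeddingTablePrediction implementation):

```python
import tensorflow as tf

class TiedEmbeddingPrediction(tf.keras.layers.Layer):
    """Scores queries against a shared item embedding table (weight tying)."""

    def __init__(self, embedding_table, use_bias=False):
        super().__init__()
        self.embedding_table = embedding_table  # (num_items, dim)
        self.use_bias = use_bias
        if use_bias:
            # The bias is only available when this head itself is exported;
            # an ANN index built from the table alone cannot reproduce it,
            # hence the default of use_bias=False.
            self.bias = self.add_weight(
                name="bias",
                shape=(embedding_table.shape[0],),
                initializer="zeros",
            )

    def call(self, query):  # query: (batch, dim)
        logits = tf.matmul(query, self.embedding_table, transpose_b=True)
        return logits + self.bias if self.use_bias else logits

table = tf.Variable(tf.random.uniform((10, 4)))
head = TiedEmbeddingPrediction(table, use_bias=False)
scores = head(tf.random.uniform((2, 4)))  # (2, 10): one logit per candidate
```

With `use_bias=False`, the dot product against the table is the entire score, so an ANN index built from the same table reproduces training-time scoring exactly.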
Goals ⚽
- Add support for ragged inputs to the BruteForce (TopKLayer) class. This is needed to export a sequential session encoder.
- Add a use_bias option to EmbeddingTablePrediction. The default is false, as we wouldn't have access to the bias term if we export the query encoder to an ANN system for inference.

Testing Details 🔍
- Add a test with ragged inputs to test_brute_force_layer.
- Add the use_bias option to the unit test test_last_item_prediction.
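For context, the scoring that a brute-force top-k layer performs can be sketched generically as follows (an illustrative function, not the actual BruteForce / TopKLayer API):

```python
import tensorflow as tf

def brute_force_topk(query, candidates, k=3):
    """Score each query against every candidate and return the k best."""
    scores = tf.matmul(query, candidates, transpose_b=True)  # (batch, n_candidates)
    return tf.math.top_k(scores, k=k)  # (values, indices), each (batch, k)

candidates = tf.random.uniform((100, 8))  # stand-in candidate index
query = tf.random.uniform((2, 8))         # e.g. last-item session encodings
top_scores, top_ids = brute_force_topk(query, candidates)
```

Supporting ragged inputs means the layer must first reduce each variable-length session to a single query vector (e.g. the last item's encoding) before this dense scoring step.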