Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of transformer-based retrieval models #1128

Merged
merged 9 commits into from
Jun 12, 2023
Merged

Add support of transformer-based retrieval models #1128

merged 9 commits into from
Jun 12, 2023

Conversation

sararb
Copy link
Contributor

@sararb sararb commented May 31, 2023

Goals ⚽

  • Support ragged queries and targets in the BruteForce(TopKLayer) class. This is needed to export a sequential session encoder.
  • Support exporting the candidate embeddings from CategoricalOutput class.
  • Make the bias term optional in the weight-tying layer EmbeddingTablePrediction. The default is false, as we wouldn't have access to the bias term if we export the query encoder to an ANN system for inference.

Testing Details 🔍

  • Add a check of ragged_query in the unit testtest_brute_force_layer.
  • Add use_bias option to the unit test test_last_item_prediction.

@sararb sararb added bug Something isn't working enhancement New feature or request P0 labels May 31, 2023
@sararb sararb self-assigned this May 31, 2023
@github-actions
Copy link

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1128

@@ -206,6 +206,8 @@ def call(
"You should call the `index` method first to " "set the _candidates index."
)

if isinstance(inputs, tf.RaggedTensor):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that we evaluate only on the last item in the session (which is the default mode during inference too). We might need to extend it to evaluate other items in the sequence in future work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sararb. I was trying to understand why we would have 1 as the 2nd dim of inputs, but from your explanation I now understand the reason is that we have predictions only for the last position.
I think it would be useful to add this remark as a comment.

@rnyak
Copy link
Contributor

rnyak commented Jun 2, 2023

@sararb I can save the xlnet model but I cannot load it back. I am getting an error. ValueError: The last dimension of the input shape of a Dense layer should be defined. Found None. Received: input_shape=(None, None)

do you think you can add a test for that in the unit test? also a test to showcase how one can do offline prediction? thanks.

@@ -206,6 +206,8 @@ def call(
"You should call the `index` method first to " "set the _candidates index."
)

if isinstance(inputs, tf.RaggedTensor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sararb. I was trying to understand why we would have 1 as the 2nd dim of inputs, but from your explanation I now understand the reason is that we have predictions only for the last position.
I think it would be useful to add this remark as a comment.

@rnyak rnyak added this to the Merlin 23.06 milestone Jun 5, 2023
@edknv edknv merged commit 9980689 into main Jun 12, 2023
@edknv edknv deleted the topk-ragged branch June 12, 2023 08:59
@rnyak
Copy link
Contributor

rnyak commented Nov 7, 2023

Note: with making the bias term optional in the weight-tying layer EmbeddingTablePrediction, we make sure that the training and inference models have same score calculation, bcs contrastive output head is not exported during inference.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request P0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants