diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java index 5abf5a1de4e..6dc1a4ed454 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java @@ -35,16 +35,19 @@ public class TextEmbeddingTranslator implements Translator { private Batchifier batchifier; private boolean normalize; private String pooling; + private boolean includeTokenTypes; TextEmbeddingTranslator( HuggingFaceTokenizer tokenizer, Batchifier batchifier, String pooling, - boolean normalize) { + boolean normalize, + boolean includeTokenTypes) { this.tokenizer = tokenizer; this.batchifier = batchifier; this.pooling = pooling; this.normalize = normalize; + this.includeTokenTypes = includeTokenTypes; } /** {@inheritDoc} */ @@ -58,7 +61,7 @@ public Batchifier getBatchifier() { public NDList processInput(TranslatorContext ctx, String input) { Encoding encoding = tokenizer.encode(input); ctx.setAttachment("encoding", encoding); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -84,6 +87,10 @@ public TextEmbeddingBatchTranslator toBatchTranslator(Batchifier batchifier) { static NDArray processEmbedding( NDManager manager, NDList list, Encoding encoding, String pooling) { NDArray embedding = list.get("last_hidden_state"); + if (embedding == null) { + // For Onnx model, NDArray name is not present + embedding = list.head(); + } long[] attentionMask = encoding.getAttentionMask(); NDArray inputAttentionMask = manager.create(attentionMask).toType(DataType.FLOAT32, true); switch (pooling) { @@ -167,6 +174,7 @@ public static final class Builder { private Batchifier batchifier = Batchifier.STACK; private boolean normalize = true; private String pooling = "mean"; + private boolean includeTokenTypes; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; @@ -214,6 +222,17 @@ public Builder optPoolingMode(String poolingMode) { return this; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Configures the builder with the model arguments. * @@ -224,6 +243,7 @@ public void configure(Map arguments) { optBatchifier(Batchifier.fromString(batchifierStr)); optNormalize(ArgumentsUtil.booleanValue(arguments, "normalize", true)); optPoolingMode(ArgumentsUtil.stringValue(arguments, "pooling", "mean")); + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); } /** @@ -233,7 +253,8 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public TextEmbeddingTranslator build() throws IOException { - return new TextEmbeddingTranslator(tokenizer, batchifier, pooling, normalize); + return new TextEmbeddingTranslator( + tokenizer, batchifier, pooling, normalize, includeTokenTypes); } } }