From 375fc779b40a4d00653ec32c1040bc337a423b10 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 28 Oct 2021 10:40:49 -0400 Subject: [PATCH] [ML] update truncation default & adding field output when input is truncated (#79942) This commit makes the following two changes (along with some refactoring): - Nlp results will now indicate if the input was truncated or not - The default truncation is now `none` instead of `first` --- docs/reference/ml/ml-shared.asciidoc | 2 +- .../MlInferenceNamedXContentProvider.java | 9 +- .../ml/inference/results/FillMaskResults.java | 37 ++--- .../core/ml/inference/results/NerResults.java | 24 ++-- .../NlpClassificationInferenceResults.java | 129 ++++++++++++++++++ .../results/NlpInferenceResults.java | 74 ++++++++++ .../results/PyTorchPassThroughResults.java | 19 ++- .../results/TextEmbeddingResults.java | 22 +-- .../inference/trainedmodel/Tokenization.java | 2 +- .../results/FillMaskResultsTests.java | 12 +- .../ml/inference/results/NerResultsTests.java | 11 +- ...lpClassificationInferenceResultsTests.java | 81 +++++++++++ .../PyTorchPassThroughResultsTests.java | 9 +- .../results/TextEmbeddingResultsTests.java | 9 +- .../deployment/DeploymentManager.java | 17 ++- .../ml/inference/nlp/FillMaskProcessor.java | 6 +- .../xpack/ml/inference/nlp/NerProcessor.java | 7 +- .../inference/nlp/PassThroughProcessor.java | 3 +- .../nlp/TextClassificationProcessor.java | 13 +- .../inference/nlp/TextEmbeddingProcessor.java | 3 +- .../nlp/ZeroShotClassificationProcessor.java | 13 +- .../nlp/tokenizers/BertTokenizer.java | 17 +-- .../nlp/tokenizers/TokenizationResult.java | 16 ++- .../deployment/DeploymentManagerTests.java | 3 + .../inference/nlp/FillMaskProcessorTests.java | 4 +- 25 files changed, 429 insertions(+), 113 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 8738b754f41f8..bf726741cd460 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -927,7 +927,7 @@ end::inference-config-nlp-tokenization-bert-do-lower-case[] tag::inference-config-nlp-tokenization-bert-truncate[] Indicates how tokens are truncated when they exceed `max_sequence_length`. -The default value is `first`. +The default value is `none`. + -- * `none`: No truncation occurs; the inference request receives an error. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 521aec151e503..457a31136315f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; +import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; @@ -498,7 +499,13 @@ public List getNamedWriteables() { new 
NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new) ); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)); - + namedWriteables.add( + new NamedWriteableRegistry.Entry( + InferenceResults.class, + NlpClassificationInferenceResults.NAME, + NlpClassificationInferenceResults::new + ) + ); // Inference Configs namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java index e46d61df6fdf3..863efad3e3b0e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java @@ -10,41 +10,27 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import java.io.IOException; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; -public class FillMaskResults extends ClassificationInferenceResults { +public class FillMaskResults extends NlpClassificationInferenceResults { public static final String NAME = "fill_mask_result"; private final String predictedSequence; public FillMaskResults( - double value, String classificationLabel, String predictedSequence, List topClasses, - String topNumClassesField, String resultsField, - Double predictionProbability + Double predictionProbability, + boolean isTruncated ) { - super( - value, - 
classificationLabel, - topClasses, - List.of(), - topNumClassesField, - resultsField, - PredictionFieldType.STRING, - 0, - predictionProbability, - null - ); + super(classificationLabel, topClasses, resultsField, predictionProbability, isTruncated); this.predictedSequence = predictedSequence; } @@ -54,8 +40,8 @@ public FillMaskResults(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); + public void doWriteTo(StreamOutput out) throws IOException { + super.doWriteTo(out); out.writeString(predictedSequence); } @@ -64,11 +50,9 @@ public String getPredictedSequence() { } @Override - public Map asMap() { - Map map = new LinkedHashMap<>(); + void addMapFields(Map map) { + super.addMapFields(map); map.put(resultsField + "_sequence", predictedSequence); - map.putAll(super.asMap()); - return map; } @Override @@ -77,8 +61,9 @@ public String getWriteableName() { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence); + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + super.doXContentBody(builder, params); + builder.field(resultsField + "_sequence", predictedSequence); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index 43bb4381d9946..dd0907c95ead2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -20,7 +20,7 @@ import java.util.Objects; import java.util.stream.Collectors; -public class NerResults implements InferenceResults { +public class NerResults extends 
NlpInferenceResults { public static final String NAME = "ner_result"; public static final String ENTITY_FIELD = "entities"; @@ -30,27 +30,28 @@ public class NerResults implements InferenceResults { private final List entityGroups; - public NerResults(String resultsField, String annotatedResult, List entityGroups) { + public NerResults(String resultsField, String annotatedResult, List entityGroups, boolean isTruncated) { + super(isTruncated); this.entityGroups = Objects.requireNonNull(entityGroups); this.resultsField = Objects.requireNonNull(resultsField); this.annotatedResult = Objects.requireNonNull(annotatedResult); } public NerResults(StreamInput in) throws IOException { + super(in); entityGroups = in.readList(EntityGroup::new); resultsField = in.readString(); annotatedResult = in.readString(); } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, annotatedResult); builder.startArray("entities"); for (EntityGroup entity : entityGroups) { entity.toXContent(builder, params); } builder.endArray(); - return builder; } @Override @@ -59,18 +60,16 @@ public String getWriteableName() { } @Override - public void writeTo(StreamOutput out) throws IOException { + void doWriteTo(StreamOutput out) throws IOException { out.writeList(entityGroups); out.writeString(resultsField); out.writeString(annotatedResult); } @Override - public Map asMap() { - Map map = new LinkedHashMap<>(); + void addMapFields(Map map) { map.put(resultsField, annotatedResult); map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList())); - return map; } @Override @@ -95,15 +94,16 @@ public String getAnnotatedResult() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; NerResults that = 
(NerResults) o; - return Objects.equals(entityGroups, that.entityGroups) - && Objects.equals(resultsField, that.resultsField) - && Objects.equals(annotatedResult, that.annotatedResult); + return Objects.equals(resultsField, that.resultsField) + && Objects.equals(annotatedResult, that.annotatedResult) + && Objects.equals(entityGroups, that.entityGroups); } @Override public int hashCode() { - return Objects.hash(entityGroups, resultsField, annotatedResult); + return Objects.hash(super.hashCode(), resultsField, annotatedResult, entityGroups); } public static class EntityGroup implements ToXContentObject, Writeable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java new file mode 100644 index 0000000000000..4c230f2b4e319 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class NlpClassificationInferenceResults extends NlpInferenceResults { + + public static final String NAME = "nlp_classification"; + + // Accessed in sub-classes + protected final String resultsField; + private final String classificationLabel; + private final Double predictionProbability; + private final List topClasses; + + public NlpClassificationInferenceResults( + String classificationLabel, + List topClasses, + String resultsField, + Double predictionProbability, + boolean isTruncated + ) { + super(isTruncated); + this.classificationLabel = Objects.requireNonNull(classificationLabel); + this.topClasses = topClasses == null ? 
Collections.emptyList() : Collections.unmodifiableList(topClasses); + this.resultsField = resultsField; + this.predictionProbability = predictionProbability; + } + + public NlpClassificationInferenceResults(StreamInput in) throws IOException { + super(in); + this.classificationLabel = in.readString(); + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); + this.resultsField = in.readString(); + this.predictionProbability = in.readOptionalDouble(); + } + + public String getClassificationLabel() { + return classificationLabel; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void doWriteTo(StreamOutput out) throws IOException { + out.writeString(classificationLabel); + out.writeCollection(topClasses); + out.writeString(resultsField); + out.writeOptionalDouble(predictionProbability); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + NlpClassificationInferenceResults that = (NlpClassificationInferenceResults) o; + return Objects.equals(resultsField, that.resultsField) + && Objects.equals(classificationLabel, that.classificationLabel) + && Objects.equals(predictionProbability, that.predictionProbability) + && Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), resultsField, classificationLabel, predictionProbability, topClasses); + } + + public Double getPredictionProbability() { + return predictionProbability; + } + + @Override + public String getResultsField() { + return resultsField; + } + + @Override + public Object predictedValue() { + return classificationLabel; + } + + @Override + void addMapFields(Map map) { + map.put(resultsField, classificationLabel); + if (topClasses.isEmpty() == false) { + map.put( + NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, + 
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()) + ); + } + if (predictionProbability != null) { + map.put(PREDICTION_PROBABILITY, predictionProbability); + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(resultsField, classificationLabel); + if (topClasses.size() > 0) { + builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); + } + if (predictionProbability != null) { + builder.field(PREDICTION_PROBABILITY, predictionProbability); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java new file mode 100644 index 0000000000000..a503c58d3d900 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +abstract class NlpInferenceResults implements InferenceResults { + + protected final boolean isTruncated; + + NlpInferenceResults(boolean isTruncated) { + this.isTruncated = isTruncated; + } + + NlpInferenceResults(StreamInput in) throws IOException { + this.isTruncated = in.readBoolean(); + } + + abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException; + + abstract void doWriteTo(StreamOutput out) throws IOException; + + abstract void addMapFields(Map map); + + @Override + public final void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(isTruncated); + doWriteTo(out); + } + + @Override + public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + doXContentBody(builder, params); + if (isTruncated) { + builder.field("is_truncated", isTruncated); + } + return builder; + } + + @Override + public final Map asMap() { + Map map = new LinkedHashMap<>(); + addMapFields(map); + if (isTruncated) { + map.put("is_truncated", isTruncated); + } + return map; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NlpInferenceResults that = (NlpInferenceResults) o; + return isTruncated == that.isTruncated; + } + + @Override + public int hashCode() { + return Objects.hash(isTruncated); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index 
30a7ff792ad75..668ba6d773c26 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -13,23 +13,24 @@ import java.io.IOException; import java.util.Arrays; -import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; -public class PyTorchPassThroughResults implements InferenceResults { +public class PyTorchPassThroughResults extends NlpInferenceResults { public static final String NAME = "pass_through_result"; private final double[][] inference; private final String resultsField; - public PyTorchPassThroughResults(String resultsField, double[][] inference) { + public PyTorchPassThroughResults(String resultsField, double[][] inference, boolean isTruncated) { + super(isTruncated); this.inference = inference; this.resultsField = resultsField; } public PyTorchPassThroughResults(StreamInput in) throws IOException { + super(in); inference = in.readArray(StreamInput::readDoubleArray, double[][]::new); resultsField = in.readString(); } @@ -39,9 +40,8 @@ public double[][] getInference() { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, inference); - return builder; } @Override @@ -50,7 +50,7 @@ public String getWriteableName() { } @Override - public void writeTo(StreamOutput out) throws IOException { + public void doWriteTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeDoubleArray, inference); out.writeString(resultsField); } @@ -61,10 +61,8 @@ public String getResultsField() { } @Override - public Map asMap() { - Map map = new LinkedHashMap<>(); + void addMapFields(Map map) { map.put(resultsField, inference); - return map; } @Override @@ -76,12 +74,13 @@ public Object 
predictedValue() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; PyTorchPassThroughResults that = (PyTorchPassThroughResults) o; return Arrays.deepEquals(inference, that.inference) && Objects.equals(resultsField, that.resultsField); } @Override public int hashCode() { - return Objects.hash(Arrays.deepHashCode(inference), resultsField); + return Objects.hash(super.hashCode(), resultsField, Arrays.deepHashCode(inference)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java index 438f458f69e2e..02cd5864ba0aa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java @@ -13,23 +13,24 @@ import java.io.IOException; import java.util.Arrays; -import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; -public class TextEmbeddingResults implements InferenceResults { +public class TextEmbeddingResults extends NlpInferenceResults { public static final String NAME = "text_embedding_result"; private final String resultsField; private final double[] inference; - public TextEmbeddingResults(String resultsField, double[] inference) { + public TextEmbeddingResults(String resultsField, double[] inference, boolean isTruncated) { + super(isTruncated); this.inference = inference; this.resultsField = resultsField; } public TextEmbeddingResults(StreamInput in) throws IOException { + super(in); inference = in.readDoubleArray(); resultsField = in.readString(); } @@ -43,8 +44,8 @@ public double[] getInference() { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params 
params) throws IOException { - return builder.field(resultsField, inference); + void doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(resultsField, inference); } @Override @@ -53,16 +54,14 @@ public String getWriteableName() { } @Override - public void writeTo(StreamOutput out) throws IOException { + void doWriteTo(StreamOutput out) throws IOException { out.writeDoubleArray(inference); out.writeString(resultsField); } @Override - public Map asMap() { - Map map = new LinkedHashMap<>(); + void addMapFields(Map map) { map.put(resultsField, inference); - return map; } @Override @@ -74,12 +73,13 @@ public Object predictedValue() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; TextEmbeddingResults that = (TextEmbeddingResults) o; - return Arrays.equals(inference, that.inference) && Objects.equals(resultsField, that.resultsField); + return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference); } @Override public int hashCode() { - return Objects.hash(Arrays.hashCode(inference), resultsField); + return Objects.hash(super.hashCode(), resultsField, Arrays.hashCode(inference)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java index 3e6d5e5f522fd..c3e187b4c46d4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java @@ -47,7 +47,7 @@ public String toString() { private static final int DEFAULT_MAX_SEQUENCE_LENGTH = 512; private static final boolean DEFAULT_DO_LOWER_CASE = false; private static final boolean 
DEFAULT_WITH_SPECIAL_TOKENS = true; - private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST; + private static final Truncate DEFAULT_TRUNCATION = Truncate.NONE; static void declareCommonFields(ConstructingObjectParser parser) { parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java index b87c9d7bc9efb..e78a174d493ee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -18,8 +18,10 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; public class FillMaskResultsTests extends AbstractWireSerializingTestCase { @@ -36,13 +38,12 @@ protected FillMaskResults createTestInstance() { resultList.add(TopClassEntryTests.createRandomTopClassEntry()); } return new FillMaskResults( - 0.0, randomAlphaOfLength(10), randomAlphaOfLength(10), resultList, - DEFAULT_TOP_CLASSES_RESULTS_FIELD, DEFAULT_RESULTS_FIELD, - randomDouble() + randomDouble(), + randomBoolean() ); } @@ -54,6 +55,11 @@ public void testAsMap() { assertThat(asMap.get(PREDICTION_PROBABILITY), equalTo(testInstance.getPredictionProbability())); assertThat(asMap.get(DEFAULT_RESULTS_FIELD + "_sequence"), 
equalTo(testInstance.getPredictedSequence())); List> resultList = (List>) asMap.get(DEFAULT_TOP_CLASSES_RESULTS_FIELD); + if (testInstance.isTruncated) { + assertThat(asMap.get("is_truncated"), is(true)); + } else { + assertThat(asMap, not(hasKey("is_truncated"))); + } if (testInstance.getTopClasses().size() == 0) { assertThat(resultList, is(nullValue())); } else { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java index a3ce28f3d46a9..ab0ceea78aae6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -17,7 +17,10 @@ import static org.elasticsearch.xpack.core.ml.inference.results.NerResults.ENTITY_FIELD; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; public class NerResultsTests extends InferenceResultsTestCase { @Override @@ -40,7 +43,8 @@ protected NerResults createTestInstance() { randomIntBetween(-1, 5), randomIntBetween(5, 10) ) - ).limit(numEntities).collect(Collectors.toList()) + ).limit(numEntities).collect(Collectors.toList()), + randomBoolean() ); } @@ -54,6 +58,11 @@ public void testAsMap() { } assertThat(resultList, hasSize(testInstance.getEntityGroups().size())); assertThat(asMap.get(testInstance.getResultsField()), equalTo(testInstance.getAnnotatedResult())); + if (testInstance.isTruncated) { + assertThat(asMap.get("is_truncated"), is(true)); + } else { + assertThat(asMap, not(hasKey("is_truncated"))); + } for (int i = 0; i < testInstance.getEntityGroups().size(); i++) { NerResults.EntityGroup entity = testInstance.getEntityGroups().get(i); 
Map map = resultList.get(i); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java new file mode 100644 index 0000000000000..92d5272eac4ff --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.hamcrest.Matchers.equalTo; + +public class NlpClassificationInferenceResultsTests extends InferenceResultsTestCase { + + public static NlpClassificationInferenceResults createRandomResults() { + return new NlpClassificationInferenceResults( + randomAlphaOfLength(10), + randomBoolean() + ? null + : Stream.generate(TopClassEntryTests::createRandomTopClassEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList()), + randomAlphaOfLength(10), + randomBoolean() ? 
null : randomDoubleBetween(0.0, 1.0, false), + randomBoolean() + ); + } + + @SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new TopClassEntry("foo", 0.7, 0.7), + new TopClassEntry("bar", 0.2, 0.2), + new TopClassEntry("baz", 0.1, 0.1) + ); + NlpClassificationInferenceResults result = new NlpClassificationInferenceResults( + "foo", + entries, + "my_results", + 0.7, + randomBoolean() + ); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + writeResult(result, document, "result_field", "test"); + + List list = document.getFieldValue("result_field.top_classes", List.class); + assertThat(list.size(), equalTo(3)); + + for (int i = 0; i < 3; i++) { + Map map = (Map) list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + + assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo")); + } + + @Override + protected NlpClassificationInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return NlpClassificationInferenceResults::new; + } + + @Override + void assertFieldValues(NlpClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) { + String path = resultsField + "." 
+ createdInstance.getResultsField(); + assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue())); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java index 70590dd6d8ee4..e33b5274231a9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java @@ -14,6 +14,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; public class PyTorchPassThroughResultsTests extends InferenceResultsTestCase { @Override @@ -32,14 +33,18 @@ protected PyTorchPassThroughResults createTestInstance() { } } - return new PyTorchPassThroughResults(DEFAULT_RESULTS_FIELD, arr); + return new PyTorchPassThroughResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); } public void testAsMap() { PyTorchPassThroughResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); - assertThat(asMap.keySet(), hasSize(1)); + int size = testInstance.isTruncated ? 
2 : 1; + assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[][]) asMap.get(DEFAULT_RESULTS_FIELD)); + if (testInstance.isTruncated) { + assertThat(asMap.get("is_truncated"), is(true)); + } } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java index 3c27af0790ea7..c255c3de8cbfd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java @@ -14,6 +14,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; public class TextEmbeddingResultsTests extends InferenceResultsTestCase { @Override @@ -29,14 +30,18 @@ protected TextEmbeddingResults createTestInstance() { arr[i] = randomDouble(); } - return new TextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr); + return new TextEmbeddingResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); } public void testAsMap() { TextEmbeddingResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); - assertThat(asMap.keySet(), hasSize(1)); + int size = testInstance.isTruncated ? 
2 : 1; + assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[]) asMap.get(DEFAULT_RESULTS_FIELD), 1e-10); + if (testInstance.isTruncated) { + assertThat(asMap.get("is_truncated"), is(true)); + } } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index dae5b83f4332e..aba23ca566640 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -238,7 +238,16 @@ public void infer( } final long requestId = requestIdCounter.getAndIncrement(); - InferenceAction inferenceAction = new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener); + InferenceAction inferenceAction = new InferenceAction( + task.getModelId(), + requestId, + timeout, + processContext, + config, + doc, + threadPool, + listener + ); try { processContext.executorService.execute(inferenceAction); } catch (Exception e) { @@ -247,6 +256,7 @@ public void infer( } static class InferenceAction extends AbstractRunnable { + private final String modelId; private final long requestId; private final TimeValue timeout; private final Scheduler.Cancellable timeoutHandler; @@ -257,6 +267,7 @@ static class InferenceAction extends AbstractRunnable { private final AtomicBoolean notified = new AtomicBoolean(); InferenceAction( + String modelId, long requestId, TimeValue timeout, ProcessContext processContext, @@ -265,6 +276,7 @@ static class InferenceAction extends AbstractRunnable { ThreadPool threadPool, ActionListener listener ) { + this.modelId = modelId; this.requestId = requestId; this.timeout = timeout; this.processContext = processContext; @@ -321,6 +333,9 @@ protected void doRun() throws Exception { 
assert config instanceof NlpConfig; NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr); logger.trace(() -> "Inference Request " + request.processInput.utf8ToString()); + if (request.tokenization.anyTruncated()) { + logger.debug("[{}] [{}] input truncated", modelId, requestId); + } PyTorchResultProcessor.PendingResult pendingResult = processContext.getResultProcessor().registerRequest(requestIdStr); processContext.process.get().writeInferenceRequest(request.processInput); waitForResult( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index e5d9b111b292c..3e54d7ab6ca16 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -24,7 +24,6 @@ import java.util.Optional; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD; public class FillMaskProcessor implements NlpTask.Processor { @@ -100,16 +99,15 @@ static InferenceResults processResult( } } return new FillMaskResults( - scoreAndIndices[0].index, tokenization.getFromVocab(scoreAndIndices[0].index), tokenization.getTokenizations() .get(0) .getInput() .replace(BertTokenizer.MASK_TOKEN, tokenization.getFromVocab(scoreAndIndices[0].index)), results, - DEFAULT_TOP_CLASSES_RESULTS_FIELD, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), - scoreAndIndices[0].score + scoreAndIndices[0].score, + tokenization.anyTruncated() ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index 238c5550b0367..2674d47fdaa4d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -213,7 +213,12 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe ? tokenization.getTokenizations().get(0).getInput().toLowerCase(Locale.ROOT) : tokenization.getTokenizations().get(0).getInput() ); - return new NerResults(resultsField, buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities), entities); + return new NerResults( + resultsField, + buildAnnotatedText(tokenization.getTokenizations().get(0).getInput(), entities), + entities, + tokenization.anyTruncated() + ); } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java index 61c8a4838433e..73bebe136a68e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java @@ -53,7 +53,8 @@ private static InferenceResults processResult(TokenizationResult tokenization, P // TODO - process all results in the batch return new PyTorchPassThroughResults( Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), - pyTorchResult.getInferenceResult()[0] + pyTorchResult.getInferenceResult()[0], + tokenization.anyTruncated() ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java index 8b8f9fa53d0a8..77646130a9435 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java @@ -7,12 +7,11 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; @@ -26,7 +25,6 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD; public class TextClassificationProcessor implements NlpTask.Processor { @@ -109,20 +107,15 @@ static InferenceResults processResult( .mapToInt(i -> i) .toArray(); - return new ClassificationInferenceResults( - sortedIndices[0], + return new NlpClassificationInferenceResults( labels.get(sortedIndices[0]), Arrays.stream(sortedIndices) .mapToObj(i -> new TopClassEntry(labels.get(i), normalizedScores[i])) .limit(numTopClasses) .collect(Collectors.toList()), - List.of(), - DEFAULT_TOP_CLASSES_RESULTS_FIELD, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), - PredictionFieldType.STRING, - 0, 
normalizedScores[sortedIndices[0]], - null + tokenization.anyTruncated() ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index b5b576b441497..77810c00774da 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -50,7 +50,8 @@ private static InferenceResults processResult(TokenizationResult tokenization, P // TODO - process all results in the batch return new TextEmbeddingResults( Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), - pyTorchResult.getInferenceResult()[0][0] + pyTorchResult.getInferenceResult()[0][0], + tokenization.anyTruncated() ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java index 69e87ee64b3bc..1959450d56b17 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java @@ -8,12 +8,11 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.common.logging.LoggerMessageFormat; -import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult; @@ -31,7 +30,6 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD; public class ZeroShotClassificationProcessor implements NlpTask.Processor { @@ -198,17 +196,12 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe .mapToInt(i -> i) .toArray(); - return new ClassificationInferenceResults( - sortedIndices[0], + return new NlpClassificationInferenceResults( labels[sortedIndices[0]], Arrays.stream(sortedIndices).mapToObj(i -> new TopClassEntry(labels[i], normalizedScores[i])).collect(Collectors.toList()), - List.of(), - DEFAULT_TOP_CLASSES_RESULTS_FIELD, Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD), - PredictionFieldType.STRING, - 0, normalizedScores[sortedIndices[0]], - null + tokenization.anyTruncated() ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java index 16bc124e6e44a..0d2a8e783b0fb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -118,10 +118,12 @@ public TokenizationResult.Tokenization tokenize(String seq) { List wordPieceTokens = innerResult.v1(); List tokenPositionMap = 
innerResult.v2(); int numTokens = withSpecialTokens ? wordPieceTokens.size() + 2 : wordPieceTokens.size(); + boolean isTruncated = false; if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: case SECOND: + isTruncated = true; wordPieceTokens = wordPieceTokens.subList(0, withSpecialTokens ? maxSequenceLength - 2 : maxSequenceLength); break; case NONE: @@ -158,7 +160,7 @@ public TokenizationResult.Tokenization tokenize(String seq) { tokenMap[i] = SPECIAL_TOKEN_POSITION; } - return new TokenizationResult.Tokenization(seq, tokens, tokenIds, tokenMap); + return new TokenizationResult.Tokenization(seq, isTruncated, tokens, tokenIds, tokenMap); } @Override @@ -175,9 +177,11 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2) { // [CLS] seq1 [SEP] seq2 [SEP] int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + 3; + boolean isTruncated = false; if (numTokens > maxSequenceLength) { switch (truncate) { case FIRST: + isTruncated = true; if (wordPieceTokenSeq2s.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the second sequence. " @@ -191,6 +195,7 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2) { wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxSequenceLength - 3 - wordPieceTokenSeq2s.size()); break; case SECOND: + isTruncated = true; if (wordPieceTokenSeq1s.size() > maxSequenceLength - 3) { throw ExceptionsHelper.badRequestException( "Attempting truncation [{}] but input is too large for the first sequence. " @@ -245,15 +250,7 @@ public TokenizationResult.Tokenization tokenize(String seq1, String seq2) { tokenIds[i] = vocab.get(SEPARATOR_TOKEN); tokenMap[i] = SPECIAL_TOKEN_POSITION; - // TODO handle seq1 truncation - if (tokenIds.length > maxSequenceLength) { - throw ExceptionsHelper.badRequestException( - "Input too large. 
The tokenized input length [{}] exceeds the maximum sequence length [{}]", - tokenIds.length, - maxSequenceLength - ); - } - return new TokenizationResult.Tokenization(seq1 + seq2, tokens, tokenIds, tokenMap); + return new TokenizationResult.Tokenization(seq1 + seq2, isTruncated, tokens, tokenIds, tokenMap); } private Tuple, List> innerTokenize(String seq) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java index 418cf7bff4746..b50c9548504a9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/TokenizationResult.java @@ -21,6 +21,10 @@ public TokenizationResult(List vocab) { this.maxLength = -1; } + public boolean anyTruncated() { + return tokenizations.stream().anyMatch(Tokenization::isTruncated); + } + public String getFromVocab(int tokenId) { return vocab.get(tokenId); } @@ -29,9 +33,9 @@ public List getTokenizations() { return tokenizations; } - public void addTokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) { + public void addTokenization(String input, boolean isTruncated, String[] tokens, int[] tokenIds, int[] tokenMap) { maxLength = Math.max(maxLength, tokenIds.length); - tokenizations.add(new Tokenization(input, tokens, tokenIds, tokenMap)); + tokenizations.add(new Tokenization(input, isTruncated, tokens, tokenIds, tokenMap)); } public void addTokenization(Tokenization tokenization) { @@ -49,14 +53,16 @@ public static class Tokenization { private final String[] tokens; private final int[] tokenIds; private final int[] tokenMap; + private final boolean truncated; - public Tokenization(String input, String[] tokens, int[] tokenIds, int[] tokenMap) { + public Tokenization(String input, boolean truncated, 
String[] tokens, int[] tokenIds, int[] tokenMap) { assert tokens.length == tokenIds.length; assert tokenIds.length == tokenMap.length; this.inputSeqs = input; this.tokens = tokens; this.tokenIds = tokenIds; this.tokenMap = tokenMap; + this.truncated = truncated; } /** @@ -91,5 +97,9 @@ public int[] getTokenMap() { public String getInput() { return inputSeqs; } + + public boolean isTruncated() { + return truncated; + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java index 775b3f305fe08..1b3d55b4fc8e2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java @@ -53,6 +53,7 @@ public void testInferListenerOnlyCalledOnce() { ListenerCounter listener = new ListenerCounter(); DeploymentManager.InferenceAction action = new DeploymentManager.InferenceAction( + "test-model", 1, TimeValue.MAX_VALUE, processContext, @@ -72,6 +73,7 @@ public void testInferListenerOnlyCalledOnce() { assertThat(listener.responseCounts, equalTo(1)); action = new DeploymentManager.InferenceAction( + "test-model", 1, TimeValue.MAX_VALUE, processContext, @@ -91,6 +93,7 @@ public void testInferListenerOnlyCalledOnce() { assertThat(listener.responseCounts, equalTo(1)); action = new DeploymentManager.InferenceAction( + "test-model", 1, TimeValue.MAX_VALUE, processContext, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index 8b36f2fd6ad6c..e43d1f3e41d60 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -50,7 +50,7 @@ public void testProcessResults() { int[] tokenIds = new int[] { 0, 1, 2, 3, 4, 5 }; TokenizationResult tokenization = new TokenizationResult(vocab); - tokenization.addTokenization(input, tokens, tokenIds, tokenMap); + tokenization.addTokenization(input, false, tokens, tokenIds, tokenMap); String resultsField = randomAlphaOfLength(10); FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult( @@ -73,7 +73,7 @@ public void testProcessResults() { public void testProcessResults_GivenMissingTokens() { TokenizationResult tokenization = new TokenizationResult(Collections.emptyList()); - tokenization.addTokenization("", new String[] {}, new int[] {}, new int[] {}); + tokenization.addTokenization("", false, new String[] {}, new int[] {}, new int[] {}); PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null); assertThat(