Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] update truncation default & adding field output when input is truncated #79942

Merged
2 changes: 1 addition & 1 deletion docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ end::inference-config-nlp-tokenization-bert-do-lower-case[]

tag::inference-config-nlp-tokenization-bert-truncate[]
Indicates how tokens are truncated when they exceed `max_sequence_length`.
The default value is `first`.
The default value is `none`.
+
--
* `none`: No truncation occurs; the inference request receives an error.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
Expand Down Expand Up @@ -498,7 +499,13 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceResults.class,
NlpClassificationInferenceResults.NAME,
NlpClassificationInferenceResults::new
)
);
// Inference Configs
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,27 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class FillMaskResults extends ClassificationInferenceResults {
public class FillMaskResults extends NlpClassificationInferenceResults {

public static final String NAME = "fill_mask_result";

private final String predictedSequence;

public FillMaskResults(
double value,
String classificationLabel,
String predictedSequence,
List<TopClassEntry> topClasses,
String topNumClassesField,
String resultsField,
Double predictionProbability
Double predictionProbability,
boolean isTruncated
) {
super(
value,
classificationLabel,
topClasses,
List.of(),
topNumClassesField,
resultsField,
PredictionFieldType.STRING,
0,
predictionProbability,
null
);
super(classificationLabel, topClasses, resultsField, predictionProbability, isTruncated);
this.predictedSequence = predictedSequence;
}

Expand All @@ -54,8 +40,8 @@ public FillMaskResults(StreamInput in) throws IOException {
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
public void doWriteTo(StreamOutput out) throws IOException {
super.doWriteTo(out);
out.writeString(predictedSequence);
}

Expand All @@ -64,11 +50,9 @@ public String getPredictedSequence() {
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
void addMapFields(Map<String, Object> map) {
super.addMapFields(map);
map.put(resultsField + "_sequence", predictedSequence);
map.putAll(super.asMap());
return map;
}

@Override
Expand All @@ -77,8 +61,9 @@ public String getWriteableName() {
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence);
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
super.doXContentBody(builder, params);
builder.field(resultsField + "_sequence", predictedSequence);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.Objects;
import java.util.stream.Collectors;

public class NerResults implements InferenceResults {
public class NerResults extends NlpInferenceResults {

public static final String NAME = "ner_result";
public static final String ENTITY_FIELD = "entities";
Expand All @@ -30,27 +30,28 @@ public class NerResults implements InferenceResults {

private final List<EntityGroup> entityGroups;

public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups) {
public NerResults(String resultsField, String annotatedResult, List<EntityGroup> entityGroups, boolean isTruncated) {
super(isTruncated);
this.entityGroups = Objects.requireNonNull(entityGroups);
this.resultsField = Objects.requireNonNull(resultsField);
this.annotatedResult = Objects.requireNonNull(annotatedResult);
}

public NerResults(StreamInput in) throws IOException {
super(in);
entityGroups = in.readList(EntityGroup::new);
resultsField = in.readString();
annotatedResult = in.readString();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, annotatedResult);
builder.startArray("entities");
for (EntityGroup entity : entityGroups) {
entity.toXContent(builder, params);
}
builder.endArray();
return builder;
}

@Override
Expand All @@ -59,18 +60,16 @@ public String getWriteableName() {
}

@Override
public void writeTo(StreamOutput out) throws IOException {
void doWriteTo(StreamOutput out) throws IOException {
out.writeList(entityGroups);
out.writeString(resultsField);
out.writeString(annotatedResult);
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
void addMapFields(Map<String, Object> map) {
map.put(resultsField, annotatedResult);
map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
return map;
}

@Override
Expand All @@ -95,15 +94,16 @@ public String getAnnotatedResult() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
NerResults that = (NerResults) o;
return Objects.equals(entityGroups, that.entityGroups)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(annotatedResult, that.annotatedResult);
return Objects.equals(resultsField, that.resultsField)
&& Objects.equals(annotatedResult, that.annotatedResult)
&& Objects.equals(entityGroups, that.entityGroups);
}

@Override
public int hashCode() {
return Objects.hash(entityGroups, resultsField, annotatedResult);
return Objects.hash(super.hashCode(), resultsField, annotatedResult, entityGroups);
}

public static class EntityGroup implements ToXContentObject, Writeable {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class NlpClassificationInferenceResults extends NlpInferenceResults {

public static final String NAME = "nlp_classification";

// Accessed in sub-classes
protected final String resultsField;
private final String classificationLabel;
private final Double predictionProbability;
private final List<TopClassEntry> topClasses;

public NlpClassificationInferenceResults(
String classificationLabel,
List<TopClassEntry> topClasses,
String resultsField,
Double predictionProbability,
boolean isTruncated
) {
super(isTruncated);
this.classificationLabel = Objects.requireNonNull(classificationLabel);
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.resultsField = resultsField;
this.predictionProbability = predictionProbability;
}

public NlpClassificationInferenceResults(StreamInput in) throws IOException {
super(in);
this.classificationLabel = in.readString();
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
this.resultsField = in.readString();
this.predictionProbability = in.readOptionalDouble();
}

public String getClassificationLabel() {
return classificationLabel;
}

public List<TopClassEntry> getTopClasses() {
return topClasses;
}

@Override
public void doWriteTo(StreamOutput out) throws IOException {
out.writeString(classificationLabel);
out.writeCollection(topClasses);
out.writeString(resultsField);
out.writeOptionalDouble(predictionProbability);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
NlpClassificationInferenceResults that = (NlpClassificationInferenceResults) o;
return Objects.equals(resultsField, that.resultsField)
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(predictionProbability, that.predictionProbability)
&& Objects.equals(topClasses, that.topClasses);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), resultsField, classificationLabel, predictionProbability, topClasses);
}

public Double getPredictionProbability() {
return predictionProbability;
}

@Override
public String getResultsField() {
return resultsField;
}

@Override
public Object predictedValue() {
return classificationLabel;
}

@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, classificationLabel);
if (topClasses.isEmpty() == false) {
map.put(
NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())
);
}
if (predictionProbability != null) {
map.put(PREDICTION_PROBABILITY, predictionProbability);
}
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, classificationLabel);
if (topClasses.size() > 0) {
builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
}
if (predictionProbability != null) {
builder.field(PREDICTION_PROBABILITY, predictionProbability);
}
}
}
Loading