Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NLP] Support the different mask tokens used by NLP models for Fill Mask #97453

Merged
merged 2 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ BERT-style tokenization is to be performed with the enclosed settings.
end::inference-config-nlp-tokenization-bert[]

tag::inference-config-nlp-tokenization-bert-ja[]
experimental:[] BERT-style tokenization for Japanese text is to be performed
experimental:[] BERT-style tokenization for Japanese text is to be performed
with the enclosed settings.
end::inference-config-nlp-tokenization-bert-ja[]

Expand Down Expand Up @@ -1125,6 +1125,10 @@ The field that is added to incoming documents to contain the inference
prediction. Defaults to `predicted_value`.
end::inference-config-results-field[]

tag::inference-config-mask-token[]
The string/token that is removed from incoming documents and replaced with the inference prediction(s). In a response, this field contains the mask token for the specified model/tokenizer. Each model and tokenizer has a predefined mask token that cannot be changed; it is therefore recommended not to set this value in requests. If the field is present in a request, however, its value must match the predefined value for that model/tokenizer, otherwise the request will fail.
end::inference-config-mask-token[]

tag::inference-config-results-field-processor[]
The field that is added to incoming documents to contain the inference
prediction. Defaults to the `results_field` value of the {dfanalytics-job} that was
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,17 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]
======

`fill_mask`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
+
.Properties of fill_mask inference
[%collapsible%open]
======
`mask_token`::::
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-mask-token]

`tokenization`::::
(Optional, object)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class BertJapaneseTokenization extends Tokenization {

public static final ParseField NAME = new ParseField("bert_ja");

public static final String MASK_TOKEN = "[MASK]";

public static ConstructingObjectParser<BertJapaneseTokenization, Void> createJpParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<BertJapaneseTokenization, Void> parser = new ConstructingObjectParser<>(
"bert_japanese_tokenization",
Expand Down Expand Up @@ -61,6 +63,11 @@ XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IO
return builder;
}

@Override
public String getMaskToken() {
    // Japanese BERT tokenization uses the standard BERT-style "[MASK]" placeholder.
    return BertJapaneseTokenization.MASK_TOKEN;
}

@Override
public String getWriteableName() {
return BertJapaneseTokenization.NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class BertTokenization extends Tokenization {

public static final ParseField NAME = new ParseField("bert");

public static final String MASK_TOKEN = "[MASK]";
maxhniebergall marked this conversation as resolved.
Show resolved Hide resolved

public static ConstructingObjectParser<BertTokenization, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<BertTokenization, Void> parser = new ConstructingObjectParser<>(
"bert_tokenization",
Expand Down Expand Up @@ -67,6 +69,11 @@ XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IO
return builder;
}

@Override
public String getMaskToken() {
    // The mask placeholder is fixed per tokenizer family; for BERT it is "[MASK]".
    return BertTokenization.MASK_TOKEN;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.util.Objects;
Expand All @@ -26,6 +29,7 @@
public class FillMaskConfig implements NlpConfig {

public static final String NAME = "fill_mask";
public static final String MASK_TOKEN = "mask_token";
public static final int DEFAULT_NUM_RESULTS = 5;

public static FillMaskConfig fromXContentStrict(XContentParser parser) {
Expand All @@ -36,6 +40,7 @@ public static FillMaskConfig fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null).build();
}

private static final ParseField MASK_TOKEN_FIELD = new ParseField(MASK_TOKEN);
private static final ObjectParser<FillMaskConfig.Builder, Void> STRICT_PARSER = createParser(false);
private static final ObjectParser<FillMaskConfig.Builder, Void> LENIENT_PARSER = createParser(true);

Expand All @@ -57,6 +62,7 @@ private static ObjectParser<FillMaskConfig.Builder, Void> createParser(boolean i
);
parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
parser.declareString(Builder::setResultsField, RESULTS_FIELD);
parser.declareString(Builder::setMaskToken, MASK_TOKEN_FIELD);
return parser;
}

Expand Down Expand Up @@ -101,6 +107,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
builder.field(MASK_TOKEN_FIELD.getPreferredName(), tokenization.getMaskToken());
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -182,8 +191,9 @@ public boolean isAllocateOnly() {
public static class Builder {
private VocabularyConfig vocabularyConfig;
private Tokenization tokenization;
private int numTopClasses;
private Integer numTopClasses;
private String resultsField;
private String maskToken;

Builder() {}

Expand Down Expand Up @@ -214,8 +224,27 @@ public FillMaskConfig.Builder setResultsField(String resultsField) {
return this;
}

public FillMaskConfig build() {
public Builder setMaskToken(String maskToken) {
    // Held only so build() can validate it against the tokenizer's predefined
    // mask token; the value itself is never stored on the resulting config.
    this.maskToken = maskToken;
    return this;
}

/**
 * Builds the {@link FillMaskConfig}, defaulting the tokenization to
 * {@link Tokenization#createDefault()} when none was supplied.
 *
 * @throws IllegalArgumentException if a mask token was explicitly requested and it
 *         does not match the predefined mask token of the effective tokenizer
 */
public FillMaskConfig build() {
    if (tokenization == null) {
        tokenization = Tokenization.createDefault();
    }
    // The mask token is fixed per model/tokenizer; a user-supplied value may only confirm it.
    validateMaskToken(tokenization.getMaskToken());
    return new FillMaskConfig(vocabularyConfig, tokenization, numTopClasses, resultsField);
}

private void validateMaskToken(String tokenizationMaskToken) throws IllegalArgumentException {
    // A null requested mask token means "use the tokenizer's own" and is always valid;
    // a non-null value must match the tokenizer's predefined mask token exactly.
    if (maskToken != null && maskToken.equals(tokenizationMaskToken) == false) {
        throw new IllegalArgumentException(
            Strings.format("Mask token requested was [%s] but must be [%s] for this model", maskToken, tokenizationMaskToken)
        );
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
public class MPNetTokenization extends Tokenization {

public static final ParseField NAME = new ParseField("mpnet");
public static final String MASK_TOKEN = "<mask>";

public static ConstructingObjectParser<MPNetTokenization, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<MPNetTokenization, Void> parser = new ConstructingObjectParser<>(
Expand Down Expand Up @@ -67,6 +68,11 @@ XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IO
return builder;
}

@Override
public String getMaskToken() {
    // MPNet uses the lowercase angle-bracket "<mask>" placeholder.
    return MPNetTokenization.MASK_TOKEN;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

public class RobertaTokenization extends Tokenization {
public static final String NAME = "roberta";
public static final String MASK_TOKEN = "<mask>";
private static final boolean DEFAULT_ADD_PREFIX_SPACE = false;

private static final ParseField ADD_PREFIX_SPACE = new ParseField("add_prefix_space");
Expand Down Expand Up @@ -99,6 +100,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(addPrefixSpace);
}

@Override
public String getMaskToken() {
    // RoBERTa-style tokenizers mask with "<mask>".
    return RobertaTokenization.MASK_TOKEN;
}

@Override
XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(ADD_PREFIX_SPACE.getPreferredName(), addPrefixSpace);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

public abstract String getMaskToken();

abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

public class XLMRobertaTokenization extends Tokenization {
public static final String NAME = "xlm_roberta";
public static final String MASK_TOKEN = "<mask>";

public static ConstructingObjectParser<XLMRobertaTokenization, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<XLMRobertaTokenization, Void> parser = new ConstructingObjectParser<>(
Expand Down Expand Up @@ -81,6 +82,11 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}

@Override
public String getMaskToken() {
    // XLM-RoBERTa shares the RoBERTa-style "<mask>" placeholder.
    return XLMRobertaTokenization.MASK_TOKEN;
}

@Override
XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,78 @@ public static FillMaskConfig createRandom() {
randomBoolean() ? null : randomAlphaOfLength(5)
);
}

public void testCreateBuilder() {
    // Building with no requested mask token, or with one that matches the
    // tokenizer's own predefined token, must always succeed.
    VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();

    Tokenization tokenization = randomBoolean()
        ? null
        : randomFrom(
            BertTokenizationTests.createRandom(),
            MPNetTokenizationTests.createRandom(),
            RobertaTokenizationTests.createRandom()
        );

    Integer numTopClasses = randomBoolean() ? null : randomInt();

    String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);

    FillMaskConfig config = new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
        .setTokenization(tokenization)
        .setNumTopClasses(numTopClasses)
        .setResultsField(resultsField)
        .setMaskToken(tokenization == null ? null : tokenization.getMaskToken())
        .build();
    // Previously this test asserted nothing; at minimum the build must produce a config.
    assertNotNull(config);
}

public void testCreateBuilderWithException() {
    // Requesting a mask token that cannot match any tokenizer's predefined
    // token must make build() fail with a descriptive message.
    VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();

    Tokenization tokenization = randomBoolean()
        ? null
        : randomFrom(
            BertTokenizationTests.createRandom(),
            MPNetTokenizationTests.createRandom(),
            RobertaTokenizationTests.createRandom()
        );

    Integer numTopClasses = randomBoolean() ? null : randomInt();

    String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);

    IllegalArgumentException e = expectThrows(
        IllegalArgumentException.class,
        () -> new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
            .setTokenization(tokenization)
            .setNumTopClasses(numTopClasses)
            .setResultsField(resultsField)
            .setMaskToken("not a real mask token")
            .build()
    );
    // Pin the error message so the validation failure is distinguishable from other IAEs.
    assertTrue(e.getMessage().startsWith("Mask token requested was [not a real mask token]"));
}

public void testCreateBuilderWithNullMaskToken() {
    // When no mask token is requested at all, validation is skipped and the
    // build must succeed regardless of which tokenizer (if any) is configured.
    VocabularyConfig vocabularyConfig = randomBoolean() ? null : VocabularyConfigTests.createRandom();

    Tokenization tokenization = randomBoolean()
        ? null
        : randomFrom(
            BertTokenizationTests.createRandom(),
            MPNetTokenizationTests.createRandom(),
            RobertaTokenizationTests.createRandom()
        );

    Integer numTopClasses = randomBoolean() ? null : randomInt();

    String resultsField = randomBoolean() ? null : randomAlphaOfLength(5);

    FillMaskConfig config = new FillMaskConfig.Builder().setVocabularyConfig(vocabularyConfig)
        .setTokenization(tokenization)
        .setNumTopClasses(numTopClasses)
        .setResultsField(resultsField)
        .build();
    // Previously this test asserted nothing; at minimum the build must produce a config.
    assertNotNull(config);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
Expand Down Expand Up @@ -39,7 +40,7 @@ public class BertTokenizer extends NlpTokenizer {
public static final String SEPARATOR_TOKEN = "[SEP]";
public static final String PAD_TOKEN = "[PAD]";
public static final String CLASS_TOKEN = "[CLS]";
public static final String MASK_TOKEN = "[MASK]";
public static final String MASK_TOKEN = BertTokenization.MASK_TOKEN;

private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

import java.util.Collections;
Expand All @@ -26,7 +27,7 @@ public class MPNetTokenizer extends BertTokenizer {
public static final String SEPARATOR_TOKEN = "</s>";
public static final String PAD_TOKEN = "<pad>";
public static final String CLASS_TOKEN = "<s>";
public static final String MASK_TOKEN = "<mask>";
public static final String MASK_TOKEN = MPNetTokenization.MASK_TOKEN;
private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

protected MPNetTokenizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class RobertaTokenizer extends NlpTokenizer {
public static final String SEPARATOR_TOKEN = "</s>";
public static final String PAD_TOKEN = "<pad>";
public static final String CLASS_TOKEN = "<s>";
public static final String MASK_TOKEN = "<mask>";
public static final String MASK_TOKEN = RobertaTokenization.MASK_TOKEN;

private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class XLMRobertaTokenizer extends NlpTokenizer {
public static final String SEPARATOR_TOKEN = "</s>";
public static final String PAD_TOKEN = "<pad>";
public static final String CLASS_TOKEN = "<s>";
public static final String MASK_TOKEN = "<mask>";
public static final String MASK_TOKEN = XLMRobertaTokenization.MASK_TOKEN;

private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -135,6 +136,7 @@ public RestResponse buildResponse(T response, XContentBuilder builder) throws Ex
Map<String, String> params = new HashMap<>(channel.request().params());
defaultToXContentParamValues.forEach((k, v) -> params.computeIfAbsent(k, defaultToXContentParamValues::get));
includes.forEach(include -> params.put(include, "true"));
params.put(ToXContentParams.FOR_INTERNAL_STORAGE, "false");
response.toXContent(builder, new ToXContent.MapParams(params));
return new RestResponse(getStatus(response), builder);
}
Expand Down
Loading