Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blacklist a number of prediction field names. #49371

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
* This way the user can see if the prediction was made with confidence they need.
*/
private static final int DEFAULT_NUM_TOP_CLASSES = 2;
/**
* User-provided name for prediction field must not clash with names of other fields emitted on the same JSON level by C++ code.
* This list should be updated every time a new field is added in lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc.
*/
private static final List<String> PREDICTION_FIELD_NAME_BLACKLIST = List.of("prediction_probability", "is_training", "top_classes");

private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
Expand All @@ -82,6 +87,11 @@ public Classification(String dependentVariable,
@Nullable String predictionFieldName,
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent) {
if (predictionFieldName != null && PREDICTION_FIELD_NAME_BLACKLIST.contains(predictionFieldName)) {
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
throw ExceptionsHelper.badRequestException(
"[{}] must not be equal to any of {}",
PREDICTION_FIELD_NAME.getPreferredName(), PREDICTION_FIELD_NAME_BLACKLIST);
}
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}

/**
* User-provided name for prediction field must not clash with names of other fields emitted on the same JSON level by C++ code.
* This list should be updated every time a new field is added in lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc.
*/
private static final List<String> PREDICTION_FIELD_NAME_BLACKLIST = List.of("is_training");

private final String dependentVariable;
private final BoostedTreeParams boostedTreeParams;
private final String predictionFieldName;
Expand All @@ -65,6 +71,10 @@ public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
if (predictionFieldName != null && PREDICTION_FIELD_NAME_BLACKLIST.contains(predictionFieldName)) {
throw ExceptionsHelper.badRequestException(
"[{}] must not be equal to any of {}", PREDICTION_FIELD_NAME.getPreferredName(), PREDICTION_FIELD_NAME_BLACKLIST);
}
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,33 @@ protected Writeable.Reader<Classification> instanceReader() {
return Classification::new;
}

public void testConstructor_GivenPredictionFieldNameIsBlacklisted() {
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "prediction_probability", 3, 50.0));

assertThat(
e.getMessage(),
equalTo("[prediction_field_name] must not be equal to any of [prediction_probability, is_training, top_classes]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "is_training", 3, 50.0));

assertThat(
e.getMessage(),
equalTo("[prediction_field_name] must not be equal to any of [prediction_probability, is_training, top_classes]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "top_classes", 3, 50.0));

assertThat(
e.getMessage(),
equalTo("[prediction_field_name] must not be equal to any of [prediction_probability, is_training, top_classes]"));
}
}

public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ protected Writeable.Reader<Regression> instanceReader() {
return Regression::new;
}

public void testConstructor_GivenPredictionFieldNameIsBlacklisted() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "is_training", 50.0));

assertThat(e.getMessage(), equalTo("[prediction_field_name] must not be equal to any of [is_training]"));
}

public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,29 @@ setup:
}
}

---
"Test put regression given prediction_field_name is blacklisted":

- do:
catch: /\[prediction_field_name\] must not be equal to any of \[is_training\]/
ml.put_data_frame_analytics:
id: "regression-prediction-field-name-is-blacklisted"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"prediction_field_name": "is_training"
}
}
}

---
"Test put regression given training_percent is less than one":

Expand Down Expand Up @@ -1727,6 +1750,29 @@ setup:
}
}

---
"Test put classification given prediction_field_name is blacklisted":

- do:
catch: /\[prediction_field_name\] must not be equal to any of \[prediction_probability, is_training, top_classes\]/
ml.put_data_frame_analytics:
id: "classification-prediction-field-name-is-blacklisted"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"classification": {
"dependent_variable": "foo",
"prediction_field_name": "is_training"
}
}
}

---
"Test put classification given training_percent is less than one":

Expand Down