diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java index d70bd713bd60a..89b6f3ffa4a59 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java @@ -35,6 +35,7 @@ import java.util.Objects; import java.util.stream.Collectors; +import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; public class EvaluateDataFrameResponse implements ToXContentObject { @@ -47,7 +48,7 @@ public static EvaluateDataFrameResponse fromXContent(XContentParser parser) thro ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); String evaluationName = parser.currentName(); parser.nextToken(); - Map metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric); + Map metrics = parser.map(LinkedHashMap::new, p -> parseMetric(evaluationName, p)); List knownMetrics = metrics.values().stream() .filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}. @@ -56,10 +57,10 @@ public static EvaluateDataFrameResponse fromXContent(XContentParser parser) thro return new EvaluateDataFrameResponse(evaluationName, knownMetrics); } - private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException { + private static EvaluationMetric.Result parseMetric(String evaluationName, XContentParser parser) throws IOException { String metricName = parser.currentName(); try { - return parser.namedObject(EvaluationMetric.Result.class, metricName, null); + return parser.namedObject(EvaluationMetric.Result.class, registeredMetricName(evaluationName, metricName), null); } catch (NamedObjectNotFoundException e) { parser.skipChildren(); // Metric name not recognized. Return {@code null} value here and filter it out later. diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index efe58b9739eda..cd5c2abdf5627 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -20,24 +20,36 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; -import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; import java.util.Arrays; import java.util.List; public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { + /** + * Constructs the name under which a metric (or metric result) is registered. + * The name is prefixed with evaluation name so that registered names are unique. + * + * @param evaluationName name of the evaluation + * @param metricName name of the metric + * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry} + */ + public static String registeredMetricName(String evaluationName, String metricName) { + return evaluationName + "." + metricName; + } + @Override public List getNamedXContentParsers() { return Arrays.asList( @@ -47,39 +59,91 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent), new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent), // Evaluation metrics - new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), - new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent), - new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)), + AucRocMetric::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent), + EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)), + PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(MulticlassConfusionMatrixMetric.NAME), + new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)), + RecallMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)), + ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), + AccuracyMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), MulticlassConfusionMatrixMetric::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), + EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)), + MeanSquaredErrorMetric::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent), + EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), + RSquaredMetric::fromXContent), // Evaluation metrics results new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)), + AucRocMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)), + PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)), + RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)), + ConfusionMatrixMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), + AccuracyMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(MulticlassConfusionMatrixMetric.NAME), + new ParseField(registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), MulticlassConfusionMatrixMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)), + MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent)); + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), + RSquaredMetric.Result::fromXContent) + ); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java index d7466fcc023b5..f64078228986b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java @@ -32,6 +32,10 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + /** * Evaluation of classification results. */ @@ -48,10 +52,10 @@ public class Classification implements Evaluation { NAME, true, a -> new Classification((String) a[0], (String) a[1], (List) a[2])); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); - PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), - (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS); + PARSER.declareString(constructorArg(), ACTUAL_FIELD); + PARSER.declareString(constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects( + optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS); } public static Classification fromXContent(XContentParser parser) { diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java new file mode 100644 index 0000000000000..8eff7986dcc36 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java @@ -0,0 +1,201 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * {@link PrecisionMetric} is a metric that answers the question: + * "What fraction of documents classified as X actually belongs to X?" + * for any given class X + * + * equation: precision(X) = TP(X) / (TP(X) + FP(X)) + * where: TP(X) - number of true positives wrt X + * FP(X) - number of false positives wrt X + */ +public class PrecisionMetric implements EvaluationMetric { + + public static final String NAME = "precision"; + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, PrecisionMetric::new); + + public static PrecisionMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public PrecisionMetric() {} + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME); + } + + public static class Result implements EvaluationMetric.Result { + + private static final ParseField CLASSES = new ParseField("classes"); + private static final ParseField AVG_PRECISION = new ParseField("avg_precision"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareDouble(constructorArg(), AVG_PRECISION); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of per-class results. */ + private final List classes; + /** Average of per-class precisions. */ + private final double avgPrecision; + + public Result(List classes, double avgPrecision) { + this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); + this.avgPrecision = avgPrecision; + } + + @Override + public String getMetricName() { + return NAME; + } + + public List getClasses() { + return classes; + } + + public double getAvgPrecision() { + return avgPrecision; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSES.getPreferredName(), classes); + builder.field(AVG_PRECISION.getPreferredName(), avgPrecision); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.classes, that.classes) + && this.avgPrecision == that.avgPrecision; + } + + @Override + public int hashCode() { + return Objects.hash(classes, avgPrecision); + } + } + + public static class PerClassResult implements ToXContentObject { + + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField PRECISION = new ParseField("precision"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), PRECISION); + } + + /** Name of the class. */ + private final String className; + /** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */ + private final double precision; + + public PerClassResult(String className, double precision) { + this.className = Objects.requireNonNull(className); + this.precision = precision; + } + + public String getClassName() { + return className; + } + + public double getPrecision() { + return precision; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(PRECISION.getPreferredName(), precision); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) + && this.precision == that.precision; + } + + @Override + public int hashCode() { + return Objects.hash(className, precision); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java new file mode 100644 index 0000000000000..d46a70da8c3f6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java @@ -0,0 +1,201 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * {@link RecallMetric} is a metric that answers the question: + * "What fraction of documents belonging to X have been predicted as X by the classifier?" + * for any given class X + * + * equation: recall(X) = TP(X) / (TP(X) + FN(X)) + * where: TP(X) - number of true positives wrt X + * FN(X) - number of false negatives wrt X + */ +public class RecallMetric implements EvaluationMetric { + + public static final String NAME = "recall"; + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, RecallMetric::new); + + public static RecallMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public RecallMetric() {} + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME); + } + + public static class Result implements EvaluationMetric.Result { + + private static final ParseField CLASSES = new ParseField("classes"); + private static final ParseField AVG_RECALL = new ParseField("avg_recall"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareDouble(constructorArg(), AVG_RECALL); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of per-class results. */ + private final List classes; + /** Average of per-class recalls. */ + private final double avgRecall; + + public Result(List classes, double avgRecall) { + this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); + this.avgRecall = avgRecall; + } + + @Override + public String getMetricName() { + return NAME; + } + + public List getClasses() { + return classes; + } + + public double getAvgRecall() { + return avgRecall; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSES.getPreferredName(), classes); + builder.field(AVG_RECALL.getPreferredName(), avgRecall); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.classes, that.classes) + && this.avgRecall == that.avgRecall; + } + + @Override + public int hashCode() { + return Objects.hash(classes, avgRecall); + } + } + + public static class PerClassResult implements ToXContentObject { + + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField RECALL = new ParseField("recall"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), RECALL); + } + + /** Name of the class. */ + private final String className; + /** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */ + private final double recall; + + public PerClassResult(String className, double recall) { + this.className = Objects.requireNonNull(className); + this.recall = recall; + } + + public String getClassName() { + return className; + } + + public double getRecall() { + return recall; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(RECALL.getPreferredName(), recall); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) + && this.recall == that.recall; + } + + @Override + public int hashCode() { + return Objects.hash(className, recall); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java index 79b9ab6eb1dd5..1d8b5bcdb0902 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java @@ -33,6 +33,10 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + /** * Evaluation of regression results. */ @@ -49,10 +53,10 @@ public class Regression implements Evaluation { NAME, true, a -> new Regression((String) a[0], (String) a[1], (List) a[2])); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); - PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), - (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS); + PARSER.declareString(constructorArg(), ACTUAL_FIELD); + PARSER.declareString(constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects( + optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS); } public static Regression fromXContent(XContentParser parser) { diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index cb531c6ab044a..b75af7cec11f6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -59,7 +60,8 @@ public class BinarySoftClassification implements Evaluation { static { PARSER.declareString(constructorArg(), ACTUAL_FIELD); PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD); - PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS); + PARSER.declareNamedObjects( + optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), null), METRICS); } public static BinarySoftClassification fromXContent(XContentParser parser) { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 0bddbfdbb0b7c..18a02f2a4608a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1860,6 +1860,70 @@ public void testEvaluateDataFrame_Classification() throws IOException { new AccuracyMetric.ActualClass("ant", 1, 0.0)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly } + { // Precision + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + null, + new Classification( + actualClassField, + predictedClassField, + new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = + evaluateDataFrameResponse.getMetricByName( + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); + assertThat( + precisionResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)); + assertThat( + precisionResult.getClasses(), + equalTo( + Arrays.asList( + // 3 out of 5 examples labeled as "cat" were classified correctly + new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6), + // 3 out of 4 examples labeled as "dog" were classified correctly + new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(0.675)); + } + { // Recall + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + null, + new Classification( + actualClassField, + predictedClassField, + new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = + evaluateDataFrameResponse.getMetricByName( + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); + assertThat( + recallResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)); + assertThat( + recallResult.getClasses(), + equalTo( + Arrays.asList( + // 3 out of 5 examples labeled as "cat" were classified correctly + new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6), + // 3 out of 4 examples labeled as "dog" were classified correctly + new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75), + // no examples labeled as "ant" were classified correctly + new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0)))); + assertThat(recallResult.getAvgRecall(), equalTo(0.45)); + } { // No size provided for MulticlassConfusionMatrixMetric, default used instead EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 4e43fae52d94c..f1d9976cd6060 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -128,6 +128,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.CoreMatchers.equalTo; @@ -688,7 +689,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(51, namedXContents.size()); + assertEquals(55, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -730,26 +731,36 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, - hasItems(AucRocMetric.NAME, - PrecisionMetric.NAME, - RecallMetric.NAME, - ConfusionMatrixMetric.NAME, - AccuracyMetric.NAME, - MulticlassConfusionMatrixMetric.NAME, - MeanSquaredErrorMetric.NAME, - RSquaredMetric.NAME)); - assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + hasItems( + registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME), + registeredMetricName(Classification.NAME, AccuracyMetric.NAME), + registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), + registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), + registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), + registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), + registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); + assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, - hasItems(AucRocMetric.NAME, - PrecisionMetric.NAME, - RecallMetric.NAME, - ConfusionMatrixMetric.NAME, - AccuracyMetric.NAME, - MulticlassConfusionMatrixMetric.NAME, - MeanSquaredErrorMetric.NAME, - RSquaredMetric.NAME)); + hasItems( + registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME), + registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME), + registeredMetricName(Classification.NAME, AccuracyMetric.NAME), + registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), + registeredMetricName( + Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), + registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), + registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), + registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 87c43e1084386..6082074c69d27 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3372,7 +3372,9 @@ public void testEvaluateDataFrame_Classification() throws Exception { "predicted_class", // <3> // Evaluation metrics // <4> new AccuracyMetric(), // <5> - new MulticlassConfusionMatrixMetric(3)); // <6> + new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6> + new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7> + new MulticlassConfusionMatrixMetric(3)); // <8> // end::evaluate-data-frame-evaluation-classification EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); @@ -3382,16 +3384,34 @@ public void testEvaluateDataFrame_Classification() throws Exception { AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1> double accuracy = accuracyResult.getOverallAccuracy(); // <2> + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = + response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3> + double precision = precisionResult.getAvgPrecision(); // <4> + + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = + response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5> + double recall = recallResult.getAvgRecall(); // <6> + MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = - response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3> + response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <7> - List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4> - long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5> + List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8> + long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9> // end::evaluate-data-frame-results-classification assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat(accuracy, equalTo(0.6)); + assertThat( + precisionResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)); + assertThat(precision, equalTo(0.675)); + + assertThat( + recallResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)); + assertThat(recall, equalTo(0.45)); + assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat( confusionMatrix, @@ -3412,7 +3432,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { 4L, Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)))); - assertThat(otherActualClassCount, equalTo(0L)); + assertThat(otherClassesCount, equalTo(0L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index f6b7459b1043b..92d3ab81bce47 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -64,6 +64,8 @@ public static EvaluateDataFrameResponse randomResponse() { metrics = randomSubsetOf( Arrays.asList( AccuracyMetricResultTests.randomResult(), + org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetricResultTests.randomResult(), + org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetricResultTests.randomResult(), MulticlassConfusionMatrixMetricResultTests.randomResult())); break; default: diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java index acb6f21cb8209..81691fcbb2eed 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -41,6 +41,8 @@ static Classification createRandom() { randomSubsetOf( Arrays.asList( AccuracyMetricTests.createRandom(), + PrecisionMetricTests.createRandom(), + RecallMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom())); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java new file mode 100644 index 0000000000000..ef6e41e78f0e8 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java @@ -0,0 +1,67 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class PrecisionMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Result randomResult() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List classes = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double precision = randomDoubleBetween(0.0, 1.0, true); + classes.add(new PerClassResult(classNames.get(i), precision)); + } + double avgPrecision = randomDoubleBetween(0.0, 1.0, true); + return new Result(classes, avgPrecision); + } + + @Override + protected Result createTestInstance() { + return randomResult(); + } + + @Override + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricTests.java new file mode 100644 index 0000000000000..7e21be190d938 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class PrecisionMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + static PrecisionMetric createRandom() { + return new PrecisionMetric(); + } + + @Override + protected PrecisionMetric createTestInstance() { + return createRandom(); + } + + @Override + protected PrecisionMetric doParseInstance(XContentParser parser) throws IOException { + return PrecisionMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java new file mode 100644 index 0000000000000..f8fffb405ea1b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java @@ -0,0 +1,67 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RecallMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Result randomResult() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List classes = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double recall = randomDoubleBetween(0.0, 1.0, true); + classes.add(new PerClassResult(classNames.get(i), recall)); + } + double avgRecall = randomDoubleBetween(0.0, 1.0, true); + return new Result(classes, avgRecall); + } + + @Override + protected Result createTestInstance() { + return randomResult(); + } + + @Override + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricTests.java new file mode 100644 index 0000000000000..087f9838aaf3e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RecallMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + static RecallMetric createRandom() { + return new RecallMetric(); + } + + @Override + protected RecallMetric createTestInstance() { + return createRandom(); + } + + @Override + protected RecallMetric doParseInstance(XContentParser parser) throws IOException { + return RecallMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc index b4abafa249ee0..57a82d1c7132f 100644 --- a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -53,7 +53,9 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification] <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example. <4> The remaining parameters are the metrics to be calculated based on the two fields described above <5> Accuracy -<6> Multiclass confusion matrix of size 3 +<6> Precision +<7> Recall +<8> Multiclass confusion matrix of size 3 ===== Regression @@ -104,9 +106,13 @@ include-tagged::{doc-tests-file}[{api}-results-classification] <1> Fetching accuracy metric by name <2> Fetching the actual accuracy value -<3> Fetching multiclass confusion matrix metric by name -<4> Fetching the contents of the confusion matrix -<5> Fetching the number of classes that were not included in the matrix +<3> Fetching precision metric by name +<4> Fetching the actual precision value +<5> Fetching recall metric by name +<6> Fetching the actual recall value +<7> Fetching multiclass confusion matrix metric by name +<8> Fetching the contents of the confusion matrix +<9> Fetching the number of classes that were not included in the matrix ===== Regression @@ -118,4 +124,4 @@ include-tagged::{doc-tests-file}[{api}-results-regression] <1> Fetching mean squared error metric by name <2> Fetching the actual mean squared error value <3> Fetching R squared metric by name -<4> Fetching the actual R squared value \ No newline at end of file +<4> Fetching the actual R squared value diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index a3343c1850ea3..ca65a05ae8949 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -79,7 +79,6 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.CloseJobAction; -import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; @@ -91,6 +90,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; @@ -146,18 +146,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; @@ -267,6 +256,9 @@ import java.util.Map; import java.util.Optional; import java.util.function.Supplier; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toList; public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPlugin { @@ -474,7 +466,8 @@ public List> getClientActions() { @Override public List getNamedWriteables() { - return Arrays.asList( + return Stream.concat( + Arrays.asList( // graph new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.GRAPH, GraphFeatureSetUsage::new), // logstash @@ -502,28 +495,6 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new), - // ML - Data frame evaluation - new NamedWriteableRegistry.Entry( - Evaluation.class, - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new), - new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(), - MulticlassConfusionMatrix::new), - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(), - MulticlassConfusionMatrix.Result::new), - new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new), - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new), - new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), - BinarySoftClassification::new), - new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), - new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new), - new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new), - new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), - ConfusionMatrix::new), - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new), - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new), - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), - ConfusionMatrix.Result::new), // ML - Inference preprocessing new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new), new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new), @@ -628,7 +599,9 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ANALYTICS, AnalyticsFeatureSetUsage::new), // Enrich new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ENRICH, EnrichFeatureSet.Usage::new) - ); + ).stream(), + MlEvaluationNamedXContentProvider.getNamedWriteables().stream() + ).collect(toList()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 98888c539c189..1a79dff41e10c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -7,12 +7,14 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -76,8 +78,9 @@ default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); for (EvaluationMetric metric : getMetrics()) { // Fetch aggregations requested by individual metrics - List aggs = metric.aggs(getActualField(), getPredictedField()); - aggs.forEach(searchSourceBuilder::aggregation); + Tuple, List> aggs = metric.aggs(getActualField(), getPredictedField()); + aggs.v1().forEach(searchSourceBuilder::aggregation); + aggs.v2().forEach(searchSourceBuilder::aggregation); } return searchSourceBuilder; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index 7a539d030dd44..36bf7634cb43f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -6,10 +6,12 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import java.util.List; import java.util.Optional; @@ -30,7 +32,7 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { * @param predictedField the field that stores the predicted value (class name or probability) * @return the aggregations required to compute the metric */ - List aggs(String actualField, String predictedField); + Tuple, List> aggs(String actualField, String predictedField); /** * Processes given aggregations as a step towards computing result diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 1ef8b89a99609..42e530a7a602d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -5,109 +5,179 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { - @Override - public List getNamedXContentParsers() { - List namedXContent = new ArrayList<>(); + /** + * Constructs the name under which a metric (or metric result) is registered. + * The name is prefixed with evaluation name so that registered names are unique. + * + * @param evaluationName name of the evaluation + * @param metricName name of the metric + * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry} + */ + public static String registeredMetricName(ParseField evaluationName, ParseField metricName) { + return registeredMetricName(evaluationName.getPreferredName(), metricName.getPreferredName()); + } - // Evaluations - namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, - BinarySoftClassification::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent)); + /** + * Constructs the name under which a metric (or metric result) is registered. + * The name is prefixed with evaluation name so that registered names are unique. + * + * @param evaluationName name of the evaluation + * @param metricName name of the metric + * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry} + */ + public static String registeredMetricName(String evaluationName, String metricName) { + return evaluationName + "." + metricName; + } - // Soft classification metrics - namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, - ConfusionMatrix::fromXContent)); + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + // Evaluations + new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent), - // Classification metrics - namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME, - MulticlassConfusionMatrix::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent)); + // Soft classification metrics + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME)), + AucRoc::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, Precision.NAME)), + Precision::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, Recall.NAME)), + Recall::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME)), + ConfusionMatrix::fromXContent), - // Regression metrics - namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); - namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent)); + // Classification metrics + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)), + MulticlassConfusionMatrix::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)), + Accuracy::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField( + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField( + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent), - return namedXContent; + // Regression metrics + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)), + MeanSquaredError::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)), + RSquared::fromXContent) + ); } - public List getNamedWriteables() { - List namedWriteables = new ArrayList<>(); - - // Evaluations - namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), - BinarySoftClassification::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(), - Classification::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new)); - - // Evaluation Metrics - namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), - AucRoc::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), - Precision::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), - Recall::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), - ConfusionMatrix::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, - MulticlassConfusionMatrix.NAME.getPreferredName(), - MulticlassConfusionMatrix::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, - MeanSquaredError.NAME.getPreferredName(), - MeanSquaredError::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, - RSquared.NAME.getPreferredName(), - RSquared::new)); + public static List getNamedWriteables() { + return Arrays.asList( + // Evaluations + new NamedWriteableRegistry.Entry(Evaluation.class, + BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new), + new NamedWriteableRegistry.Entry(Evaluation.class, + Classification.NAME.getPreferredName(), + Classification::new), + new NamedWriteableRegistry.Entry(Evaluation.class, + Regression.NAME.getPreferredName(), + Regression::new), - // Evaluation Metrics Results - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), - AucRoc.Result::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, - ScoreByThresholdResult::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), - ConfusionMatrix.Result::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - MulticlassConfusionMatrix.NAME.getPreferredName(), - MulticlassConfusionMatrix.Result::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - Accuracy.NAME.getPreferredName(), - Accuracy.Result::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - MeanSquaredError.NAME.getPreferredName(), - MeanSquaredError.Result::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - RSquared.NAME.getPreferredName(), - RSquared.Result::new)); + // Evaluation metrics + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME), + AucRoc::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(BinarySoftClassification.NAME, Precision.NAME), + Precision::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(BinarySoftClassification.NAME, Recall.NAME), + Recall::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME), + ConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME), + MulticlassConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Classification.NAME, Accuracy.NAME), + Accuracy::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Regression.NAME, MeanSquaredError.NAME), + MeanSquaredError::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Regression.NAME, RSquared.NAME), + RSquared::new), - return namedWriteables; + // Evaluation metrics results + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME), + AucRoc.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(BinarySoftClassification.NAME, ScoreByThresholdResult.NAME), + ScoreByThresholdResult::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME), + ConfusionMatrix.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME), + MulticlassConfusionMatrix.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Classification.NAME, Accuracy.NAME), + Accuracy.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName( + Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Regression.NAME, MeanSquaredError.NAME), + MeanSquaredError.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Regression.NAME, RSquared.NAME), + RSquared.Result::new) + ); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 6acd5de4f45f8..01f303caf8445 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -18,8 +19,10 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,6 +37,7 @@ import java.util.Optional; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** * {@link Accuracy} is a metric that answers the question: @@ -41,7 +45,7 @@ * * equation: accuracy = 1/n * Σ(y == y´) */ -public class Accuracy implements ClassificationMetric { +public class Accuracy implements EvaluationMetric { public static final ParseField NAME = new ParseField("accuracy"); @@ -68,7 +72,7 @@ public Accuracy(StreamInput in) throws IOException {} @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Classification.NAME, NAME); } @Override @@ -77,16 +81,18 @@ public String getName() { } @Override - public final List aggs(String actualField, String predictedField) { + public final Tuple, List> aggs(String actualField, String predictedField) { if (result != null) { - return Collections.emptyList(); + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } Script accuracyScript = new Script(buildScript(actualField, predictedField)); - return Arrays.asList( - AggregationBuilders.terms(CLASSES_AGG_NAME) - .field(actualField) - .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), - AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.terms(CLASSES_AGG_NAME) + .field(actualField) + .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), + AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)), + Collections.emptyList()); } @Override @@ -169,7 +175,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Classification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index ee312ee7c7fd8..fb8014697555e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -20,6 +21,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + /** * Evaluation of classification results. */ @@ -33,13 +36,13 @@ public class Classification implements Evaluation { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), - (p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS); + (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS); } public static Classification fromXContent(XContentParser parser) { @@ -61,22 +64,22 @@ public static Classification fromXContent(XContentParser parser) { /** * The list of metrics to calculate */ - private final List metrics; + private final List metrics; - public Classification(String actualField, String predictedField, @Nullable List metrics) { + public Classification(String actualField, String predictedField, @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); this.metrics = initMetrics(metrics, Classification::defaultMetrics); } - private static List defaultMetrics() { + private static List defaultMetrics() { return Arrays.asList(new MulticlassConfusionMatrix()); } public Classification(StreamInput in) throws IOException { this.actualField = in.readString(); this.predictedField = in.readString(); - this.metrics = in.readNamedWriteableList(ClassificationMetric.class); + this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @Override @@ -95,7 +98,7 @@ public String getPredictedField() { } @Override - public List getMetrics() { + public List getMetrics() { return metrics; } @@ -118,8 +121,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); builder.startObject(METRICS.getPreferredName()); - for (ClassificationMetric metric : metrics) { - builder.field(metric.getWriteableName(), metric); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); } builder.endObject(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java deleted file mode 100644 index a61ac9a702fa2..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; - -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; - -public interface ClassificationMetric extends EvaluationMetric { -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 5dfda00a00a68..4f049efead348 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -7,6 +7,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -19,10 +20,12 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.filter.Filters; import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -38,13 +41,14 @@ import static java.util.Comparator.comparing; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** * {@link MulticlassConfusionMatrix} is a metric that answers the question: - * "How many examples belonging to class X were classified as Y by the classifier?" + * "How many documents belonging to class X were classified as Y by the classifier?" * for all the possible class pairs {X, Y}. */ -public class MulticlassConfusionMatrix implements ClassificationMetric { +public class MulticlassConfusionMatrix implements EvaluationMetric { public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); @@ -92,7 +96,7 @@ public MulticlassConfusionMatrix(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Classification.NAME, NAME); } @Override @@ -105,13 +109,15 @@ public int getSize() { } @Override - public final List aggs(String actualField, String predictedField) { + public final Tuple, List> aggs(String actualField, String predictedField) { if (topActualClassNames == null) { // This is step 1 - return Arrays.asList( - AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) - .field(actualField) - .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) - .size(size)); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) + .field(actualField) + .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) + .size(size)), + Collections.emptyList()); } if (result == null) { // This is step 2 KeyedFilter[] keyedFiltersActual = @@ -122,15 +128,17 @@ public final List aggs(String actualField, String predictedF topActualClassNames.stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); - return Arrays.asList( - AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) - .field(actualField), - AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) - .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) - .otherBucket(true) - .otherBucketKey(OTHER_BUCKET_KEY))); - } - return Collections.emptyList(); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) + .field(actualField), + AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) + .otherBucket(true) + .otherBucketKey(OTHER_BUCKET_KEY))), + Collections.emptyList()); + } + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } @Override @@ -232,7 +240,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Classification.NAME, NAME); } @Override @@ -301,7 +309,7 @@ public static class ActualClass implements ToXContentObject, Writeable { /** Name of the actual class. */ private final String actualClass; - /** Number of documents (examples) belonging to the {code actualClass} class. */ + /** Number of documents belonging to the {code actualClass} class. */ private final long actualClassDocCount; /** List of predicted classes. */ private final List predictedClasses; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java new file mode 100644 index 0000000000000..b8b468aa0371c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -0,0 +1,347 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + +/** + * {@link Precision} is a metric that answers the question: + * "What fraction of documents classified as X actually belongs to X?" + * for any given class X + * + * equation: precision(X) = TP(X) / (TP(X) + FP(X)) + * where: TP(X) - number of true positives wrt X + * FP(X) - number of false positives wrt X + */ +public class Precision implements EvaluationMetric { + + public static final ParseField NAME = new ParseField("precision"); + + private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; + private static final String AGG_NAME_PREFIX = "classification_precision_"; + static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; + static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class"; + static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision"; + static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision"; + + private static Script buildScript(Object...args) { + return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); + } + + private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new); + + public static Precision fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + + private final int maxClassesCardinality; + private String actualField; + private List topActualClassNames; + private EvaluationMetricResult result; + + public Precision() { + this((Integer) null); + } + + // Visible for testing + public Precision(@Nullable Integer maxClassesCardinality) { + this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; + } + + public Precision(StreamInput in) throws IOException { + this.maxClassesCardinality = in.readVInt(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public final Tuple, List> aggs(String actualField, String predictedField) { + // Store given {@code actualField} for the purpose of generating error message in {@code process}. + this.actualField = actualField; + if (topActualClassNames == null) { // This is step 1 + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) + .field(actualField) + .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) + .size(maxClassesCardinality)), + Collections.emptyList()); + } + if (result == null) { // This is step 2 + KeyedFilter[] keyedFiltersPredicted = + topActualClassNames.stream() + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .toArray(KeyedFilter[]::new); + Script script = buildScript(actualField, predictedField); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted) + .subAggregation(AggregationBuilders.avg(PER_PREDICTED_CLASS_PRECISION_AGG_NAME).script(script))), + Arrays.asList( + PipelineAggregatorBuilders.avgBucket( + AVG_PRECISION_AGG_NAME, + BY_PREDICTED_CLASS_AGG_NAME + ">" + PER_PREDICTED_CLASS_PRECISION_AGG_NAME))); + } + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); + } + + @Override + public void process(Aggregations aggs) { + if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { + Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); + if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { + // This means there were more than {@code maxClassesCardinality} buckets. + // We cannot calculate average precision accurately, so we fail. + throw ExceptionsHelper.badRequestException( + "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); + } + topActualClassNames = + topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); + } + if (result == null && + aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters && + aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { + Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME); + NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggs.get(AVG_PRECISION_AGG_NAME); + List classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size()); + for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) { + String className = bucket.getKeyAsString(); + NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME); + double precision = precisionAgg.value(); + if (Double.isFinite(precision)) { + classes.add(new PerClassResult(className, precision)); + } + } + result = new Result(classes, avgPrecisionAgg.value()); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(maxClassesCardinality); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME.getPreferredName()); + } + + public static class Result implements EvaluationMetricResult { + + private static final ParseField CLASSES = new ParseField("classes"); + private static final ParseField AVG_PRECISION = new ParseField("avg_precision"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareDouble(constructorArg(), AVG_PRECISION); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of per-class results. */ + private final List classes; + /** Average of per-class precisions. */ + private final double avgPrecision; + + public Result(List classes, double avgPrecision) { + this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); + this.avgPrecision = avgPrecision; + } + + public Result(StreamInput in) throws IOException { + this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); + this.avgPrecision = in.readDouble(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public List getClasses() { + return classes; + } + + public double getAvgPrecision() { + return avgPrecision; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(classes); + out.writeDouble(avgPrecision); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSES.getPreferredName(), classes); + builder.field(AVG_PRECISION.getPreferredName(), avgPrecision); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.classes, that.classes) + && this.avgPrecision == that.avgPrecision; + } + + @Override + public int hashCode() { + return Objects.hash(classes, avgPrecision); + } + } + + public static class PerClassResult implements ToXContentObject, Writeable { + + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField PRECISION = new ParseField("precision"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), PRECISION); + } + + /** Name of the class. */ + private final String className; + /** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */ + private final double precision; + + public PerClassResult(String className, double precision) { + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); + this.precision = precision; + } + + public PerClassResult(StreamInput in) throws IOException { + this.className = in.readString(); + this.precision = in.readDouble(); + } + + public String getClassName() { + return className; + } + + public double getPrecision() { + return precision; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + out.writeDouble(precision); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(PRECISION.getPreferredName(), precision); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) + && this.precision == that.precision; + } + + @Override + public int hashCode() { + return Objects.hash(className, precision); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java new file mode 100644 index 0000000000000..c3151b82484b0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -0,0 +1,321 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + +/** + * {@link Recall} is a metric that answers the question: + * "What fraction of documents belonging to X have been predicted as X by the classifier?" + * for any given class X + * + * equation: recall(X) = TP(X) / (TP(X) + FN(X)) + * where: TP(X) - number of true positives wrt X + * FN(X) - number of false negatives wrt X + */ +public class Recall implements EvaluationMetric { + + public static final ParseField NAME = new ParseField("recall"); + + private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; + private static final String AGG_NAME_PREFIX = "classification_recall_"; + static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; + static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall"; + static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall"; + + private static Script buildScript(Object...args) { + return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); + } + + private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new); + + public static Recall fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000; + + private final int maxClassesCardinality; + private String actualField; + private EvaluationMetricResult result; + + public Recall() { + this((Integer) null); + } + + // Visible for testing + public Recall(@Nullable Integer maxClassesCardinality) { + this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY; + } + + public Recall(StreamInput in) throws IOException { + this.maxClassesCardinality = in.readVInt(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public final Tuple, List> aggs(String actualField, String predictedField) { + // Store given {@code actualField} for the purpose of generating error message in {@code process}. + this.actualField = actualField; + if (result != null) { + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); + } + Script script = buildScript(actualField, predictedField); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) + .field(actualField) + .size(maxClassesCardinality) + .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), + Arrays.asList( + PipelineAggregatorBuilders.avgBucket( + AVG_RECALL_AGG_NAME, + BY_ACTUAL_CLASS_AGG_NAME + ">" + PER_ACTUAL_CLASS_RECALL_AGG_NAME))); + } + + @Override + public void process(Aggregations aggs) { + if (result == null && + aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms && + aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { + Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); + if (byActualClassAgg.getSumOfOtherDocCounts() > 0) { + // This means there were more than {@code maxClassesCardinality} buckets. + // We cannot calculate average recall accurately, so we fail. + throw ExceptionsHelper.badRequestException( + "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); + } + NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME); + List classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); + for (Terms.Bucket bucket : byActualClassAgg.getBuckets()) { + String className = bucket.getKeyAsString(); + NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME); + classes.add(new PerClassResult(className, recallAgg.value())); + } + result = new Result(classes, avgRecallAgg.value()); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(maxClassesCardinality); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME.getPreferredName()); + } + + public static class Result implements EvaluationMetricResult { + + private static final ParseField CLASSES = new ParseField("classes"); + private static final ParseField AVG_RECALL = new ParseField("avg_recall"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareDouble(constructorArg(), AVG_RECALL); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of per-class results. */ + private final List classes; + /** Average of per-class recalls. */ + private final double avgRecall; + + public Result(List classes, double avgRecall) { + this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); + this.avgRecall = avgRecall; + } + + public Result(StreamInput in) throws IOException { + this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); + this.avgRecall = in.readDouble(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public List getClasses() { + return classes; + } + + public double getAvgRecall() { + return avgRecall; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(classes); + out.writeDouble(avgRecall); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSES.getPreferredName(), classes); + builder.field(AVG_RECALL.getPreferredName(), avgRecall); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.classes, that.classes) + && this.avgRecall == that.avgRecall; + } + + @Override + public int hashCode() { + return Objects.hash(classes, avgRecall); + } + } + + public static class PerClassResult implements ToXContentObject, Writeable { + + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField RECALL = new ParseField("recall"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), RECALL); + } + + /** Name of the class. */ + private final String className; + /** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */ + private final double recall; + + public PerClassResult(String className, double recall) { + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); + this.recall = recall; + } + + public PerClassResult(StreamInput in) throws IOException { + this.className = in.readString(); + this.recall = in.readDouble(); + } + + public String getClassName() { + return className; + } + + public double getRecall() { + return recall; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + out.writeDouble(recall); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(RECALL.getPreferredName(), recall); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) + && this.recall == that.recall; + } + + @Override + public int hashCode() { + return Objects.hash(className, recall); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index dc8de45f7bce7..f2abbe54454f0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ObjectParser; @@ -15,7 +16,9 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import java.io.IOException; @@ -27,12 +30,14 @@ import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + /** * Calculates the mean squared error between two known numerical fields. * * equation: mse = 1/n * Σ(y - y´)^2 */ -public class MeanSquaredError implements RegressionMetric { +public class MeanSquaredError implements EvaluationMetric { public static final ParseField NAME = new ParseField("mean_squared_error"); @@ -62,11 +67,13 @@ public String getName() { } @Override - public List aggs(String actualField, String predictedField) { + public Tuple, List> aggs(String actualField, String predictedField) { if (result != null) { - return Collections.emptyList(); + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } - return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))); + return Tuple.tuple( + Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))), + Collections.emptyList()); } @Override @@ -82,7 +89,7 @@ public Optional getResult() { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Regression.NAME, NAME); } @Override @@ -125,7 +132,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Regression.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index 408f8ff0a6900..c7b989dca1182 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ObjectParser; @@ -15,9 +16,11 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ExtendedStats; import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import java.io.IOException; @@ -29,6 +32,8 @@ import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + /** * Calculates R-Squared between two known numerical fields. * @@ -37,7 +42,7 @@ * SSres = Σ(y - y´)^2, The residual sum of squares * SStot = Σ(y - y_mean)^2, The total sum of squares */ -public class RSquared implements RegressionMetric { +public class RSquared implements EvaluationMetric { public static final ParseField NAME = new ParseField("r_squared"); @@ -67,13 +72,15 @@ public String getName() { } @Override - public List aggs(String actualField, String predictedField) { + public Tuple, List> aggs(String actualField, String predictedField) { if (result != null) { - return Collections.emptyList(); + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } - return Arrays.asList( - AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))), - AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)); + return Tuple.tuple( + Arrays.asList( + AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))), + AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)), + Collections.emptyList()); } @Override @@ -97,7 +104,7 @@ public Optional getResult() { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Regression.NAME, NAME); } @Override @@ -140,7 +147,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(Regression.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index ccf16a9618ec6..cc32ea4049282 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -20,6 +21,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + /** * Evaluation of regression results. */ @@ -33,13 +36,13 @@ public class Regression implements Evaluation { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List) a[2])); + NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), - (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS); + (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS); } public static Regression fromXContent(XContentParser parser) { @@ -61,22 +64,22 @@ public static Regression fromXContent(XContentParser parser) { /** * The list of metrics to calculate */ - private final List metrics; + private final List metrics; - public Regression(String actualField, String predictedField, @Nullable List metrics) { + public Regression(String actualField, String predictedField, @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); this.metrics = initMetrics(metrics, Regression::defaultMetrics); } - private static List defaultMetrics() { + private static List defaultMetrics() { return Arrays.asList(new MeanSquaredError(), new RSquared()); } public Regression(StreamInput in) throws IOException { this.actualField = in.readString(); this.predictedField = in.readString(); - this.metrics = in.readNamedWriteableList(RegressionMetric.class); + this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @Override @@ -95,7 +98,7 @@ public String getPredictedField() { } @Override - public List getMetrics() { + public List getMetrics() { return metrics; } @@ -118,8 +121,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); builder.startObject(METRICS.getPreferredName()); - for (RegressionMetric metric : metrics) { - builder.field(metric.getWriteableName(), metric); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); } builder.endObject(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java deleted file mode 100644 index 5b46829b4c852..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; - -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; - -public interface RegressionMetric extends EvaluationMetric { -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java index 53455bce3fa44..34667aaabc9b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -15,6 +16,8 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -23,9 +26,9 @@ import java.util.List; import java.util.Optional; -import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery; -abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { +abstract class AbstractConfusionMatrixMetric implements EvaluationMetric { public static final ParseField AT = new ParseField("at"); @@ -63,11 +66,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public final List aggs(String actualField, String predictedProbabilityField) { + public Tuple, List> aggs(String actualField, String predictedProbabilityField) { if (result != null) { - return Collections.emptyList(); + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } - return aggsAt(actualField, predictedProbabilityField); + return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java index 135e32ff508a9..614d351c887bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -7,6 +7,7 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -18,8 +19,10 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.filter.Filter; import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -33,7 +36,8 @@ import java.util.Optional; import java.util.stream.IntStream; -import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery; /** * Area under the curve (AUC) of the receiver operating characteristic (ROC). @@ -53,7 +57,7 @@ * When this is used for multi-class classification, it will calculate the ROC * curve of each class versus the rest. */ -public class AucRoc implements SoftClassificationMetric { +public class AucRoc implements EvaluationMetric { public static final ParseField NAME = new ParseField("auc_roc"); @@ -88,7 +92,7 @@ public AucRoc(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override @@ -123,9 +127,9 @@ public int hashCode() { } @Override - public List aggs(String actualField, String predictedProbabilityField) { + public Tuple, List> aggs(String actualField, String predictedProbabilityField) { if (result != null) { - return Collections.emptyList(); + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); AggregationBuilder percentilesForClassValueAgg = @@ -138,7 +142,9 @@ public List aggs(String actualField, String predictedProbabi .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField))) .subAggregation( AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles)); - return Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg); + return Tuple.tuple( + Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg), + Collections.emptyList()); } @Override @@ -330,7 +336,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 67a635e078be2..8d4f4f01d02cd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -12,7 +12,10 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -20,6 +23,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + /** * Evaluation of binary soft classification methods, e.g. outlier detection. * This is useful to evaluate problems where a model outputs a probability of whether @@ -34,19 +39,23 @@ public class BinarySoftClassification implements Evaluation { private static final ParseField METRICS = new ParseField("metrics"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List) a[2])); + NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD); PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), - (p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS); + (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS); } public static BinarySoftClassification fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } + static QueryBuilder actualIsTrueQuery(String actualField) { + return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); + } + /** * The field where the actual class is marked up. * The value of this field is assumed to either be 1 or 0, or true or false. @@ -61,16 +70,16 @@ public static BinarySoftClassification fromXContent(XContentParser parser) { /** * The list of metrics to calculate */ - private final List metrics; + private final List metrics; public BinarySoftClassification(String actualField, String predictedProbabilityField, - @Nullable List metrics) { + @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics); } - private static List defaultMetrics() { + private static List defaultMetrics() { return Arrays.asList( new AucRoc(false), new Precision(Arrays.asList(0.25, 0.5, 0.75)), @@ -81,7 +90,7 @@ private static List defaultMetrics() { public BinarySoftClassification(StreamInput in) throws IOException { this.actualField = in.readString(); this.predictedProbabilityField = in.readString(); - this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class); + this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @Override @@ -100,7 +109,7 @@ public String getPredictedField() { } @Override - public List getMetrics() { + public List getMetrics() { return metrics; } @@ -123,7 +132,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); builder.startObject(METRICS.getPreferredName()); - for (SoftClassificationMetric metric : metrics) { + for (EvaluationMetric metric : metrics) { builder.field(metric.getName(), metric); } builder.endObject(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java index d52468a0214b6..1b1d5b8f9d170 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + public class ConfusionMatrix extends AbstractConfusionMatrixMetric { public static final ParseField NAME = new ParseField("confusion_matrix"); @@ -46,7 +48,7 @@ public ConfusionMatrix(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override @@ -129,7 +131,7 @@ public Result(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java index 80f838dd5d166..d05ddb5fc4c9b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java @@ -19,6 +19,8 @@ import java.util.Arrays; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + public class Precision extends AbstractConfusionMatrixMetric { public static final ParseField NAME = new ParseField("precision"); @@ -44,7 +46,7 @@ public Precision(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java index 70bda8099db89..2dd44aff6715d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -19,6 +19,8 @@ import java.util.Arrays; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + public class Recall extends AbstractConfusionMatrixMetric { public static final ParseField NAME = new ParseField("recall"); @@ -44,7 +46,7 @@ public Recall(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME.getPreferredName(); + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java index 0ad99a83cf25b..8fdb06bde4d6e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -13,9 +14,11 @@ import java.io.IOException; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + public class ScoreByThresholdResult implements EvaluationMetricResult { - public static final String NAME = "score_by_threshold_result"; + public static final ParseField NAME = new ParseField("score_by_threshold_result"); private final String name; private final double[] thresholds; @@ -36,7 +39,7 @@ public ScoreByThresholdResult(StreamInput in) throws IOException { @Override public String getWriteableName() { - return NAME; + return registeredMetricName(BinarySoftClassification.NAME, NAME); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java deleted file mode 100644 index 9a9c382caf9d1..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; - -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; - -public interface SoftClassificationMetric extends EvaluationMetric { - - static QueryBuilder actualIsTrueQuery(String actualField) { - return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java index 77bb6f30e20eb..be0e0dd13ef5a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java @@ -31,7 +31,7 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); - namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(MlEvaluationNamedXContentProvider.getNamedWriteables()); namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -46,13 +46,11 @@ protected NamedXContentRegistry xContentRegistry() { @Override protected Request createTestInstance() { - Request request = new Request(); int indicesCount = randomIntBetween(1, 5); List indices = new ArrayList<>(indicesCount); for (int i = 0; i < indicesCount; i++) { indices.add(randomAlphaOfLength(10)); } - request.setIndices(indices); QueryProvider queryProvider = null; if (randomBoolean()) { try { @@ -62,10 +60,11 @@ protected Request createTestInstance() { throw new UncheckedIOException(e); } } - request.setQueryProvider(queryProvider); Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom(); - request.setEvaluation(evaluation); - return request; + return new Request() + .setIndices(indices) + .setQueryProvider(queryProvider) + .setEvaluation(evaluation); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java index c9eb0ae437eb1..d722dd72eec44 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java @@ -11,7 +11,10 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; @@ -22,7 +25,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); } @Override @@ -30,11 +33,13 @@ protected Response createTestInstance() { String evaluationName = randomAlphaOfLength(10); List metrics = Arrays.asList( + AccuracyResultTests.createRandom(), + PrecisionResultTests.createRandom(), + RecallResultTests.createRandom(), MulticlassConfusionMatrixResultTests.createRandom(), new MeanSquaredError.Result(randomDouble()), new RSquared.Result(randomDouble())); - int numMetrics = randomIntBetween(0, metrics.size()); - return new Response(evaluationName, metrics.subList(0, numMetrics)); + return new Response(evaluationName, randomSubsetOf(metrics)); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java index bb3cc99192067..8fb4c6c02408d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -17,15 +17,9 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class AccuracyResultTests extends AbstractWireSerializingTestCase { +public class AccuracyResultTests extends AbstractWireSerializingTestCase { - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); - } - - @Override - protected Accuracy.Result createTestInstance() { + public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List actualClasses = new ArrayList<>(numClasses); @@ -38,7 +32,17 @@ protected Accuracy.Result createTestInstance() { } @Override - protected Writeable.Reader instanceReader() { - return Accuracy.Result::new; + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + } + + @Override + protected Result createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Result::new; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index 6deb06cf66dfd..132195e78d1d3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -8,6 +8,7 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -19,8 +20,10 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; @@ -42,7 +45,7 @@ public class ClassificationTests extends AbstractSerializingTestCase metrics = + List metrics = randomSubsetOf( Arrays.asList( AccuracyTests.createRandom(), + PrecisionTests.createRandom(), + RecallTests.createRandom(), MulticlassConfusionMatrixTests.createRandom())); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } @@ -101,10 +106,10 @@ public void testBuildSearch() { } public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() { - ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2); - ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3); - ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4); - ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5); + EvaluationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2); + EvaluationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3); + EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4); + EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5); Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4)); assertThat(metric1.getResult(), isEmpty()); @@ -168,7 +173,7 @@ private static SearchResponse mockSearchResponseWithNonZeroTotalHits() { * Number of steps is configurable. * Upon reaching the last step, the result is produced. */ - private static class FakeClassificationMetric implements ClassificationMetric { + private static class FakeClassificationMetric implements EvaluationMetric { private final String name; private final int numSteps; @@ -191,8 +196,8 @@ public String getWriteableName() { } @Override - public List aggs(String actualField, String predictedField) { - return Collections.emptyList(); + public Tuple, List> aggs(String actualField, String predictedField) { + return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index bb6c484a545c8..da0778db140b8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -6,10 +6,12 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; @@ -25,9 +27,9 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { @@ -74,8 +76,8 @@ public void testConstructor_SizeValidationFailures() { public void testAggs() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); - List aggs = confusionMatrix.aggs("act", "pred"); - assertThat(aggs, is(not(empty()))); + Tuple, List> aggs = confusionMatrix.aggs("act", "pred"); + assertThat(aggs, isTuple(not(empty()), empty())); assertThat(confusionMatrix.getResult(), isEmpty()); } @@ -109,7 +111,7 @@ public void testEvaluate() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( @@ -151,7 +153,7 @@ public void testEvaluate_OtherClassesCountGreaterThanZero() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java new file mode 100644 index 0000000000000..b86448a4daacb --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class PrecisionResultTests extends AbstractWireSerializingTestCase { + + public static Result createRandom() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List classes = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double precision = randomDoubleBetween(0.0, 1.0, true); + classes.add(new PerClassResult(classNames.get(i), precision)); + } + double avgPrecision = randomDoubleBetween(0.0, 1.0, true); + return new Result(classes, avgPrecision); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + } + + @Override + protected Result createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Result::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java new file mode 100644 index 0000000000000..85741a4c39ce8 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; + +public class PrecisionTests extends AbstractSerializingTestCase { + + @Override + protected Precision doParseInstance(XContentParser parser) throws IOException { + return Precision.fromXContent(parser); + } + + @Override + protected Precision createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Precision::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static Precision createRandom() { + return new Precision(); + } + + public void testProcess() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME), + mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME), + mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + + Precision precision = new Precision(); + precision.process(aggs); + + assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123))); + } + + public void testProcess_GivenMissingAgg() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + Precision precision = new Precision(); + precision.process(aggs); + assertThat(precision.getResult(), isEmpty()); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + Precision precision = new Precision(); + precision.process(aggs); + assertThat(precision.getResult(), isEmpty()); + } + } + + public void testProcess_GivenAggOfWrongType() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME), + mockFilters(Precision.AVG_PRECISION_AGG_NAME) + )); + Precision precision = new Precision(); + precision.process(aggs); + assertThat(precision.getResult(), isEmpty()); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockSingleValue(Precision.BY_PREDICTED_CLASS_AGG_NAME, 1.0), + mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123) + )); + Precision precision = new Precision(); + precision.process(aggs); + assertThat(precision.getResult(), isEmpty()); + } + } + + public void testProcess_GivenCardinalityTooHigh() { + Aggregations aggs = + new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1))); + Precision precision = new Precision(); + precision.aggs("foo", "bar"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs)); + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java new file mode 100644 index 0000000000000..a2a44ded76189 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RecallResultTests extends AbstractWireSerializingTestCase { + + public static Result createRandom() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List classes = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double recall = randomDoubleBetween(0.0, 1.0, true); + classes.add(new PerClassResult(classNames.get(i), recall)); + } + double avgRecall = randomDoubleBetween(0.0, 1.0, true); + return new Result(classes, avgRecall); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + } + + @Override + protected Result createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Result::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java new file mode 100644 index 0000000000000..e3062fa863544 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; + +public class RecallTests extends AbstractSerializingTestCase { + + @Override + protected Recall doParseInstance(XContentParser parser) throws IOException { + return Recall.fromXContent(parser); + } + + @Override + protected Recall createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Recall::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static Recall createRandom() { + return new Recall(); + } + + public void testProcess() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME), + mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + + Recall recall = new Recall(); + recall.process(aggs); + + assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty())); + assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123))); + } + + public void testProcess_GivenMissingAgg() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + Recall recall = new Recall(); + recall.process(aggs); + assertThat(recall.getResult(), isEmpty()); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) + )); + Recall recall = new Recall(); + recall.process(aggs); + assertThat(recall.getResult(), isEmpty()); + } + } + + public void testProcess_GivenAggOfWrongType() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME), + mockTerms(Recall.AVG_RECALL_AGG_NAME) + )); + Recall recall = new Recall(); + recall.process(aggs); + assertThat(recall.getResult(), isEmpty()); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + mockSingleValue(Recall.BY_ACTUAL_CLASS_AGG_NAME, 1.0), + mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123) + )); + Recall recall = new Recall(); + recall.process(aggs); + assertThat(recall.getResult(), isEmpty()); + } + } + + public void testProcess_GivenCardinalityTooHigh() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1), + mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123))); + Recall recall = new Recall(); + recall.aggs("foo", "bar"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs)); + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/TupleMatchers.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/TupleMatchers.java new file mode 100644 index 0000000000000..8bd8ff7572f54 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/TupleMatchers.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.collect.Tuple; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import java.util.Arrays; + +public class TupleMatchers { + + private static class TupleMatcher extends TypeSafeMatcher> { + + private final Matcher v1Matcher; + private final Matcher v2Matcher; + + private TupleMatcher(Matcher v1Matcher, Matcher v2Matcher) { + this.v1Matcher = v1Matcher; + this.v2Matcher = v2Matcher; + } + + @Override + protected boolean matchesSafely(final Tuple item) { + return item != null && v1Matcher.matches(item.v1()) && v2Matcher.matches(item.v2()); + } + + @Override + public void describeTo(final Description description) { + description.appendText("expected tuple matching ").appendList("[", ", ", "]", Arrays.asList(v1Matcher, v2Matcher)); + } + } + + /** + * Creates a matcher that matches iff: + * 1. the examined tuple's v1() matches the specified v1Matcher + * and + * 2. the examined tuple's v2() matches the specified v2Matcher + * For example: + *
assertThat(Tuple.tuple("myValue1", "myValue2"), isTuple(startsWith("my"), containsString("Val")))
+ */ + public static TupleMatcher isTuple(Matcher v1Matcher, Matcher v2Matcher) { + return new TupleMatcher(v1Matcher, v2Matcher); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 077998b66aed0..96ba97ecc9348 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; @@ -29,7 +30,7 @@ public class RegressionTests extends AbstractSerializingTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); } @Override @@ -38,7 +39,7 @@ protected NamedXContentRegistry xContentRegistry() { } public static Regression createRandom() { - List metrics = new ArrayList<>(); + List metrics = new ArrayList<>(); if (randomBoolean()) { metrics.add(MeanSquaredErrorTests.createRandom()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java index e63e88f6f848f..28e0a045b190d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; @@ -29,7 +30,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase metrics = new ArrayList<>(); + List metrics = new ArrayList<>(); if (randomBoolean()) { metrics.add(AucRocTests.createRandom()); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 9f9db8084404d..d90609c896793 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -5,22 +5,24 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.junit.After; import org.junit.Before; import java.util.Arrays; import java.util.List; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -117,6 +119,69 @@ public void testEvaluate_Accuracy_BooleanField() { assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } + public void testEvaluate_Precision() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + assertThat( + precisionResult.getClasses(), + equalTo( + Arrays.asList( + new Precision.PerClassResult("ant", 1.0 / 15), + new Precision.PerClassResult("cat", 1.0 / 15), + new Precision.PerClassResult("dog", 1.0 / 15), + new Precision.PerClassResult("fox", 1.0 / 15), + new Precision.PerClassResult("mouse", 1.0 / 15)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); + } + + public void testEvaluate_Precision_CardinalityTooHigh() { + ElasticsearchStatusException e = + expectThrows( + ElasticsearchStatusException.class, + () -> evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4))))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + } + + public void testEvaluate_Recall() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + assertThat( + recallResult.getClasses(), + equalTo( + Arrays.asList( + new Recall.PerClassResult("ant", 1.0 / 15), + new Recall.PerClassResult("cat", 1.0 / 15), + new Recall.PerClassResult("dog", 1.0 / 15), + new Recall.PerClassResult("fox", 1.0 / 15), + new Recall.PerClassResult("mouse", 1.0 / 15)))); + assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); + } + + public void testEvaluate_Recall_CardinalityTooHigh() { + ElasticsearchStatusException e = + expectThrows( + ElasticsearchStatusException.class, + () -> evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4))))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + } + public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( @@ -132,50 +197,50 @@ public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { assertThat( confusionMatrixResult.getConfusionMatrix(), equalTo(Arrays.asList( - new ActualClass("ant", + new MulticlassConfusionMatrix.ActualClass("ant", 15, Arrays.asList( - new PredictedClass("ant", 1L), - new PredictedClass("cat", 4L), - new PredictedClass("dog", 3L), - new PredictedClass("fox", 2L), - new PredictedClass("mouse", 5L)), + new MulticlassConfusionMatrix.PredictedClass("ant", 1L), + new MulticlassConfusionMatrix.PredictedClass("cat", 4L), + new MulticlassConfusionMatrix.PredictedClass("dog", 3L), + new MulticlassConfusionMatrix.PredictedClass("fox", 2L), + new MulticlassConfusionMatrix.PredictedClass("mouse", 5L)), 0), - new ActualClass("cat", + new MulticlassConfusionMatrix.ActualClass("cat", 15, Arrays.asList( - new PredictedClass("ant", 3L), - new PredictedClass("cat", 1L), - new PredictedClass("dog", 5L), - new PredictedClass("fox", 4L), - new PredictedClass("mouse", 2L)), + new MulticlassConfusionMatrix.PredictedClass("ant", 3L), + new MulticlassConfusionMatrix.PredictedClass("cat", 1L), + new MulticlassConfusionMatrix.PredictedClass("dog", 5L), + new MulticlassConfusionMatrix.PredictedClass("fox", 4L), + new MulticlassConfusionMatrix.PredictedClass("mouse", 2L)), 0), - new ActualClass("dog", + new MulticlassConfusionMatrix.ActualClass("dog", 15, Arrays.asList( - new PredictedClass("ant", 4L), - new PredictedClass("cat", 2L), - new PredictedClass("dog", 1L), - new PredictedClass("fox", 5L), - new PredictedClass("mouse", 3L)), + new MulticlassConfusionMatrix.PredictedClass("ant", 4L), + new MulticlassConfusionMatrix.PredictedClass("cat", 2L), + new MulticlassConfusionMatrix.PredictedClass("dog", 1L), + new MulticlassConfusionMatrix.PredictedClass("fox", 5L), + new MulticlassConfusionMatrix.PredictedClass("mouse", 3L)), 0), - new ActualClass("fox", + new MulticlassConfusionMatrix.ActualClass("fox", 15, Arrays.asList( - new PredictedClass("ant", 5L), - new PredictedClass("cat", 3L), - new PredictedClass("dog", 2L), - new PredictedClass("fox", 1L), - new PredictedClass("mouse", 4L)), + new MulticlassConfusionMatrix.PredictedClass("ant", 5L), + new MulticlassConfusionMatrix.PredictedClass("cat", 3L), + new MulticlassConfusionMatrix.PredictedClass("dog", 2L), + new MulticlassConfusionMatrix.PredictedClass("fox", 1L), + new MulticlassConfusionMatrix.PredictedClass("mouse", 4L)), 0), - new ActualClass("mouse", + new MulticlassConfusionMatrix.ActualClass("mouse", 15, Arrays.asList( - new PredictedClass("ant", 2L), - new PredictedClass("cat", 5L), - new PredictedClass("dog", 4L), - new PredictedClass("fox", 3L), - new PredictedClass("mouse", 1L)), + new MulticlassConfusionMatrix.PredictedClass("ant", 2L), + new MulticlassConfusionMatrix.PredictedClass("cat", 5L), + new MulticlassConfusionMatrix.PredictedClass("dog", 4L), + new MulticlassConfusionMatrix.PredictedClass("fox", 3L), + new MulticlassConfusionMatrix.PredictedClass("mouse", 1L)), 0)))); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } @@ -194,17 +259,26 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { assertThat( confusionMatrixResult.getConfusionMatrix(), equalTo(Arrays.asList( - new ActualClass("ant", + new MulticlassConfusionMatrix.ActualClass("ant", 15, - Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)), + Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("ant", 1L), + new MulticlassConfusionMatrix.PredictedClass("cat", 4L), + new MulticlassConfusionMatrix.PredictedClass("dog", 3L)), 7), - new ActualClass("cat", + new MulticlassConfusionMatrix.ActualClass("cat", 15, - Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)), + Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("ant", 3L), + new MulticlassConfusionMatrix.PredictedClass("cat", 1L), + new MulticlassConfusionMatrix.PredictedClass("dog", 5L)), 6), - new ActualClass("dog", + new MulticlassConfusionMatrix.ActualClass("dog", 15, - Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)), + Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("ant", 4L), + new MulticlassConfusionMatrix.PredictedClass("cat", 2L), + new MulticlassConfusionMatrix.PredictedClass("dog", 1L)), 8)))); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 33537af3f4477..a95d104eee97f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.junit.After; import java.util.ArrayList; @@ -450,9 +452,11 @@ private void assertEvaluation(String dependentVariable, List dependentVar evaluateDataFrame( destIndex, new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification( - dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix()))); + dependentVariable, + predictedClassField, + Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); { // Accuracy Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); @@ -483,6 +487,24 @@ private void assertEvaluation(String dependentVariable, List dependentVar } assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } + + { // Precision + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + for (Precision.PerClassResult klass : precisionResult.getClasses()) { + assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); + assertThat(klass.getPrecision(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + } + } + + { // Recall + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + for (Recall.PerClassResult klass : recallResult.getClasses()) { + assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); + assertThat(klass.getRecall(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + } + } } protected String stateDocId() { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 9d0d645e3d33a..95a7ef4e33218 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -632,6 +632,58 @@ setup: accuracy: 0.5 # 1 out of 2 overall_accuracy: 0.625 # 5 out of 8 --- +"Test classification precision": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "precision": {} } + } + } + } + + - match: + classification.precision: + classes: + - class_name: "cat" + precision: 0.5 # 2 out of 4 + - class_name: "dog" + precision: 0.6666666666666666 # 2 out of 3 + - class_name: "mouse" + precision: 1.0 # 1 out of 1 + avg_precision: 0.7222222222222222 +--- +"Test classification recall": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "recall": {} } + } + } + } + + - match: + classification.recall: + classes: + - class_name: "cat" + recall: 0.6666666666666666 # 2 out of 3 + - class_name: "dog" + recall: 0.6666666666666666 # 2 out of 3 + - class_name: "mouse" + recall: 0.5 # 1 out of 2 + avg_recall: 0.611111111111111 +--- "Test classification multiclass_confusion_matrix": - do: ml.evaluate_data_frame: