diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 24b5e345430df..29cab602dab06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -27,7 +27,7 @@ public class InferModelAction extends ActionType { public static final InferModelAction INSTANCE = new InferModelAction(); - public static final String NAME = "cluster:admin/xpack/ml/infer"; + public static final String NAME = "cluster:admin/xpack/ml/inference/infer"; private InferModelAction() { super(NAME, Response::new); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 5aa0403d94753..4c9fc4d89e93b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -5,12 +5,16 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; public class ClassificationConfig implements InferenceConfig { @@ -18,11 +22,21 @@ public class ClassificationConfig implements InferenceConfig { public static final String NAME = "classification"; public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0); private final int numTopClasses; + public static ClassificationConfig fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); + } + return new ClassificationConfig(numTopClasses); + } + public ClassificationConfig(Integer numTopClasses) { this.numTopClasses = numTopClasses == null ? 
0 : numTopClasses; } @@ -78,4 +92,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.CLASSIFICATION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index 6129d71d5ff95..5d1dc7983ff3c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -13,4 +14,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable { boolean isTargetTypeSupported(TargetType targetType); + /** + * All nodes in the cluster must be at least this version + */ + Version getMinimalSupportedVersion(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index bb7f772f86ba4..58bd7bbd3d558 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -5,16 +5,27 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Map; import java.util.Objects; public class RegressionConfig implements InferenceConfig { public static final String NAME = "regression"; + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; + + public static RegressionConfig fromMap(Map map) { + if (map.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfig(); + } public RegressionConfig() { } @@ -61,4 +72,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.REGRESSION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java index 7628d0beec25f..42757d889818e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.elasticsearch.Version; import 
org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -26,6 +27,11 @@ public boolean isTargetTypeSupported(TargetType targetType) { return true; } + @Override + public Version getMinimalSupportedVersion() { + return Version.CURRENT; + } + @Override public String getWriteableName() { return "null"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index c6558d781bbbc..05bc8250333e0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -85,6 +85,8 @@ public final class Messages { public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = "Failed to serialize the trained model [{0}] for storage"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = + "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java index 4df3263215f63..808aaf960f4e1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -5,15 +5,35 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + public class ClassificationConfigTests extends AbstractWireSerializingTestCase { public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? 
null : randomIntBetween(-1, 10)); } + public void testFromMap() { + ClassificationConfig expected = new ClassificationConfig(0); + assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfig(3); + assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)), + equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected ClassificationConfig createTestInstance() { return randomClassificationConfig(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 57efcdd15009a..bdb0e6d03201f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -5,15 +5,31 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + public class RegressionConfigTests extends AbstractWireSerializingTestCase { public static RegressionConfig randomRegressionConfig() { return new RegressionConfig(); } + public void testFromMap() { + RegressionConfig expected = new RegressionConfig(); + assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected RegressionConfig createTestInstance() { return randomRegressionConfig(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java new file mode 100644 index 0000000000000..852b3fcea0f0e --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -0,0 +1,530 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
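As a quick orientation for the new map-based constructors above, here is a minimal sketch of how ClassificationConfig.fromMap and RegressionConfig.fromMap behave, derived from the config classes and their unit tests earlier in this patch; the wrapper class name is invented for illustration and is not part of the patch:

import java.util.Collections;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

class InferenceConfigFromMapSketch {
    static void example() {
        // "classification" recognises only num_top_classes; a null/absent value means 0.
        ClassificationConfig classification =
            ClassificationConfig.fromMap(Collections.singletonMap("num_top_classes", 2));
        // "regression" currently takes no options, so only an empty map is accepted.
        RegressionConfig regression = RegressionConfig.fromMap(Collections.emptyMap());
        // Any unrecognised key fails fast with "Unrecognized fields [some_key]." (400 Bad Request).
    }
}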
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; +import org.elasticsearch.action.ingest.SimulatePipelineResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.junit.Before; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { + + @Before + public void createBothModels() { + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, + "_doc", + "test_classification") + .setSource(CLASSIFICATION_MODEL, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, + "_doc", + "test_regression") + .setSource(REGRESSION_MODEL, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + } + + public void testPipelineCreationAndDeletion() throws Exception { + + for (int i = 0; i < 10; i++) { + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_classification_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_regression_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + } + + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new 
BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + for (int i = 0; i < 10; i++) { + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(generateSourceDoc()) + .setPipeline("simple_classification_pipeline") + .get(); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(generateSourceDoc()) + .setPipeline("simple_regression_pipeline") + .get(); + } + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + + client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get(); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + } + + public void testSimulate() { + String source = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class_prob\",\n" + + " \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + SimulatePipelineResponse response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); + assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0)); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second")); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2)); + + String sourceWithMissingModel = "{\n" + + " \"pipeline\": {\n" + + " 
\"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification_missing\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + + assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), + containsString("Could not find trained model [test_classification_missing]")); + } + + private Map generateSourceDoc() { + return new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}; + } + + private static final String REGRESSION_MODEL = "{" + + " \"model_id\": \"test_regression\",\n" + + " \"model_version\": 0,\n" + + " \"definition\": {\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_sum\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"regression\",\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 
0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"test model for regression\",\n" + + " \"version\": \"8.0.0\",\n" + + " \"created_by\": \"ml_test\",\n" + + " \"model_type\": \"local\",\n" + + " \"created_time\": 0" + + "}"; + + private static final String CLASSIFICATION_MODEL = "" + + "{\n" + + " \"model_id\": \"test_classification\",\n" + + " \"model_version\": 0,\n" + + " \"definition\":{\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_mode\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"classification\",\n" + + " \"classification_labels\": [\"first\", \"second\"],\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": 
\"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"test model for classification\",\n" + + " \"version\": \"8.0.0\",\n" + + " \"created_by\": \"benwtrent\",\n" + + " \"model_type\": \"local\",\n" + + " \"created_time\": 0\n" + + "}"; + + private static final String CLASSIFICATION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification\",\n" + + " \"inference_config\": {\"classification\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + + private static final String REGRESSION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 500a71b3a9416..72e903d79baf5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -208,6 +208,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; @@ -313,7 +314,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, IngestPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -341,7 +342,8 @@ protected Setting roleSetting() { public Map getProcessors(Processor.Parameters parameters) { InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, parameters.ingestService.getClusterService(), - this.settings); + this.settings, + parameters.ingestService); parameters.ingestService.addIngestClusterStateListener(inferenceFactory); return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); } @@ -435,7 +437,9 @@ public List> getSettings() { AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC, MAX_OPEN_JOBS_PER_NODE, MIN_DISK_SPACE_OFF_HEAP, - MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION); + MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION, + InferenceProcessor.MAX_INFERENCE_PROCESSORS + ); } public Settings additionalSettings() 
{ diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index b8cccc0d45e23..40ca9ba253594 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -5,32 +5,86 @@ */ package org.elasticsearch.xpack.ml.inference.ingest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; -import org.elasticsearch.ingest.AbstractProcessor; -import org.elasticsearch.ingest.IngestDocument; - -import java.util.function.BiConsumer; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + public class InferenceProcessor extends AbstractProcessor { + // How many total inference processors are allowed to be used in the cluster. 
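// Illustrative summary (comments only, not lines from the patch), inferred from the setting and
// factory code that follows: xpack.ml.max_inference_processors is dynamic and node-scoped,
// defaulting to 50 with a minimum of 1. Factory.accept(ClusterState) recounts inference
// processors across every ingest pipeline whenever ingest metadata changes, and create()
// rejects a new processor with 409 CONFLICT once that count reaches the configured maximum.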
+ public static final Setting MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors", + 50, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope); + public static final String TYPE = "inference"; public static final String MODEL_ID = "model_id"; + public static final String INFERENCE_CONFIG = "inference_config"; + public static final String TARGET_FIELD = "target_field"; + public static final String FIELD_MAPPINGS = "field_mappings"; + public static final String MODEL_INFO_FIELD = "model_info_field"; private final Client client; private final String modelId; - public InferenceProcessor(Client client, String tag, String modelId) { + private final String targetField; + private final String modelInfoField; + private final Map modelInfo; + private final InferenceConfig inferenceConfig; + private final Map fieldMapping; + + public InferenceProcessor(Client client, + String tag, + String targetField, + String modelId, + InferenceConfig inferenceConfig, + Map fieldMapping, + String modelInfoField) { super(tag); this.client = client; + this.targetField = targetField; + this.modelInfoField = modelInfoField; this.modelId = modelId; + this.inferenceConfig = inferenceConfig; + this.fieldMapping = fieldMapping; + this.modelInfo = new HashMap<>(); + this.modelInfo.put("model_id", modelId); } public String getModelId() { @@ -39,8 +93,44 @@ public String getModelId() { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { - //TODO actually work - handler.accept(ingestDocument, null); + executeAsyncWithOrigin(client, + ML_ORIGIN, + InferModelAction.INSTANCE, + this.buildRequest(ingestDocument), + ActionListener.wrap( + r -> { + try { + mutateDocument(r, ingestDocument); + handler.accept(ingestDocument, null); + } catch(ElasticsearchException ex) { + handler.accept(ingestDocument, ex); + } + }, + e -> handler.accept(ingestDocument, e) + )); + } + + InferModelAction.Request buildRequest(IngestDocument ingestDocument) { + Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); + if (fieldMapping != null) { + fieldMapping.forEach((src, dest) -> { + Object srcValue = fields.remove(src); + if (srcValue != null) { + fields.put(dest, srcValue); + } + }); + } + return new InferModelAction.Request(modelId, fields, inferenceConfig); + } + + void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) { + if (response.getInferenceResults().isEmpty()) { + throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); + } + response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + if (modelInfoField != null) { + ingestDocument.setFieldValue(modelInfoField, modelInfo); + } } @Override @@ -53,26 +143,123 @@ public String getType() { return TYPE; } - public static class Factory implements Processor.Factory, Consumer { + public static final class Factory implements Processor.Factory, Consumer { + + private static final Logger logger = LogManager.getLogger(Factory.class); private final Client client; - private final ClusterService clusterService; + private final IngestService ingestService; + private volatile int currentInferenceProcessors; + private volatile int maxIngestProcessors; + private volatile Version minNodeVersion = Version.CURRENT; - public Factory(Client client, ClusterService clusterService, Settings settings) { + public Factory(Client client, ClusterService clusterService, Settings settings, IngestService ingestService) { 
this.client = client; - this.clusterService = clusterService; + this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings); + this.ingestService = ingestService; + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors); + } + + @Override + public void accept(ClusterState state) { + minNodeVersion = state.nodes().getMinNodeVersion(); + MetaData metaData = state.getMetaData(); + if (metaData == null) { + currentInferenceProcessors = 0; + return; + } + IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + currentInferenceProcessors = 0; + return; + } + + int count = 0; + for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) { + try { + Pipeline pipeline = Pipeline.create(configuration.getId(), + configuration.getConfigAsMap(), + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count(); + } catch (Exception ex) { + logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex); + } + } + currentInferenceProcessors = count; + } + + // Used for testing + int numInferenceProcessors() { + return currentInferenceProcessors; } @Override - public Processor create(Map processorFactories, String tag, Map config) + public InferenceProcessor create(Map processorFactories, String tag, Map config) throws Exception { + + if (this.maxIngestProcessors <= currentInferenceProcessors) { + throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " + + "Adjust the setting [{}]: [{}] if a greater number is desired.", + RestStatus.CONFLICT, + currentInferenceProcessors, + MAX_INFERENCE_PROCESSORS.getKey(), + maxIngestProcessors); + } + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); - return new InferenceProcessor(client, tag, modelId); + String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD); + Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); + InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); + if (modelInfoField != null && tag != null) { + modelInfoField += "." 
+ tag; + } + return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); } - @Override - public void accept(ClusterState clusterState) { + // Package private for testing + void setMaxIngestProcessors(int maxIngestProcessors) { + logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors); + this.maxIngestProcessors = maxIngestProcessors; + } + + InferenceConfig inferenceConfigFromMap(Map inferenceConfig) throws IOException { + ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + + if (inferenceConfig.size() != 1) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + Object value = inferenceConfig.values().iterator().next(); + + if ((value instanceof Map) == false) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + @SuppressWarnings("unchecked") + Map valueMap = (Map)value; + + if (inferenceConfig.containsKey(ClassificationConfig.NAME)) { + checkSupportedVersion(new ClassificationConfig(0)); + return ClassificationConfig.fromMap(valueMap); + } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) { + checkSupportedVersion(new RegressionConfig()); + return RegressionConfig.fromMap(valueMap); + } else { + throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", + inferenceConfig.keySet(), + Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME)); + } + } + void checkSupportedVersion(InferenceConfig config) { + if (config.getMinimalSupportedVersion().after(minNodeVersion)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION, + config.getName(), + config.getMinimalSupportedVersion(), + minNodeVersion)); + } } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index e6a5baf42ba12..da294f1e1580b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -52,20 +52,21 @@ public void getModel(String modelId, ActionListener modelActionListener) if (cachedModel != null) { if (cachedModel.isSuccess()) { modelActionListener.onResponse(cachedModel.getModel()); + logger.trace("[{}] loaded from cache", modelId); return; } } if (loadModelIfNecessary(modelId, modelActionListener) == false) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.debug("[{}] not actively loading, eager loading without cache", modelId); + logger.trace("[{}] not actively loading, eager loading without cache", modelId); provider.getTrainedModel(modelId, ActionListener.wrap( trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), modelActionListener::onFailure )); } else { - logger.debug("[{}] is currently loading, added new listener to queue", modelId); + logger.trace("[{}] is loading or loaded, added new listener to queue", modelId); } } @@ -87,7 +88,7 
@@ private boolean loadModelIfNecessary(String modelId, ActionListener model if (loadingListeners.computeIfPresent( modelId, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { - logger.debug("[{}] attempting to load and cache", modelId); + logger.trace("[{}] attempting to load and cache", modelId); loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); loadModel(modelId); } @@ -156,6 +157,8 @@ public void clusterChanged(ClusterChangedEvent event) { // The listeners still waiting for a model and we are canceling the load? List>>> drainWithFailure = new ArrayList<>(); synchronized (loadingListeners) { + HashSet loadedModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadedModels.keySet()) : null; + HashSet loadingModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadingListeners.keySet()) : null; // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners for (String modelId : loadingListeners.keySet()) { @@ -180,6 +183,17 @@ public void clusterChanged(ClusterChangedEvent event) { for (String modelId : allReferencedModelKeys) { loadingListeners.put(modelId, new ArrayDeque<>()); } + if (loadedModelBeforeClusterState != null && loadingModelBeforeClusterState != null) { + if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, + loadingListeners.keySet()); + } + if (loadedModels.keySet().equals(loadedModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loaded models: before {} after {}", loadedModelBeforeClusterState, + loadedModels.keySet()); + } + } + } for (Tuple>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( @@ -223,7 +237,6 @@ private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); if (modelId != null) { assert modelId instanceof String; - // TODO also read model version allReferencedModelKeys.add(modelId.toString()); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java new file mode 100644 index 0000000000000..322b5cfb4ec2c --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -0,0 +1,266 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
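// Reader-oriented sketch of the ModelLoadingService lookup path touched above (a paraphrase of
// the existing code, not new behaviour or patch lines):
//   getModel(modelId, listener):
//     1. cache hit in loadedModels        -> respond immediately (now logged at TRACE)
//     2. referenced by an ingest pipeline -> listener queued in loadingListeners; the model is
//        but not yet loaded                  loaded once and cached for all queued listeners
//     3. not referenced (e.g. a simulated -> TrainedModelConfig fetched directly and wrapped in
//        pipeline)                           a LocalModel, bypassing the cache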
+ */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InferenceProcessorFactoryTests extends ESTestCase { + + private static final IngestPlugin SKINNY_PLUGIN = new IngestPlugin() { + @Override + public Map getProcessors(Processor.Parameters parameters) { + return Collections.singletonMap(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + Settings.EMPTY, + parameters.ingestService)); + } + }; + private Client client; + private ClusterService clusterService; + private IngestService ingestService; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_PLUGIN), client); + } + + public void testNumInferenceProcessors() throws Exception { + MetaData metaData = null; + + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, ingestService); + processorFactory.accept(buildClusterState(metaData)); + + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + metaData = MetaData.builder().build(); + + processorFactory.accept(buildClusterState(metaData)); + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + + processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); + assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); + } + + public void 
testCreateProcessorWithTooManyExisting() throws Exception { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(), + ingestService); + + processorFactory.accept(buildClusterStateWithModelReferences("model1")); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap())); + + assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " + + "Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired.")); + } + + public void testCreateProcessorWithInvalidInferenceConfig() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map config = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); + }}; + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config)); + assertThat(ex.getMessage(), + equalTo("unrecognized inference configuration type [unknown_type]. Supported types [classification, regression]")); + + Map config2 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + + Map config3 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + } + + public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("Should not have successfully created"); + } catch 
(ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [regression] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + fail("Should not have successfully created"); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [classification] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + public void testCreateProcessor() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + private static ClusterState buildClusterState(MetaData metaData) { + return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + return builderClusterStateWithModelReferences(Version.CURRENT, modelId); + } + + private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... 
modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + minNodeVersion)) + .add(new DiscoveryNode("current_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id")) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap<>() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + put(InferenceProcessor.TARGET_FIELD, "new_field"); + put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest")); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java new file mode 100644 index 0000000000000..4f55768407339 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
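// The unit tests below pin down the document-mutation contract. As a sketch (field names taken
// from the tests that follow, values illustrative): for a processor tagged "my_processor" with a
// model_info_field of "_ml_model.my_processor", a regression response of 0.7 produces
//   {
//     "regression_value": 0.7,
//     "_ml_model": { "my_processor": { "model_id": "regression_model" } }
//   }
// while a null model_info_field writes only the target field.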
+ */
+package org.elasticsearch.xpack.ml.inference.ingest;
+
+import org.elasticsearch.client.Client;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.action.InferModelAction;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Mockito.mock;
+
+public class InferenceProcessorTests extends ESTestCase {
+
+    private Client client;
+
+    @Before
+    public void setUpVariables() {
+        client = mock(Client.class);
+    }
+
+    public void testMutateDocumentWithClassification() {
+        String targetField = "classification_value";
+        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
+            "my_processor",
+            targetField,
+            "classification_model",
+            new ClassificationConfig(0),
+            Collections.emptyMap(),
+            "_ml_model.my_processor");
+
+        Map<String, Object> source = new HashMap<>();
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        InferModelAction.Response response = new InferModelAction.Response(
+            Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)));
+        inferenceProcessor.mutateDocument(response, document);
+
+        assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
+        assertThat(document.getFieldValue("_ml_model", Map.class),
+            equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testMutateDocumentClassificationTopNClasses() {
+        String targetField = "classification_value_probabilities";
+        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
+            "my_processor",
+            targetField,
+            "classification_model",
+            new ClassificationConfig(2),
+            Collections.emptyMap(),
+            "_ml_model.my_processor");
+
+        Map<String, Object> source = new HashMap<>();
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
+        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
+        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
+
+        InferModelAction.Response response = new InferModelAction.Response(
+            Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)));
+        inferenceProcessor.mutateDocument(response, document);
+
+        assertThat((List<Map<String, Object>>)document.getFieldValue(targetField, List.class),
+            contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
+        assertThat(document.getFieldValue("_ml_model", Map.class),
+            equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
+    }
+
+    public void testMutateDocumentRegression() {
+        String targetField = "regression_value";
+        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
+            "my_processor",
+            targetField,
"regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "_ml_model.my_processor"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); + } + + public void testMutateDocumentNoModelMetaData() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.hasField("_ml_model"), is(false)); + } + + public void testGenerateRequestWithEmptyMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(){{ + put("value1", 1); + put("value2", 4); + put("categorical", "foo"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + } + + public void testGenerateWithMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? 
+
+        Map<String, String> fieldMapping = new HashMap<>(3) {{
+            put("value1", "new_value1");
+            put("value2", "new_value2");
+            put("categorical", "new_categorical");
+        }};
+
+        InferenceProcessor processor = new InferenceProcessor(client,
+            "my_processor",
+            "my_field",
+            modelId,
+            new ClassificationConfig(topNClasses),
+            fieldMapping,
+            null);
+
+        Map<String, Object> source = new HashMap<>(3){{
+            put("value1", 1);
+            put("categorical", "foo");
+            put("un_touched", "bar");
+        }};
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        Map<String, Object> expectedMap = new HashMap<>(2) {{
+            put("new_value1", 1);
+            put("new_categorical", "foo");
+            put("un_touched", "bar");
+        }};
+        assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
+    }
+}
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index a18b29487eac5..a8b199a7a3b59 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -97,7 +97,10 @@
           "processors": [
             {
               "inference" : {
-                "model_id" : "used-regression-model"
+                "model_id" : "used-regression-model",
+                "inference_config": {"regression": {}},
+                "target_field": "regression_field",
+                "field_mappings": {}
              }
            }
          ]
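
Note on the field-mapping behaviour exercised by testGenerateWithMapping above: the assertions imply that source keys listed in the field mapping are renamed before being sent for inference, while unmapped keys pass through unchanged. The stand-alone Java sketch below reproduces that expectation; FieldMappingSketch and remap are hypothetical names for illustration only and are not part of this PR or of InferenceProcessor.buildRequest.

import java.util.HashMap;
import java.util.Map;

class FieldMappingSketch {
    // Hypothetical illustration of the remapping the tests assert on:
    // keys present in fieldMapping are renamed, everything else is passed through.
    static Map<String, Object> remap(Map<String, Object> source, Map<String, String> fieldMapping) {
        Map<String, Object> fields = new HashMap<>(source);
        fieldMapping.forEach((oldName, newName) -> {
            if (fields.containsKey(oldName)) {
                fields.put(newName, fields.remove(oldName));
            }
        });
        return fields;
    }

    public static void main(String[] args) {
        Map<String, Object> source = new HashMap<>();
        source.put("value1", 1);
        source.put("categorical", "foo");
        source.put("un_touched", "bar");

        Map<String, String> fieldMapping = new HashMap<>();
        fieldMapping.put("value1", "new_value1");
        fieldMapping.put("value2", "new_value2");
        fieldMapping.put("categorical", "new_categorical");

        // Prints {new_value1=1, new_categorical=foo, un_touched=bar}, matching
        // the expectedMap asserted in testGenerateWithMapping.
        System.out.println(remap(source, fieldMapping));
    }
}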