From c7b5b53bb1b20ea14d4b578a5639d193b9d736d1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 10 Oct 2019 08:34:07 -0400 Subject: [PATCH 1/5] [ML][Inference] Adding ingest processor --- .../core/ml/action/InferModelAction.java | 2 +- .../trainedmodel/ClassificationConfig.java | 19 + .../trainedmodel/InferenceConfig.java | 2 + .../trainedmodel/RegressionConfig.java | 16 + .../ensemble/NullInferenceConfig.java | 6 + .../xpack/core/ml/job/messages/Messages.java | 2 + .../ClassificationConfigTests.java | 20 + .../trainedmodel/RegressionConfigTests.java | 16 + .../ml/integration/InferenceIngestIT.java | 531 ++++++++++++++++++ .../xpack/ml/MachineLearning.java | 19 +- .../inference/ingest/InferenceProcessor.java | 220 +++++++- .../loadingservice/ModelLoadingService.java | 20 +- .../InferenceProcessorFactoryTests.java | 266 +++++++++ .../ingest/InferenceProcessorTests.java | 188 +++++++ 14 files changed, 1318 insertions(+), 9 deletions(-) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 67e3a75283d67..0093c86f78542 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -27,7 +27,7 @@ public class InferModelAction extends ActionType { public static final InferModelAction INSTANCE = new InferModelAction(); - public static final String NAME = "cluster:admin/xpack/ml/infer"; + public static final String NAME = "cluster:admin/xpack/ml/inference/infer"; private InferModelAction() { super(NAME, Response::new); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 5aa0403d94753..4c9fc4d89e93b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -5,12 +5,16 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; public class ClassificationConfig implements InferenceConfig { @@ -18,11 +22,21 @@ public class ClassificationConfig implements InferenceConfig { public static final String NAME = "classification"; public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0); private final int numTopClasses; + public static ClassificationConfig fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); + } + return new ClassificationConfig(numTopClasses); + } + public ClassificationConfig(Integer numTopClasses) { this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; } @@ -78,4 +92,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.CLASSIFICATION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index 6129d71d5ff95..d423f5b0eb6ed 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -13,4 +14,5 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable { boolean isTargetTypeSupported(TargetType targetType); + Version getMinimalSupportedVersion(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index bb7f772f86ba4..58bd7bbd3d558 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -5,16 +5,27 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Map; import java.util.Objects; public class RegressionConfig implements InferenceConfig { public static final String NAME = "regression"; + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; + + public static RegressionConfig fromMap(Map map) { + if (map.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfig(); + } public RegressionConfig() { } @@ -61,4 +72,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.REGRESSION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java index 7628d0beec25f..42757d889818e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -26,6 +27,11 @@ public boolean isTargetTypeSupported(TargetType targetType) { return true; } + @Override + public Version getMinimalSupportedVersion() { + return Version.CURRENT; + } + @Override public String getWriteableName() { return "null"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index c302d04186a0d..d9edf670d763f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -85,6 +85,8 @@ public final class Messages { public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = "Failed to serialize the trained model [{0}] with version [{1}] for storage"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]"; + public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = + "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java index 4df3263215f63..808aaf960f4e1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -5,15 +5,35 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + public class ClassificationConfigTests extends AbstractWireSerializingTestCase { public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10)); } + public void testFromMap() { + ClassificationConfig expected = new ClassificationConfig(0); + assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfig(3); + assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)), + equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected ClassificationConfig createTestInstance() { return randomClassificationConfig(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 57efcdd15009a..bdb0e6d03201f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -5,15 +5,31 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + public class RegressionConfigTests extends AbstractWireSerializingTestCase { public static RegressionConfig randomRegressionConfig() { return new RegressionConfig(); } + public void testFromMap() { + RegressionConfig expected = new RegressionConfig(); + assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected RegressionConfig createTestInstance() { return randomRegressionConfig(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java new file mode 100644 index 0000000000000..31fd1c6f2ef36 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -0,0 +1,531 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; +import org.elasticsearch.action.ingest.SimulatePipelineResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.junit.Before; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { + + @Before + public void createBothModels() { + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, + "_doc", + TrainedModelConfig.documentId("test_classification", 0)) + .setSource(CLASSIFICATION_MODEL, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, + "_doc", + TrainedModelConfig.documentId("test_regression", 0)) + .setSource(REGRESSION_MODEL, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + } + + public void testPipelineCreationAndDeletion() throws Exception { + + for (int i = 0; i < 10; i++) { + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_classification_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_regression_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + } + + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + for (int i = 0; i < 10; i++) { + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_classification_pipeline") + .get(); + + client().prepareIndex("index_for_inference_test", "_doc") + .setSource(new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_regression_pipeline") + .get(); + } + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + + client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get(); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + } + + public void testSimulate() { + String source = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class_prob\",\n" + + " \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + SimulatePipelineResponse response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); + assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0)); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second")); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2)); + + String sourceWithMissingModel = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification_missing\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + + assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), + containsString("Could not find trained model [test_classification_missing] with version [0]")); + } + + private static final String REGRESSION_MODEL = "{" + + " \"model_id\": \"test_regression\",\n" + + " \"model_version\": 0,\n" + + " \"definition\": {\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_sum\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"regression\",\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"test model for regression\",\n" + + " \"version\": \"8.0.0\",\n" + + " \"created_by\": \"ml_test\",\n" + + " \"model_type\": \"local\",\n" + + " \"created_time\": 0" + + "}"; + + private static final String CLASSIFICATION_MODEL = "" + + "{\n" + + " \"model_id\": \"test_classification\",\n" + + " \"model_version\": 0,\n" + + " \"definition\":{\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_mode\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"classification\",\n" + + " \"classification_labels\": [\"first\", \"second\"],\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"description\": \"test model for classification\",\n" + + " \"version\": \"8.0.0\",\n" + + " \"created_by\": \"benwtrent\",\n" + + " \"model_type\": \"local\",\n" + + " \"created_time\": 0\n" + + "}"; + + private static final String CLASSIFICATION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification\",\n" + + " \"inference_config\": {\"classification\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + + private static final String REGRESSION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 0b09b0736bfbb..0ba71fff62266 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -42,12 +42,14 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestController; @@ -202,6 +204,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -303,7 +306,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -327,6 +330,16 @@ protected Setting roleSetting() { }; + @Override + public Map getProcessors(Processor.Parameters parameters) { + InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + this.settings, + parameters.ingestService); + parameters.ingestService.addIngestClusterStateListener(inferenceFactory); + return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); + } + @Override public Set getRoles() { return Collections.singleton(ML_ROLE); @@ -416,7 +429,9 @@ public List> getSettings() { AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC, MAX_OPEN_JOBS_PER_NODE, MIN_DISK_SPACE_OFF_HEAP, - MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION); + MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION, + InferenceProcessor.MAX_INFERENCE_PROCESSORS + ); } public Settings additionalSettings() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 59f6c62a7f55e..bcb293c377748 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -5,28 +5,126 @@ */ package org.elasticsearch.xpack.ml.inference.ingest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.ingest.AbstractProcessor; +import org.elasticsearch.ingest.ConfigurationUtils; import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + public class InferenceProcessor extends AbstractProcessor { + // How many total inference processors are allowed to be used in the cluster. + public static final Setting MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors", + 50, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope); + public static final String TYPE = "inference"; public static final String MODEL_ID = "model_id"; + public static final String INFERENCE_CONFIG = "inference_config"; + public static final String TARGET_FIELD = "target_field"; + public static final String FIELD_MAPPINGS = "field_mappings"; + public static final String MODEL_INFO_FIELD = "model_info_field"; private final Client client; - public InferenceProcessor(Client client, String tag) { + private final String targetField; + private final String modelInfoField; + private final Map modelInfo; + private final String modelId; + private final InferenceConfig inferenceConfig; + private final Map fieldMapping; + + public InferenceProcessor(Client client, + String tag, + String targetField, + String modelId, + InferenceConfig inferenceConfig, + Map fieldMapping, + String modelInfoField) { super(tag); this.client = client; + this.targetField = targetField; + this.modelInfoField = modelInfoField; + this.modelId = modelId; + this.inferenceConfig = inferenceConfig; + this.fieldMapping = fieldMapping; + this.modelInfo = new HashMap<>(); + this.modelInfo.put("model_id", modelId); } @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { - //TODO actually work - handler.accept(ingestDocument, null); + executeAsyncWithOrigin(client, + ML_ORIGIN, + InferModelAction.INSTANCE, + this.buildRequest(ingestDocument), + ActionListener.wrap( + r -> { + try { + mutateDocument(r, ingestDocument); + handler.accept(ingestDocument, null); + } catch(ElasticsearchException ex) { + handler.accept(ingestDocument, ex); + } + }, + e -> handler.accept(ingestDocument, e) + )); + } + + InferModelAction.Request buildRequest(IngestDocument ingestDocument) { + Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); + if (fieldMapping != null) { + fieldMapping.forEach((src, dest) -> { + Object srcValue = fields.remove(src); + if (srcValue != null) { + fields.put(dest, srcValue); + } + }); + } + return new InferModelAction.Request(modelId, 0, fields, inferenceConfig); + } + + void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) { + response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + if (modelInfoField != null) { + ingestDocument.setFieldValue(modelInfoField, modelInfo); + } } @Override @@ -38,4 +136,120 @@ public IngestDocument execute(IngestDocument ingestDocument) { public String getType() { return TYPE; } + + public static final class Factory implements Processor.Factory, Consumer { + + private static final Logger logger = LogManager.getLogger(Factory.class); + + private final Client client; + private final IngestService ingestService; + private volatile int currentInferenceProcessors; + private volatile int maxIngestProcessors; + private volatile Version minNodeVersion = Version.CURRENT; + public Factory(Client client, ClusterService clusterService, Settings settings, IngestService ingestService) { + this.client = client; + this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings); + this.ingestService = ingestService; + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors); + } + + @Override + public void accept(ClusterState state) { + minNodeVersion = state.nodes().getMinNodeVersion(); + MetaData metaData = state.getMetaData(); + if (metaData == null) { + currentInferenceProcessors = 0; + return; + } + IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + currentInferenceProcessors = 0; + return; + } + + int count = 0; + for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) { + try { + Pipeline pipeline = Pipeline.create(configuration.getId(), + configuration.getConfigAsMap(), + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count(); + } catch (Exception ex) { + logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex); + } + } + currentInferenceProcessors = count; + } + + // Used for testing + int numInferenceProcessors() { + return currentInferenceProcessors; + } + + @Override + public InferenceProcessor create(Map processorFactories, String tag, Map config) + throws Exception { + + if (this.maxIngestProcessors <= currentInferenceProcessors) { + throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " + + "Adjust the setting [{}]: [{}] if a greater number is desired.", + RestStatus.CONFLICT, + currentInferenceProcessors, + MAX_INFERENCE_PROCESSORS.getKey(), + maxIngestProcessors); + } + + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); + String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD); + Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); + InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); + return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); + } + + // Package private for testing + void setMaxIngestProcessors(int maxIngestProcessors) { + logger.debug("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors); + this.maxIngestProcessors = maxIngestProcessors; + } + + InferenceConfig inferenceConfigFromMap(Map inferenceConfig) throws IOException { + ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + + if (inferenceConfig.keySet().size() != 1) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + Object value = inferenceConfig.values().iterator().next(); + + if ((value instanceof Map) == false) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + @SuppressWarnings("unchecked") + Map valueMap = (Map)value; + + if (inferenceConfig.containsKey(ClassificationConfig.NAME)) { + checkSupportedVersion(new ClassificationConfig(0)); + return ClassificationConfig.fromMap(valueMap); + } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) { + checkSupportedVersion(new RegressionConfig()); + return RegressionConfig.fromMap(valueMap); + } else { + throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", + inferenceConfig.keySet(), + Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME)); + } + } + + void checkSupportedVersion(InferenceConfig config) { + if (config.getMinimalSupportedVersion().after(minNodeVersion)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION, + config.getName(), + config.getMinimalSupportedVersion(), + minNodeVersion)); + } + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index b4fc552ba5f93..adb4453605f9c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -53,20 +53,21 @@ public void getModel(String modelId, long modelVersion, ActionListener mo if (cachedModel != null) { if (cachedModel.isSuccess()) { modelActionListener.onResponse(cachedModel.getModel()); + logger.trace("[{}] version [{}] loaded from cache", modelId, modelVersion); return; } } if (loadModelIfNecessary(key, modelId, modelVersion, modelActionListener) == false) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); + logger.trace("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), modelActionListener::onFailure )); } else { - logger.debug("[{}] version [{}] is currently loading, added new listener to queue", modelId, modelVersion); + logger.trace("[{}] version [{}] is loading or loaded, added new listener to queue", modelId, modelVersion); } } @@ -88,7 +89,7 @@ private boolean loadModelIfNecessary(String key, String modelId, long modelVersi if (loadingListeners.computeIfPresent( key, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { - logger.debug("[{}] version [{}] attempting to load and cache", modelId, modelVersion); + logger.trace("[{}] version [{}] attempting to load and cache", modelId, modelVersion); loadingListeners.put(key, addFluently(new ArrayDeque<>(), modelActionListener)); loadModel(key, modelId, modelVersion); } @@ -157,6 +158,8 @@ public void clusterChanged(ClusterChangedEvent event) { // The listeners still waiting for a model and we are canceling the load? List>>> drainWithFailure = new ArrayList<>(); synchronized (loadingListeners) { + HashSet loadedModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadedModels.keySet()) : null; + HashSet loadingModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadingListeners.keySet()) : null; // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners for (String modelKey : loadingListeners.keySet()) { @@ -181,6 +184,17 @@ public void clusterChanged(ClusterChangedEvent event) { for (String modelId : allReferencedModelKeys) { loadingListeners.put(modelId, new ArrayDeque<>()); } + if (loadedModelBeforeClusterState != null && loadingModelBeforeClusterState != null) { + if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, + loadingListeners.keySet()); + } + if (loadedModels.keySet().equals(loadedModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loaded models: before {} after {}", loadedModelBeforeClusterState, + loadedModels.keySet()); + } + } + } for (Tuple>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java new file mode 100644 index 0000000000000..6b43d7a6f0d9d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -0,0 +1,266 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InferenceProcessorFactoryTests extends ESTestCase { + + private static final IngestPlugin SKINNY_PLUGIN = new IngestPlugin() { + @Override + public Map getProcessors(Processor.Parameters parameters) { + return Collections.singletonMap(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + Settings.EMPTY, + parameters.ingestService)); + } + }; + private Client client; + private ClusterService clusterService; + private IngestService ingestService; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_PLUGIN), client); + } + + public void testNumInferenceProcessors() throws Exception { + MetaData metaData = null; + + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, ingestService); + processorFactory.accept(buildState(metaData)); + + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + metaData = MetaData.builder().build(); + + processorFactory.accept(buildState(metaData)); + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + + processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); + assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); + } + + public void testCreateProcessorWithTooManyExisting() throws Exception { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(), + ingestService); + + processorFactory.accept(buildClusterStateWithModelReferences("model1")); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap())); + + assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " + + "Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired.")); + } + + public void testCreateProcessorWithInvalidInferenceConfig() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map config = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); + }}; + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config)); + assertThat(ex.getMessage(), + equalTo("unrecognized inference configuration type [unknown_type]. Supported types [classification, regression]")); + + Map config2 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + + Map config3 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + } + + public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("Should not have successfully created"); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [regression] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + fail("Should not have successfully created"); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [classification] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + public void testCreateProcessor() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + private static ClusterState buildState(MetaData metaData) { + return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + return builderClusterStateWithModelReferences(Version.CURRENT, modelId); + } + + private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + minNodeVersion)) + .add(new DiscoveryNode("current_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id")) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap<>() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + put(InferenceProcessor.TARGET_FIELD, "new_field"); + put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest")); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java new file mode 100644 index 0000000000000..32a00cd24cb73 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -0,0 +1,188 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.Is.is; +import static org.mockito.Mockito.mock; + +public class InferenceProcessorTests extends ESTestCase { + + private Client client; + + @Before + public void setUpVariables() { + client = mock(Client.class); + } + + public void testMutateDocumentWithClassification() { + String targetField = "classification_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(0), + Collections.emptyMap(), + "_ml_model"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); + assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "classification_model"))); + } + + @SuppressWarnings("unchecked") + public void testMutateDocumentClassificationTopNClasses() { + String targetField = "classification_value_probabilities"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(2), + Collections.emptyMap(), + "_ml_model"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + List classes = new ArrayList<>(2); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes))); + inferenceProcessor.mutateDocument(response, document); + + assertThat((List>)document.getFieldValue(targetField, List.class), + contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); + assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "classification_model"))); + } + + public void testMutateDocumentRegression() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "_ml_model"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "regression_model"))); + } + + public void testMutateDocumentNoModelMetaData() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.hasField("_ml_model"), is(false)); + } + + public void testGenerateRequestWithEmptyMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(){{ + put("value1", 1); + put("value2", 4); + put("categorical", "foo"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + } + + public void testGenerateWithMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + Map fieldMapping = new HashMap<>(3) {{ + put("value1", "new_value1"); + put("value2", "new_value2"); + put("categorical", "new_categorical"); + }}; + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + fieldMapping, + null); + + Map source = new HashMap<>(3){{ + put("value1", 1); + put("categorical", "foo"); + put("un_touched", "bar"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + Map expectedMap = new HashMap<>(2) {{ + put("new_value1", 1); + put("new_categorical", "foo"); + put("un_touched", "bar"); + }}; + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + } +} From 13661cf4f381130a0cc2454b0b47ebbc7252a24d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 10 Oct 2019 09:27:37 -0400 Subject: [PATCH 2/5] optionally including tag in model metadata injection in processor --- .../ml/inference/ingest/InferenceProcessor.java | 5 ++++- .../inference/ingest/InferenceProcessorTests.java | 15 +++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index bcb293c377748..906d8b64e5daa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -205,12 +205,15 @@ public InferenceProcessor create(Map processorFactori Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); + if (modelInfoField != null && tag != null) { + modelInfoField += "." + tag; + } return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); } // Package private for testing void setMaxIngestProcessors(int maxIngestProcessors) { - logger.debug("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors); + logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors); this.maxIngestProcessors = maxIngestProcessors; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 32a00cd24cb73..4f55768407339 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -43,7 +43,7 @@ public void testMutateDocumentWithClassification() { "classification_model", new ClassificationConfig(0), Collections.emptyMap(), - "_ml_model"); + "_ml_model.my_processor"); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -54,7 +54,8 @@ public void testMutateDocumentWithClassification() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); - assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "classification_model"))); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } @SuppressWarnings("unchecked") @@ -66,7 +67,7 @@ public void testMutateDocumentClassificationTopNClasses() { "classification_model", new ClassificationConfig(2), Collections.emptyMap(), - "_ml_model"); + "_ml_model.my_processor"); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -82,7 +83,8 @@ public void testMutateDocumentClassificationTopNClasses() { assertThat((List>)document.getFieldValue(targetField, List.class), contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); - assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "classification_model"))); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } public void testMutateDocumentRegression() { @@ -93,7 +95,7 @@ public void testMutateDocumentRegression() { "regression_model", new RegressionConfig(), Collections.emptyMap(), - "_ml_model"); + "_ml_model.my_processor"); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -104,7 +106,8 @@ public void testMutateDocumentRegression() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); - assertThat(document.getFieldValue("_ml_model", Map.class), equalTo(Collections.singletonMap("model_id", "regression_model"))); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); } public void testMutateDocumentNoModelMetaData() { From 62ec0ccc6e2f8a484c8c2d6e5830a88d98da2365 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 16 Oct 2019 15:24:12 -0400 Subject: [PATCH 3/5] fixing test --- .../elasticsearch/xpack/ml/integration/InferenceIngestIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index f5131ea5cd55b..5a2a7af362f64 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -232,7 +232,7 @@ public void testSimulate() { XContentType.JSON).get(); assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), - containsString("Could not find trained model [test_classification_missing] with version [0]")); + containsString("Could not find trained model [test_classification_missing]")); } private static final String REGRESSION_MODEL = "{" + From 67ca80533da08854ff50df0a9ba4dd93c581f3b0 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 21 Oct 2019 07:20:25 -0400 Subject: [PATCH 4/5] addressing PR comments --- .../ml/integration/InferenceIngestIT.java | 24 +++++++++---------- .../inference/ingest/InferenceProcessor.java | 6 ++++- .../InferenceProcessorFactoryTests.java | 6 ++--- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 5a2a7af362f64..852b3fcea0f0e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -92,22 +93,12 @@ public void testPipelineCreationAndDeletion() throws Exception { for (int i = 0; i < 10; i++) { client().prepareIndex("index_for_inference_test", "_doc") - .setSource(new HashMap<>(){{ - put("col1", randomFrom("female", "male")); - put("col2", randomFrom("S", "M", "L", "XL")); - put("col3", randomFrom("true", "false", "none", "other")); - put("col4", randomIntBetween(0, 10)); - }}) + .setSource(generateSourceDoc()) .setPipeline("simple_classification_pipeline") .get(); client().prepareIndex("index_for_inference_test", "_doc") - .setSource(new HashMap<>(){{ - put("col1", randomFrom("female", "male")); - put("col2", randomFrom("S", "M", "L", "XL")); - put("col3", randomFrom("true", "false", "none", "other")); - put("col4", randomIntBetween(0, 10)); - }}) + .setSource(generateSourceDoc()) .setPipeline("simple_regression_pipeline") .get(); } @@ -235,6 +226,15 @@ public void testSimulate() { containsString("Could not find trained model [test_classification_missing]")); } + private Map generateSourceDoc() { + return new HashMap<>(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}; + } + private static final String REGRESSION_MODEL = "{" + " \"model_id\": \"test_regression\",\n" + " \"model_version\": 0,\n" + diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 5951894ca0df7..40ca9ba253594 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -124,6 +124,9 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { } void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) { + if (response.getInferenceResults().isEmpty()) { + throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); + } response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); if (modelInfoField != null) { ingestDocument.setFieldValue(modelInfoField, modelInfo); @@ -149,6 +152,7 @@ public static final class Factory implements Processor.Factory, Consumer inferenceConfig) throws IOException { ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); - if (inferenceConfig.keySet().size() != 1) { + if (inferenceConfig.size() != 1) { throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", INFERENCE_CONFIG); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 6b43d7a6f0d9d..322b5cfb4ec2c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -75,12 +75,12 @@ public void testNumInferenceProcessors() throws Exception { MetaData metaData = null; InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, ingestService); - processorFactory.accept(buildState(metaData)); + processorFactory.accept(buildClusterState(metaData)); assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); metaData = MetaData.builder().build(); - processorFactory.accept(buildState(metaData)); + processorFactory.accept(buildClusterState(metaData)); assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); @@ -220,7 +220,7 @@ public void testCreateProcessor() { } } - private static ClusterState buildState(MetaData metaData) { + private static ClusterState buildClusterState(MetaData metaData) { return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); } From 40957e825e7c15efb400f28778104643c465cbcd Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 21 Oct 2019 07:56:34 -0400 Subject: [PATCH 5/5] adding comment --- .../xpack/core/ml/inference/trainedmodel/InferenceConfig.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index d423f5b0eb6ed..5d1dc7983ff3c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -14,5 +14,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable { boolean isTargetTypeSupported(TargetType targetType); + /** + * All nodes in the cluster must be at least this version + */ Version getMinimalSupportedVersion(); }