diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index 005f0d180cdc1..b86cfced5524f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -5,17 +5,20 @@ */ package org.elasticsearch.xpack.core.ml.action; -import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionType; -import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + public class GetTrainedModelsAction extends ActionType { @@ -28,19 +31,20 @@ private GetTrainedModelsAction() { public static class Request extends AbstractGetResourcesRequest { + public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition"); public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); - public Request() { - setAllowNoResources(true); - } + private final boolean includeModelDefinition; - public Request(String id) { + public Request(String id, boolean includeModelDefinition) { setResourceId(id); setAllowNoResources(true); + this.includeModelDefinition = includeModelDefinition; } public Request(StreamInput in) throws IOException { super(in); + this.includeModelDefinition = in.readBoolean(); } @Override @@ -48,6 +52,32 @@ public String getResourceIdField() { return TrainedModelConfig.MODEL_ID.getPreferredName(); } + public boolean isIncludeModelDefinition() { + return includeModelDefinition; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(includeModelDefinition); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), includeModelDefinition); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition; + } } public static class Response extends AbstractGetResourcesResponse { @@ -66,12 +96,33 @@ public Response(QueryPage trainedModels) { protected Reader getReader() { return TrainedModelConfig::new; } - } - public static class RequestBuilder extends ActionRequestBuilder { + public static Builder builder() { + return new Builder(); + } + + public static class Builder { - public RequestBuilder(ElasticsearchClient client) { - super(client, INSTANCE, new Request()); + private long totalCount; + private List configs = Collections.emptyList(); + + private Builder() { + } + + public Builder setTotalCount(long totalCount) { + this.totalCount = totalCount; + return this; + } + + public Builder setModels(List configs) { + this.configs = configs; + return this; + } + + public Response build() { + return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD)); + } } } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 61cd542a9fc2d..00fce4e58e826 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -83,9 +83,13 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; + public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; + public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = + "Getting model definition is not supported when getting more than one model"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java index 0abc0318e215e..85345467df169 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -14,7 +14,7 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas @Override protected Request createTestInstance() { - Request request = new Request(randomAlphaOfLength(20)); + Request request = new Request(randomAlphaOfLength(20), randomBoolean()); request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); return request; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 20213ba99d62d..5fdadac712d0d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; @@ -12,6 +13,8 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchModule; @@ -22,7 +25,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; -import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ -33,27 +35,21 @@ import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { - private boolean lenient; - - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); - } - @Override protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException { - return TrainedModelDefinition.fromXContent(parser, lenient).build(); + return TrainedModelDefinition.fromXContent(parser, true).build(); } @Override protected boolean supportsUnknownFields() { - return lenient; + return true; } @Override @@ -63,7 +59,7 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected ToXContent.Params getToXContentParams() { - return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); + return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); } @Override @@ -286,9 +282,27 @@ public void testTreeSchemaDeserialization() throws IOException { assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class)); } + public void testStrictParser() throws IOException { + TrainedModelDefinition.Builder builder = createRandomBuilder("asdf"); + BytesReference reference = XContentHelper.toXContent(builder.build(), + XContentType.JSON, + new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")), + false); + + XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + reference, + XContentType.JSON); + + XContentParseException exception = expectThrows(XContentParseException.class, + () -> TrainedModelDefinition.fromXContent(parser, false)); + + assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]")); + } + @Override protected TrainedModelDefinition createTestInstance() { - return createRandomBuilder(null).build(); + return createRandomBuilder(randomAlphaOfLength(10)).build(); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 9c059bced93d3..1982cec7eca0c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; @@ -63,6 +64,11 @@ public void testGetTrainedModels() throws IOException { model1.setJsonEntity(buildRegressionModel(modelId)); assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + Request modelDefinition1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + Request model2 = new Request("PUT", InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); model2.setJsonEntity(buildRegressionModel(modelId2)); @@ -85,8 +91,26 @@ public void testGetTrainedModels() throws IOException { response = EntityUtils.toString(getModel.getEntity()); assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, not(containsString("\"definition\""))); assertThat(response, containsString("\"count\":2")); + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"heap_memory_estimation_bytes\"")); + assertThat(response, containsString("\"heap_memory_estimation\"")); + assertThat(response, containsString("\"definition\"")); + assertThat(response, containsString("\"count\":1")); + + ResponseException responseException = expectThrows(ResponseException.class, () -> + client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true"))); + assertThat(EntityUtils.toString(responseException.getResponse().getEntity()), + containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)); + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); @@ -131,6 +155,11 @@ public void testDeleteTrainedModels() throws IOException { model1.setJsonEntity(buildRegressionModel(modelId)); assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + Request modelDefinition1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); Response delModel = client().performRequest(new Request("DELETE", @@ -141,6 +170,18 @@ public void testDeleteTrainedModels() throws IOException { ResponseException responseException = expectThrows(ResponseException.class, () -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId))); assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); } private static String buildRegressionModel(String modelId) throws IOException { @@ -149,9 +190,6 @@ private static String buildRegressionModel(String modelId) throws IOException { .setModelId(modelId) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) .setCreatedBy("ml_test") - .setDefinition(new TrainedModelDefinition.Builder() - .setPreProcessors(Collections.emptyList()) - .setTrainedModel(LocalModelTests.buildRegression())) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build() @@ -160,6 +198,18 @@ private static String buildRegressionModel(String modelId) throws IOException { } } + private static String buildRegressionModelDefinition(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + new TrainedModelDefinition.Builder() + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(LocalModelTests.buildRegression()) + .setModelId(modelId) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + @After public void clearMlState() throws Exception { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index ee95ddbd9670d..15629579368f3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -5,86 +5,72 @@ */ package org.elasticsearch.xpack.ml.action; -import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.ParseField; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import java.util.Collections; +import java.util.Set; -public class TransportGetTrainedModelsAction extends AbstractTransportGetResourcesAction { +public class TransportGetTrainedModelsAction extends HandledTransportAction { + + private final TrainedModelProvider provider; @Inject - public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, Client client, - NamedXContentRegistry xContentRegistry) { - super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new, client, - xContentRegistry); + public TransportGetTrainedModelsAction(TransportService transportService, + ActionFilters actionFilters, + TrainedModelProvider trainedModelProvider) { + super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new); + this.provider = trainedModelProvider; } @Override - protected ParseField getResultsField() { - return GetTrainedModelsAction.Response.RESULTS_FIELD; - } + protected void doExecute(Task task, Request request, ActionListener listener) { - @Override - protected String[] getIndices() { - return new String[] { InferenceIndexConstants.INDEX_PATTERN }; - } + Response.Builder responseBuilder = Response.builder(); - @Override - protected TrainedModelConfig parse(XContentParser parser) { - return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); - } + ActionListener>> idExpansionListener = ActionListener.wrap( + totalAndIds -> { + responseBuilder.setTotalCount(totalAndIds.v1()); - @Override - protected ResourceNotFoundException notFoundException(String resourceId) { - return ExceptionsHelper.missingTrainedModel(resourceId); - } + if (totalAndIds.v2().isEmpty()) { + listener.onResponse(responseBuilder.build()); + return; + } - @Override - protected void doExecute(Task task, GetTrainedModelsAction.Request request, - ActionListener listener) { - searchResources(request, ActionListener.wrap( - queryPage -> listener.onResponse(new GetTrainedModelsAction.Response(queryPage)), - listener::onFailure - )); - } + if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) { + listener.onFailure( + ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED) + ); + return; + } - @Override - protected String executionOrigin() { - return ML_ORIGIN; - } - - @Override - protected String extractIdFromResource(TrainedModelConfig config) { - return config.getModelId(); - } + if (request.isIncludeModelDefinition()) { + provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap( + config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), + listener::onFailure + )); + } else { + provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( + configs -> listener.onResponse(responseBuilder.setModels(configs).build()), + listener::onFailure + )); + } + }, + listener::onFailure + ); - @Override - protected SearchSourceBuilder customSearchOptions(SearchSourceBuilder searchSourceBuilder) { - return searchSourceBuilder.sort("_index", SortOrder.DESC); + provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener); } - @Nullable - protected QueryBuilder additionalQuery() { - return QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index aabd760be4097..a15579b62de6a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -11,37 +11,23 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; -import org.elasticsearch.common.regex.Regex; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.ingest.Pipeline; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayList; import java.util.HashMap; @@ -63,17 +49,20 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction private final Client client; private final ClusterService clusterService; private final IngestService ingestService; + private final TrainedModelProvider trainedModelProvider; @Inject public TransportGetTrainedModelsStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, IngestService ingestService, + TrainedModelProvider trainedModelProvider, Client client) { super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new); this.client = client; this.clusterService = clusterService; this.ingestService = ingestService; + this.trainedModelProvider = trainedModelProvider; } @Override @@ -105,7 +94,7 @@ protected void doExecute(Task task, listener::onFailure ); - expandIds(request, idsListener); + trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener); } static Map inferenceIngestStatsByPipelineId(NodesStatsResponse response, @@ -124,91 +113,6 @@ static Map inferenceIngestStatsByPipelineId(NodesStatsRespo return ingestStatsMap; } - - private void expandIds(GetTrainedModelsStatsAction.Request request, ActionListener>> idsListener) { - String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ","); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() - .sort(SortBuilders.fieldSort(request.getResourceIdField()) - // If there are no resources, there might be no mapping for the id field. - // This makes sure we don't get an error if that happens. - .unmappedType("long")) - .query(buildQuery(tokens, request.getResourceIdField())); - if (request.getPageParams() != null) { - sourceBuilder.from(request.getPageParams().getFrom()) - .size(request.getPageParams().getSize()); - } - sourceBuilder.trackTotalHits(true) - // we only care about the item id's, there is no need to load large model definitions. - .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); - - IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; - SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) - .indicesOptions(IndicesOptions.fromOptions(true, - indicesOptions.allowNoIndices(), - indicesOptions.expandWildcardsOpen(), - indicesOptions.expandWildcardsClosed(), - indicesOptions)) - .source(sourceBuilder); - - executeAsyncWithOrigin(client.threadPool().getThreadContext(), - ML_ORIGIN, - searchRequest, - ActionListener.wrap( - response -> { - Set foundResourceIds = new LinkedHashSet<>(); - long totalHitCount = response.getHits().getTotalHits().value; - for (SearchHit hit : response.getHits().getHits()) { - Map docSource = hit.getSourceAsMap(); - if (docSource == null) { - continue; - } - Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); - if (idValue instanceof String) { - foundResourceIds.add(idValue.toString()); - } - } - ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, request.isAllowNoResources()); - requiredMatches.filterMatchedIds(foundResourceIds); - if (requiredMatches.hasUnmatchedIds()) { - idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); - } else { - idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); - } - }, - idsListener::onFailure - ), - client::search); - - } - - private QueryBuilder buildQuery(String[] tokens, String resourceIdField) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); - - if (Strings.isAllOrWildcard(tokens)) { - return boolQuery; - } - // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards - // e.g. id1,id2*,id3 - BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); - List terms = new ArrayList<>(); - for (String token : tokens) { - if (Regex.isSimpleMatchPattern(token)) { - shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); - } else { - terms.add(token); - } - } - if (terms.isEmpty() == false) { - shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); - } - - if (shouldQueries.should().isEmpty() == false) { - boolQuery.filter(shouldQueries); - } - return boolQuery; - } - static String[] ingestNodes(final ClusterState clusterState) { String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 63e496060a788..d47a4eb886079 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -20,10 +20,19 @@ import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -34,12 +43,18 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; @@ -49,10 +64,17 @@ import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FAILED_TO_DESERIALIZE; public class TrainedModelProvider { @@ -191,6 +213,56 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio multiSearchResponseActionListener); } + /** + * Gets all the provided trained config model objects + * + * NOTE: + * This does no expansion on the ids. + * It assumes that there are fewer than 10k. + */ + public void getTrainedModels(Set modelIds, boolean allowNoResources, final ActionListener> listener) { + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); + + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC) + .addSort("_index", SortOrder.DESC) + .setQuery(queryBuilder) + .request(); + + ActionListener configSearchHandler = ActionListener.wrap( + searchResponse -> { + Set observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f); + List configs = new ArrayList<>(searchResponse.getHits().getHits().length); + for(SearchHit searchHit : searchResponse.getHits().getHits()) { + try { + if (observedIds.contains(searchHit.getId()) == false) { + configs.add( + parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() + ); + observedIds.add(searchHit.getId()); + } + } catch (IOException ex) { + listener.onFailure( + ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); + return; + } + } + // We previously expanded the IDs. + // If the config has gone missing between then and now we should throw if allowNoResources is false + // Otherwise, treat it as if it was never expanded to begin with. + Set missingConfigs = Sets.difference(modelIds, observedIds); + if (missingConfigs.isEmpty() == false && allowNoResources == false) { + listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + return; + } + listener.onResponse(configs); + }, + listener::onFailure + ); + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); + } + public void deleteTrainedModel(String modelId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); @@ -216,6 +288,92 @@ public void deleteTrainedModel(String modelId, ActionListener listener) })); } + public void expandIds(String idExpression, + boolean allowNoResources, + @Nullable PageParams pageParams, + ActionListener>> idsListener) { + String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) + // If there are no resources, there might be no mapping for the id field. + // This makes sure we don't get an error if that happens. + .unmappedType("long")) + .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); + if (pageParams != null) { + sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); + } + sourceBuilder.trackTotalHits(true) + // we only care about the item id's + .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); + + IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) + .indicesOptions(IndicesOptions.fromOptions(true, + indicesOptions.allowNoIndices(), + indicesOptions.expandWildcardsOpen(), + indicesOptions.expandWildcardsClosed(), + indicesOptions)) + .source(sourceBuilder); + + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + Set foundResourceIds = new LinkedHashSet<>(); + long totalHitCount = response.getHits().getTotalHits().value; + for (SearchHit hit : response.getHits().getHits()) { + Map docSource = hit.getSourceAsMap(); + if (docSource == null) { + continue; + } + Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (idValue instanceof String) { + foundResourceIds.add(idValue.toString()); + } + } + ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); + requiredMatches.filterMatchedIds(foundResourceIds); + if (requiredMatches.hasUnmatchedIds()) { + idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); + } else { + idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + } + }, + idsListener::onFailure + ), + client::search); + + } + + private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); + + if (Strings.isAllOrWildcard(tokens)) { + return boolQuery; + } + // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards + // e.g. id1,id2*,id3 + BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); + List terms = new ArrayList<>(); + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); + } else { + terms.add(token); + } + } + if (terms.isEmpty() == false) { + shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); + } + + if (shouldQueries.should().isEmpty() == false) { + boolQuery.filter(shouldQueries); + } + return boolQuery; + } + private static T handleSearchItem(MultiSearchResponse.Item item, String resourceId, CheckedBiFunction parseLeniently) throws Exception { @@ -228,23 +386,23 @@ private static T handleSearchItem(MultiSearchResponse.Item item, return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId); } - private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelConfig.fromXContent(parser, true); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); throw e; } } - private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelDefinition.fromXContent(parser, true).build(); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); throw e; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 40ddd05827043..578b75fbc0793 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -41,7 +41,11 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient if (Strings.isNullOrEmpty(modelId)) { modelId = MetaData.ALL; } - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + boolean includeModelDefinition = restRequest.paramAsBoolean( + GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), + false + ); + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition); if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 2362635142d0d..272628f4c12f8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -147,6 +147,14 @@ public void testMaxCachedLimitReached() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + // Should have been loaded from the cluster change event + // Verify that we have at least loaded all three so that evictions occur in the following loop + assertBusy(() -> { + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); + }); + String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 40c52f55de144..099baf949b684 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -65,7 +65,8 @@ public void testInferModels() throws Exception { .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) - .setTrainedModel(buildClassification(true))) + .setTrainedModel(buildClassification(true)) + .setModelId(modelId1)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build(); @@ -73,7 +74,8 @@ public void testInferModels() throws Exception { .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) - .setTrainedModel(buildRegression())) + .setTrainedModel(buildRegression()) + .setModelId(modelId2)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 481f8b25975bb..22d16a6c36941 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -33,6 +33,12 @@ "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", "default":true }, + "include_model_definition":{ + "type":"boolean", + "required":false, + "description":"Should the full model definition be included in the results. These definitions can be large", + "default":false + }, "from":{ "type":"int", "description":"skips a number of trained models", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index efc6b784dbeac..6062f6519067f 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -207,10 +207,12 @@ setup: failed: 0 processors: - inference: - count: 0 - time_in_millis: 0 - current: 0 - failed: 0 + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 - match: trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1: @@ -220,7 +222,9 @@ setup: failed: 0 processors: - inference: - count: 0 - time_in_millis: 0 - current: 0 - failed: 0 + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0