Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] add new flag for optionally including model definition #48718

Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;


public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

Expand All @@ -28,26 +31,53 @@ private GetTrainedModelsAction() {

public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");

public Request() {
setAllowNoResources(true);
}
private final boolean includeModelDefinition;

public Request(String id) {
public Request(String id, boolean includeModelDefinition) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
}

public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}

public boolean isIncludeModelDefinition() {
return includeModelDefinition;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
}
}

public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {
Expand All @@ -66,12 +96,33 @@ public Response(QueryPage<TrainedModelConfig> trainedModels) {
protected Reader<TrainedModelConfig> getReader() {
return TrainedModelConfig::new;
}
}

public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
public static Builder builder() {
return new Builder();
}

public static class Builder {

public RequestBuilder(ElasticsearchClient client) {
super(client, INSTANCE, new Request());
private long totalCount;
private List<TrainedModelConfig> configs = Collections.emptyList();

private Builder() {
}

public Builder setTotalCount(long totalCount) {
this.totalCount = totalCount;
return this;
}

public Builder setModels(List<TrainedModelConfig> configs) {
this.configs = configs;
return this;
}

public Response build() {
return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD));
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";

public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas

@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20));
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
*/
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.SearchModule;
Expand All @@ -22,7 +25,6 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -33,27 +35,21 @@
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;


public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {

private boolean lenient;

@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}

@Override
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
return TrainedModelDefinition.fromXContent(parser, lenient).build();
return TrainedModelDefinition.fromXContent(parser, true).build();
}

@Override
protected boolean supportsUnknownFields() {
return lenient;
return true;
}

@Override
Expand All @@ -63,7 +59,7 @@ protected Predicate<String> getRandomFieldsExcludeFilter() {

@Override
protected ToXContent.Params getToXContentParams() {
return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
}

@Override
Expand Down Expand Up @@ -286,9 +282,27 @@ public void testTreeSchemaDeserialization() throws IOException {
assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
}

public void testStrictParser() throws IOException {
TrainedModelDefinition.Builder builder = createRandomBuilder("asdf");
BytesReference reference = XContentHelper.toXContent(builder.build(),
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")),
false);

XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
reference,
XContentType.JSON);

XContentParseException exception = expectThrows(XContentParseException.class,
() -> TrainedModelDefinition.fromXContent(parser, false));

assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]"));
}

@Override
protected TrainedModelDefinition createTestInstance() {
return createRandomBuilder(null).build();
return createRandomBuilder(randomAlphaOfLength(10)).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
Expand Down Expand Up @@ -63,6 +64,11 @@ public void testGetTrainedModels() throws IOException {
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));

Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));

Request model2 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
model2.setJsonEntity(buildRegressionModel(modelId2));
Expand All @@ -85,8 +91,26 @@ public void testGetTrainedModels() throws IOException {
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
assertThat(response, not(containsString("\"definition\"")));
assertThat(response, containsString("\"count\":2"));

getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));

response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"heap_memory_estimation_bytes\""));
assertThat(response, containsString("\"heap_memory_estimation\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, containsString("\"count\":1"));

ResponseException responseException = expectThrows(ResponseException.class, () ->
client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));

getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
Expand Down Expand Up @@ -131,6 +155,11 @@ public void testDeleteTrainedModels() throws IOException {
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));

Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));

adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));

Response delModel = client().performRequest(new Request("DELETE",
Expand All @@ -141,6 +170,18 @@ public void testDeleteTrainedModels() throws IOException {
ResponseException responseException = expectThrows(ResponseException.class,
() -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId)));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));

responseException = expectThrows(ResponseException.class,
() -> client().performRequest(
new Request("GET",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId))));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));

responseException = expectThrows(ResponseException.class,
() -> client().performRequest(
new Request("GET",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId)));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
}

private static String buildRegressionModel(String modelId) throws IOException {
Expand All @@ -149,9 +190,6 @@ private static String buildRegressionModel(String modelId) throws IOException {
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
.setCreatedBy("ml_test")
.setDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Collections.emptyList())
.setTrainedModel(LocalModelTests.buildRegression()))
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.build()
Expand All @@ -160,6 +198,18 @@ private static String buildRegressionModel(String modelId) throws IOException {
}
}

private static String buildRegressionModelDefinition(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
new TrainedModelDefinition.Builder()
.setPreProcessors(Collections.emptyList())
.setTrainedModel(LocalModelTests.buildRegression())
.setModelId(modelId)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
}
}


@After
public void clearMlState() throws Exception {
Expand Down
Loading