Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] Adding read/del trained models #47882

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class DeleteTrainedModelAction extends ActionType<AcknowledgedResponse> {

public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/delete";

private DeleteTrainedModelAction() {
super(NAME, AcknowledgedResponse::new);
}

public static class Request extends AcknowledgedRequest<Request> implements ToXContentFragment {

private String id;

public Request(StreamInput in) throws IOException {
super(in);
id = in.readString();
}

public Request() {}

public Request(String id) {
this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID);
}

public String getId() {
return id;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id);
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o;
return Objects.equals(id, request.id);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(id);
}

@Override
public int hashCode() {
return Objects.hash(id);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.io.IOException;

public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
public static final String NAME = "cluster:monitor/xpack/ml/inference/get";

private GetTrainedModelsAction() {
super(NAME, Response::new);
}

public static class Request extends AbstractGetResourcesRequest {

public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");

public Request() {
setAllowNoResources(true);
}

public Request(String id) {
setResourceId(id);
setAllowNoResources(true);
}

public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}

}

public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {

public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs");

public Response(StreamInput in) throws IOException {
super(in);
}

public Response(QueryPage<TrainedModelConfig> analytics) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename analytics to trainedModels or similar?

super(analytics);
}

@Override
protected Reader<TrainedModelConfig> getReader() {
return TrainedModelConfig::new;
}
}

public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {

public RequestBuilder(ElasticsearchClient client) {
super(client, INSTANCE, new Request());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.job.config.Job;

import java.util.Date;


public class InferenceAuditMessage extends AbstractAuditMessage {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a unit test, similar to AnomalyDetectionAuditMessageTests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will, but it seems like overkill to me as this class essentially does nothing.


//TODO this should be MODEL_ID...
private static final ParseField JOB_ID = Job.ID;
public static final ConstructingObjectParser<InferenceAuditMessage, Void> PARSER =
createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID);

public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) {
super(resourceId, message, level, timestamp, nodeName);
}

@Override
public final String getJobType() {
return "inference";
}

@Override
protected String getResourceField() {
return JOB_ID.getPreferredName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(Str
return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id);
}

public static ResourceNotFoundException missingTrainedModel(String modelId) {
return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId);
}

public static ElasticsearchException serverError(String msg) {
return new ElasticsearchException(msg);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request;

public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {

@Override
protected Request createTestInstance() {
return new Request(randomAlphaOfLengthBetween(1, 20));
}

@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;

public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {

@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}

@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
}
Loading