Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] adds lazy model loader and inference #47410

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
Expand Down Expand Up @@ -323,6 +324,7 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
StartDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class InferModelAction extends ActionType<InferModelAction.Response> {

public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/infer";

private InferModelAction() {
super(NAME, Response::new);
}

public static class Request extends ActionRequest {

private final String modelId;
private final long modelVersion;
private final List<Map<String, Object>> objectsToInfer;
private final int topClasses;

public Request(String modelId, long modelVersion) {
this(modelId, modelVersion, Collections.emptyList(), null);
}

public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.modelVersion = modelVersion;
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.topClasses = topClasses == null ? 0 : topClasses;
}

public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) {
this(modelId,
modelVersion,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
topClasses);
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.modelVersion = in.readVLong();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.topClasses = in.readInt();
}

public String getModelId() {
return modelId;
}

public long getModelVersion() {
return modelVersion;
}

public List<Map<String, Object>> getObjectsToInfer() {
return objectsToInfer;
}

public int getTopClasses() {
return topClasses;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeVLong(modelVersion);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
out.writeInt(topClasses);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Request that = (InferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(modelVersion, that.modelVersion)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses);
}

}

public static class Response extends ActionResponse {

private final List<InferenceResults<?>> inferenceResults;
private final String resultsType;

public Response(List<InferenceResults<?>> inferenceResponse, String resultsType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we call inferenceResponse -> inferenceResults? Also for the member variable.

super();
this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType");
this.inferenceResults = inferenceResponse == null ?
Collections.emptyList() :
Collections.unmodifiableList(inferenceResponse);
}

public Response(StreamInput in) throws IOException {
super(in);
this.resultsType = in.readString();
if(resultsType.equals(ClassificationInferenceResults.RESULT_TYPE)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space after if

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This polymorphism via if here makes me think that InferenceResults could be a NamedWriteable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, but since InferenceResults<T> supports a generic, I am not sure it is possible...

I could maybe make SingleValueInferenceResults the named writable, but then we would still have a conditional here to choose between single value, multi-value, etc.

this.inferenceResults = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new));
} else if (this.resultsType.equals(RegressionInferenceResults.RESULT_TYPE)) {
this.inferenceResults = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new));
} else {
throw new IOException("Unrecognized result type [" + resultsType + "]");
}
}

public List<InferenceResults<?>> getInferenceResults() {
return inferenceResults;
}

public String getResultsType() {
return resultsType;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(resultsType);
out.writeCollection(inferenceResults);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Response that = (InferModelAction.Response) o;
return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResults, that.inferenceResults);
}

@Override
public int hashCode() {
return Objects.hash(resultsType, inferenceResults);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

public class ClassificationInferenceResults extends SingleValueInferenceResults {

public static final String RESULT_TYPE = "classification";
public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the predicted class? If yes should we call it predicted_class instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this field is just the optional string label for the ordinal value returned from the trained model.

Might be null as it is the specific string label for the numeric value returned via the model. If the user already transformed their classes into ordinal numerics, there is no label, just the numeric value.

public static final ParseField TOP_CLASSES = new ParseField("top_classes");

private final String classificationLabel;
private final List<TopClassEntry> topClasses;

public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses) {
super(value);
dimitris-athanasiou marked this conversation as resolved.
Show resolved Hide resolved
this.classificationLabel = classificationLabel;
dimitris-athanasiou marked this conversation as resolved.
Show resolved Hide resolved
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
}

public ClassificationInferenceResults(StreamInput in) throws IOException {
super(in);
this.classificationLabel = in.readOptionalString();
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
}

public String getClassificationLabel() {
return classificationLabel;
}

public List<TopClassEntry> getTopClasses() {
return topClasses;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(classificationLabel);
out.writeCollection(topClasses);
}

@Override
XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
if (classificationLabel != null) {
builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel);
}
if (topClasses.isEmpty() == false) {
builder.field(TOP_CLASSES.getPreferredName(), topClasses);
}
return builder;
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(topClasses, that.topClasses);
}

@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses);
}

@Override
public String resultType() {
return RESULT_TYPE;
}

@Override
public String valueAsString() {
return classificationLabel == null ? super.valueAsString() : classificationLabel;
}

public static class TopClassEntry implements ToXContentObject, Writeable {

public final ParseField CLASSIFICATION = new ParseField("classification");
public final ParseField PROBABILITY = new ParseField("probability");

private final String classification;
private final double probability;

public TopClassEntry(String classification, Double probability) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION);
this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY);
}

public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
this.probability = in.readDouble();
}

public String getClassification() {
return classification;
}

public double getProbability() {
return probability;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
out.writeDouble(probability);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSIFICATION.getPreferredName(), classification);
builder.field(PROBABILITY.getPreferredName(), probability);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) &&
Objects.equals(probability, that.probability);
}

@Override
public int hashCode() {
return Objects.hash(classification, probability);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;

public interface InferenceResults<T> extends ToXContentObject, Writeable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like us to consider an alternative idea here.

Right now this needs to be generic because of the T value() method. The paradigm is that we call value() to get the result. But how are we going to use this result?

I think eventually the result is an object we append on the object-to-infer, right?

Could we thus have:

Map<String, Object> result()?

It could be we can find a better type than Map<String, Object>. However, that would mean that each implementation of the results could be returning an object flexibly. Hard to discuss all this in text so I'm sure it warrants a nice design discussion!

Copy link
Member Author

@benwtrent benwtrent Oct 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think eventually the result is an object we append on the object-to-infer, right?

No, the result could be used any number of ways. We don't really append it to the mapped fields of the inference fields, we will supply it to the caller either via an API call, or through the ingest processor (which will have a target_field parameter to tell us where to put the result).

Could we thus have:
Map<String, Object> result()?

I would rather not pass around Map<String, Object>. If we were going down that path, what is the point of having an object defined at all?

I honestly think having a generic T covers our uses cases.

Regression: Always returns a single numeric
Classification: Could be numeric or string (depending on if we have field mapped values)
Future: Covers the more exotic cases of List, Map, etc. without sacrificing type-safety.


String resultType();

T value();

String valueAsString();

}
Loading