-
Notifications
You must be signed in to change notification settings - Fork 24.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML][Inference] adds lazy model loader and inference
- Loading branch information
Showing
16 changed files
with
1,265 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
.../plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.ml.action; | ||
|
||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.action.support.ActionFilters; | ||
import org.elasticsearch.action.support.HandledTransportAction; | ||
import org.elasticsearch.client.Client; | ||
import org.elasticsearch.common.inject.Inject; | ||
import org.elasticsearch.tasks.Task; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.transport.TransportService; | ||
import org.elasticsearch.xpack.ml.inference.action.InferModelAction; | ||
import org.elasticsearch.xpack.ml.inference.loadingservice.Model; | ||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; | ||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; | ||
|
||
import java.util.List; | ||
|
||
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> { | ||
|
||
private final ModelLoadingService modelLoadingService; | ||
private final Client client; | ||
|
||
@Inject | ||
public TransportInferModelAction(String actionName, | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
ModelLoadingService modelLoadingService, | ||
Client client) { | ||
super(actionName, transportService, actionFilters, InferModelAction.Request::new); | ||
this.modelLoadingService = modelLoadingService; | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) { | ||
|
||
ActionListener<List<Object>> inferenceCompleteListener = ActionListener.wrap( | ||
inferenceResponse -> listener.onResponse(new InferModelAction.Response(inferenceResponse)), | ||
listener::onFailure | ||
); | ||
|
||
ActionListener<Model> getModelListener = ActionListener.wrap( | ||
model -> { | ||
TypedChainTaskExecutor<Object> typedChainTaskExecutor = | ||
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), | ||
// run through all tasks | ||
r -> true, | ||
// Always fail immediately and return an error | ||
ex -> true); | ||
if (request.getTopClasses() != null) { | ||
request.getObjectsToInfer().forEach(stringObjectMap -> | ||
typedChainTaskExecutor.add(chainedTask -> model.confidence(stringObjectMap, request.getTopClasses(), chainedTask)) | ||
); | ||
} else { | ||
request.getObjectsToInfer().forEach(stringObjectMap -> | ||
typedChainTaskExecutor.add(chainedTask -> model.infer(stringObjectMap, chainedTask)) | ||
); | ||
} | ||
typedChainTaskExecutor.execute(inferenceCompleteListener); | ||
}, | ||
listener::onFailure | ||
); | ||
|
||
this.modelLoadingService.getModelAndCache(request.getModelId(), request.getModelVersion(), getModelListener); | ||
} | ||
} |
168 changes: 168 additions & 0 deletions
168
...plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.ml.inference.action; | ||
|
||
import org.elasticsearch.action.ActionRequest; | ||
import org.elasticsearch.action.ActionRequestBuilder; | ||
import org.elasticsearch.action.ActionRequestValidationException; | ||
import org.elasticsearch.action.ActionResponse; | ||
import org.elasticsearch.action.ActionType; | ||
import org.elasticsearch.client.ElasticsearchClient; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
public class InferModelAction extends ActionType<InferModelAction.Response> { | ||
|
||
public static final InferModelAction INSTANCE = new InferModelAction(); | ||
public static final String NAME = "cluster:admin/xpack/ml/infer"; | ||
|
||
private InferModelAction() { | ||
super(NAME, Response::new); | ||
} | ||
|
||
public static class Request extends ActionRequest { | ||
|
||
private final String modelId; | ||
private final long modelVersion; | ||
private final List<Map<String, Object>> objectsToInfer; | ||
private final boolean cacheModel; | ||
private final Integer topClasses; | ||
|
||
public Request(String modelId, long modelVersion) { | ||
this(modelId, modelVersion, Collections.emptyList(), null); | ||
} | ||
|
||
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) { | ||
this.modelId = modelId; | ||
this.modelVersion = modelVersion; | ||
this.objectsToInfer = objectsToInfer == null ? Collections.emptyList() : | ||
Collections.unmodifiableList(new ArrayList<>(objectsToInfer)); | ||
this.cacheModel = true; | ||
this.topClasses = topClasses; | ||
} | ||
|
||
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) { | ||
this(modelId, | ||
modelVersion, | ||
objectToInfer == null ? Collections.emptyList() : Collections.singletonList(objectToInfer), | ||
topClasses); | ||
} | ||
|
||
public Request(StreamInput in) throws IOException { | ||
super(in); | ||
this.modelId = in.readString(); | ||
this.modelVersion = in.readVLong(); | ||
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); | ||
this.topClasses = in.readOptionalInt(); | ||
this.cacheModel = in.readBoolean(); | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
public long getModelVersion() { | ||
return modelVersion; | ||
} | ||
|
||
public List<Map<String, Object>> getObjectsToInfer() { | ||
return objectsToInfer; | ||
} | ||
|
||
public boolean isCacheModel() { | ||
return cacheModel; | ||
} | ||
|
||
public Integer getTopClasses() { | ||
return topClasses; | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(modelId); | ||
out.writeVLong(modelVersion); | ||
out.writeCollection(objectsToInfer, StreamOutput::writeMap); | ||
out.writeOptionalInt(topClasses); | ||
out.writeBoolean(cacheModel); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
InferModelAction.Request that = (InferModelAction.Request) o; | ||
return Objects.equals(modelId, that.modelId) | ||
&& Objects.equals(modelVersion, that.modelVersion) | ||
&& Objects.equals(topClasses, that.topClasses) | ||
&& Objects.equals(cacheModel, that.cacheModel) | ||
&& Objects.equals(objectsToInfer, that.objectsToInfer); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel); | ||
} | ||
|
||
} | ||
|
||
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> { | ||
public RequestBuilder(ElasticsearchClient client, Request request) { | ||
super(client, INSTANCE, request); | ||
} | ||
} | ||
|
||
public static class Response extends ActionResponse { | ||
|
||
// TODO come up with a better union type object | ||
private final List<Object> inferenceResponse; | ||
|
||
public Response(List<Object> inferenceResponse) { | ||
super(); | ||
this.inferenceResponse = Collections.unmodifiableList(inferenceResponse); | ||
} | ||
|
||
public Response(StreamInput in) throws IOException { | ||
super(in); | ||
this.inferenceResponse = Collections.unmodifiableList(in.readList(StreamInput::readGenericValue)); | ||
} | ||
|
||
public List<Object> getInferenceResponse() { | ||
return inferenceResponse; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeCollection(inferenceResponse, StreamOutput::writeGenericValue); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
InferModelAction.Response that = (InferModelAction.Response) o; | ||
return Objects.equals(inferenceResponse, that.inferenceResponse); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(inferenceResponse); | ||
} | ||
|
||
} | ||
} |
Oops, something went wrong.