Skip to content

Commit

Permalink
[ML][Inference] adds lazy model loader and inference
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Oct 1, 2019
1 parent e2b9c1b commit ac1d0ab
Show file tree
Hide file tree
Showing 16 changed files with 1,265 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.env.Environment;
import org.elasticsearch.gateway.GatewayService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class Parameters {
*/
public final Client client;

public Parameters(Environment env, ScriptService scriptService, AnalysisRegistry analysisRegistry, ThreadContext threadContext,
public Parameters(Environment env, ScriptService scriptService, AnalysisRegistry analysisRegistry, ThreadContext threadContext,
LongSupplier relativeTimeSupplier, BiFunction<Long, Runnable, Scheduler.ScheduledCancellable> scheduler,
IngestService ingestService, Client client) {
this.env = env;
Expand Down
8 changes: 4 additions & 4 deletions server/src/main/java/org/elasticsearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ protected Node(
clusterService.addLocalNodeMasterListener(
new ConsistentSettingsService(settings, clusterService, settingsModule.getConsistentSettings())
.newHashPublisher());
final IngestService ingestService = new IngestService(clusterService, threadPool, this.environment,
scriptModule.getScriptService(), analysisModule.getAnalysisRegistry(),
pluginsService.filterPlugins(IngestPlugin.class), client);

final ClusterInfoService clusterInfoService = newClusterInfoService(settings, clusterService, threadPool, client);
final UsageService usageService = new UsageService();

Expand Down Expand Up @@ -405,7 +403,9 @@ protected Node(
ClusterModule.getNamedXWriteables().stream())
.flatMap(Function.identity()).collect(toList()));
final MetaStateService metaStateService = new MetaStateService(nodeEnvironment, xContentRegistry);

final IngestService ingestService = new IngestService(clusterService, threadPool, this.environment,
scriptModule.getScriptService(), analysisModule.getAnalysisRegistry(),
pluginsService.filterPlugins(IngestPlugin.class), client);
// collect engine factory providers from server and from plugins
final Collection<EnginePlugin> enginePlugins = pluginsService.filterPlugins(EnginePlugin.class);
final Collection<Function<IndexSettings, Optional<EngineFactory>>> engineFactoryProviders =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public List<String> getFeatureNames() {

@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
List<Double> features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList());
return infer(features);
}

Expand All @@ -128,7 +128,7 @@ public List<Double> classificationProbability(Map<String, Object> fields) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
List<Double> features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList());
return classificationProbability(features);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.elasticsearch.persistent.PersistentTasksExecutor;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.AnalysisPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
Expand Down Expand Up @@ -161,6 +162,7 @@
import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;
import org.elasticsearch.xpack.ml.action.TransportMlInfoAction;
Expand Down Expand Up @@ -200,7 +202,10 @@
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.action.InferModelAction;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
Expand Down Expand Up @@ -299,7 +304,7 @@
import static java.util.Collections.emptyList;
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;

public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements ActionPlugin, IngestPlugin, AnalysisPlugin, PersistentTaskPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
Expand Down Expand Up @@ -495,6 +500,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
notifier,
xContentRegistry);

final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider);
// special holder for {@link MachineLearningFeatureSetUsage} which needs access to job manager if ML is enabled
JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager);

Expand Down Expand Up @@ -607,7 +614,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
analyticsProcessManager,
memoryEstimationProcessManager,
dataFrameAnalyticsConfigProvider,
nativeStorageProvider
nativeStorageProvider,
modelLoadingService,
trainedModelProvider
);
}

Expand Down Expand Up @@ -762,6 +771,7 @@ public List<RestHandler> getRestHandlers(Settings settings, RestController restC
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
usageAction,
infoAction);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.ml.inference.action.InferModelAction;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;

import java.util.List;

/**
 * Transport-layer handler for {@link InferModelAction}.
 *
 * <p>For each request it asks the {@link ModelLoadingService} for the requested model and, once the
 * model is available, queues one task per document in the request on a {@link TypedChainTaskExecutor}.
 * The collected results are returned to the caller wrapped in an {@link InferModelAction.Response}.
 *
 * <p>NOTE(review): {@code request.isCacheModel()} is never consulted here; the model is always fetched
 * through {@code getModelAndCache}. Confirm whether a non-caching path is intended.
 */
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {

    private final ModelLoadingService modelLoadingService;
    private final Client client;

    /**
     * @param actionName          name this action is registered under
     *                            (NOTE(review): injected as a plain String - confirm a binding exists)
     * @param transportService    passed through to {@link HandledTransportAction}
     * @param actionFilters       passed through to {@link HandledTransportAction}
     * @param modelLoadingService service used to obtain the model to run inference with
     * @param client              client whose thread pool supplies the executor for the inference tasks
     */
    @Inject
    public TransportInferModelAction(String actionName,
                                     TransportService transportService,
                                     ActionFilters actionFilters,
                                     ModelLoadingService modelLoadingService,
                                     Client client) {
        super(actionName, transportService, actionFilters, InferModelAction.Request::new);
        this.client = client;
        this.modelLoadingService = modelLoadingService;
    }

    /**
     * Loads the requested model and runs it against every document in the request.
     *
     * <p>When {@code topClasses} is set on the request, {@code Model#confidence} is invoked for each
     * document; otherwise {@code Model#infer} is. Any failure, whether from loading the model or from
     * an individual task, is forwarded to {@code listener}.
     */
    @Override
    protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {

        // Invoked once the model has been obtained from the loading service.
        final ActionListener<Model> onModelLoaded = ActionListener.wrap(
            loadedModel -> {
                // Receives the list of per-document results and converts it into the action response.
                final ActionListener<List<Object>> onAllInferred = ActionListener.wrap(
                    results -> listener.onResponse(new InferModelAction.Response(results)),
                    listener::onFailure);

                final TypedChainTaskExecutor<Object> chain = new TypedChainTaskExecutor<>(
                    client.threadPool().executor(ThreadPool.Names.SAME),
                    // run through all tasks
                    ignoredResult -> true,
                    // Always fail immediately and return an error
                    ignoredFailure -> true);

                if (request.getTopClasses() == null) {
                    // No top-classes requested: plain inference per document.
                    request.getObjectsToInfer().forEach(fields ->
                        chain.add(next -> loadedModel.infer(fields, next)));
                } else {
                    // Top-classes requested: ask the model for confidence values instead.
                    request.getObjectsToInfer().forEach(fields ->
                        chain.add(next -> loadedModel.confidence(fields, request.getTopClasses(), next)));
                }

                chain.execute(onAllInferred);
            },
            listener::onFailure);

        modelLoadingService.getModelAndCache(request.getModelId(), request.getModelVersion(), onModelLoaded);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.action;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Action type used to run inference against a trained model, identified by id and version.
 *
 * <p>The nested {@link Request} carries the documents to infer on, and the nested {@link Response}
 * carries one result object per document.
 */
public class InferModelAction extends ActionType<InferModelAction.Response> {

    public static final InferModelAction INSTANCE = new InferModelAction();
    public static final String NAME = "cluster:admin/xpack/ml/infer";

    private InferModelAction() {
        super(NAME, Response::new);
    }

    /**
     * Request to run inference with a given model over zero or more documents.
     *
     * <p>Instances are immutable: the list of documents is copied on construction and exposed as an
     * unmodifiable view.
     */
    public static class Request extends ActionRequest {

        private final String modelId;
        private final long modelVersion;
        private final List<Map<String, Object>> objectsToInfer;
        private final boolean cacheModel;
        private final Integer topClasses;

        /**
         * Creates a request with no documents and no {@code topClasses} value.
         *
         * @param modelId      id of the model to use
         * @param modelVersion version of the model to use
         */
        public Request(String modelId, long modelVersion) {
            this(modelId, modelVersion, Collections.emptyList(), null);
        }

        /**
         * Creates a request for a list of documents.
         *
         * @param modelId        id of the model to use
         * @param modelVersion   version of the model to use
         * @param objectsToInfer documents to infer on; {@code null} is treated as an empty list
         * @param topClasses     optional value; may be {@code null}
         */
        public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) {
            this.modelId = modelId;
            this.modelVersion = modelVersion;
            this.objectsToInfer = objectsToInfer == null ? Collections.emptyList() :
                Collections.unmodifiableList(new ArrayList<>(objectsToInfer));
            // Requests built locally always ask for the model to be cached.
            this.cacheModel = true;
            this.topClasses = topClasses;
        }

        /**
         * Creates a request for a single document.
         *
         * @param modelId       id of the model to use
         * @param modelVersion  version of the model to use
         * @param objectToInfer the document to infer on; {@code null} results in an empty document list
         * @param topClasses    optional value; may be {@code null}
         */
        public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) {
            this(modelId,
                modelVersion,
                objectToInfer == null ? Collections.emptyList() : Collections.singletonList(objectToInfer),
                topClasses);
        }

        /**
         * Reads a request from the wire. Field order must mirror {@link #writeTo(StreamOutput)}.
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.modelId = in.readString();
            this.modelVersion = in.readVLong();
            this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
            this.topClasses = in.readOptionalInt();
            this.cacheModel = in.readBoolean();
        }

        public String getModelId() {
            return modelId;
        }

        public long getModelVersion() {
            return modelVersion;
        }

        /**
         * @return an unmodifiable list of the documents to infer on; never {@code null}
         */
        public List<Map<String, Object>> getObjectsToInfer() {
            return objectsToInfer;
        }

        public boolean isCacheModel() {
            return cacheModel;
        }

        /**
         * @return the {@code topClasses} value, or {@code null} when it was not supplied
         */
        public Integer getTopClasses() {
            return topClasses;
        }

        /**
         * Validates the request before it is executed.
         *
         * <p>A {@code null} model id is rejected here so the caller receives a validation error
         * rather than a {@link NullPointerException} when the request is later serialized.
         *
         * @return a validation exception describing the problem, or {@code null} when the request is valid
         */
        @Override
        public ActionRequestValidationException validate() {
            if (modelId == null) {
                ActionRequestValidationException validationException = new ActionRequestValidationException();
                validationException.addValidationError("model id must not be null");
                return validationException;
            }
            return null;
        }

        /**
         * Writes the request to the wire. Field order must mirror {@link #Request(StreamInput)}.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(modelId);
            out.writeVLong(modelVersion);
            out.writeCollection(objectsToInfer, StreamOutput::writeMap);
            out.writeOptionalInt(topClasses);
            out.writeBoolean(cacheModel);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            InferModelAction.Request that = (InferModelAction.Request) o;
            // Primitives are compared directly to avoid needless boxing.
            return modelVersion == that.modelVersion
                && cacheModel == that.cacheModel
                && Objects.equals(modelId, that.modelId)
                && Objects.equals(topClasses, that.topClasses)
                && Objects.equals(objectsToInfer, that.objectsToInfer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel);
        }

    }

    /**
     * Builder that binds a {@link Request} to this action for execution through a client.
     */
    public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
        public RequestBuilder(ElasticsearchClient client, Request request) {
            super(client, INSTANCE, request);
        }
    }

    /**
     * Response holding the inference results as an unmodifiable list of generic values.
     */
    public static class Response extends ActionResponse {

        // TODO come up with a better union type object
        private final List<Object> inferenceResponse;

        /**
         * @param inferenceResponse the results; copied defensively so later changes to the caller's
         *                          list are not visible here. {@code null} is treated as an empty list,
         *                          matching how {@link Request} treats a null document list.
         */
        public Response(List<Object> inferenceResponse) {
            super();
            this.inferenceResponse = inferenceResponse == null ? Collections.emptyList() :
                Collections.unmodifiableList(new ArrayList<>(inferenceResponse));
        }

        /**
         * Reads a response from the wire as a list of generic values.
         */
        public Response(StreamInput in) throws IOException {
            super(in);
            this.inferenceResponse = Collections.unmodifiableList(in.readList(StreamInput::readGenericValue));
        }

        /**
         * @return an unmodifiable list of the inference results; never {@code null}
         */
        public List<Object> getInferenceResponse() {
            return inferenceResponse;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeCollection(inferenceResponse, StreamOutput::writeGenericValue);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            InferModelAction.Response that = (InferModelAction.Response) o;
            return Objects.equals(inferenceResponse, that.inferenceResponse);
        }

        @Override
        public int hashCode() {
            return Objects.hash(inferenceResponse);
        }

    }
}
Loading

0 comments on commit ac1d0ab

Please sign in to comment.