Skip to content

Commit

Permalink
[ML] ML Model Inference Ingest Processor (elastic#49052)
Browse files Browse the repository at this point in the history
* [ML][Inference] adds lazy model loader and inference (elastic#47410)

This adds a couple of things:

- A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them
- A Model class and its first sub-class LocalModel. Used to cache model information and run inference.
- Transport action and handler for requests to infer against a local model
Related Feature PRs:
* [ML][Inference] Adjust inference configuration option API (elastic#47812)

* [ML][Inference] adds logistic_regression output aggregator (elastic#48075)

* [ML][Inference] Adding read/del trained models (elastic#47882)

* [ML][Inference] Adding inference ingest processor (elastic#47859)

* [ML][Inference] fixing classification inference for ensemble (elastic#48463)

* [ML][Inference] Adding model memory estimations (elastic#48323)

* [ML][Inference] adding more options to inference processor (elastic#48545)

* [ML][Inference] handle string values better in feature extraction (elastic#48584)

* [ML][Inference] Adding _stats endpoint for inference (elastic#48492)

* [ML][Inference] add inference processors and trained models to usage (elastic#47869)

* [ML][Inference] add new flag for optionally including model definition (elastic#48718)

* [ML][Inference] adding license checks (elastic#49056)

* [ML][Inference] Adding memory and compute estimates to inference (elastic#48955)
  • Loading branch information
benwtrent committed Nov 18, 2019
1 parent 5f9965e commit bf04d3b
Show file tree
Hide file tree
Showing 97 changed files with 7,855 additions and 362 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
}

public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
Expand All @@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;

TrainedModelConfig(String modelId,
String createdBy,
Expand All @@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
Expand All @@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
}

public String getModelId() {
Expand Down Expand Up @@ -138,6 +149,18 @@ public TrainedModelInput getInput() {
return input;
}

public ByteSizeValue getEstimatedHeapMemory() {
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
}

public Long getEstimatedHeapMemoryBytes() {
return estimatedHeapMemory;
}

public Long getEstimatedOperations() {
return estimatedOperations;
}

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
if (estimatedHeapMemory != null) {
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
}
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
builder.endObject();
return builder;
}
Expand All @@ -194,6 +223,8 @@ public boolean equals(Object o) {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -206,6 +237,8 @@ public int hashCode() {
definition,
description,
tags,
estimatedHeapMemory,
estimatedOperations,
metadata,
input);
}
Expand All @@ -222,6 +255,8 @@ public static class Builder {
private List<String> tags;
private TrainedModelDefinition definition;
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;

public Builder setModelId(String modelId) {
this.modelId = modelId;
Expand Down Expand Up @@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) {
return this;
}

public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}

public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
Expand All @@ -287,7 +332,9 @@ public TrainedModelConfig build() {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() {
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());

}

@Override
Expand Down
62 changes: 62 additions & 0 deletions server/src/main/java/org/elasticsearch/ingest/IngestStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class IngestStats implements Writeable, ToXContentFragment {
Expand Down Expand Up @@ -150,6 +151,21 @@ public Map<String, List<ProcessorStat>> getProcessorStats() {
return processorStats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats that = (IngestStats) o;
return Objects.equals(totalStats, that.totalStats)
&& Objects.equals(pipelineStats, that.pipelineStats)
&& Objects.equals(processorStats, that.processorStats);
}

@Override
public int hashCode() {
return Objects.hash(totalStats, pipelineStats, processorStats);
}

public static class Stats implements Writeable, ToXContentFragment {

private final long ingestCount;
Expand Down Expand Up @@ -218,6 +234,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field("failed", ingestFailedCount);
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.Stats that = (IngestStats.Stats) o;
return Objects.equals(ingestCount, that.ingestCount)
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
&& Objects.equals(ingestCurrent, that.ingestCurrent);
}

@Override
public int hashCode() {
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
}
}

/**
Expand Down Expand Up @@ -270,6 +302,20 @@ public String getPipelineId() {
public Stats getStats() {
return stats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
return Objects.equals(pipelineId, that.pipelineId)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(pipelineId, stats);
}
}

/**
Expand Down Expand Up @@ -297,5 +343,21 @@ public String getType() {
public Stats getStats() {
return stats;
}


@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
return Objects.equals(name, that.name)
&& Objects.equals(type, that.type)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(name, type, stats);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
Expand All @@ -109,6 +110,9 @@
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
Expand Down Expand Up @@ -153,6 +157,19 @@
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
Expand Down Expand Up @@ -371,6 +388,10 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
Expand Down Expand Up @@ -519,6 +540,16 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(OutputAggregator.class,
LogisticRegression.NAME.getPreferredName(),
LogisticRegression::new),
// ML - Inference Results
new NamedWriteableRegistry.Entry(InferenceResults.class,
ClassificationInferenceResults.NAME,
ClassificationInferenceResults::new),
new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new),
// ML - Inference Configuration
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),

// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String CREATED_BY = "created_by";
public static final String NODE_COUNT = "node_count";
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
public static final String INFERENCE_FIELD = "inference";

private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final Map<String, Object> analyticsUsage;
private final Map<String, Object> inferenceUsage;
private final int nodeCount;

public MachineLearningFeatureSetUsage(boolean available,
boolean enabled,
Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage,
Map<String, Object> analyticsUsage,
Map<String, Object> inferenceUsage,
int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
this.nodeCount = nodeCount;
}

Expand All @@ -57,12 +61,17 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
} else {
this.analyticsUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
this.inferenceUsage = in.readMap();
} else {
this.inferenceUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
this.nodeCount = in.readInt();
} else {
this.nodeCount = -1;
}
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Expand All @@ -72,17 +81,21 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeMap(analyticsUsage);
}
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeMap(inferenceUsage);
}
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeInt(nodeCount);
}
}
}

@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
super.innerXContent(builder, params);
builder.field(JOBS_FIELD, jobsUsage);
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
builder.field(INFERENCE_FIELD, inferenceUsage);
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}
Expand Down
Loading

0 comments on commit bf04d3b

Please sign in to comment.