[ML][Inference] add inference processors and trained models to usage #47869

Merged
MachineLearningFeatureSetUsage.java
@@ -29,22 +29,26 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String CREATED_BY = "created_by";
public static final String NODE_COUNT = "node_count";
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
public static final String INFERENCE_FIELD = "inference";

private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final Map<String, Object> analyticsUsage;
private final Map<String, Object> inferenceUsage;
private final int nodeCount;

public MachineLearningFeatureSetUsage(boolean available,
boolean enabled,
Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage,
Map<String, Object> analyticsUsage,
Map<String, Object> inferenceUsage,
int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
this.nodeCount = nodeCount;
}

@@ -57,6 +61,11 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
} else {
this.analyticsUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.inferenceUsage = in.readMap();
} else {
this.inferenceUsage = Collections.emptyMap();
}
this.nodeCount = in.readInt();
}

@@ -68,6 +77,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeMap(analyticsUsage);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeMap(inferenceUsage);
}
out.writeInt(nodeCount);
}

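As a quick illustration of the version-gating above, here is a minimal round-trip sketch (a hypothetical snippet, not part of this PR; "usage" is assumed to be a populated MachineLearningFeatureSetUsage):

    // Serialize as if talking to a pre-8.0.0 node: inferenceUsage is skipped.
    BytesStreamOutput out = new BytesStreamOutput();
    out.setVersion(Version.V_7_4_0);
    usage.writeTo(out);

    // Deserialize with the same stream version: the reader takes the else branch.
    StreamInput in = out.bytes().streamInput();
    in.setVersion(Version.V_7_4_0);
    MachineLearningFeatureSetUsage read = new MachineLearningFeatureSetUsage(in);
    // read now holds Collections.emptyMap() for inferenceUsage, because the
    // 7.4 stream never carried it; on V_8_0_0+ streams the map round-trips intact.
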
@@ -77,6 +89,7 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(JOBS_FIELD, jobsUsage);
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
builder.field(INFERENCE_FIELD, inferenceUsage);
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}
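For orientation, the usage document this builds should render roughly like the following sketch (assuming MachineLearningFeatureSetUsage.ALL is the "_all" key and createCountUsageEntry produces a {"count": n} object; values are illustrative):

    "inference": {
      "trained_models": { "_all": { "count": 3 } },
      "ingest_processors": {
        "_all": {
          "pipelines": { "count": 1 },
          "num_docs_processed": { "sum": 150, "min": 0, "max": 100 },
          "time_ms": { "sum": 42, "min": 0, "max": 30 },
          "num_failures": { "sum": 0, "min": 0, "max": 0 }
        }
      }
    }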
MachineLearningUsageTransportAction.java
@@ -7,6 +7,12 @@

import org.apache.lucene.util.Counter;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
@@ -16,11 +22,13 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.env.Environment;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.protocol.xpack.XPackUsageRequest;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse;
@@ -32,19 +40,24 @@
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction {
@@ -72,28 +85,56 @@ protected void masterOperation(Task task, XPackUsageRequest request, ClusterState state,
ActionListener<XPackUsageFeatureResponse> listener) {
if (enabled == false) {
MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(licenseState.isMachineLearningAllowed(), enabled,
-            Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0);
+            Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0);
listener.onResponse(new XPackUsageFeatureResponse(usage));
return;
}

Map<String, Object> jobsUsage = new LinkedHashMap<>();
Map<String, Object> datafeedsUsage = new LinkedHashMap<>();
Map<String, Object> analyticsUsage = new LinkedHashMap<>();
Map<String, Object> inferenceUsage = new LinkedHashMap<>();
int nodeCount = mlNodeCount(state);

-        // Step 3. Extract usage from data frame analytics stats and return usage response
-        ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
+        // Step 5. Extract trained model config count and then return results
+        ActionListener<SearchResponse> trainedModelConfigCountListener = ActionListener.wrap(
            response -> {
-                addDataFrameAnalyticsUsage(response, analyticsUsage);
+                addTrainedModelStats(response, inferenceUsage);
MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(licenseState.isMachineLearningAllowed(),
-                    enabled, jobsUsage, datafeedsUsage, analyticsUsage, nodeCount);
+                    enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount);
listener.onResponse(new XPackUsageFeatureResponse(usage));
},
listener::onFailure
);

-        // Step 2. Extract usage from datafeeds stats and return usage response
+        // Step 4. Extract usage from ingest statistics and gather trained model config count
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
response -> {
addInferenceIngestUsage(response, inferenceUsage);
SearchRequestBuilder requestBuilder = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setSize(0)
.setTrackTotalHits(true);
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ClientHelper.ML_ORIGIN,
requestBuilder.request(),
trainedModelConfigCountListener,
client::search);
},
listener::onFailure
);

// Step 3. Extract usage from data frame analytics stats and then request ingest node stats
ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
response -> {
addDataFrameAnalyticsUsage(response, analyticsUsage);
String[] ingestNodes = ingestNodes(state);
NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true);
client.execute(NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
},
listener::onFailure
);

// Step 2. Extract usage from datafeeds stats and then request stats for data frame analytics
ActionListener<GetDatafeedsStatsAction.Response> datafeedStatsListener =
ActionListener.wrap(response -> {
addDatafeedsUsage(response, datafeedsUsage);
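(The hunk is truncated here.) Reading the listeners bottom-up, the chain now has five steps, each kicking off the next request; a summary sketch follows (Step 1 sits outside the displayed hunk):

    // Step 1: job stats (GetJobsStatsAction)           -> jobsUsage   (not shown above)
    // Step 2: datafeed stats (GetDatafeedsStatsAction) -> datafeedsUsage
    // Step 3: data frame analytics stats               -> analyticsUsage
    // Step 4: ingest stats from ingest nodes           -> inference ingest_processors usage
    // Step 5: trained model config count (search)      -> inference trained_models usage, then respond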
@@ -227,6 +268,66 @@ private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsStatsAction.Response
}
}

private static void initializeStats(Map<String, Long> emptyStatsMap) {
emptyStatsMap.put("sum", 0L);
emptyStatsMap.put("min", 0L);
emptyStatsMap.put("max", 0L);
}

private static void updateStats(Map<String, Long> statsMap, Long value) {
statsMap.compute("sum", (k, v) -> v + value);
statsMap.compute("min", (k, v) -> Math.min(v, value));
statsMap.compute("max", (k, v) -> Math.max(v, value));
}
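
A quick worked example of these helpers (illustrative values): after initializeStats seeds {sum: 0, min: 0, max: 0}, a first call updateStats(stats, 100L) gives

    sum = 0 + 100          = 100
    min = Math.min(0, 100) = 0    // the 0L seed is already the floor for non-negative counters
    max = Math.max(0, 100) = 100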

//TODO separate out ours and users models possibly regression vs classification
private void addTrainedModelStats(SearchResponse response, Map<String, Object> inferenceUsage) {
inferenceUsage.put("trained_models",
Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, createCountUsageEntry(response.getHits().getTotalHits().value)));
}

//TODO separate out ours and users models possibly regression vs classification
private void addInferenceIngestUsage(NodesStatsResponse response, Map<String, Object> inferenceUsage) {
Set<String> pipelines = new HashSet<>();
Map<String, Long> docCountStats = new HashMap<>(3);
Map<String, Long> timeStats = new HashMap<>(3);
Map<String, Long> failureStats = new HashMap<>(3);
initializeStats(docCountStats);
initializeStats(timeStats);
initializeStats(failureStats);

response.getNodes()
.stream()
.map(NodeStats::getIngestStats)
.map(IngestStats::getProcessorStats)
.forEach(map ->
map.forEach((pipelineId, processors) -> {
boolean containsInference = false;
for (IngestStats.ProcessorStat stats : processors) {
if (stats.getName().equals(InferenceProcessor.TYPE)) {
containsInference = true;
long ingestCount = stats.getStats().getIngestCount();
long ingestTime = stats.getStats().getIngestTimeInMillis();
long failureCount = stats.getStats().getIngestFailedCount();
updateStats(docCountStats, ingestCount);
updateStats(timeStats, ingestTime);
updateStats(failureStats, failureCount);
}
}
if (containsInference) {
pipelines.add(pipelineId);
}
})
);

Map<String, Object> ingestUsage = new HashMap<>(6);
ingestUsage.put("pipelines", createCountUsageEntry(pipelines.size()));
ingestUsage.put("num_docs_processed", docCountStats);
ingestUsage.put("time_ms", timeStats);
ingestUsage.put("num_failures", failureStats);
inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage));
}
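
To make the aggregation concrete, a hypothetical two-node example: node A reports pipeline "p1" with an inference processor that handled 100 docs in 30 ms, node B reports the same pipeline with 50 docs in 12 ms. The result is one distinct pipeline, num_docs_processed = {sum: 150, min: 0, max: 100} and time_ms = {sum: 42, min: 0, max: 30}: per-node, per-processor counters are folded into cluster-wide sum/min/max under the "_all" key.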

private static int mlNodeCount(final ClusterState clusterState) {
int mlNodeCount = 0;
for (DiscoveryNode node : clusterState.getNodes()) {
@@ -236,4 +337,14 @@ private static int mlNodeCount(final ClusterState clusterState) {
}
return mlNodeCount;
}

private static String[] ingestNodes(final ClusterState clusterState) {
String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()];
Iterator<String> nodeIterator = clusterState.nodes().getIngestNodes().keysIt();
int i = 0;
while (nodeIterator.hasNext()) {
ingestNodes[i++] = nodeIterator.next();
}
return ingestNodes;
}
}
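
Side note on the ingestNodes helper: an equivalent sketch (same behavior, not what the PR ships) could collect the ids without manual index bookkeeping, assuming keysIt() iterates the ingest node ids as above:

    List<String> ids = new ArrayList<>();  // java.util.ArrayList would need importing
    clusterState.nodes().getIngestNodes().keysIt().forEachRemaining(ids::add);
    return ids.toArray(new String[0]);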