From 475e98e125be3db300023f5f376e19e1bef64fa8 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 19 Mar 2024 15:30:57 +0000 Subject: [PATCH] Apply code review suggestions --- .../TransportGetTrainedModelsStatsAction.java | 123 ++++++++---------- 1 file changed, 55 insertions(+), 68 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index dc5976a9b6db8..c2394859a5a0e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; @@ -27,7 +28,6 @@ import org.elasticsearch.common.metrics.CounterMetric; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.ListenableFuture; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.QueryBuilder; @@ -39,6 +39,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.Transports; import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -122,47 +123,36 @@ protected void doExecute( GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); - ListenableFuture> modelSizeStatsListener = new ListenableFuture<>(); - modelSizeStatsListener.addListener(listener.delegateFailureAndWrap((l, modelSizeStatsByModelId) -> { - responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId); - l.onResponse( - responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)) - ); - })); - - ListenableFuture deploymentStatsListener = new ListenableFuture<>(); - deploymentStatsListener.addListener(listener.delegateFailureAndWrap((delegate, deploymentStats) -> executor.execute(() -> { - // deployment stats for each matching deployment - // not necessarily for all models - responseBuilder.setDeploymentStatsByDeploymentId( - deploymentStats.getStats() - .results() - .stream() - .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) - ); + SubscribableListener.>>>newForked(l -> { + // When the request resource is a deployment find the + // model used in that deployment for the model stats + String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata); + logger.debug("Expanded models/deployment Ids request [{}]", idExpression); - int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum(); - modelSizeStats( - responseBuilder.getExpandedModelIdsWithAliases(), + // the request id may contain deployment ids + // It is not an error if these don't match a model id but + // they need to be included in case the deployment id is also + // a model id. Hence, the `matchedDeploymentIds` parameter + trainedModelProvider.expandIds( + idExpression, request.isAllowNoResources(), + request.getPageParams(), + Collections.emptySet(), + modelAliasMetadata, parentTaskId, - modelSizeStatsListener, - numberOfAllocations + matchedDeploymentIds, + l ); - }))); - - ListenableFuture> inferenceStatsListener = new ListenableFuture<>(); - // inference stats are per model and are only - // persisted for boosted tree models - inferenceStatsListener.addListener(listener.delegateFailureAndWrap((l, inferenceStats) -> executor.execute(() -> { - responseBuilder.setInferenceStatsByModelId( - inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())) + }).andThen((l, tuple) -> { + responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); + executeAsyncWithOrigin( + client, + ML_ORIGIN, + TransportNodesStatsAction.TYPE, + nodeStatsRequest(clusterService.state(), parentTaskId), + l ); - getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, deploymentStatsListener); - }))); - - ListenableFuture nodesStatsListener = new ListenableFuture<>(); - nodesStatsListener.addListener(listener.delegateFailureAndWrap((delegate, nodesStatsResponse) -> executor.execute(() -> { + }).>andThen(executor, null, (l, nodesStatsResponse) -> { // find all pipelines whether using the model id, // alias or deployment id. Set allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases() @@ -182,46 +172,43 @@ protected void doExecute( trainedModelProvider.getInferenceStats( responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]), parentTaskId, - inferenceStatsListener + l ); - }))); - - ListenableFuture>>> idsListener = new ListenableFuture<>(); - idsListener.addListener(listener.delegateFailureAndWrap((delegate, tuple) -> { - responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); - executeAsyncWithOrigin( - client, - ML_ORIGIN, - TransportNodesStatsAction.TYPE, - nodeStatsRequest(clusterService.state(), parentTaskId), - nodesStatsListener + }).andThen(executor, null, (l, inferenceStats) -> { + // inference stats are per model and are only + // persisted for boosted tree models + responseBuilder.setInferenceStatsByModelId( + inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())) + ); + getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l); + }).>andThen(executor, null, (l, deploymentStats) -> { + // deployment stats for each matching deployment + // not necessarily for all models + responseBuilder.setDeploymentStatsByDeploymentId( + deploymentStats.getStats() + .results() + .stream() + .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) ); - })); - - executor.execute(() -> { - // When the request resource is a deployment find the - // model used in that deployment for the model stats - String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata); - logger.debug("Expanded models/deployment Ids request [{}]", idExpression); - // the request id may contain deployment ids - // It is not an error if these don't match a model id but - // they need to be included in case the deployment id is also - // a model id. Hence, the `matchedDeploymentIds` parameter - trainedModelProvider.expandIds( - idExpression, + int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum(); + modelSizeStats( + responseBuilder.getExpandedModelIdsWithAliases(), request.isAllowNoResources(), - request.getPageParams(), - Collections.emptySet(), - modelAliasMetadata, parentTaskId, - matchedDeploymentIds, - idsListener + l, + numberOfAllocations ); - }); + }).andThen((l, modelSizeStatsByModelId) -> { + responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId); + l.onResponse( + responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)) + ); + }).addListener(listener, executor, null); } static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) { + assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); if (Strings.isAllOrWildcard(idExpression)) { return idExpression; } else {