Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support task resource tracking in OpenSearch #3982

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {}

@Override
public void waitForTaskCompletion(Task task) {}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
});
}
// Need to run the task in a separate thread because node client's .execute() is blocked by our task listener
Expand Down Expand Up @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) {
waitForWaitingToStart.countDown();
}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}

@Override
public void onTaskRegistered(Task task) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -65,8 +66,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) {

private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30);

private final TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
public TransportListTasksAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
TaskResourceTrackingService taskResourceTrackingService
) {
super(
ListTasksAction.NAME,
clusterService,
Expand All @@ -77,6 +85,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService
TaskInfo::new,
ThreadPool.Names.MANAGEMENT
);
this.taskResourceTrackingService = taskResourceTrackingService;
}

@Override
Expand Down Expand Up @@ -106,6 +115,8 @@ protected void processTasks(ListTasksRequest request, Consumer<Task> operation)
}
taskManager.waitForTaskCompletion(task, timeoutNanos);
});
} else {
operation = operation.andThen(taskResourceTrackingService::refreshResourceStats);
}
super.processTasks(request, operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public SearchShardTask(long id, String type, String action, String description,
super(id, type, action, description, parentTaskId, headers);
}

@Override
public boolean supportsResourceTracking() {
return true;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public final String getDescription() {
return descriptionSupplier.get();
}

@Override
public boolean supportsResourceTracking() {
return true;
}

/**
* Attach a {@link SearchProgressListener} to this task.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
Expand Down Expand Up @@ -93,31 +94,39 @@ public final Task execute(Request request, ActionListener<Response> listener) {
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
final Task task;

try {
task = taskManager.register("transport", actionName, request);
} catch (TaskCancelledException e) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);

ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
}
}
});
});
} finally {
storedContext.close();
}

return task;
}

Expand All @@ -134,25 +143,30 @@ public final Task execute(Request request, TaskListener<Response> listener) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}
}
}
});
});
} finally {
storedContext.close();
}
return task;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import org.opensearch.script.ScriptMetadata;
import org.opensearch.snapshots.SnapshotsInfoService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResultsService;

import java.util.ArrayList;
Expand Down Expand Up @@ -396,6 +397,7 @@ protected void configure() {
bind(NodeMappingRefreshAction.class).asEagerSingleton();
bind(MappingUpdatedAction.class).asEagerSingleton();
bind(TaskResultsService.class).asEagerSingleton();
bind(TaskResourceTrackingService.class).asEagerSingleton();
bind(AllocationDeciders.class).toInstance(allocationDeciders);
bind(ShardsAllocator.class).toInstance(shardsAllocator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.index.ShardIndexingPressureMemoryManager;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.ShardIndexingPressureStore;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction;
import org.opensearch.action.admin.indices.close.TransportCloseIndexAction;
Expand Down Expand Up @@ -575,7 +576,8 @@ public void apply(Settings value, Settings current, Settings previous) {
ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS,
ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT,
ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS,
IndexingPressure.MAX_INDEXING_BYTES
IndexingPressure.MAX_INDEXING_BYTES,
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.node.Node;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TaskAwareRunnable;

import java.util.List;
import java.util.Optional;
Expand All @@ -55,6 +57,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -182,21 +185,32 @@ public static OpenSearchThreadPoolExecutor newResizable(
int size,
int queueCapacity,
ThreadFactory threadFactory,
ThreadContext contextHolder
ThreadContext contextHolder,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {

if (queueCapacity <= 0) {
throw new IllegalArgumentException("queue capacity for [" + name + "] executor must be positive, got: " + queueCapacity);
}

Function<Runnable, WrappedRunnable> runnableWrapper;
if (runnableTaskListener != null) {
runnableWrapper = (runnable) -> {
TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener);
return new TimedRunnable(taskAwareRunnable);
};
} else {
runnableWrapper = TimedRunnable::new;
}

return new QueueResizableOpenSearchThreadPoolExecutor(
name,
size,
size,
0,
TimeUnit.MILLISECONDS,
new ResizableBlockingQueue<>(ConcurrentCollections.<Runnable>newBlockingQueue(), queueCapacity),
TimedRunnable::new,
runnableWrapper,
threadFactory,
new OpenSearchAbortPolicy(),
contextHolder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
Expand Down Expand Up @@ -135,16 +136,23 @@ public StoredContext stashContext() {
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
* Otherwise when context is stash, it should be empty.
*/

ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;

if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(
threadContextStruct = threadContextStruct.putHeaders(
MapBuilder.<String, String>newMapBuilder()
.put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID))
.immutableMap()
);
threadLocal.set(threadContextStruct);
} else {
threadLocal.set(DEFAULT_CONTEXT);
}

if (context.transientHeaders.containsKey(TASK_ID)) {
threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID));
}

threadLocal.set(threadContextStruct);

return () -> {
// If the node and thus the threadLocal get closed while this task
// is still executing, we don't want this runnable to fail with an
Expand Down
14 changes: 12 additions & 2 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.cluster.routing.allocation.AwarenessReplicaBalance;
import org.opensearch.index.IndexingPressureService;
import org.opensearch.index.store.RemoteDirectoryFactory;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.SegmentReplicationSourceService;
import org.opensearch.index.store.RemoteDirectoryFactory;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.Assertions;
import org.opensearch.Build;
Expand Down Expand Up @@ -219,6 +221,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -338,6 +341,7 @@ public static class DiscoverySettings {
private final LocalNodeFactory localNodeFactory;
private final NodeService nodeService;
final NamedWriteableRegistry namedWriteableRegistry;
private final AtomicReference<RunnableTaskExecutionListener> runnableTaskListener;

public Node(Environment environment) {
this(environment, Collections.emptyList(), true);
Expand Down Expand Up @@ -447,7 +451,8 @@ protected Node(

final List<ExecutorBuilder<?>> executorBuilders = pluginsService.getExecutorBuilders(settings);

final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0]));
runnableTaskListener = new AtomicReference<>();
final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0]));
resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS));
final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool);
resourcesToClose.add(resourceWatcherService);
Expand Down Expand Up @@ -1095,6 +1100,11 @@ public Node start() throws NodeValidationException {
TransportService transportService = injector.getInstance(TransportService.class);
transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class));
transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService));

TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class);
transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService);
runnableTaskListener.set(taskResourceTrackingService);

transportService.start();
assert localNodeFactory.getNode() != null;
assert transportService.getLocalNode().equals(localNodeFactory.getNode())
Expand Down
Loading