Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel search on shard failure when partial results disallowed #63520

Merged
merged 3 commits into from
Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.search.SearchTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
Expand All @@ -42,17 +45,22 @@
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.lookup.LeafFieldsLookup;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
Expand Down Expand Up @@ -273,13 +281,83 @@ public void testCancelMultiSearch() throws Exception {
}
}

/**
 * Verifies that a search running with {@code allowPartialSearchResults == false} has its
 * search task cancelled once one shard-level script execution fails.
 *
 * The scripted query is instrumented through {@code setBeforeExecution}: exactly one script
 * invocation throws a simulated {@link IllegalStateException}, while every other invocation
 * is parked on a latch so the search cannot complete on its own before the test has
 * observed the task being cancelled.
 */
public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception {
final List<ScriptedBlockPlugin> plugins = initBlockFactory();
// at least two shards are requested so that one can fail while another is still held on the latch
int numberOfShards = between(2, 5);
// flipped by the first script invocation that wins the compare-and-set; that invocation fails
AtomicBoolean failed = new AtomicBoolean();
// released by the test once the search task has been observed
CountDownLatch queryLatch = new CountDownLatch(1);
// released in the finally block, after the cancellation assertions have run
CountDownLatch cancelledLatch = new CountDownLatch(1);
for (ScriptedBlockPlugin plugin : plugins) {
// NOTE(review): presumably turns off the plugin's built-in blocking so only the hook below controls timing - confirm
plugin.disableBlock();
plugin.setBeforeExecution(() -> {
try {
queryLatch.await(); // block the query until we get a search task
} catch (InterruptedException e) {
throw new AssertionError(e);
}
// only the first caller sees false -> true and throws; all later callers fall through
if (failed.compareAndSet(false, true)) {
throw new IllegalStateException("simulated");
}
try {
cancelledLatch.await(); // block the query until the search is cancelled
} catch (InterruptedException e) {
throw new AssertionError(e);
}
});
}
createIndex("test", Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.build());
indexTestData();
// the search runs on its own thread because the script hook blocks it until the latches are released
Thread searchThread = new Thread(() -> {
// the search is expected to end with an exception rather than a response
expectThrows(Exception.class, () -> {
client().prepareSearch("test")
.setSearchType(SearchType.QUERY_THEN_FETCH)
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())))
.setAllowPartialSearchResults(false).setSize(1000).get();
});
});
searchThread.start();
try {
// wait until exactly one search task is visible before letting the scripts proceed
assertBusy(() -> assertThat(getSearchTasks(), hasSize(1)));
queryLatch.countDown();
// the simulated failure should lead to the still-registered search task being marked cancelled
assertBusy(() -> {
final List<SearchTask> searchTasks = getSearchTasks();
assertThat(searchTasks, hasSize(1));
assertTrue(searchTasks.get(0).isCancelled());
}, 30, TimeUnit.SECONDS);
} finally {
// replace the hook with a no-op so later script invocations neither block nor fail
for (ScriptedBlockPlugin plugin : plugins) {
plugin.setBeforeExecution(() -> {});
}
// release the invocations parked on the latch, then wait for the search thread to finish
cancelledLatch.countDown();
searchThread.join();
}
}

/**
 * Gathers the search tasks currently registered as cancellable on every node of the
 * internal test cluster.
 *
 * @return the cancellable tasks whose action equals {@code SearchAction.NAME}, in node
 *         iteration order; empty when none is registered
 */
List<SearchTask> getSearchTasks() {
    final List<SearchTask> searchTasks = new ArrayList<>();
    for (String node : internalCluster().getNodeNames()) {
        // each node has its own TransportService and therefore its own task manager
        for (Task candidate : internalCluster().getInstance(TransportService.class, node)
                .getTaskManager().getCancellableTasks().values()) {
            if (candidate.getAction().equals(SearchAction.NAME) == false) {
                continue;
            }
            searchTasks.add((SearchTask) candidate);
        }
    }
    return searchTasks;
}

public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SCRIPT_NAME = "search_block";

private final AtomicInteger hits = new AtomicInteger();

private final AtomicBoolean shouldBlock = new AtomicBoolean(true);

private final AtomicReference<Runnable> beforeExecution = new AtomicReference<>();

/**
 * Resets the {@code hits} counter to zero.
 */
public void reset() {
hits.set(0);
}
Expand All @@ -292,9 +370,17 @@ public void enableBlock() {
shouldBlock.set(true);
}

/**
 * Installs a hook that the mock script runs at the start of each execution, replacing any
 * previously installed hook.
 *
 * @param runnable the hook to store; the script skips the hook when the stored value is {@code null}
 */
public void setBeforeExecution(Runnable runnable) {
beforeExecution.set(runnable);
}

@Override
public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
return Collections.singletonMap(SCRIPT_NAME, params -> {
final Runnable runnable = beforeExecution.get();
if (runnable != null) {
runnable.run();
}
LeafFieldsLookup fieldsLookup = (LeafFieldsLookup) params.get("_fields");
LogManager.getLogger(SearchCancellationIT.class).info("Blocking on the document {}", fieldsLookup.get("_id"));
hits.incrementAndGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final int maxConcurrentRequestsPerNode;
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
private final boolean throttleConcurrentRequests;
private final AtomicBoolean requestCancelled = new AtomicBoolean();

private final List<Releasable> releasables = new ArrayList<>();

Expand Down Expand Up @@ -393,6 +394,15 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
logger.debug(() -> new ParameterizedMessage("{}: Failed to execute [{}] lastShard [{}]",
shard != null ? shard : shardIt.shardId(), request, lastShard), e);
if (lastShard) {
if (request.allowPartialSearchResults() == false) {
if (requestCancelled.compareAndSet(false, true)) {
try {
searchTransportService.cancelSearchTask(task, "partial results are not allowed and at least one shard has failed");
} catch (Exception cancelFailure) {
logger.debug("Failed to cancel search request", cancelFailure);
}
}
}
onShardGroupFailure(shardIndex, shard, e);
}
final int totalOps = this.totalOps.incrementAndGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -46,6 +50,7 @@
import org.elasticsearch.search.query.QuerySearchRequest;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.query.ScrollQuerySearchResult;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.transport.Transport;
Expand Down Expand Up @@ -81,12 +86,14 @@ public class SearchTransportService {
public static final String QUERY_CAN_MATCH_NAME = "indices:data/read/search[can_match]";

private final TransportService transportService;
private final NodeClient client;
private final BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper;
private final Map<String, Long> clientConnections = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

public SearchTransportService(TransportService transportService,
public SearchTransportService(TransportService transportService, NodeClient client,
BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper) {
this.transportService = transportService;
this.client = client;
this.responseWrapper = responseWrapper;
}

Expand Down Expand Up @@ -423,4 +430,12 @@ private boolean assertNodePresent() {
return true;
}
}

/**
 * Submits a cancellation request for the given search task, addressed by the local node id
 * and the task's id.
 *
 * The call does not report the outcome to the caller: the cancel request is handed a
 * listener built from an empty callback.
 *
 * @param task   the search task to cancel
 * @param reason text appended to the {@code "Fatal failure during search: "} prefix of the
 *               cancellation reason
 */
public void cancelSearchTask(SearchTask task, String reason) {
    final TaskId localTaskId = new TaskId(client.getLocalNodeId(), task.getId());
    final CancelTasksRequest cancelRequest = new CancelTasksRequest();
    cancelRequest.setTaskId(localTaskId);
    cancelRequest.setReason("Fatal failure during search: " + reason);
    // force the origin to execute the cancellation as a system user
    final OriginSettingClient tasksOriginClient = new OriginSettingClient(client, GetTaskAction.TASKS_ORIGIN);
    tasksOriginClient.admin().cluster().cancelTasks(cancelRequest, ActionListener.wrap(() -> {}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,13 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
Expand Down Expand Up @@ -69,7 +66,6 @@
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
Expand Down Expand Up @@ -97,7 +93,6 @@
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.elasticsearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.elasticsearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;
Expand All @@ -108,7 +103,6 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
"action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope);

private final NodeClient client;
private final ThreadPool threadPool;
private final ClusterService clusterService;
private final SearchTransportService searchTransportService;
Expand All @@ -120,8 +114,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
private final CircuitBreaker circuitBreaker;

@Inject
public TransportSearchAction(NodeClient client,
ThreadPool threadPool,
public TransportSearchAction(ThreadPool threadPool,
CircuitBreakerService circuitBreakerService,
TransportService transportService,
SearchService searchService,
Expand All @@ -132,7 +125,6 @@ public TransportSearchAction(NodeClient client,
IndexNameExpressionResolver indexNameExpressionResolver,
NamedWriteableRegistry namedWriteableRegistry) {
super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new);
this.client = client;
this.threadPool = threadPool;
this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST);
this.searchPhaseController = searchPhaseController;
Expand Down Expand Up @@ -801,7 +793,8 @@ public void run() {
}, clusters);
} else {
final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor,
circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(), exc -> cancelTask(task, exc));
circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(),
exc -> searchTransportService.cancelSearchTask(task, "failed to merge result [" + exc.getMessage() + "]"));
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) {
case DFS_QUERY_THEN_FETCH:
Expand All @@ -821,15 +814,6 @@ public void run() {
}
}

/**
 * Submits a cancellation request for the given search task, using the exception's message
 * (or an empty string when it has none) as the detail of the cancellation reason.
 *
 * @param task the search task to cancel
 * @param exc  the failure that triggered the cancellation
 */
private void cancelTask(SearchTask task, Exception exc) {
    String errorMsg = "";
    if (exc.getMessage() != null) {
        errorMsg = exc.getMessage();
    }
    final TaskId localTaskId = new TaskId(client.getLocalNodeId(), task.getId());
    final CancelTasksRequest cancelRequest = new CancelTasksRequest();
    cancelRequest.setTaskId(localTaskId);
    cancelRequest.setReason("Fatal failure during search: " + errorMsg);
    // force the origin to execute the cancellation as a system user
    final OriginSettingClient tasksOriginClient = new OriginSettingClient(client, TASKS_ORIGIN);
    tasksOriginClient.admin().cluster().cancelTasks(cancelRequest, ActionListener.wrap(() -> {}));
}

private static void failIfOverShardCountLimit(ClusterService clusterService, int shardCount) {
final long shardCountLimit = clusterService.getClusterSettings().get(SHARD_COUNT_LIMIT_SETTING);
if (shardCount > shardCountLimit) {
Expand Down
2 changes: 1 addition & 1 deletion server/src/main/java/org/elasticsearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ protected Node(final Environment initialEnvironment,
networkModule.getTransportInterceptor(), localNodeFactory, settingsModule.getClusterSettings(), taskHeaders);
final GatewayMetaState gatewayMetaState = new GatewayMetaState();
final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService);
final SearchTransportService searchTransportService = new SearchTransportService(transportService,
final SearchTransportService searchTransportService = new SearchTransportService(transportService, client,
SearchExecutionStatsCollector.makeWrapper(responseCollectorService));
final HttpServerTransport httpServerTransport = newHttpTransport(networkModule);
final IndexingPressure indexingLimits = new IndexingPressure(settings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public void testFilterShards() throws InterruptedException {
final boolean shard1 = randomBoolean();
final boolean shard2 = randomBoolean();

SearchTransportService searchTransportService = new SearchTransportService(null, null) {
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
@Override
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
ActionListener<SearchService.CanMatchResponse> listener) {
Expand Down Expand Up @@ -129,7 +129,7 @@ public void testFilterWithFailure() throws InterruptedException {
lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));
final boolean shard1 = randomBoolean();
SearchTransportService searchTransportService = new SearchTransportService(null, null) {
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
@Override
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
ActionListener<SearchService.CanMatchResponse> listener) {
Expand Down Expand Up @@ -195,7 +195,7 @@ public void testLotsOfShards() throws InterruptedException {


final SearchTransportService searchTransportService =
new SearchTransportService(null, null) {
new SearchTransportService(null, null, null) {
@Override
public void sendCanMatch(
Transport.Connection connection,
Expand All @@ -213,7 +213,7 @@ public void sendCanMatch(
final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
final SearchRequest searchRequest = new SearchRequest();
searchRequest.allowPartialSearchResults(true);
SearchTransportService transportService = new SearchTransportService(null, null);
SearchTransportService transportService = new SearchTransportService(null, null, null);
ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
(e) -> { throw new AssertionError("unexpected", e);});
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
Expand Down Expand Up @@ -296,7 +296,7 @@ public void testSortShards() throws InterruptedException {
List<MinAndMax<?>> minAndMaxes = new ArrayList<>();
Set<ShardId> shardToSkip = new HashSet<>();

SearchTransportService searchTransportService = new SearchTransportService(null, null) {
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
@Override
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
ActionListener<SearchService.CanMatchResponse> listener) {
Expand Down Expand Up @@ -369,7 +369,7 @@ public void testInvalidSortShards() throws InterruptedException {
List<ShardId> shardIds = new ArrayList<>();
Set<ShardId> shardToSkip = new HashSet<>();

SearchTransportService searchTransportService = new SearchTransportService(null, null) {
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
@Override
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
ActionListener<SearchService.CanMatchResponse> listener) {
Expand Down
Loading