Skip to content

Commit

Permalink
During concurrent slice searches in IndexSearcher stop other tasks if…
Browse files Browse the repository at this point in the history
… one throws an Exception.

Since TaskExecutor now waits for all concurrent tasks to finish, even if one throws an
Exception, and cancels any remaining unscheduled tasks when an exception is thrown,
the next step is to notify currently running tasks to exit early. This is done via a
new QueryTimeout implementation, ExceptionBasedQueryTimeout, which holds a volatile
boolean recording whether any sibling task threw an exception. If the boolean is true,
then the shouldExit method returns true, so that the in-progress task exits early.

Closes #12278
  • Loading branch information
quux00 committed Nov 3, 2023
1 parent d6836d3 commit dd0bb0a
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 3 deletions.
70 changes: 67 additions & 3 deletions lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -674,14 +674,24 @@ private <C extends Collector, T> T search(
"CollectorManager does not always produce collectors with the same score mode");
}
}
if (collectors.size() > 1) {
addExceptionBasedQueryTimeout(queryTimeout);
}
final List<Callable<C>> listTasks = new ArrayList<>(leafSlices.length);
for (int i = 0; i < leafSlices.length; ++i) {
final LeafReaderContext[] leaves = leafSlices[i].leaves;
final C collector = collectors.get(i);
listTasks.add(
() -> {
search(Arrays.asList(leaves), weight, collector);
return collector;
try {
search(Arrays.asList(leaves), weight, collector);
return collector;
} catch (Exception e) {
if (queryTimeout instanceof ExceptionBasedQueryTimeout eqt) {
eqt.notifyExceptionThrown();
}
throw e;
}
});
}
List<C> results = taskExecutor.invokeAll(listTasks);
Expand Down Expand Up @@ -725,7 +735,7 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c
BulkScorer scorer = weight.bulkScorer(ctx);
if (scorer != null) {
if (queryTimeout != null) {
scorer = new TimeLimitingBulkScorer(scorer, queryTimeout);
scorer = createTimeLimitingBulkScorer(scorer, queryTimeout);
}
try {
scorer.score(leafCollector, ctx.reader().getLiveDocs());
Expand Down Expand Up @@ -954,6 +964,27 @@ public TaskExecutor getTaskExecutor() {
return taskExecutor;
}

/**
 * Installs a fresh {@link ExceptionBasedQueryTimeout} around the given timeout by passing it to
 * {@code setTimeout}, so that concurrently running slice tasks can be told to exit early.
 *
 * <p>If {@code delegate} is already an {@link ExceptionBasedQueryTimeout} (for example one
 * installed by an earlier concurrent search on this searcher), it is unwrapped down to the
 * timeout it was wrapping before being re-wrapped. Without this, every concurrent search would
 * nest one more wrapper around the previous one, and a wrapper that had already been notified of
 * an exception would keep {@code shouldExit} returning true for every later search.
 *
 * <p>Package private so that it can be overridden for testing.
 *
 * @param delegate the timeout to wrap; may be {@code null}, in which case the wrapper substitutes
 *     a timeout that never expires
 * @lucene.internal
 */
void addExceptionBasedQueryTimeout(QueryTimeout delegate) {
  QueryTimeout underlying = delegate;
  // peel off any wrappers left over from earlier searches so stale state is not inherited
  while (underlying instanceof ExceptionBasedQueryTimeout previous) {
    underlying = previous.delegate;
  }
  setTimeout(new ExceptionBasedQueryTimeout(underlying));
}

/**
 * Factory hook for the time-limiting scorer; it exists so that tests can override it and supply
 * their own BulkScorer.
 *
 * @param scorer the scorer handed to the {@link TimeLimitingBulkScorer} constructor
 * @param queryTimeout the timeout handed to the {@link TimeLimitingBulkScorer} constructor
 * @return a new {@link TimeLimitingBulkScorer} built from the two arguments
 * @lucene.internal
 */
BulkScorer createTimeLimitingBulkScorer(BulkScorer scorer, QueryTimeout queryTimeout) {
  final BulkScorer timeLimited = new TimeLimitingBulkScorer(scorer, queryTimeout);
  return timeLimited;
}

/**
* Thrown when an attempt is made to add more than {@link #getMaxClauseCount()} clauses. This
* typically happens if a PrefixQuery, FuzzyQuery, WildcardQuery, or TermRangeQuery is expanded to
Expand Down Expand Up @@ -1029,4 +1060,37 @@ public LeafSlice[] get() {
return leafSlices;
}
}

/**
 * {@link QueryTimeout} whose {@code shouldExit} reports true once {@link #notifyExceptionThrown()}
 * has been called. It wraps whatever QueryTimeout was previously set on the IndexSearcher (if
 * any) and otherwise defers to that timeout's own {@code shouldExit} result.
 *
 * <p>Intended for concurrent slice searches: when one slice task fails, its siblings can observe
 * that through this timeout and abort early, letting the whole search fail fast.
 */
static class ExceptionBasedQueryTimeout implements QueryTimeout {

  /** Flipped to true by {@link #notifyExceptionThrown()}; starts out false. */
  private volatile boolean exceptionThrown;

  /** The wrapped timeout; never null (see constructor). */
  private final QueryTimeout delegate;

  /**
   * @param delegate timeout to consult when no exception has been reported; {@code null} is
   *     replaced by a timeout that never asks to exit
   */
  public ExceptionBasedQueryTimeout(QueryTimeout delegate) {
    // a no-op stand-in for null keeps shouldExit free of null checks
    this.delegate = (delegate != null) ? delegate : () -> false;
  }

  /** Records that an exception was thrown, so later {@code shouldExit} calls return true. */
  public void notifyExceptionThrown() {
    exceptionThrown = true;
  }

  @Override
  public boolean shouldExit() {
    if (exceptionThrown) {
      return true;
    }
    return delegate.shouldExit();
  }
}
}
219 changes: 219 additions & 0 deletions lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.NamedThreadFactory;
import org.hamcrest.Matchers;

public class TestIndexSearcher extends LuceneTestCase {
Directory dir;
Expand Down Expand Up @@ -293,4 +296,220 @@ public void testNullExecutorNonNullTaskExecutor() {
IndexSearcher indexSearcher = new IndexSearcher(reader);
assertNotNull(indexSearcher.getTaskExecutor());
}

/*
 * The goal of this test is to ensure that when multiple concurrent slices are
 * being searched and one of the concurrent tasks throws an Exception, the other
 * tasks become aware of it (via the ExceptionBasedQueryTimeout in IndexSearcher)
 * and exit immediately rather than completing their search actions.
 *
 * To test this:
 * - a concurrent Executor is used to ensure concurrent tasks are running
 * - the MatchAllOrThrowExceptionQuery is used to ensure that one of the search
 *   tasks throws an Exception
 * - a testing ExceptionBasedQueryTimeoutWrapper is used to coordinate, via the two
 *   latches below, the task that throws with the tasks that check shouldExit
 * - a CountDownLatch is used to synchronize the task that is going to throw an Exception
 *   with another task that is in the ExceptionBasedQueryTimeout.shouldExit method,
 *   ensuring the Exception is thrown while at least one other task is still running
 * - a second CountDownLatch is used to synchronize between the
 *   ExceptionBasedQueryTimeout.notifyExceptionThrown call (coming from the task thread
 *   where the exception is thrown) and the ExceptionBasedQueryTimeout.shouldExit method
 *   to ensure that at least one task has shouldExit return true (for an early exit)
 * - an atomic earlyExitCounter tracks how many tasks exited early due to
 *   TimeLimitingBulkScorer.TimeExceededException in the TimeLimitingBulkScorer
 */
public void testMultipleSegmentsWithExceptionCausesEarlyTerminationOfRunningTasks() {
  // skip this test when only one leaf, since one leaf means one task
  // and the TimeLimitingBulkScorer will NOT be added in IndexSearcher
  if (reader.leaves().size() <= 1) {
    return;
  }
  // tracks how many tasks exited early due to Exception being thrown in another task
  AtomicInteger earlyExitCounter = new AtomicInteger(0);
  // task that throws an Exception waits on this latch to ensure at least one task is checking the
  // QueryTimeout.shouldExit method before it throws the Exception
  CountDownLatch shouldExitLatch = new CountDownLatch(1);
  // latch used by the ExceptionBasedQueryTimeoutWrapper to ensure that the shouldExit method
  // of at least one task waits until the ExceptionBasedQueryTimeout.notifyExceptionThrown method
  // is called before it computes its shouldExit result
  CountDownLatch exceptionThrownLatch = new CountDownLatch(1);
  ExecutorService executorService =
      Executors.newFixedThreadPool(7, new NamedThreadFactory("concurrentSlicesTest"));
  try {
    IndexSearcher searcher =
        new IndexSearcher(reader, executorService) {
          // force one slice per leaf so that each leaf is searched by its own task
          @Override
          protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
            return slices(leaves, 1, 1);
          }

          // count early exits by wrapping the time-limiting scorer
          @Override
          BulkScorer createTimeLimitingBulkScorer(BulkScorer scorer, QueryTimeout queryTimeout) {
            return new TimeLimitingBulkScorerWrapper(earlyExitCounter, scorer, queryTimeout);
          }

          // install the latch-coordinated timeout instead of the production one
          @Override
          void addExceptionBasedQueryTimeout(QueryTimeout delegate) {
            setTimeout(
                new ExceptionBasedQueryTimeoutWrapper(
                    shouldExitLatch, exceptionThrownLatch, delegate));
          }
        };

    MatchAllOrThrowExceptionQuery query = new MatchAllOrThrowExceptionQuery(shouldExitLatch);
    RuntimeException exc = expectThrows(RuntimeException.class, () -> searcher.search(query, 10));
    assertThat(
        exc.getMessage(), Matchers.containsString("MatchAllOrThrowExceptionQuery Exception"));
    assertThat(earlyExitCounter.get(), Matchers.greaterThan(0));

  } finally {
    executorService.shutdown();
  }
}

/**
 * Test-only {@link IndexSearcher.ExceptionBasedQueryTimeout} that uses two latches to line up the
 * task that throws an Exception with the tasks that are polling {@code shouldExit}.
 */
private static class ExceptionBasedQueryTimeoutWrapper
    extends IndexSearcher.ExceptionBasedQueryTimeout {
  // counted down whenever some task enters shouldExit
  private final CountDownLatch shouldExitLatch;
  // counted down once notifyExceptionThrown has been called
  private final CountDownLatch exceptionThrownLatch;

  /**
   * @param shouldExitLatch latch counted down on every {@code shouldExit} call
   * @param exceptionThrownLatch latch counted down by {@code notifyExceptionThrown} and awaited
   *     in {@code shouldExit}
   * @param delegate timeout passed through to the superclass constructor
   */
  public ExceptionBasedQueryTimeoutWrapper(
      CountDownLatch shouldExitLatch,
      CountDownLatch exceptionThrownLatch,
      QueryTimeout delegate) {
    super(delegate);
    this.shouldExitLatch = shouldExitLatch;
    this.exceptionThrownLatch = exceptionThrownLatch;
  }

  // called on the thread of the task that threw an Exception
  @Override
  public void notifyExceptionThrown() {
    super.notifyExceptionThrown();
    exceptionThrownLatch.countDown();
  }

  @Override
  public boolean shouldExit() {
    // notifies other tasks that this task is in the shouldExit method
    shouldExitLatch.countDown();
    try {
      // wait until at least one task has been notified that an exception was thrown;
      // the result is deliberately ignored: on timeout we just report the superclass result
      exceptionThrownLatch.await(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      // restore the interrupt flag and keep the cause so the failure is diagnosable
      Thread.currentThread().interrupt();
      throw new RuntimeException("Interrupted while awaiting exceptionThrownLatch", e);
    }
    return super.shouldExit();
  }
}

/**
 * Test-only query that behaves like a {@link MatchAllDocsQuery} except that the first call to its
 * weight's {@code scorer} method throws a RuntimeException.
 */
private static class MatchAllOrThrowExceptionQuery extends Query {

  // awaited before throwing, so the exception is raised only after a task entered shouldExit
  private final CountDownLatch shouldExitLatch;
  // starts at 1; the scorer call that observes the value 1 is the one that throws
  private final AtomicInteger numExceptionsToThrow;
  private final Query delegate;

  /**
   * Throws an Exception out of the {@code scorer} method the first time it is called. Otherwise,
   * it delegates all calls to the MatchAllDocsQuery.
   *
   * @param shouldExitLatch latch the throwing {@code scorer} call waits on (for up to 10 seconds)
   *     before it throws
   */
  public MatchAllOrThrowExceptionQuery(CountDownLatch shouldExitLatch) {
    this.numExceptionsToThrow = new AtomicInteger(1);
    this.shouldExitLatch = shouldExitLatch;
    this.delegate = new MatchAllDocsQuery();
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
      throws IOException {
    Weight matchAllWeight = delegate.createWeight(searcher, scoreMode, boost);

    return new Weight(delegate) {
      @Override
      public boolean isCacheable(LeafReaderContext ctx) {
        return matchAllWeight.isCacheable(ctx);
      }

      @Override
      public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        return matchAllWeight.explain(context, doc);
      }

      @Override
      public Scorer scorer(LeafReaderContext context) throws IOException {
        // only throw exception when the counter hits 1 (so only one exception gets thrown)
        if (numExceptionsToThrow.getAndDecrement() == 1) {
          try {
            // wait until we know another task is in the QueryTimeout.shouldExit method
            // before throwing an Exception to ensure at least one shouldExit method
            // returns true causing another task to exit early
            shouldExitLatch.await(10, TimeUnit.SECONDS);
          } catch (InterruptedException e) {
            // restore the interrupt flag before surfacing the failure
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
          }
          throw new RuntimeException("MatchAllOrThrowExceptionQuery Exception");
        }
        return matchAllWeight.scorer(context);
      }
    };
  }

  @Override
  public void visit(QueryVisitor visitor) {
    delegate.visit(visitor);
  }

  @Override
  public String toString(String field) {
    return "MatchAllOrThrowExceptionQuery";
  }

  @Override
  public int hashCode() {
    return delegate.hashCode();
  }

  // identity equality: each instance carries its own latch and counter
  @Override
  public boolean equals(Object other) {
    return other == this;
  }
}

/**
 * BulkScorer that forwards to an internally created {@link TimeLimitingBulkScorer} and counts
 * every {@link TimeLimitingBulkScorer.TimeExceededException} that escapes its {@code score}
 * method.
 */
static class TimeLimitingBulkScorerWrapper extends BulkScorer {
  private final AtomicInteger earlyExitCounter;
  private final TimeLimitingBulkScorer delegate;

  /**
   * Builds the wrapped {@link TimeLimitingBulkScorer} from the given scorer and timeout.
   *
   * @param earlyExitCounter incremented each time {@link
   *     TimeLimitingBulkScorer.TimeExceededException} is caught in {@code score}
   * @param scorer forwarded to the {@link TimeLimitingBulkScorer} constructor
   * @param queryTimeout forwarded to the {@link TimeLimitingBulkScorer} constructor
   */
  public TimeLimitingBulkScorerWrapper(
      AtomicInteger earlyExitCounter, BulkScorer scorer, QueryTimeout queryTimeout) {
    this.delegate = new TimeLimitingBulkScorer(scorer, queryTimeout);
    this.earlyExitCounter = earlyExitCounter;
  }

  @Override
  public long cost() {
    return delegate.cost();
  }

  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
      throws IOException {
    final int result;
    try {
      result = delegate.score(collector, acceptDocs, min, max);
    } catch (TimeLimitingBulkScorer.TimeExceededException timeExceeded) {
      // record the early exit, then let the exception continue to the caller unchanged
      earlyExitCounter.incrementAndGet();
      throw timeExceeded;
    }
    return result;
  }
}
}

0 comments on commit dd0bb0a

Please sign in to comment.