Skip to content

Commit

Permalink
Make AnalyticsProcessManager class more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Nov 19, 2019
1 parent 68870ac commit cf672f2
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
Expand Down Expand Up @@ -239,7 +240,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
"Finished analysis");
}

@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/49095")
public void testStopAndRestart() throws Exception {
initialize("regression_stop_and_restart");

Expand Down Expand Up @@ -270,8 +270,12 @@ public void testStopAndRestart() throws Exception {
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
assertBusy(() -> {
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
equalTo(DataFrameAnalyticsState.STOPPED))));
assertThat(
state,
is(anyOf(
equalTo(DataFrameAnalyticsState.REINDEXING),
equalTo(DataFrameAnalyticsState.ANALYZING),
equalTo(DataFrameAnalyticsState.STOPPED))));
});
stopAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
Expand All @@ -287,7 +291,7 @@ public void testStopAndRestart() throws Exception {
}
}

waitUntilAnalyticsIsStopped(jobId);
waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));

SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
Expand Down Expand Up @@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);

private final Client client;
private final ThreadPool threadPool;
private final ExecutorService executorServiceForJob;
private final ExecutorService executorServiceForProcess;
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor;
Expand All @@ -65,40 +65,59 @@ public AnalyticsProcessManager(Client client,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this(
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory,
auditor,
trainedModelProvider);
}

// Visible for testing
public AnalyticsProcessManager(Client client,
ExecutorService executorServiceForJob,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
}

public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
Consumer<Exception> finishHandler) {
threadPool.generic().execute(() -> {
if (task.isStopping()) {
// The task was requested to stop before we created the process context
finishHandler.accept(null);
return;
executorServiceForJob.execute(() -> {
ProcessContext processContext = new ProcessContext(config.getId());
synchronized (this) {
if (task.isStopping()) {
// The task was requested to stop before we created the process context
finishHandler.accept(null);
return;
}
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
finishHandler.accept(
ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
return;
}
}

// First we refresh the dest index to ensure data is searchable
// Refresh the dest index to ensure data is searchable
refreshDest(config);

ProcessContext processContext = new ProcessContext(config.getId());
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id
+ "] Could not create process as one already exists"));
return;
}

// Fetch existing model state (if any)
BytesReference state = getModelState(config);

if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
executorService.execute(() -> processResults(processContext));
executorService.execute(() -> processData(task, config, processContext.dataExtractor,
executorServiceForProcess.execute(() -> processResults(processContext));
executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
processContext.process, processContext.resultProcessor, finishHandler, state));
} else {
processContextByAllocation.remove(task.getAllocationId());
finishHandler.accept(null);
}
});
Expand All @@ -111,8 +130,6 @@ private BytesReference getModelState(DataFrameAnalyticsConfig config) {
}

try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setSize(1)
.setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
Expand Down Expand Up @@ -246,9 +263,8 @@ private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesRefere

private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
executorService, onProcessCrash(task));
AnalyticsProcess<AnalyticsResult> process =
processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
if (process.isProcessAlive() == false) {
throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
}
Expand Down Expand Up @@ -285,17 +301,22 @@ private void closeProcess(DataFrameAnalyticsTask task) {
}
}

public void stop(DataFrameAnalyticsTask task) {
public synchronized void stop(DataFrameAnalyticsTask task) {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
if (processContext != null) {
LOGGER.debug("[{}] Stopping process", task.getParams().getId() );
LOGGER.debug("[{}] Stopping process", task.getParams().getId());
processContext.stop();
} else {
LOGGER.debug("[{}] No process context to stop", task.getParams().getId() );
LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
task.markAsCompleted();
}
}

// Visible for testing
int getProcessContextCount() {
return processContextByAllocation.size();
}

class ProcessContext {

private final String id;
Expand All @@ -309,31 +330,26 @@ class ProcessContext {
this.id = Objects.requireNonNull(id);
}

public String getId() {
return id;
}

public boolean isProcessKilled() {
return processKilled;
synchronized String getFailureReason() {
return failureReason;
}

private synchronized void setFailureReason(String failureReason) {
synchronized void setFailureReason(String failureReason) {
// Only set the new reason if there isn't one already as we want to keep the first reason
if (failureReason != null) {
if (this.failureReason == null && failureReason != null) {
this.failureReason = failureReason;
}
}

private String getFailureReason() {
return failureReason;
}

public synchronized void stop() {
synchronized void stop() {
LOGGER.debug("[{}] Stopping process", id);
processKilled = true;
if (dataExtractor != null) {
dataExtractor.cancel();
}
if (resultProcessor != null) {
resultProcessor.cancel();
}
if (process != null) {
try {
process.kill();
Expand All @@ -346,8 +362,8 @@ public synchronized void stop() {
/**
* @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
*/
private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
if (processKilled) {
// The job was stopped before we started the process so no need to start it
return false;
Expand All @@ -365,8 +381,8 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr
process = createProcess(task, config, analyticsProcessConfig, state);
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
trainedModelProvider, auditor, dataExtractor.getFieldNames());
resultProcessor = new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,26 @@
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

public class AnalyticsResultProcessor {

private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);

private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final Supplier<Boolean> isProcessKilled;
private final ProgressTracker progressTracker;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final List<String> fieldNames;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure;
private volatile boolean isCancelled;

public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
List<String> fieldNames) {
ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressTracker = Objects.requireNonNull(progressTracker);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
Expand All @@ -74,6 +71,10 @@ public void awaitForCompletion() {
}
}

public void cancel() {
isCancelled = true;
}

public void process(AnalyticsProcess<AnalyticsResult> process) {
long totalRows = process.getConfig().rows();
long processedRows = 0;
Expand All @@ -82,20 +83,23 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
while (iterator.hasNext()) {
if (isCancelled) {
break;
}
AnalyticsResult result = iterator.next();
processResult(result, resultsJoiner);
if (result.getRowResults() != null) {
processedRows++;
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
}
}
if (isProcessKilled.get() == false) {
if (isCancelled == false) {
// This means we completed successfully so we need to set the progress to 100.
// This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
progressTracker.writingResultsPercent.set(100);
}
} catch (Exception e) {
if (isProcessKilled.get()) {
if (isCancelled) {
// No need to log error as it's due to stopping
} else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
Expand Down
Loading

0 comments on commit cf672f2

Please sign in to comment.