From 2b6cd7aace6f80adaecb76145826748c67f8e124 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 13 Nov 2018 13:31:39 +0000 Subject: [PATCH] [ML] Reimplement established model memory (#35263) This is the 6.6/6.7 implementation of a master node service to keep track of the native process memory requirement of each ML job with an associated native process. The new ML memory tracker service works when the whole cluster is upgraded to at least version 6.6. For mixed version clusters the old mechanism of established model memory stored on the job in cluster state is used. This means that the old (and complex) code to keep established model memory up to date on the job object cannot yet be removed. When this change is forward ported to 7.0 the old way of keeping established model memory updated will be removed. --- .../xpack/core/ml/MlMetadata.java | 62 +++- .../xpack/core/ml/job/config/Job.java | 7 +- .../xpack/core/ml/job/config/JobTests.java | 2 +- .../xpack/ml/MachineLearning.java | 10 +- .../ml/action/TransportDeleteJobAction.java | 9 +- .../ml/action/TransportOpenJobAction.java | 157 +++++++-- .../output/AutoDetectResultProcessor.java | 1 + .../xpack/ml/process/MlMemoryTracker.java | 331 ++++++++++++++++++ .../xpack/ml/MlMetadataTests.java | 13 +- .../action/TransportOpenJobActionTests.java | 40 ++- .../integration/MlDistributedFailureIT.java | 89 ++++- .../xpack/ml/integration/TooManyJobsIT.java | 2 - .../ml/process/MlMemoryTrackerTests.java | 195 +++++++++++ 13 files changed, 853 insertions(+), 65 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java index 81762def4cc35..febed3d97efbb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java @@ -57,8 +57,9 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom { public static final String TYPE = "ml"; private static final ParseField JOBS_FIELD = new ParseField("jobs"); private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds"); + private static final ParseField LAST_MEMORY_REFRESH_VERSION_FIELD = new ParseField("last_memory_refresh_version"); - public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap()); + public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null); // This parser follows the pattern that metadata is parsed leniently (to allow for enhancements) public static final ObjectParser LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new); @@ -66,15 +67,18 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom { LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD); LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds, (p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD); + LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshVersion, LAST_MEMORY_REFRESH_VERSION_FIELD); } private final SortedMap jobs; private final SortedMap datafeeds; + private final Long 
lastMemoryRefreshVersion; private final GroupOrJobLookup groupOrJobLookup; - private MlMetadata(SortedMap jobs, SortedMap datafeeds) { + private MlMetadata(SortedMap jobs, SortedMap datafeeds, Long lastMemoryRefreshVersion) { this.jobs = Collections.unmodifiableSortedMap(jobs); this.datafeeds = Collections.unmodifiableSortedMap(datafeeds); + this.lastMemoryRefreshVersion = lastMemoryRefreshVersion; this.groupOrJobLookup = new GroupOrJobLookup(jobs.values()); } @@ -112,6 +116,10 @@ public Set expandDatafeedIds(String expression, boolean allowNoDatafeeds .expand(expression, allowNoDatafeeds); } + public Long getLastMemoryRefreshVersion() { + return lastMemoryRefreshVersion; + } + @Override public Version getMinimalSupportedVersion() { return Version.V_5_4_0; @@ -145,7 +153,11 @@ public MlMetadata(StreamInput in) throws IOException { datafeeds.put(in.readString(), new DatafeedConfig(in)); } this.datafeeds = datafeeds; - + if (in.getVersion().onOrAfter(Version.V_6_6_0)) { + lastMemoryRefreshVersion = in.readOptionalLong(); + } else { + lastMemoryRefreshVersion = null; + } this.groupOrJobLookup = new GroupOrJobLookup(jobs.values()); } @@ -153,6 +165,9 @@ public MlMetadata(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { writeMap(jobs, out); writeMap(datafeeds, out); + if (out.getVersion().onOrAfter(Version.V_6_6_0)) { + out.writeOptionalLong(lastMemoryRefreshVersion); + } } private static void writeMap(Map map, StreamOutput out) throws IOException { @@ -169,6 +184,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params); mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams); mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams); + if (lastMemoryRefreshVersion != null) { + builder.field(LAST_MEMORY_REFRESH_VERSION_FIELD.getPreferredName(), lastMemoryRefreshVersion); + } return builder; } @@ -185,30 +203,46 @@ public static class MlMetadataDiff implements NamedDiff { final Diff> jobs; final Diff> datafeeds; + final Long lastMemoryRefreshVersion; MlMetadataDiff(MlMetadata before, MlMetadata after) { this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer()); this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer()); + this.lastMemoryRefreshVersion = after.lastMemoryRefreshVersion; } public MlMetadataDiff(StreamInput in) throws IOException { this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new, MlMetadataDiff::readJobDiffFrom); this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new, - MlMetadataDiff::readSchedulerDiffFrom); + MlMetadataDiff::readDatafeedDiffFrom); + if (in.getVersion().onOrAfter(Version.V_6_6_0)) { + lastMemoryRefreshVersion = in.readOptionalLong(); + } else { + lastMemoryRefreshVersion = null; + } } + /** + * Merge the diff with the ML metadata. + * @param part The current ML metadata. + * @return The new ML metadata. 
+ */ @Override public MetaData.Custom apply(MetaData.Custom part) { TreeMap newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs)); TreeMap newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds)); - return new MlMetadata(newJobs, newDatafeeds); + // lastMemoryRefreshVersion always comes from the diff - no need to merge with the old value + return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshVersion); } @Override public void writeTo(StreamOutput out) throws IOException { jobs.writeTo(out); datafeeds.writeTo(out); + if (out.getVersion().onOrAfter(Version.V_6_6_0)) { + out.writeOptionalLong(lastMemoryRefreshVersion); + } } @Override @@ -220,7 +254,7 @@ static Diff readJobDiffFrom(StreamInput in) throws IOException { return AbstractDiffable.readDiffFrom(Job::new, in); } - static Diff readSchedulerDiffFrom(StreamInput in) throws IOException { + static Diff readDatafeedDiffFrom(StreamInput in) throws IOException { return AbstractDiffable.readDiffFrom(DatafeedConfig::new, in); } } @@ -233,7 +267,8 @@ public boolean equals(Object o) { return false; MlMetadata that = (MlMetadata) o; return Objects.equals(jobs, that.jobs) && - Objects.equals(datafeeds, that.datafeeds); + Objects.equals(datafeeds, that.datafeeds) && + Objects.equals(lastMemoryRefreshVersion, that.lastMemoryRefreshVersion); } @Override @@ -243,13 +278,14 @@ public final String toString() { @Override public int hashCode() { - return Objects.hash(jobs, datafeeds); + return Objects.hash(jobs, datafeeds, lastMemoryRefreshVersion); } public static class Builder { private TreeMap jobs; private TreeMap datafeeds; + private Long lastMemoryRefreshVersion; public Builder() { jobs = new TreeMap<>(); @@ -263,6 +299,7 @@ public Builder(@Nullable MlMetadata previous) { } else { jobs = new TreeMap<>(previous.jobs); datafeeds = new TreeMap<>(previous.datafeeds); + lastMemoryRefreshVersion = previous.lastMemoryRefreshVersion; } } @@ -382,8 +419,13 @@ private Builder putDatafeeds(Collection datafeeds) { return this; } + public Builder setLastMemoryRefreshVersion(Long lastMemoryRefreshVersion) { + this.lastMemoryRefreshVersion = lastMemoryRefreshVersion; + return this; + } + public MlMetadata build() { - return new MlMetadata(jobs, datafeeds); + return new MlMetadata(jobs, datafeeds, lastMemoryRefreshVersion); } public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) { @@ -420,8 +462,6 @@ void checkJobHasNoDatafeed(String jobId) { } } - - public static MlMetadata getMlMetadata(ClusterState state) { MlMetadata mlMetadata = (state == null) ? null : state.getMetaData().custom(TYPE); if (mlMetadata == null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java index ffe24dff8ced0..fd7fa70bded43 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java @@ -142,6 +142,7 @@ private static ObjectParser createParser(boolean ignoreUnknownFie private final Date createTime; private final Date finishedTime; private final Date lastDataTime; + // TODO: Remove in 7.0 private final Long establishedModelMemory; private final AnalysisConfig analysisConfig; private final AnalysisLimits analysisLimits; @@ -439,6 +440,7 @@ public Collection allInputFields() { * program code and stack. 
* @return an estimate of the memory requirement of this job, in bytes */ + // TODO: remove this method in 7.0 public long estimateMemoryFootprint() { if (establishedModelMemory != null && establishedModelMemory > 0) { return establishedModelMemory + PROCESS_MEMORY_OVERHEAD.getBytes(); @@ -658,6 +660,7 @@ public static class Builder implements Writeable, ToXContentObject { private Date createTime; private Date finishedTime; private Date lastDataTime; + // TODO: remove in 7.0 private Long establishedModelMemory; private ModelPlotConfig modelPlotConfig; private Long renormalizationWindowDays; @@ -1102,10 +1105,6 @@ private void validateGroups() { public Job build(Date createTime) { setCreateTime(createTime); setJobVersion(Version.CURRENT); - // TODO: Maybe we _could_ accept a value for this supplied at create time - it would - // mean cloned jobs that hadn't been edited much would start with an accurate expected size. - // But on the other hand it would mean jobs that were cloned and then completely changed - // would start with a size that was completely wrong. setEstablishedModelMemory(null); return build(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java index 62340ba6cf63c..0fae85f6d6b5c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java @@ -561,7 +561,7 @@ public void testEstimateMemoryFootprint_GivenNoLimitAndNotEstablished() { builder.setEstablishedModelMemory(0L); } assertEquals(ByteSizeUnit.MB.toBytes(AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB) - + Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint()); + + Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint()); } public void testEarliestValidTimestamp_GivenEmptyDataCounts() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 77f898dcb8773..9dbf5cc9f6d16 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -181,6 +181,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory; import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.notifications.Auditor; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeControllerHolder; import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; @@ -278,6 +279,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu private final SetOnce autodetectProcessManager = new SetOnce<>(); private final SetOnce datafeedManager = new SetOnce<>(); + private final SetOnce memoryTracker = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -420,6 +422,8 @@ public Collection createComponents(Client client, ClusterService cluster this.datafeedManager.set(datafeedManager); MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, autodetectProcessManager); + 
MlMemoryTracker memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider); + this.memoryTracker.set(memoryTracker); // This object's constructor attaches to the license state, so there's no need to retain another reference to it new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager); @@ -438,7 +442,8 @@ public Collection createComponents(Client client, ClusterService cluster jobDataCountsPersister, datafeedManager, auditor, - new MlAssignmentNotifier(auditor, clusterService) + new MlAssignmentNotifier(auditor, clusterService), + memoryTracker ); } @@ -449,7 +454,8 @@ public List> getPersistentTasksExecutor(ClusterServic } return Arrays.asList( - new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get()), + new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(), + memoryTracker.get()), new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(settings, datafeedManager.get()) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java index 761c21b63f165..10e6b8093f763 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java @@ -69,6 +69,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobDataDeleter; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import org.elasticsearch.xpack.ml.notifications.Auditor; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; import java.util.ArrayList; @@ -94,6 +95,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction(); } @@ -211,6 +215,9 @@ private void normalDeleteJob(ParentTaskAssigningClient parentTaskClient, DeleteJ ActionListener listener) { String jobId = request.getJobId(); + // We clean up the memory tracker on delete rather than close as close is not a master node action + memoryTracker.removeJob(jobId); + // Step 4. 
When the job has been removed from the cluster state, return a response // ------- CheckedConsumer apiResponseHandler = jobDeleted -> { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java index 5b4be78596a64..ce820db0babec 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java @@ -69,6 +69,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.io.IOException; import java.util.ArrayList; @@ -95,20 +96,23 @@ To ensure that a subsequent close job call will see that same task status (and s */ public class TransportOpenJobAction extends TransportMasterNodeAction { + private static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = + new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); + private final XPackLicenseState licenseState; private final PersistentTasksService persistentTasksService; private final Client client; private final JobResultsProvider jobResultsProvider; private final JobConfigProvider jobConfigProvider; - private static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = - new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); + private final MlMemoryTracker memoryTracker; @Inject public TransportOpenJobAction(Settings settings, TransportService transportService, ThreadPool threadPool, XPackLicenseState licenseState, ClusterService clusterService, PersistentTasksService persistentTasksService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Client client, - JobResultsProvider jobResultsProvider, JobConfigProvider jobConfigProvider) { + JobResultsProvider jobResultsProvider, JobConfigProvider jobConfigProvider, + MlMemoryTracker memoryTracker) { super(settings, OpenJobAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, OpenJobAction.Request::new); this.licenseState = licenseState; @@ -116,6 +120,7 @@ public TransportOpenJobAction(Settings settings, TransportService transportServi this.client = client; this.jobResultsProvider = jobResultsProvider; this.jobConfigProvider = jobConfigProvider; + this.memoryTracker = memoryTracker; } /** @@ -144,6 +149,7 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j int maxConcurrentJobAllocations, int fallbackMaxNumberOfOpenJobs, int maxMachineMemoryPercent, + MlMemoryTracker memoryTracker, Logger logger) { String resultsIndexName = job != null ? 
job.getResultsIndexName() : null; List unavailableIndices = verifyIndicesPrimaryShardsAreActive(resultsIndexName, clusterState); @@ -154,10 +160,38 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j return new PersistentTasksCustomMetaData.Assignment(null, reason); } + // Try to allocate jobs according to memory usage, but if that's not possible (maybe due to a mixed version cluster or maybe + // because of some weird OS problem) then fall back to the old mechanism of only considering numbers of assigned jobs + boolean allocateByMemory = true; + + if (memoryTracker.isRecentlyRefreshed() == false) { + + boolean scheduledRefresh = memoryTracker.asyncRefresh(ActionListener.wrap( + acknowledged -> { + if (acknowledged) { + logger.trace("Job memory requirement refresh request completed successfully"); + } else { + logger.warn("Job memory requirement refresh request completed but did not set time in cluster state"); + } + }, + e -> logger.error("Failed to refresh job memory requirements", e) + )); + if (scheduledRefresh) { + String reason = "Not opening job [" + jobId + "] because job memory requirements are stale - refresh requested"; + logger.debug(reason); + return new PersistentTasksCustomMetaData.Assignment(null, reason); + } else { + allocateByMemory = false; + logger.warn("Falling back to allocating job [{}] by job counts because a memory requirement refresh could not be scheduled", + jobId); + } + } + List reasons = new LinkedList<>(); long maxAvailableCount = Long.MIN_VALUE; + long maxAvailableMemory = Long.MIN_VALUE; DiscoveryNode minLoadedNodeByCount = null; - + DiscoveryNode minLoadedNodeByMemory = null; PersistentTasksCustomMetaData persistentTasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); for (DiscoveryNode node : clusterState.getNodes()) { Map nodeAttributes = node.getAttributes(); @@ -197,10 +231,9 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j } } - long numberOfAssignedJobs = 0; int numberOfAllocatingJobs = 0; - + long assignedJobMemory = 0; if (persistentTasks != null) { // find all the job tasks assigned to this node Collection> assignedTasks = persistentTasks.findTasks( @@ -231,6 +264,15 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j if (jobState.isAnyOf(JobState.CLOSED, JobState.FAILED) == false) { // Don't count CLOSED or FAILED jobs, as they don't consume native memory ++numberOfAssignedJobs; + OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams(); + Long jobMemoryRequirement = memoryTracker.getJobMemoryRequirement(params.getJobId()); + if (jobMemoryRequirement == null) { + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because " + + "the memory requirement for job [{}] was not available", jobId, params.getJobId()); + } else { + assignedJobMemory += jobMemoryRequirement; + } } } } @@ -271,10 +313,62 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j maxAvailableCount = availableCount; minLoadedNodeByCount = node; } + + String machineMemoryStr = nodeAttributes.get(MachineLearning.MACHINE_MEMORY_NODE_ATTR); + long machineMemory = -1; + // TODO: remove leniency and reject the node if the attribute is null in 7.0 + if (machineMemoryStr != null) { + try { + machineMemory = Long.parseLong(machineMemoryStr); + } catch (NumberFormatException e) { + String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) 
+ "], because " + + MachineLearning.MACHINE_MEMORY_NODE_ATTR + " attribute [" + machineMemoryStr + "] is not a long"; + logger.trace(reason); + reasons.add(reason); + continue; + } + } + + if (allocateByMemory) { + if (machineMemory > 0) { + long maxMlMemory = machineMemory * maxMachineMemoryPercent / 100; + Long estimatedMemoryFootprint = memoryTracker.getJobMemoryRequirement(jobId); + if (estimatedMemoryFootprint != null) { + long availableMemory = maxMlMemory - assignedJobMemory; + if (estimatedMemoryFootprint > availableMemory) { + String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + + "], because this node has insufficient available memory. Available memory for ML [" + maxMlMemory + + "], memory required by existing jobs [" + assignedJobMemory + + "], estimated memory required for this job [" + estimatedMemoryFootprint + "]"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + if (maxAvailableMemory < availableMemory) { + maxAvailableMemory = availableMemory; + minLoadedNodeByMemory = node; + } + } else { + // If we cannot get the job memory requirement, + // fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because its memory requirement was not available", + jobId); + } + } else { + // If we cannot get the available memory on any machine in + // the cluster, fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because machine memory was not available for node [{}]", + jobId, nodeNameAndMlAttributes(node)); + } + } } - if (minLoadedNodeByCount != null) { - logger.debug("selected node [{}] for job [{}]", minLoadedNodeByCount, jobId); - return new PersistentTasksCustomMetaData.Assignment(minLoadedNodeByCount.getId(), ""); + DiscoveryNode minLoadedNode = allocateByMemory ? 
minLoadedNodeByMemory : minLoadedNodeByCount; + if (minLoadedNode != null) { + logger.debug("selected node [{}] for job [{}]", minLoadedNode, jobId); + return new PersistentTasksCustomMetaData.Assignment(minLoadedNode.getId(), ""); } else { String explanation = String.join("|", reasons); logger.debug("no node selected for job [{}], reasons [{}]", jobId, explanation); @@ -415,6 +509,11 @@ protected void masterOperation(OpenJobAction.Request request, ClusterState state OpenJobAction.JobParams jobParams = request.getJobParams(); if (licenseState.isMachineLearningAllowed()) { + // If the whole cluster supports the ML memory tracker then we don't need + // to worry about updating established model memory on the job objects + // TODO: remove in 7.0 as it will always be true + boolean clusterSupportsMlMemoryTracker = state.getNodes().getMinNodeVersion().onOrAfter(Version.V_6_6_0); + // Clear job finished time once the job is started and respond ActionListener clearJobFinishTime = ActionListener.wrap( response -> { @@ -446,15 +545,19 @@ public void onFailure(Exception e) { }; // Start job task - ActionListener jobUpateListener = ActionListener.wrap( - response -> { - persistentTasksService.sendStartRequest(MlTasks.jobTaskId(jobParams.getJobId()), - MlTasks.JOB_TASK_NAME, jobParams, waitForJobToStart); - }, - listener::onFailure + ActionListener memoryRequirementRefreshListener = ActionListener.wrap( + mem -> persistentTasksService.sendStartRequest(MlTasks.jobTaskId(jobParams.getJobId()), MlTasks.JOB_TASK_NAME, jobParams, + waitForJobToStart), + listener::onFailure + ); + + // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks + ActionListener jobUpdateListener = ActionListener.wrap( + response -> memoryTracker.refreshJobMemoryAndAllOthers(jobParams.getJobId(), memoryRequirementRefreshListener), + listener::onFailure ); - // Update established model memory for pre-6.1 jobs that haven't had it set + // Update established model memory for pre-6.1 jobs that haven't had it set (TODO: remove in 7.0) // and increase the model memory limit for 6.1 - 6.3 jobs ActionListener missingMappingsListener = ActionListener.wrap( response -> { @@ -462,8 +565,9 @@ public void onFailure(Exception e) { if (job != null) { Version jobVersion = job.getJobVersion(); Long jobEstablishedModelMemory = job.getEstablishedModelMemory(); - if ((jobVersion == null || jobVersion.before(Version.V_6_1_0)) + if (clusterSupportsMlMemoryTracker == false && (jobVersion == null || jobVersion.before(Version.V_6_1_0)) && (jobEstablishedModelMemory == null || jobEstablishedModelMemory == 0)) { + // TODO: remove in 7.0 - established model memory no longer needs to be set on the job object // Set the established memory usage for pre 6.1 jobs jobResultsProvider.getEstablishedMemoryUsage(job.getId(), null, null, establishedModelMemory -> { if (establishedModelMemory != null && establishedModelMemory > 0) { @@ -472,9 +576,9 @@ public void onFailure(Exception e) { UpdateJobAction.Request updateRequest = UpdateJobAction.Request.internal(job.getId(), update); executeAsyncWithOrigin(client, ML_ORIGIN, UpdateJobAction.INSTANCE, updateRequest, - jobUpateListener); + jobUpdateListener); } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } }, listener::onFailure); } else if (jobVersion != null && @@ -491,16 +595,16 @@ public void onFailure(Exception e) { .setAnalysisLimits(limits).build(); UpdateJobAction.Request updateRequest = 
UpdateJobAction.Request.internal(job.getId(), update); executeAsyncWithOrigin(client, ML_ORIGIN, UpdateJobAction.INSTANCE, updateRequest, - jobUpateListener); + jobUpdateListener); } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } }, listener::onFailure ); @@ -644,6 +748,7 @@ private void addDocMappingIfMissing(String alias, CheckedSupplier { private final AutodetectProcessManager autodetectProcessManager; + private final MlMemoryTracker memoryTracker; /** * The maximum number of open jobs can be different on each node. However, nodes on older versions @@ -657,9 +762,10 @@ public static class OpenJobPersistentTasksExecutor extends PersistentTasksExecut private volatile int maxLazyMLNodes; public OpenJobPersistentTasksExecutor(Settings settings, ClusterService clusterService, - AutodetectProcessManager autodetectProcessManager) { + AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) { super(MlTasks.JOB_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME); this.autodetectProcessManager = autodetectProcessManager; + this.memoryTracker = memoryTracker; this.fallbackMaxNumberOfOpenJobs = AutodetectProcessManager.MAX_OPEN_JOBS_PER_NODE.get(settings); this.maxConcurrentJobAllocations = MachineLearning.CONCURRENT_JOB_ALLOCATIONS.get(settings); this.maxMachineMemoryPercent = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); @@ -679,10 +785,11 @@ public PersistentTasksCustomMetaData.Assignment getAssignment(OpenJobAction.JobP maxConcurrentJobAllocations, fallbackMaxNumberOfOpenJobs, maxMachineMemoryPercent, + memoryTracker, logger); if (assignment.getExecutorNode() == null) { int numMlNodes = 0; - for(DiscoveryNode node : clusterState.getNodes()) { + for (DiscoveryNode node : clusterState.getNodes()) { if (Boolean.valueOf(node.getAttributes().get(MachineLearning.ML_ENABLED_NODE_ATTR))) { numMlNodes++; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java index 9ae247cca73b6..727732540e72b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java @@ -111,6 +111,7 @@ public class AutoDetectResultProcessor { * New model size stats are read as the process is running */ private volatile ModelSizeStats latestModelSizeStats; + // TODO: remove in 7.0, along with all established model memory functionality in this class private volatile Date latestDateForEstablishedModelMemoryCalc; private volatile long latestEstablishedModelMemory; private volatile boolean haveNewLatestModelSizeStats; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java new file mode 100644 index 0000000000000..2ea92d98391eb --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -0,0 +1,331 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.LocalNodeMasterListener; +import org.elasticsearch.cluster.ack.AckedRequest; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.JobManager; +import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +/** + * This class keeps track of the memory requirement of ML jobs. + * It only functions on the master node - for this reason it should only be used by master node actions. + * The memory requirement for ML jobs can be updated in 3 ways: + * 1. For all open ML jobs (via {@link #asyncRefresh}) + * 2. For all open ML jobs, plus one named ML job that is not open (via {@link #refreshJobMemoryAndAllOthers}) + * 3. For one named ML job (via {@link #refreshJobMemory}) + * In all cases a listener informs the caller when the requested updates are complete. 
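 *
 * Editorial aside (illustration added in review, not part of the original change): a minimal
 * sketch of how a master node action might consult this tracker before making an allocation
 * decision, using only the methods declared below. The {@code memoryTracker}, {@code jobId}
 * and {@code logger} names are assumed to exist in the caller.
 * <pre>
 * if (memoryTracker.isRecentlyRefreshed() == false) {
 *     // Information is stale - request a refresh and postpone the decision rather than
 *     // allocating on out-of-date numbers.
 *     memoryTracker.asyncRefresh(ActionListener.wrap(
 *         acknowledged -> logger.trace("refresh requested, acknowledged [{}]", acknowledged),
 *         e -> logger.error("failed to refresh job memory requirements", e)));
 * } else {
 *     // May be null if the requirement cannot be calculated for this job.
 *     Long requiredBytes = memoryTracker.getJobMemoryRequirement(jobId);
 * }
 * </pre>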
+ */ +public class MlMemoryTracker implements LocalNodeMasterListener { + + private static final AckedRequest ACKED_REQUEST = new AckedRequest() { + @Override + public TimeValue ackTimeout() { + return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT; + } + + @Override + public TimeValue masterNodeTimeout() { + return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT; + } + }; + + private static final Duration RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1); + + private final Logger logger = LogManager.getLogger(MlMemoryTracker.class); + private final ConcurrentHashMap memoryRequirementByJob = new ConcurrentHashMap<>(); + private final List> fullRefreshCompletionListeners = new ArrayList<>(); + + private final ThreadPool threadPool; + private final ClusterService clusterService; + private final JobManager jobManager; + private final JobResultsProvider jobResultsProvider; + private volatile boolean isMaster; + private volatile Instant lastUpdateTime; + + public MlMemoryTracker(ClusterService clusterService, ThreadPool threadPool, JobManager jobManager, + JobResultsProvider jobResultsProvider) { + this.threadPool = threadPool; + this.clusterService = clusterService; + this.jobManager = jobManager; + this.jobResultsProvider = jobResultsProvider; + clusterService.addLocalNodeMasterListener(this); + } + + @Override + public void onMaster() { + isMaster = true; + logger.trace("ML memory tracker on master"); + } + + @Override + public void offMaster() { + isMaster = false; + logger.trace("ML memory tracker off master"); + memoryRequirementByJob.clear(); + lastUpdateTime = null; + } + + @Override + public String executorName() { + return MachineLearning.UTILITY_THREAD_POOL_NAME; + } + + /** + * Is the information in this object sufficiently up to date + * for valid allocation decisions to be made using it? + */ + public boolean isRecentlyRefreshed() { + Instant localLastUpdateTime = lastUpdateTime; + return localLastUpdateTime != null && localLastUpdateTime.plus(RECENT_UPDATE_THRESHOLD).isAfter(Instant.now()); + } + + /** + * Get the memory requirement for a job. + * This method only works on the master node. + * @param jobId The job ID. + * @return The memory requirement of the job specified by {@code jobId}, + * or null if it cannot be calculated. + */ + public Long getJobMemoryRequirement(String jobId) { + + if (isMaster == false) { + return null; + } + + Long memoryRequirement = memoryRequirementByJob.get(jobId); + if (memoryRequirement != null) { + return memoryRequirement; + } + + // Fallback for mixed version 6.6+/pre-6.6 cluster - TODO: remove in 7.0 + Job job = MlMetadata.getMlMetadata(clusterService.state()).getJobs().get(jobId); + if (job != null) { + return job.estimateMemoryFootprint(); + } + + return null; + } + + /** + * Remove any memory requirement that is stored for the specified job. + * It doesn't matter if this method is called for a job that doesn't have + * a stored memory requirement. + */ + public void removeJob(String jobId) { + memoryRequirementByJob.remove(jobId); + } + + /** + * Uses a separate thread to refresh the memory requirement for every ML job that has + * a corresponding persistent task. This method only works on the master node. + * @param listener Will be called when the async refresh completes or fails. The + * boolean value indicates whether the cluster state was updated + * with the refresh completion time. (If it was then this will in + * cause the persistent tasks framework to check if any persistent + * tasks are awaiting allocation.) 
+ * @return true if the async refresh is scheduled, and false + * if this is not possible for some reason. + */ + public boolean asyncRefresh(ActionListener listener) { + + if (isMaster) { + try { + ActionListener mlMetaUpdateListener = ActionListener.wrap( + aVoid -> recordUpdateTimeInClusterState(listener), + listener::onFailure + ); + threadPool.executor(executorName()).execute( + () -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), mlMetaUpdateListener)); + return true; + } catch (EsRejectedExecutionException e) { + logger.debug("Couldn't schedule ML memory update - node might be shutting down", e); + } + } + + return false; + } + + /** + * This refreshes the memory requirement for every ML job that has a corresponding + * persistent task and, in addition, one job that doesn't have a persistent task. + * This method only works on the master node. + * @param jobId The job ID of the job whose memory requirement is to be refreshed + * despite not having a corresponding persistent task. + * @param listener Receives the memory requirement of the job specified by {@code jobId}, + * or null if it cannot be calculated. + */ + public void refreshJobMemoryAndAllOthers(String jobId, ActionListener listener) { + + if (isMaster == false) { + listener.onResponse(null); + return; + } + + PersistentTasksCustomMetaData persistentTasks = clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + refresh(persistentTasks, ActionListener.wrap(aVoid -> refreshJobMemory(jobId, listener), listener::onFailure)); + } + + /** + * This refreshes the memory requirement for every ML job that has a corresponding persistent task. + * It does NOT remove entries for jobs that no longer have a persistent task, because that would + * lead to a race where a job was opened part way through the refresh. (Instead, entries are removed + * when jobs are deleted.) 
+ */ + void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener onCompletion) { + + synchronized (fullRefreshCompletionListeners) { + fullRefreshCompletionListeners.add(onCompletion); + if (fullRefreshCompletionListeners.size() > 1) { + // A refresh is already in progress, so don't do another + return; + } + } + + ActionListener refreshComplete = ActionListener.wrap(aVoid -> { + lastUpdateTime = Instant.now(); + synchronized (fullRefreshCompletionListeners) { + assert fullRefreshCompletionListeners.isEmpty() == false; + for (ActionListener listener : fullRefreshCompletionListeners) { + listener.onResponse(null); + } + fullRefreshCompletionListeners.clear(); + } + }, onCompletion::onFailure); + + // persistentTasks will be null if there's never been a persistent task created in this cluster + if (persistentTasks == null) { + refreshComplete.onResponse(null); + } else { + List> mlJobTasks = persistentTasks.tasks().stream() + .filter(task -> MlTasks.JOB_TASK_NAME.equals(task.getTaskName())).collect(Collectors.toList()); + iterateMlJobTasks(mlJobTasks.iterator(), refreshComplete); + } + } + + private void recordUpdateTimeInClusterState(ActionListener listener) { + + clusterService.submitStateUpdateTask("ml-memory-last-update-time", + new AckedClusterStateUpdateTask(ACKED_REQUEST, listener) { + @Override + protected Boolean newResponse(boolean acknowledged) { + return acknowledged; + } + + @Override + public ClusterState execute(ClusterState currentState) { + MlMetadata currentMlMetadata = MlMetadata.getMlMetadata(currentState); + MlMetadata.Builder builder = new MlMetadata.Builder(currentMlMetadata); + builder.setLastMemoryRefreshVersion(currentState.getVersion() + 1); + MlMetadata newMlMetadata = builder.build(); + if (newMlMetadata.equals(currentMlMetadata)) { + // Return same reference if nothing has changed + return currentState; + } else { + ClusterState.Builder newState = ClusterState.builder(currentState); + newState.metaData(MetaData.builder(currentState.getMetaData()).putCustom(MlMetadata.TYPE, newMlMetadata).build()); + return newState.build(); + } + } + }); + } + + private void iterateMlJobTasks(Iterator> iterator, + ActionListener refreshComplete) { + if (iterator.hasNext()) { + OpenJobAction.JobParams jobParams = (OpenJobAction.JobParams) iterator.next().getParams(); + refreshJobMemory(jobParams.getJobId(), + ActionListener.wrap(mem -> iterateMlJobTasks(iterator, refreshComplete), refreshComplete::onFailure)); + } else { + refreshComplete.onResponse(null); + } + } + + /** + * Refresh the memory requirement for a single job. + * This method only works on the master node. + * @param jobId The ID of the job to refresh the memory requirement for. + * @param listener Receives the job's memory requirement, or null + * if it cannot be calculated. 
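 * Editorial note (summarising the implementation below, not additional behaviour): the value
 * delivered to the listener is the established model memory plus {@code Job.PROCESS_MEMORY_OVERHEAD}
 * when an established value exists, otherwise the configured model memory limit plus the same
 * overhead, otherwise null. A worked example, assuming an overhead of 100 MB for illustration:
 * <pre>
 * // established model memory of 120 MB known from the results index:
 * //   requirement = 120 MB + 100 MB overhead = 220 MB
 * // no established value, job configured with model_memory_limit of 1024 MB:
 * //   requirement = 1024 MB + 100 MB overhead = 1124 MB
 * </pre>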
+ */ + public void refreshJobMemory(String jobId, ActionListener listener) { + if (isMaster == false) { + listener.onResponse(null); + return; + } + + try { + jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null, + establishedModelMemoryBytes -> { + if (establishedModelMemoryBytes <= 0L) { + setJobMemoryToLimit(jobId, listener); + } else { + Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); + memoryRequirementByJob.put(jobId, memoryRequirementBytes); + listener.onResponse(memoryRequirementBytes); + } + }, + e -> { + logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); + setJobMemoryToLimit(jobId, listener); + } + ); + } catch (Exception e) { + logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); + setJobMemoryToLimit(jobId, listener); + } + } + + private void setJobMemoryToLimit(String jobId, ActionListener listener) { + jobManager.getJob(jobId, ActionListener.wrap(job -> { + Long memoryLimitMb = job.getAnalysisLimits().getModelMemoryLimit(); + if (memoryLimitMb != null) { + Long memoryRequirementBytes = ByteSizeUnit.MB.toBytes(memoryLimitMb) + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); + memoryRequirementByJob.put(jobId, memoryRequirementBytes); + listener.onResponse(memoryRequirementBytes); + } else { + memoryRequirementByJob.remove(jobId); + listener.onResponse(null); + } + }, e -> { + if (e instanceof ResourceNotFoundException) { + // TODO: does this also happen if the .ml-config index exists but is unavailable? + logger.trace("[{}] job deleted during ML memory update", jobId); + } else { + logger.error("[" + jobId + "] failed to get job during ML memory update", e); + } + memoryRequirementByJob.remove(jobId); + listener.onResponse(null); + })); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java index c7ca2ff805eba..eb58221bf5f35 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java @@ -69,6 +69,9 @@ protected MlMetadata createTestInstance() { builder.putJob(job, false); } } + if (randomBoolean()) { + builder.setLastMemoryRefreshVersion(randomNonNegativeLong()); + } return builder.build(); } @@ -438,8 +441,9 @@ protected MlMetadata mutateInstance(MlMetadata instance) { for (Map.Entry entry : datafeeds.entrySet()) { metadataBuilder.putDatafeed(entry.getValue(), Collections.emptyMap()); } + metadataBuilder.setLastMemoryRefreshVersion(instance.getLastMemoryRefreshVersion()); - switch (between(0, 1)) { + switch (between(0, 2)) { case 0: metadataBuilder.putJob(JobTests.createRandomizedJob(), true); break; @@ -459,6 +463,13 @@ protected MlMetadata mutateInstance(MlMetadata instance) { metadataBuilder.putJob(randomJob, false); metadataBuilder.putDatafeed(datafeedConfig, Collections.emptyMap()); break; + case 2: + if (instance.getLastMemoryRefreshVersion() == null) { + metadataBuilder.setLastMemoryRefreshVersion(randomNonNegativeLong()); + } else { + metadataBuilder.setLastMemoryRefreshVersion(null); + } + break; default: throw new AssertionError("Illegal randomisation branch"); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java index 
4a98b380b0929..393fc492f5d63 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java @@ -50,7 +50,9 @@ import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.notifications.AuditorField; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; import java.io.IOException; import java.net.InetAddress; @@ -71,6 +73,14 @@ public class TransportOpenJobActionTests extends ESTestCase { + private MlMemoryTracker memoryTracker; + + @Before + public void setup() { + memoryTracker = mock(MlMemoryTracker.class); + when(memoryTracker.isRecentlyRefreshed()).thenReturn(true); + } + public void testValidate_jobMissing() { expectThrows(ResourceNotFoundException.class, () -> TransportOpenJobAction.validate("job_id2", null)); } @@ -125,7 +135,7 @@ public void testSelectLeastLoadedMlNode_byCount() { jobBuilder.setJobVersion(Version.CURRENT); Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id4", jobBuilder.build(), - cs.build(), 2, 10, 30, logger); + cs.build(), 2, 10, 30, memoryTracker, logger); assertEquals("", result.getExplanation()); assertEquals("_node_id3", result.getExecutorNode()); } @@ -161,7 +171,7 @@ public void testSelectLeastLoadedMlNode_maxCapacity() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id0", new ByteSizeValue(150, ByteSizeUnit.MB)).build(new Date()); Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id0", job, cs.build(), 2, - maxRunningJobsPerNode, 30, logger); + maxRunningJobsPerNode, 30, memoryTracker, logger); assertNull(result.getExecutorNode()); assertTrue(result.getExplanation().contains("because this node is full. 
Number of opened jobs [" + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]")); @@ -187,7 +197,7 @@ public void testSelectLeastLoadedMlNode_noMlNodes() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id2", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id2", job, cs.build(), 2, 10, 30, memoryTracker, logger); assertTrue(result.getExplanation().contains("because this node isn't a ml node")); assertNull(result.getExecutorNode()); } @@ -221,7 +231,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id6", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); ClusterState cs = csBuilder.build(); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id6", job, cs, 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id6", job, cs, 2, 10, 30, memoryTracker, logger); assertEquals("_node_id3", result.getExecutorNode()); tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); @@ -231,7 +241,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because OPENING state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); @@ -242,7 +252,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because stale task", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); @@ -253,7 +263,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because null state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); } @@ -291,7 +301,7 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id7", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); // Allocation won't be possible if the stale failed job 
is treated as opening - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertEquals("_node_id1", result.getExecutorNode()); tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); @@ -301,7 +311,7 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id8", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id8", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because OPENING state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); } @@ -332,7 +342,8 @@ public void testSelectLeastLoadedMlNode_noCompatibleJobTypeNodes() { cs.nodes(nodes); metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); cs.metaData(metaData); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, + memoryTracker, logger); assertThat(result.getExplanation(), containsString("because this node does not support jobs of type [incompatible_type]")); assertNull(result.getExecutorNode()); } @@ -359,7 +370,8 @@ public void testSelectLeastLoadedMlNode_noNodesPriorTo_V_5_5() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id7", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, + memoryTracker, logger); assertThat(result.getExplanation(), containsString("because this node does not support machine learning jobs")); assertNull(result.getExecutorNode()); } @@ -385,7 +397,8 @@ public void testSelectLeastLoadedMlNode_jobWithRulesButNoNodeMeetsRequiredVersio cs.metaData(metaData); Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, memoryTracker, + logger); assertThat(result.getExplanation(), containsString( "because jobs using custom_rules require a node of version [6.4.0] or higher")); assertNull(result.getExecutorNode()); @@ -412,7 +425,8 @@ public void testSelectLeastLoadedMlNode_jobWithRulesAndNodeMeetsRequiredVersion( cs.metaData(metaData); Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, memoryTracker, + logger); assertNotNull(result.getExecutorNode()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java index 2e14289da705e..5e4d8fd06030c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java @@ -8,10 +8,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.CheckedRunnable; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentHelper; @@ -31,21 +34,32 @@ import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.persistent.PersistentTasksClusterService.needsReassignment; public class MlDistributedFailureIT extends BaseMlIntegTestCase { + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder().put(super.nodeSettings(nodeOrdinal)) + .put(MachineLearning.CONCURRENT_JOB_ALLOCATIONS.getKey(), 4) + .build(); + } + public void testFailOver() throws Exception { internalCluster().ensureAtLeastNumDataNodes(3); ensureStableClusterOnAllNodes(3); @@ -58,8 +72,6 @@ public void testFailOver() throws Exception { }); } - @TestLogging("org.elasticsearch.xpack.ml.action:DEBUG,org.elasticsearch.xpack.persistent:TRACE," + - "org.elasticsearch.xpack.ml.datafeed:TRACE") public void testLoseDedicatedMasterNode() throws Exception { internalCluster().ensureAtMostNumDataNodes(0); logger.info("Starting dedicated master node..."); @@ -136,12 +148,12 @@ public void testCloseUnassignedJobAndDatafeed() throws Exception { // Job state is opened but the job is not assigned to a node (because we just killed the only ML node) GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(jobId); GetJobsStatsAction.Response jobStatsResponse = client().execute(GetJobsStatsAction.INSTANCE, jobStatsRequest).actionGet(); - assertEquals(jobStatsResponse.getResponse().results().get(0).getState(), JobState.OPENED); + assertEquals(JobState.OPENED, jobStatsResponse.getResponse().results().get(0).getState()); GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(datafeedId); GetDatafeedsStatsAction.Response 
datafeedStatsResponse = client().execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest).actionGet(); - assertEquals(datafeedStatsResponse.getResponse().results().get(0).getDatafeedState(), DatafeedState.STARTED); + assertEquals(DatafeedState.STARTED, datafeedStatsResponse.getResponse().results().get(0).getDatafeedState()); // Can't normal stop an unassigned datafeed StopDatafeedAction.Request stopDatafeedRequest = new StopDatafeedAction.Request(datafeedId); @@ -170,6 +182,73 @@ public void testCloseUnassignedJobAndDatafeed() throws Exception { assertTrue(closeJobResponse.isClosed()); } + @TestLogging("org.elasticsearch.xpack.ml.action:TRACE,org.elasticsearch.xpack.ml.process:TRACE") + public void testJobRelocationIsMemoryAware() throws Exception { + + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableClusterOnAllNodes(1); + + // Open 4 small jobs. Since there is only 1 node in the cluster they'll have to go on that node. + + setupJobWithoutDatafeed("small1", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small2", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small3", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small4", new ByteSizeValue(2, ByteSizeUnit.MB)); + + // Expand the cluster to 3 nodes. The 4 small jobs will stay on the + // same node because we don't rebalance jobs that are happily running. + + internalCluster().ensureAtLeastNumDataNodes(3); + ensureStableClusterOnAllNodes(3); + + // Open a big job. This should go on a different node to the 4 small ones. + + setupJobWithoutDatafeed("big1", new ByteSizeValue(500, ByteSizeUnit.MB)); + + // Stop the current master node - this should be the one with the 4 small jobs on. + + internalCluster().stopCurrentMasterNode(); + ensureStableClusterOnAllNodes(2); + + // If memory requirements are used to reallocate the 4 small jobs (as we expect) then they should + // all reallocate to the same node, that being the one that doesn't have the big job on. If job counts + // are used to reallocate the small jobs then this implies the fallback allocation mechanism has been + // used in a situation we don't want it to be used in, and at least one of the small jobs will be on + // the same node as the big job. (This all relies on xpack.ml.node_concurrent_job_allocations being set + // to at least 4, which we do in the nodeSettings() method.) 
+ + assertBusy(() -> { + GetJobsStatsAction.Response statsResponse = + client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(MetaData.ALL)).actionGet(); + QueryPage<GetJobsStatsAction.Response.JobStats> jobStats = statsResponse.getResponse(); + assertNotNull(jobStats); + List<String> smallJobNodes = jobStats.results().stream().filter(s -> s.getJobId().startsWith("small") && s.getNode() != null) + .map(s -> s.getNode().getName()).collect(Collectors.toList()); + List<String> bigJobNodes = jobStats.results().stream().filter(s -> s.getJobId().startsWith("big") && s.getNode() != null) + .map(s -> s.getNode().getName()).collect(Collectors.toList()); + logger.info("small job nodes: " + smallJobNodes + ", big job nodes: " + bigJobNodes); + assertEquals(5, jobStats.count()); + assertEquals(4, smallJobNodes.size()); + assertEquals(1, bigJobNodes.size()); + assertEquals(1L, smallJobNodes.stream().distinct().count()); + assertEquals(1L, bigJobNodes.stream().distinct().count()); + assertNotEquals(smallJobNodes, bigJobNodes); + }); + } + + private void setupJobWithoutDatafeed(String jobId, ByteSizeValue modelMemoryLimit) throws Exception { + Job.Builder job = createFareQuoteJob(jobId, modelMemoryLimit); + PutJobAction.Request putJobRequest = new PutJobAction.Request(job); + client().execute(PutJobAction.INSTANCE, putJobRequest).actionGet(); + + client().execute(OpenJobAction.INSTANCE, new OpenJobAction.Request(job.getId())).actionGet(); + assertBusy(() -> { + GetJobsStatsAction.Response statsResponse = + client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet(); + assertEquals(JobState.OPENED, statsResponse.getResponse().results().get(0).getState()); + }); + } + private void setupJobAndDatafeed(String jobId, String datafeedId) throws Exception { Job.Builder job = createScheduledJob(jobId); PutJobAction.Request putJobRequest = new PutJobAction.Request(job); @@ -183,7 +262,7 @@ private void setupJobAndDatafeed(String jobId, String datafeedId) throws Excepti assertBusy(() -> { GetJobsStatsAction.Response statsResponse = client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet(); - assertEquals(statsResponse.getResponse().results().get(0).getState(), JobState.OPENED); + assertEquals(JobState.OPENED, statsResponse.getResponse().results().get(0).getState()); }); StartDatafeedAction.Request startDatafeedRequest = new StartDatafeedAction.Request(config.getId(), 0L); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java index 87aa3c5b926e3..c4150d633a8f0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java @@ -123,12 +123,10 @@ public void testLazyNodeValidation() throws Exception { }); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34084") public void testSingleNode() throws Exception { verifyMaxNumberOfJobsLimit(1, randomIntBetween(1, 100)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34084") public void testMultipleNodes() throws Exception { verifyMaxNumberOfJobsLimit(3, randomIntBetween(1, 100)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java new file mode
100644 index 0000000000000..cbba7ffa04972 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.process; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.ml.job.JobManager; +import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; +import org.junit.Before; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; + import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class MlMemoryTrackerTests extends ESTestCase { + + private ClusterService clusterService; + private ThreadPool threadPool; + private JobManager jobManager; + private JobResultsProvider jobResultsProvider; + private MlMemoryTracker memoryTracker; + + @Before + public void setup() { + + clusterService = mock(ClusterService.class); + threadPool = mock(ThreadPool.class); + ExecutorService executorService = mock(ExecutorService.class); + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Runnable r = (Runnable) invocation.getArguments()[0]; + r.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + when(threadPool.executor(anyString())).thenReturn(executorService); + jobManager = mock(JobManager.class); + jobResultsProvider = mock(JobResultsProvider.class); + memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider); + } + + public void testRefreshAll() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + int numMlJobTasks = randomIntBetween(2, 5); + Map<String, PersistentTasksCustomMetaData.PersistentTask<?>> tasks = new HashMap<>(); + for (int i = 1; i <= numMlJobTasks; ++i) { + String jobId = "job" + i; + PersistentTasksCustomMetaData.PersistentTask<?> task = makeTestTask(jobId); + tasks.put(task.getId(), task); + } + PersistentTasksCustomMetaData persistentTasks = new PersistentTasksCustomMetaData(numMlJobTasks, tasks); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Consumer<Long> listener = (Consumer<Long>)
invocation.getArguments()[3]; + listener.accept(randomLongBetween(1000, 1000000)); + return null; + }).when(jobResultsProvider).getEstablishedMemoryUsage(anyString(), any(), any(), any(Consumer.class), any()); + + memoryTracker.refresh(persistentTasks, ActionListener.wrap(aVoid -> {}, ESTestCase::assertNull)); + + if (isMaster) { + for (int i = 1; i <= numMlJobTasks; ++i) { + String jobId = "job" + i; + verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any()); + } + } else { + verify(jobResultsProvider, never()).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any()); + } + } + + public void testRefreshOne() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + String jobId = "job"; + boolean haveEstablishedModelMemory = randomBoolean(); + + long modelBytes = 1024 * 1024; + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Consumer<Long> listener = (Consumer<Long>) invocation.getArguments()[3]; + listener.accept(haveEstablishedModelMemory ? modelBytes : 0L); + return null; + }).when(jobResultsProvider).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(Consumer.class), any()); + + long modelMemoryLimitMb = 2; + Job job = mock(Job.class); + when(job.getAnalysisLimits()).thenReturn(new AnalysisLimits(modelMemoryLimitMb, 4L)); + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener<Job> listener = (ActionListener<Job>) invocation.getArguments()[1]; + listener.onResponse(job); + return null; + }).when(jobManager).getJob(eq(jobId), any(ActionListener.class)); + + AtomicReference<Long> refreshedMemoryRequirement = new AtomicReference<>(); + memoryTracker.refreshJobMemory(jobId, ActionListener.wrap(refreshedMemoryRequirement::set, ESTestCase::assertNull)); + + if (isMaster) { + if (haveEstablishedModelMemory) { + assertEquals(Long.valueOf(modelBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), + memoryTracker.getJobMemoryRequirement(jobId)); + } else { + assertEquals(Long.valueOf(ByteSizeUnit.MB.toBytes(modelMemoryLimitMb) + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), + memoryTracker.getJobMemoryRequirement(jobId)); + } + } else { + assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + } + + assertEquals(memoryTracker.getJobMemoryRequirement(jobId), refreshedMemoryRequirement.get()); + + memoryTracker.removeJob(jobId); + assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + } + + @SuppressWarnings("unchecked") + public void testRecordUpdateTimeInClusterState() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE); + + AtomicReference<Long> updateVersion = new AtomicReference<>(); + + doAnswer(invocation -> { + AckedClusterStateUpdateTask task = (AckedClusterStateUpdateTask) invocation.getArguments()[1]; + ClusterState currentClusterState = ClusterState.EMPTY_STATE; + ClusterState newClusterState = task.execute(currentClusterState); + assertThat(currentClusterState, not(equalTo(newClusterState))); + MlMetadata newMlMetadata = MlMetadata.getMlMetadata(newClusterState); + updateVersion.set(newMlMetadata.getLastMemoryRefreshVersion()); + task.onAllNodesAcked(null); + return null; + }).when(clusterService).submitStateUpdateTask(anyString(), any(AckedClusterStateUpdateTask.class)); + + memoryTracker.asyncRefresh(ActionListener.wrap(ESTestCase::assertTrue, ESTestCase::assertNull)); + + if (isMaster) {
assertNotNull(updateVersion.get()); + } else { + assertNull(updateVersion.get()); + } + } + + private PersistentTasksCustomMetaData.PersistentTask<OpenJobAction.JobParams> makeTestTask(String jobId) { + return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId), + 0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT); + } +}
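
Note on the expected values in MlMemoryTrackerTests.testRefreshOne above: the test asserts that the tracked requirement is the established model memory plus the fixed native process overhead when an established figure exists, and otherwise the configured model memory limit plus the same overhead. The sketch below restates that calculation in isolation; it is illustrative only and not part of the patch, the class and method names are hypothetical, and only Job.PROCESS_MEMORY_OVERHEAD and ByteSizeUnit are taken from the code above.

import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.xpack.core.ml.job.config.Job;

final class MemoryRequirementSketch {

    // Hypothetical helper mirroring the expectation in MlMemoryTrackerTests.testRefreshOne.
    static long expectedJobMemoryRequirement(Long establishedModelMemoryBytes, long modelMemoryLimitMb) {
        long overhead = Job.PROCESS_MEMORY_OVERHEAD.getBytes();
        if (establishedModelMemoryBytes != null && establishedModelMemoryBytes > 0) {
            // An established model memory figure exists: use it plus the fixed process overhead.
            return establishedModelMemoryBytes + overhead;
        }
        // No established figure yet: fall back to the configured model memory limit plus the overhead.
        return ByteSizeUnit.MB.toBytes(modelMemoryLimitMb) + overhead;
    }
}

With the test's values (modelBytes = 1024 * 1024, modelMemoryLimitMb = 2), this returns modelBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes() when an established figure is reported and ByteSizeUnit.MB.toBytes(2) + Job.PROCESS_MEMORY_OVERHEAD.getBytes() otherwise, matching the two assertEquals branches.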