Skip to content

Commit

Permalink
[ML] Reimplement established model memory (#35263)
Browse files Browse the repository at this point in the history
This is the 6.6/6.7 implementation of a master node service to
keep track of the native process memory requirement of each ML
job with an associated native process.

The new ML memory tracker service works when the whole cluster
is upgraded to at least version 6.6.  For mixed version clusters
the old mechanism of established model memory stored on the job
in cluster state is used.  This means that the old (and complex)
code to keep established model memory up to date on the job object
cannot yet be removed.  When this change is forward ported to 7.0
the old way of keeping established model memory updated will be
removed.
  • Loading branch information
droberts195 committed Nov 13, 2018
1 parent 3a0a5e7 commit 2b6cd7a
Show file tree
Hide file tree
Showing 13 changed files with 853 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,28 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom {
public static final String TYPE = "ml";
private static final ParseField JOBS_FIELD = new ParseField("jobs");
private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds");
private static final ParseField LAST_MEMORY_REFRESH_VERSION_FIELD = new ParseField("last_memory_refresh_version");

public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap());
public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null);
// This parser follows the pattern that metadata is parsed leniently (to allow for enhancements)
public static final ObjectParser<Builder, Void> LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new);

static {
LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD);
LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds,
(p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD);
LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshVersion, LAST_MEMORY_REFRESH_VERSION_FIELD);
}

private final SortedMap<String, Job> jobs;
private final SortedMap<String, DatafeedConfig> datafeeds;
private final Long lastMemoryRefreshVersion;
private final GroupOrJobLookup groupOrJobLookup;

private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds) {
private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds, Long lastMemoryRefreshVersion) {
this.jobs = Collections.unmodifiableSortedMap(jobs);
this.datafeeds = Collections.unmodifiableSortedMap(datafeeds);
this.lastMemoryRefreshVersion = lastMemoryRefreshVersion;
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

Expand Down Expand Up @@ -112,6 +116,10 @@ public Set<String> expandDatafeedIds(String expression, boolean allowNoDatafeeds
.expand(expression, allowNoDatafeeds);
}

public Long getLastMemoryRefreshVersion() {
return lastMemoryRefreshVersion;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.V_5_4_0;
Expand Down Expand Up @@ -145,14 +153,21 @@ public MlMetadata(StreamInput in) throws IOException {
datafeeds.put(in.readString(), new DatafeedConfig(in));
}
this.datafeeds = datafeeds;

if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshVersion = in.readOptionalLong();
} else {
lastMemoryRefreshVersion = null;
}
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
writeMap(jobs, out);
writeMap(datafeeds, out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
out.writeOptionalLong(lastMemoryRefreshVersion);
}
}

private static <T extends Writeable> void writeMap(Map<String, T> map, StreamOutput out) throws IOException {
Expand All @@ -169,6 +184,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params);
mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams);
mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams);
if (lastMemoryRefreshVersion != null) {
builder.field(LAST_MEMORY_REFRESH_VERSION_FIELD.getPreferredName(), lastMemoryRefreshVersion);
}
return builder;
}

Expand All @@ -185,30 +203,46 @@ public static class MlMetadataDiff implements NamedDiff<MetaData.Custom> {

final Diff<Map<String, Job>> jobs;
final Diff<Map<String, DatafeedConfig>> datafeeds;
final Long lastMemoryRefreshVersion;

MlMetadataDiff(MlMetadata before, MlMetadata after) {
this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer());
this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer());
this.lastMemoryRefreshVersion = after.lastMemoryRefreshVersion;
}

public MlMetadataDiff(StreamInput in) throws IOException {
this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new,
MlMetadataDiff::readJobDiffFrom);
this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new,
MlMetadataDiff::readSchedulerDiffFrom);
MlMetadataDiff::readDatafeedDiffFrom);
if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshVersion = in.readOptionalLong();
} else {
lastMemoryRefreshVersion = null;
}
}

/**
* Merge the diff with the ML metadata.
* @param part The current ML metadata.
* @return The new ML metadata.
*/
@Override
public MetaData.Custom apply(MetaData.Custom part) {
TreeMap<String, Job> newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs));
TreeMap<String, DatafeedConfig> newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds));
return new MlMetadata(newJobs, newDatafeeds);
// lastMemoryRefreshVersion always comes from the diff - no need to merge with the old value
return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshVersion);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
jobs.writeTo(out);
datafeeds.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
out.writeOptionalLong(lastMemoryRefreshVersion);
}
}

@Override
Expand All @@ -220,7 +254,7 @@ static Diff<Job> readJobDiffFrom(StreamInput in) throws IOException {
return AbstractDiffable.readDiffFrom(Job::new, in);
}

static Diff<DatafeedConfig> readSchedulerDiffFrom(StreamInput in) throws IOException {
static Diff<DatafeedConfig> readDatafeedDiffFrom(StreamInput in) throws IOException {
return AbstractDiffable.readDiffFrom(DatafeedConfig::new, in);
}
}
Expand All @@ -233,7 +267,8 @@ public boolean equals(Object o) {
return false;
MlMetadata that = (MlMetadata) o;
return Objects.equals(jobs, that.jobs) &&
Objects.equals(datafeeds, that.datafeeds);
Objects.equals(datafeeds, that.datafeeds) &&
Objects.equals(lastMemoryRefreshVersion, that.lastMemoryRefreshVersion);
}

@Override
Expand All @@ -243,13 +278,14 @@ public final String toString() {

@Override
public int hashCode() {
return Objects.hash(jobs, datafeeds);
return Objects.hash(jobs, datafeeds, lastMemoryRefreshVersion);
}

public static class Builder {

private TreeMap<String, Job> jobs;
private TreeMap<String, DatafeedConfig> datafeeds;
private Long lastMemoryRefreshVersion;

public Builder() {
jobs = new TreeMap<>();
Expand All @@ -263,6 +299,7 @@ public Builder(@Nullable MlMetadata previous) {
} else {
jobs = new TreeMap<>(previous.jobs);
datafeeds = new TreeMap<>(previous.datafeeds);
lastMemoryRefreshVersion = previous.lastMemoryRefreshVersion;
}
}

Expand Down Expand Up @@ -382,8 +419,13 @@ private Builder putDatafeeds(Collection<DatafeedConfig> datafeeds) {
return this;
}

public Builder setLastMemoryRefreshVersion(Long lastMemoryRefreshVersion) {
this.lastMemoryRefreshVersion = lastMemoryRefreshVersion;
return this;
}

public MlMetadata build() {
return new MlMetadata(jobs, datafeeds);
return new MlMetadata(jobs, datafeeds, lastMemoryRefreshVersion);
}

public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) {
Expand Down Expand Up @@ -420,8 +462,6 @@ void checkJobHasNoDatafeed(String jobId) {
}
}



public static MlMetadata getMlMetadata(ClusterState state) {
MlMetadata mlMetadata = (state == null) ? null : state.getMetaData().custom(TYPE);
if (mlMetadata == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ private static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFie
private final Date createTime;
private final Date finishedTime;
private final Date lastDataTime;
// TODO: Remove in 7.0
private final Long establishedModelMemory;
private final AnalysisConfig analysisConfig;
private final AnalysisLimits analysisLimits;
Expand Down Expand Up @@ -439,6 +440,7 @@ public Collection<String> allInputFields() {
* program code and stack.
* @return an estimate of the memory requirement of this job, in bytes
*/
// TODO: remove this method in 7.0
public long estimateMemoryFootprint() {
if (establishedModelMemory != null && establishedModelMemory > 0) {
return establishedModelMemory + PROCESS_MEMORY_OVERHEAD.getBytes();
Expand Down Expand Up @@ -658,6 +660,7 @@ public static class Builder implements Writeable, ToXContentObject {
private Date createTime;
private Date finishedTime;
private Date lastDataTime;
// TODO: remove in 7.0
private Long establishedModelMemory;
private ModelPlotConfig modelPlotConfig;
private Long renormalizationWindowDays;
Expand Down Expand Up @@ -1102,10 +1105,6 @@ private void validateGroups() {
public Job build(Date createTime) {
setCreateTime(createTime);
setJobVersion(Version.CURRENT);
// TODO: Maybe we _could_ accept a value for this supplied at create time - it would
// mean cloned jobs that hadn't been edited much would start with an accurate expected size.
// But on the other hand it would mean jobs that were cloned and then completely changed
// would start with a size that was completely wrong.
setEstablishedModelMemory(null);
return build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ public void testEstimateMemoryFootprint_GivenNoLimitAndNotEstablished() {
builder.setEstablishedModelMemory(0L);
}
assertEquals(ByteSizeUnit.MB.toBytes(AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB)
+ Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint());
+ Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint());
}

public void testEarliestValidTimestamp_GivenEmptyDataCounts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction;
Expand Down Expand Up @@ -278,6 +279,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu

private final SetOnce<AutodetectProcessManager> autodetectProcessManager = new SetOnce<>();
private final SetOnce<DatafeedManager> datafeedManager = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
Expand Down Expand Up @@ -420,6 +422,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
this.datafeedManager.set(datafeedManager);
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
autodetectProcessManager);
MlMemoryTracker memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider);
this.memoryTracker.set(memoryTracker);

// This object's constructor attaches to the license state, so there's no need to retain another reference to it
new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager);
Expand All @@ -438,7 +442,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
jobDataCountsPersister,
datafeedManager,
auditor,
new MlAssignmentNotifier(auditor, clusterService)
new MlAssignmentNotifier(auditor, clusterService),
memoryTracker
);
}

Expand All @@ -449,7 +454,8 @@ public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(ClusterServic
}

return Arrays.asList(
new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get()),
new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(),
memoryTracker.get()),
new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(settings, datafeedManager.get())
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import org.elasticsearch.xpack.ml.job.persistence.JobDataDeleter;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;

import java.util.ArrayList;
Expand All @@ -94,6 +95,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
private final JobResultsProvider jobResultsProvider;
private final JobConfigProvider jobConfigProvider;
private final DatafeedConfigProvider datafeedConfigProvider;
private final MlMemoryTracker memoryTracker;

/**
* A map of task listeners by job_id.
Expand All @@ -108,7 +110,8 @@ public TransportDeleteJobAction(Settings settings, TransportService transportSer
ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, PersistentTasksService persistentTasksService,
Client client, Auditor auditor, JobResultsProvider jobResultsProvider,
JobConfigProvider jobConfigProvider, DatafeedConfigProvider datafeedConfigProvider) {
JobConfigProvider jobConfigProvider, DatafeedConfigProvider datafeedConfigProvider,
MlMemoryTracker memoryTracker) {
super(settings, DeleteJobAction.NAME, transportService, clusterService, threadPool, actionFilters,
indexNameExpressionResolver, DeleteJobAction.Request::new);
this.client = client;
Expand All @@ -117,6 +120,7 @@ public TransportDeleteJobAction(Settings settings, TransportService transportSer
this.jobResultsProvider = jobResultsProvider;
this.jobConfigProvider = jobConfigProvider;
this.datafeedConfigProvider = datafeedConfigProvider;
this.memoryTracker = memoryTracker;
this.listenersByJobId = new HashMap<>();
}

Expand Down Expand Up @@ -211,6 +215,9 @@ private void normalDeleteJob(ParentTaskAssigningClient parentTaskClient, DeleteJ
ActionListener<AcknowledgedResponse> listener) {
String jobId = request.getJobId();

// We clean up the memory tracker on delete rather than close as close is not a master node action
memoryTracker.removeJob(jobId);

// Step 4. When the job has been removed from the cluster state, return a response
// -------
CheckedConsumer<Boolean, Exception> apiResponseHandler = jobDeleted -> {
Expand Down
Loading

0 comments on commit 2b6cd7a

Please sign in to comment.