Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Reimplement established model memory #35263

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
Expand All @@ -57,24 +58,28 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom {
public static final String TYPE = "ml";
private static final ParseField JOBS_FIELD = new ParseField("jobs");
private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds");
private static final ParseField LAST_MEMORY_REFRESH_TIME_FIELD = new ParseField("last_memory_refresh_time");

public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap());
public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null);
// This parser follows the pattern that metadata is parsed leniently (to allow for enhancements)
public static final ObjectParser<Builder, Void> LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new);

static {
// Metadata is parsed leniently so that future additions to the format do not break older nodes
LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD);
LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds,
(p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD);
// The refresh time is persisted as epoch milliseconds (see toXContent), hence the long-based setter
LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshTimeMs, LAST_MEMORY_REFRESH_TIME_FIELD);
}

private final SortedMap<String, Job> jobs;
private final SortedMap<String, DatafeedConfig> datafeeds;
private final Instant lastMemoryRefreshTime;
private final GroupOrJobLookup groupOrJobLookup;

private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds) {
private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds, Instant lastMemoryRefreshTime) {
this.jobs = Collections.unmodifiableSortedMap(jobs);
this.datafeeds = Collections.unmodifiableSortedMap(datafeeds);
this.lastMemoryRefreshTime = lastMemoryRefreshTime;
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

Expand Down Expand Up @@ -112,6 +117,10 @@ public Set<String> expandDatafeedIds(String expression, boolean allowNoDatafeeds
.expand(expression, allowNoDatafeeds);
}

/**
 * @return the time at which ML job memory requirements were last refreshed,
 *         or {@code null} if no refresh time has been recorded
 */
public Instant getLastMemoryRefreshTime() {
return lastMemoryRefreshTime;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.V_5_4_0;
Expand Down Expand Up @@ -145,14 +154,27 @@ public MlMetadata(StreamInput in) throws IOException {
datafeeds.put(in.readString(), new DatafeedConfig(in));
}
this.datafeeds = datafeeds;

if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshTime = in.readBoolean() ? Instant.ofEpochSecond(in.readVLong(), in.readVInt()) : null;
} else {
lastMemoryRefreshTime = null;
}
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

/**
 * Serializes this metadata: the jobs map, the datafeeds map, and (for nodes on or
 * after 6.6.0) the optional last memory refresh time.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    writeMap(jobs, out);
    writeMap(datafeeds, out);
    // Nodes before 6.6.0 do not know about the refresh time, so only newer streams carry it
    if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
        boolean hasRefreshTime = lastMemoryRefreshTime != null;
        out.writeBoolean(hasRefreshTime);
        if (hasRefreshTime) {
            // Epoch seconds plus nanos round-trips the Instant without loss of precision
            out.writeVLong(lastMemoryRefreshTime.getEpochSecond());
            out.writeVInt(lastMemoryRefreshTime.getNano());
        }
    }
}

private static <T extends Writeable> void writeMap(Map<String, T> map, StreamOutput out) throws IOException {
Expand All @@ -169,6 +191,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params);
mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams);
mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams);
if (lastMemoryRefreshTime != null) {
// We lose precision lower than milliseconds here - OK as millisecond precision is adequate for this use case
builder.timeField(LAST_MEMORY_REFRESH_TIME_FIELD.getPreferredName(),
LAST_MEMORY_REFRESH_TIME_FIELD.getPreferredName() + "_string", lastMemoryRefreshTime.toEpochMilli());
}
return builder;
}

Expand All @@ -185,30 +212,47 @@ public static class MlMetadataDiff implements NamedDiff<MetaData.Custom> {

final Diff<Map<String, Job>> jobs;
final Diff<Map<String, DatafeedConfig>> datafeeds;
final Instant lastMemoryRefreshTime;

MlMetadataDiff(MlMetadata before, MlMetadata after) {
this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer());
this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer());
this.lastMemoryRefreshTime = after.lastMemoryRefreshTime;
}

/**
 * Reads a diff from the wire. The read order must mirror {@code MlMetadataDiff.writeTo}:
 * jobs diff, datafeeds diff, then the optional refresh time.
 */
public MlMetadataDiff(StreamInput in) throws IOException {
this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new,
MlMetadataDiff::readJobDiffFrom);
this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new,
MlMetadataDiff::readSchedulerDiffFrom);
if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
// Optional Instant: presence flag, then epoch seconds (VLong) and nanos (VInt)
lastMemoryRefreshTime = in.readBoolean() ? Instant.ofEpochSecond(in.readVLong(), in.readVInt()) : null;
} else {
// Streams from pre-6.6.0 nodes never contain the refresh time
lastMemoryRefreshTime = null;
}
}

/**
 * Applies this diff to the given metadata, producing the updated {@link MlMetadata}.
 * The refresh time is not itself diffed: this diff simply carries the "after" value
 * (set in the two-arg constructor and serialized in writeTo), so that value is used here.
 */
@Override
public MetaData.Custom apply(MetaData.Custom part) {
    TreeMap<String, Job> newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs));
    TreeMap<String, DatafeedConfig> newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds));
    // Use the diff's own "after" refresh time; reading it from `part` (the "before"
    // state) would mean an applied diff could never change the refresh time
    return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshTime);
}

/**
 * Serializes this diff: the jobs diff, the datafeeds diff, and (for nodes on or
 * after 6.6.0) the optional refresh time.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    jobs.writeTo(out);
    datafeeds.writeTo(out);
    // Only streams to 6.6.0+ nodes carry the refresh time
    if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
        boolean hasRefreshTime = lastMemoryRefreshTime != null;
        out.writeBoolean(hasRefreshTime);
        if (hasRefreshTime) {
            // Same seconds + nanos encoding as MlMetadata.writeTo
            out.writeVLong(lastMemoryRefreshTime.getEpochSecond());
            out.writeVInt(lastMemoryRefreshTime.getNano());
        }
    }
}

@Override
Expand All @@ -233,7 +277,8 @@ public boolean equals(Object o) {
return false;
MlMetadata that = (MlMetadata) o;
return Objects.equals(jobs, that.jobs) &&
Objects.equals(datafeeds, that.datafeeds);
Objects.equals(datafeeds, that.datafeeds) &&
Objects.equals(lastMemoryRefreshTime, that.lastMemoryRefreshTime);
}

@Override
Expand All @@ -243,13 +288,14 @@ public final String toString() {

@Override
public int hashCode() {
    // Must stay in step with equals(), which compares jobs, datafeeds and lastMemoryRefreshTime
    return Objects.hash(jobs, datafeeds, lastMemoryRefreshTime);
}

public static class Builder {

private TreeMap<String, Job> jobs;
private TreeMap<String, DatafeedConfig> datafeeds;
private Instant lastMemoryRefreshTime;

public Builder() {
jobs = new TreeMap<>();
Expand All @@ -263,6 +309,7 @@ public Builder(@Nullable MlMetadata previous) {
} else {
jobs = new TreeMap<>(previous.jobs);
datafeeds = new TreeMap<>(previous.datafeeds);
lastMemoryRefreshTime = previous.lastMemoryRefreshTime;
}
}

Expand Down Expand Up @@ -382,8 +429,18 @@ private Builder putDatafeeds(Collection<DatafeedConfig> datafeeds) {
return this;
}

// Package-private hook for the lenient parser: the persisted XContent form stores
// the refresh time as epoch milliseconds, so convert back to an Instant here.
Builder setLastMemoryRefreshTimeMs(long lastMemoryRefreshTimeMs) {
lastMemoryRefreshTime = Instant.ofEpochMilli(lastMemoryRefreshTimeMs);
return this;
}

/**
 * Sets the time at which ML job memory requirements were last refreshed.
 * @param lastMemoryRefreshTime the refresh time, or {@code null} if never refreshed
 * @return this builder, for chaining
 */
public Builder setLastMemoryRefreshTime(Instant lastMemoryRefreshTime) {
this.lastMemoryRefreshTime = lastMemoryRefreshTime;
return this;
}

/**
 * Builds an immutable {@link MlMetadata} from the current state of this builder.
 */
public MlMetadata build() {
    return new MlMetadata(jobs, datafeeds, lastMemoryRefreshTime);
}

public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) {
Expand Down Expand Up @@ -420,8 +477,6 @@ void checkJobHasNoDatafeed(String jobId) {
}
}



public static MlMetadata getMlMetadata(ClusterState state) {
MlMetadata mlMetadata = (state == null) ? null : state.getMetaData().custom(TYPE);
if (mlMetadata == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j
if (memoryTracker.isRecentlyRefreshed() == false) {

boolean scheduledRefresh = memoryTracker.asyncRefresh(ActionListener.wrap(
aVoid -> {
// TODO: find a way to get the persistent task framework to do another reassignment check BLOCKER!
// Persistent task allocation reacts to custom metadata changes, so one way would be to retain the
// MlMetadata as a single number that we increment when we want to kick persistent tasks.
// A less sneaky way would be to introduce an internal action specifically for the purpose of
// asking persistent tasks to re-check whether unallocated tasks can be allocated.
acknowledged -> {
if (acknowledged) {
logger.trace("Job memory requirement refresh request completed successfully");
} else {
logger.warn("Job memory requirement refresh request completed but did not set time in cluster state");
}
},
e -> logger.error("Failed to refresh job memory requirements", e)
));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.LocalNodeMasterListener;
import org.elasticsearch.cluster.ack.AckedRequest;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -23,6 +29,8 @@
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
Expand All @@ -32,15 +40,27 @@
/**
* This class keeps track of the memory requirement of ML jobs.
* It only functions on the master node - for this reason it should only be used by master node actions.
* The memory requirement for open ML jobs is updated at the following times:
* 1. When a master node is elected the memory requirement for all non-closed ML jobs is updated
* 2. The memory requirement for all non-closed ML jobs is updated periodically thereafter - every 30 seconds by default
* 3. When a job is opened the memory requirement for that single job is updated
* As a result of this every open job should have a value for its memory requirement that is no more than 30 seconds out-of-date.
* The memory requirement for ML jobs can be updated in 3 ways:
* 1. For all open ML jobs (via {@link #asyncRefresh})
* 2. For all open ML jobs, plus one named ML job that is not open (via {@link #refreshJobMemoryAndAllOthers})
* 3. For one named ML job (via {@link #refreshJobMemory})
* In all cases a listener informs the caller when the requested updates are complete.
*/
public class MlMemoryTracker implements LocalNodeMasterListener {

private static final Long RECENT_UPDATE_THRESHOLD_NS = 30_000_000_000L; // 30 seconds
// Fixed-timeout AckedRequest used for the cluster state update that records the
// memory refresh time; both timeouts reuse the standard acknowledgement timeout.
private static final AckedRequest ACKED_REQUEST = new AckedRequest() {
@Override
public TimeValue ackTimeout() {
return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT;
}

@Override
public TimeValue masterNodeTimeout() {
return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT;
}
};

private static final Duration RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1);

private final Logger logger = LogManager.getLogger(MlMemoryTracker.class);
private final ConcurrentHashMap<String, Long> memoryRequirementByJob = new ConcurrentHashMap<>();
Expand All @@ -51,7 +71,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
private final JobManager jobManager;
private final JobResultsProvider jobResultsProvider;
private volatile boolean isMaster;
private volatile Long lastUpdateNanotime;
private volatile Instant lastUpdateTime;

public MlMemoryTracker(ClusterService clusterService, ThreadPool threadPool, JobManager jobManager,
JobResultsProvider jobResultsProvider) {
Expand All @@ -72,7 +92,7 @@ public void onMaster() {
// This node is no longer master: stop tracking and discard all cached memory
// requirements, since the new master node now owns this state.
public void offMaster() {
isMaster = false;
memoryRequirementByJob.clear();
lastUpdateTime = null;
}

@Override
Expand All @@ -85,8 +105,8 @@ public String executorName() {
* for valid allocation decisions to be made using it?
*/
public boolean isRecentlyRefreshed() {
    // Snapshot the volatile field once so the null check and the comparison see the same value
    Instant snapshot = lastUpdateTime;
    if (snapshot == null) {
        return false;
    }
    // Recent means the last update happened within RECENT_UPDATE_THRESHOLD of now
    return snapshot.plus(RECENT_UPDATE_THRESHOLD).isAfter(Instant.now());
}

/**
Expand Down Expand Up @@ -122,16 +142,24 @@ public void removeJob(String jobId) {
/**
* Uses a separate thread to refresh the memory requirement for every ML job that has
* a corresponding persistent task. This method only works on the master node.
* @param listener Will be called when the async refresh completes or fails.
* @param listener Will be called when the async refresh completes or fails. The
* boolean value indicates whether the cluster state was updated
* with the refresh completion time. (If it was then this will in
* turn cause the persistent tasks framework to check if any persistent
* tasks are awaiting allocation.)
* @return <code>true</code> if the async refresh is scheduled, and <code>false</code>
* if this is not possible for some reason.
*/
public boolean asyncRefresh(ActionListener<Void> listener) {
public boolean asyncRefresh(ActionListener<Boolean> listener) {

if (isMaster) {
try {
ActionListener<Void> mlMetaUpdateListener = ActionListener.wrap(
aVoid -> recordUpdateTimeInClusterState(listener),
listener::onFailure
);
threadPool.executor(executorName()).execute(
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener));
() -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), mlMetaUpdateListener));
return true;
} catch (EsRejectedExecutionException e) {
logger.debug("Couldn't schedule ML memory update - node might be shutting down", e);
Expand Down Expand Up @@ -175,7 +203,7 @@ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener<Void>
}

ActionListener<Void> refreshComplete = ActionListener.wrap(aVoid -> {
lastUpdateNanotime = System.nanoTime();
lastUpdateTime = Instant.now();
synchronized (fullRefreshCompletionListeners) {
assert fullRefreshCompletionListeners.isEmpty() == false;
for (ActionListener<Void> listener : fullRefreshCompletionListeners) {
Expand All @@ -195,6 +223,33 @@ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener<Void>
}
}

void recordUpdateTimeInClusterState(ActionListener<Boolean> listener) {

clusterService.submitStateUpdateTask("ml-memory-last-update-time",
new AckedClusterStateUpdateTask<Boolean>(ACKED_REQUEST, listener) {
@Override
protected Boolean newResponse(boolean acknowledged) {
return acknowledged;
}

@Override
public ClusterState execute(ClusterState currentState) {
MlMetadata currentMlMetadata = MlMetadata.getMlMetadata(currentState);
MlMetadata.Builder builder = new MlMetadata.Builder(currentMlMetadata);
MlMetadata newMlMetadata = builder.build();
builder.setLastMemoryRefreshTime(lastUpdateTime);
droberts195 marked this conversation as resolved.
Show resolved Hide resolved
if (newMlMetadata.equals(currentMlMetadata)) {
// Return same reference if nothing has changed
return currentState;
} else {
ClusterState.Builder newState = ClusterState.builder(currentState);
newState.metaData(MetaData.builder(currentState.getMetaData()).putCustom(MlMetadata.TYPE, newMlMetadata).build());
return newState.build();
}
}
});
}

private void iterateMlJobTasks(Iterator<PersistentTasksCustomMetaData.PersistentTask<?>> iterator,
ActionListener<Void> refreshComplete) {
if (iterator.hasNext()) {
Expand Down
Loading