[ML] Use perAllocation and perDeployment memory usage in the model assignment planner (#98874)

Building upon #98139, this PR extends the model assignment planning algorithms and the linear solver to use the extended memory fields. It also adds unit tests to verify the new behavior.

I needed to adjust the old unit tests because we now use the estimateMemoryUsage routine, which computes 2*memoryBytes + 240 MB as the memory requirement; previously the unit tests simply used the memoryBytes field value.
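
For context, here is a minimal sketch of the two estimates. The legacy formula (2*memoryBytes + 240 MB) comes from the description above; the extended formula and every name in the sketch are assumptions for illustration, not the actual StartTrainedModelDeploymentAction.estimateMemoryUsageBytes implementation:

```java
// Illustrative sketch only: the legacy formula is taken from the commit description;
// the extended formula and all names here are assumptions, not the real
// StartTrainedModelDeploymentAction.estimateMemoryUsageBytes implementation.
public final class MemoryEstimateSketch {

    private static final long MEMORY_OVERHEAD_BYTES = 240L * 1024 * 1024; // 240 MB

    // Old-style estimate: a fixed multiple of the model size plus overhead.
    static long legacyEstimateBytes(long modelBytes) {
        return 2 * modelBytes + MEMORY_OVERHEAD_BYTES;
    }

    // New-style estimate (assumed): account for per-deployment and per-allocation memory
    // when those fields are populated, otherwise fall back to the legacy formula.
    static long extendedEstimateBytes(long modelBytes, long perDeploymentBytes, long perAllocationBytes, int allocations) {
        if (perDeploymentBytes > 0 || perAllocationBytes > 0) {
            return modelBytes + perDeploymentBytes + perAllocationBytes * allocations;
        }
        return legacyEstimateBytes(modelBytes);
    }

    public static void main(String[] args) {
        long modelBytes = 1_000_000_000L; // 1 GB model, purely illustrative
        System.out.println("legacy:   " + legacyEstimateBytes(modelBytes));
        System.out.println("extended: " + extendedEstimateBytes(modelBytes, 300_000_000L, 50_000_000L, 4));
    }
}
```

As the diffs below show, the PR gates the new fields on the cluster's minimum transport version so that mixed-version clusters keep using the old estimate.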
valeriy42 committed Nov 6, 2023
1 parent f3a4813 commit aa2f6e7
Showing 20 changed files with 2,076 additions and 483 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/98874.yaml
@@ -0,0 +1,5 @@
pr: 98874
summary: Estimate the memory required to deploy trained models more accurately
area: Machine Learning
type: enhancement
issues: []
@@ -9,6 +9,7 @@

import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Randomness;
@@ -96,6 +97,10 @@ public final class TrainedModelAssignment implements SimpleDiffable<TrainedModel
private final Instant startTime;
private final int maxAssignedAllocations;

public static boolean useNewMemoryFields(TransportVersion minClusterVersion) {
return minClusterVersion.onOrAfter(TransportVersions.V_8_500_064);
}

public static TrainedModelAssignment fromXContent(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
@@ -45,9 +45,11 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
@@ -296,29 +298,23 @@ private void modelSizeStats(
for (TrainedModelConfig model : models) {
if (model.getModelType() == TrainedModelType.PYTORCH) {
long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
// We ensure that in the mixed cluster state trained model stats uses the same values for memory estimation
// as the rebalancer.
boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(
TransportVersionUtils.getMinTransportVersion(clusterService.state())
);
long estimatedMemoryUsageBytes = totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0,
numberOfAllocations
)
: 0L;
modelSizeStatsByModelId.put(
model.getModelId(),
new TrainedModelSizeStats(
totalDefinitionLength,
totalDefinitionLength > 0L
? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
model.getModelId(),
totalDefinitionLength,
model.getPerDeploymentMemoryBytes(),
model.getPerAllocationMemoryBytes(),
numberOfAllocations
)
: 0L
)
new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes)
);
} else {
modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
@@ -47,6 +47,7 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
@@ -76,6 +77,8 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;

private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064;

private final ClusterService clusterService;
private final ThreadPool threadPool;
private final NodeLoadDetector nodeLoadDetector;
@@ -644,12 +647,14 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
Map<DiscoveryNode, NodeLoad> nodeLoads = detectNodeLoads(nodes, currentState);
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState);

boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState));
TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
currentMetadata,
nodeLoads,
nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState),
modelToAdd,
allocatedProcessorsScale
allocatedProcessorsScale,
useNewMemoryFields
);

Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
@@ -52,18 +52,22 @@ class TrainedModelAssignmentRebalancer {
private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
private final int allocatedProcessorsScale;

private final boolean useNewMemoryFields;

TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata currentMetadata,
Map<DiscoveryNode, NodeLoad> nodeLoads,
Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone,
Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd,
int allocatedProcessorsScale
int allocatedProcessorsScale,
boolean useNewMemoryFields
) {
this.currentMetadata = Objects.requireNonNull(currentMetadata);
this.nodeLoads = Objects.requireNonNull(nodeLoads);
this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone);
this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd);
this.allocatedProcessorsScale = allocatedProcessorsScale;
this.useNewMemoryFields = useNewMemoryFields;
}

TrainedModelAssignmentMetadata.Builder rebalance() {
@@ -138,9 +142,11 @@ private static void copyAssignments(
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
dest.assignModelToNode(m, originalNode, assignment.getValue());
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
dest.accountMemory(m, originalNode);
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
dest.accountMemory(m, originalNode, requiredMemory);
}
}
}
@@ -168,11 +174,14 @@ private AssignmentPlan computePlanForNormalPriorityModels(
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
return new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
currentAssignments,
assignment.getMaxAssignedAllocations()
assignment.getMaxAssignedAllocations(),
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
);
})
.forEach(planDeployments::add);
@@ -181,11 +190,14 @@
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getModelBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0
0,
// in the mixed cluster state use old memory fields to avoid unstable assignment plans
useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0,
useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0
)
);
}
@@ -217,12 +229,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNod
.map(
assignment -> new AssignmentPlan.Deployment(
assignment.getDeploymentId(),
assignment.getTaskParams().estimateMemoryUsageBytes(),
assignment.getTaskParams().getModelBytes(),
assignment.getTaskParams().getNumberOfAllocations(),
assignment.getTaskParams().getThreadsPerAllocation(),
findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory),
assignment.getMaxAssignedAllocations(),
Priority.LOW
Priority.LOW,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0
)
)
.forEach(planDeployments::add);
@@ -231,12 +245,14 @@
planDeployments.add(
new AssignmentPlan.Deployment(
taskParams.getDeploymentId(),
taskParams.estimateMemoryUsageBytes(),
taskParams.getModelBytes(),
taskParams.getNumberOfAllocations(),
taskParams.getThreadsPerAllocation(),
Map.of(),
0,
Priority.LOW
Priority.LOW,
(useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0,
(useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0
)
);
}
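
The useNewMemoryFields flag in the rebalancer above is derived from the cluster's minimum transport version and decides whether the per-deployment/per-allocation fields are passed into the plan or replaced with zeros, so memory estimates stay consistent while old and new nodes coexist. A standalone paraphrase of that gate, mirroring the normal-priority branch (the record and method names are hypothetical, only the decision logic follows the diff):

```java
// Hypothetical paraphrase of the mixed-cluster gating used in the rebalancer above;
// only the decision logic mirrors the diff, the names are illustrative.
record EffectiveMemoryFields(long perDeploymentBytes, long perAllocationBytes) {

    static EffectiveMemoryFields forCluster(boolean useNewMemoryFields, long perDeploymentBytes, long perAllocationBytes) {
        // While any node is still on an older transport version, fall back to zeros so
        // old and new nodes agree on the memory requirement and plans do not flip-flop.
        return useNewMemoryFields
            ? new EffectiveMemoryFields(perDeploymentBytes, perAllocationBytes)
            : new EffectiveMemoryFields(0, 0);
    }
}
```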
@@ -35,7 +35,8 @@ private Node modifyNodePreservingAllocations(Node n) {
int coresUsed = 0;
for (Deployment m : deployments) {
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
bytesUsed += m.memoryBytes();
int allocations = m.currentAllocationsByNodeId().get(n.id());
bytesUsed += m.estimateMemoryUsageBytes(allocations);
coresUsed += calculateUsedCores(n, m);
}
}
@@ -58,7 +59,9 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) {
m.allocations() - calculatePreservedAllocations(m),
m.threadsPerAllocation(),
calculateAllocationsPerNodeToPreserve(m),
m.maxAssignedAllocations()
m.maxAssignedAllocations(),
m.perDeploymentMemoryBytes(),
m.perAllocationMemoryBytes()
);
}

@@ -67,28 +70,37 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// they will not match the models/nodes members we have in this class.
// Therefore, we build a lookup table based on the ids so we can merge the plan
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> assignmentsByModelNodeIdPair = new HashMap<>();
final Map<Tuple<String, String>, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>();
for (Deployment m : assignmentPlan.models()) {
Map<Node, Integer> assignments = assignmentPlan.assignments(m).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignments.entrySet()) {
assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue());
}
}

AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments);
for (Deployment m : deployments) {
for (Node n : nodes) {
int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0);
if (m.currentAllocationsByNodeId().containsKey(n.id())) {
if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) {
allocations += addPreservedAllocations(n, m);
// As the node has all its available memory we need to manually account memory of models with
// current allocations.
mergedPlanBuilder.accountMemory(m, n);
for (Node n : nodes) {
// TODO (#101612) Should the first loop happen in the builder constructor?
for (Deployment deploymentAllocationsToPreserve : deployments) {

// if the model m is already allocated on the node n and I want to preserve this allocation
int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve);
if (preservedAllocations > 0) {
long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations);
if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory);
}
}
if (allocations > 0) {
mergedPlanBuilder.assignModelToNode(m, n, allocations);
}
for (Deployment deploymentNewAllocations : deployments) {
int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault(
Tuple.tuple(deploymentNewAllocations.id(), n.id()),
0
);

long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations);
if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) {
mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations);
}
}
}
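
The preserve-allocations changes above replace the flat m.memoryBytes() accounting with m.estimateMemoryUsageBytes(allocations), so the memory reserved on a node now depends on how many allocations of that deployment actually sit there. A minimal sketch of that per-node scaling, under the assumption that the estimate grows linearly with the allocation count (names and formula are illustrative, not the real AssignmentPlan.Deployment API):

```java
// Minimal sketch: per-node memory accounting now scales with the number of allocations
// placed on the node. Names and the linear formula are assumptions, not the real
// AssignmentPlan.Deployment implementation.
final class NodeMemoryAccountingSketch {

    private final long modelBytes;
    private final long perDeploymentBytes;
    private final long perAllocationBytes;

    NodeMemoryAccountingSketch(long modelBytes, long perDeploymentBytes, long perAllocationBytes) {
        this.modelBytes = modelBytes;
        this.perDeploymentBytes = perDeploymentBytes;
        this.perAllocationBytes = perAllocationBytes;
    }

    // Before: one flat value per node, independent of how many allocations it hosts.
    long flatMemoryBytes() {
        return modelBytes;
    }

    // After: the reserved memory grows with the allocations assigned to this node.
    long estimateMemoryUsageBytes(int allocationsOnNode) {
        return modelBytes + perDeploymentBytes + perAllocationBytes * allocationsOnNode;
    }
}
```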