diff --git a/docs/changelog/98874.yaml b/docs/changelog/98874.yaml new file mode 100644 index 0000000000000..e3eb7b5acc63f --- /dev/null +++ b/docs/changelog/98874.yaml @@ -0,0 +1,5 @@ +pr: 98874 +summary: Estimate the memory required to deploy trained models more accurately +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index f69be31939b32..d27d325a5c596 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -9,6 +9,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.common.Randomness; @@ -96,6 +97,10 @@ public final class TrainedModelAssignment implements SimpleDiffable 0L ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( model.getModelId(), totalDefinitionLength, - model.getPerDeploymentMemoryBytes(), - model.getPerAllocationMemoryBytes(), + useNewMemoryFields ? model.getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? model.getPerAllocationMemoryBytes() : 0, numberOfAllocations ) : 0L; modelSizeStatsByModelId.put( model.getModelId(), - new TrainedModelSizeStats( - totalDefinitionLength, - totalDefinitionLength > 0L - ? 
StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( - model.getModelId(), - totalDefinitionLength, - model.getPerDeploymentMemoryBytes(), - model.getPerAllocationMemoryBytes(), - numberOfAllocations - ) - : 0L - ) + new TrainedModelSizeStats(totalDefinitionLength, estimatedMemoryUsageBytes) ); } else { modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index 2caf338d2a3c7..fe4462d6556ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -47,6 +47,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; +import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper; import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer; @@ -76,6 +77,8 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0; public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0; + private static final TransportVersion NEW_ALLOCATION_MEMORY_VERSION = TransportVersions.V_8_500_064; + private final ClusterService clusterService; private final ThreadPool threadPool; private final 
NodeLoadDetector nodeLoadDetector; @@ -644,12 +647,14 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( Map nodeLoads = detectNodeLoads(nodes, currentState); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState); + boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(currentState)); TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer( currentMetadata, nodeLoads, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState), modelToAdd, - allocatedProcessorsScale + allocatedProcessorsScale, + useNewMemoryFields ); Set shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index e1241dc8a93c3..6e6b447fcea3d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -52,18 +52,22 @@ class TrainedModelAssignmentRebalancer { private final Optional deploymentToAdd; private final int allocatedProcessorsScale; + private final boolean useNewMemoryFields; + TrainedModelAssignmentRebalancer( TrainedModelAssignmentMetadata currentMetadata, Map nodeLoads, Map, Collection> mlNodesByZone, Optional deploymentToAdd, - int allocatedProcessorsScale + int allocatedProcessorsScale, + boolean useNewMemoryFields ) { this.currentMetadata = Objects.requireNonNull(currentMetadata); this.nodeLoads = Objects.requireNonNull(nodeLoads); this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone); this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd); 
this.allocatedProcessorsScale = allocatedProcessorsScale; + this.useNewMemoryFields = useNewMemoryFields; } TrainedModelAssignmentMetadata.Builder rebalance() { @@ -138,9 +142,11 @@ private static void copyAssignments( AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id()); dest.assignModelToNode(m, originalNode, assignment.getValue()); if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) { + // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. - dest.accountMemory(m, originalNode); + long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id())); + dest.accountMemory(m, originalNode, requiredMemory); } } } @@ -168,11 +174,14 @@ private AssignmentPlan computePlanForNormalPriorityModels( .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations())); return new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().estimateMemoryUsageBytes(), + assignment.getTaskParams().getModelBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), currentAssignments, - assignment.getMaxAssignedAllocations() + assignment.getMaxAssignedAllocations(), + // in the mixed cluster state use old memory fields to avoid unstable assignment plans + useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? 
assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ); }) .forEach(planDeployments::add); @@ -181,11 +190,14 @@ private AssignmentPlan computePlanForNormalPriorityModels( planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.estimateMemoryUsageBytes(), + taskParams.getModelBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), - 0 + 0, + // in the mixed cluster state use old memory fields to avoid unstable assignment plans + useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0, + useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0 ) ); } @@ -217,12 +229,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod .map( assignment -> new AssignmentPlan.Deployment( assignment.getDeploymentId(), - assignment.getTaskParams().estimateMemoryUsageBytes(), + assignment.getTaskParams().getModelBytes(), assignment.getTaskParams().getNumberOfAllocations(), assignment.getTaskParams().getThreadsPerAllocation(), findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory), assignment.getMaxAssignedAllocations(), - Priority.LOW + Priority.LOW, + (useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, + (useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ) ) .forEach(planDeployments::add); @@ -231,12 +245,14 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), - taskParams.estimateMemoryUsageBytes(), + taskParams.getModelBytes(), taskParams.getNumberOfAllocations(), taskParams.getThreadsPerAllocation(), Map.of(), 0, - Priority.LOW + Priority.LOW, + (useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0, + (useNewMemoryFields == false) ? 
taskParams.getPerAllocationMemoryBytes() : 0 ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 4843cc43d1187..026b433a8c2d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -35,7 +35,8 @@ private Node modifyNodePreservingAllocations(Node n) { int coresUsed = 0; for (Deployment m : deployments) { if (m.currentAllocationsByNodeId().containsKey(n.id())) { - bytesUsed += m.memoryBytes(); + int allocations = m.currentAllocationsByNodeId().get(n.id()); + bytesUsed += m.estimateMemoryUsageBytes(allocations); coresUsed += calculateUsedCores(n, m); } } @@ -58,7 +59,9 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) { m.allocations() - calculatePreservedAllocations(m), m.threadsPerAllocation(), calculateAllocationsPerNodeToPreserve(m), - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ); } @@ -67,28 +70,37 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) { // they will not match the models/nodes members we have in this class. // Therefore, we build a lookup table based on the ids so we can merge the plan // with its preserved allocations. 
- final Map, Integer> assignmentsByModelNodeIdPair = new HashMap<>(); + final Map, Integer> plannedAssignmentsByModelNodeIdPair = new HashMap<>(); for (Deployment m : assignmentPlan.models()) { Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); for (Map.Entry nodeAssignment : assignments.entrySet()) { - assignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); + plannedAssignmentsByModelNodeIdPair.put(Tuple.tuple(m.id(), nodeAssignment.getKey().id()), nodeAssignment.getValue()); } } AssignmentPlan.Builder mergedPlanBuilder = AssignmentPlan.builder(nodes, deployments); - for (Deployment m : deployments) { - for (Node n : nodes) { - int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0); - if (m.currentAllocationsByNodeId().containsKey(n.id())) { - if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) { - allocations += addPreservedAllocations(n, m); - // As the node has all its available memory we need to manually account memory of models with - // current allocations. - mergedPlanBuilder.accountMemory(m, n); + for (Node n : nodes) { + // TODO (#101612) Should the first loop happen in the builder constructor? 
+ for (Deployment deploymentAllocationsToPreserve : deployments) { + + // if the model m is already allocated on the node n and I want to preserve this allocation + int preservedAllocations = addPreservedAllocations(n, deploymentAllocationsToPreserve); + if (preservedAllocations > 0) { + long requiredMemory = deploymentAllocationsToPreserve.estimateMemoryUsageBytes(preservedAllocations); + if (mergedPlanBuilder.canAssign(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory)) { + mergedPlanBuilder.assignModelToNode(deploymentAllocationsToPreserve, n, preservedAllocations, requiredMemory); } } - if (allocations > 0) { - mergedPlanBuilder.assignModelToNode(m, n, allocations); + } + for (Deployment deploymentNewAllocations : deployments) { + int newAllocations = plannedAssignmentsByModelNodeIdPair.getOrDefault( + Tuple.tuple(deploymentNewAllocations.id(), n.id()), + 0 + ); + + long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations); + if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) { + mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 72a83d7579463..1dce7f0bb46ba 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Tuple; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import 
org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.util.ArrayList; @@ -36,18 +37,32 @@ public record Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, - Priority priority + Priority priority, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes ) { public Deployment( String id, - long memoryBytes, + long modelBytes, int allocations, int threadsPerAllocation, Map currentAllocationsByNodeId, - int maxAssignedAllocations + int maxAssignedAllocations, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes ) { - this(id, memoryBytes, allocations, threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations, Priority.NORMAL); + this( + id, + modelBytes, + allocations, + threadsPerAllocation, + currentAllocationsByNodeId, + maxAssignedAllocations, + Priority.NORMAL, + perDeploymentMemoryBytes, + perAllocationMemoryBytes + ); } int getCurrentAssignedAllocations() { @@ -58,6 +73,60 @@ boolean hasEverBeenAllocated() { return maxAssignedAllocations > 0; } + public long estimateMemoryUsageBytes(int allocations) { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocations + ); + } + + long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocationsNew + ) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + allocationsOld + ); + + } + + long minimumMemoryRequiredBytes() { + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + id, + memoryBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + 1 + ); + } + + int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) { + if 
(perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { + return (int) Math.max( + Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes - estimateMemoryUsageBytes(0), perAllocationMemoryBytes)), + 0 + ); + } + return maxAllocations; + } + + int findExcessAllocations(int maxAllocations, long availableMemoryBytes) { + if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { + return (int) Math.min(maxAllocations, Math.floorDiv(availableMemoryBytes, perAllocationMemoryBytes)); + } + return maxAllocations; + } + @Override public String toString() { return id @@ -71,6 +140,8 @@ public String toString() { + currentAllocationsByNodeId + ") (max_assigned_allocations = " + maxAssignedAllocations + + ") (memory_usage = " + + ByteSizeValue.ofBytes(estimateMemoryUsageBytes(allocations)) + ")"; } }; @@ -304,19 +375,42 @@ int getRemainingAllocations(Deployment m) { } boolean canAssign(Deployment deployment, Node node, int allocations) { - return (isAlreadyAssigned(deployment, node) - || (deployment.memoryBytes() <= remainingNodeMemory.get(node)) - && (deployment.priority == Priority.LOW - || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node))); + long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations); + return canAssign(deployment, node, allocations, requiredMemory); + } + + boolean canAssign(Deployment deployment, Node node, int allocations, long requiredMemory) { + return (requiredMemory <= remainingNodeMemory.get(node)) + && (deployment.priority == Priority.LOW || allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node)); + } + + public long getDeploymentMemoryRequirement(Deployment deployment, Node node, int newAllocations) { + int assignedAllocations = getAssignedAllocations(deployment, node); + + if (assignedAllocations > 0) { + return deployment.estimateAdditionalMemoryUsageBytes(assignedAllocations, assignedAllocations + newAllocations); + } + return 
deployment.estimateMemoryUsageBytes(newAllocations); } public Builder assignModelToNode(Deployment deployment, Node node, int allocations) { + return assignModelToNode(deployment, node, allocations, getDeploymentMemoryRequirement(deployment, node, allocations)); + } + + public Builder assignModelToNode(Deployment deployment, Node node, int allocations, long requiredMemory) { if (allocations <= 0) { return this; } - if (isAlreadyAssigned(deployment, node) == false && deployment.memoryBytes() > remainingNodeMemory.get(node)) { + if (/*isAlreadyAssigned(deployment, node) == false + &&*/ requiredMemory > remainingNodeMemory.get(node)) { throw new IllegalArgumentException( - "not enough memory on node [" + node.id() + "] to assign model [" + deployment.id() + "]" + "not enough memory on node [" + + node.id() + + "] to assign [" + + allocations + + "] allocations to deployment [" + + deployment.id() + + "]" ); } if (deployment.priority == Priority.NORMAL && allocations * deployment.threadsPerAllocation() > remainingNodeCores.get(node)) { @@ -333,9 +427,9 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio ); } - long additionalModelMemory = isAlreadyAssigned(deployment, node) ? 
0 : deployment.memoryBytes; assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations); - remainingNodeMemory.compute(node, (n, remMemory) -> remMemory - additionalModelMemory); + accountMemory(deployment, node, requiredMemory); + if (deployment.priority == Priority.NORMAL) { remainingNodeCores.compute(node, (n, remCores) -> remCores - allocations * deployment.threadsPerAllocation()); } @@ -347,9 +441,26 @@ private boolean isAlreadyAssigned(Deployment deployment, Node node) { return deployment.currentAllocationsByNodeId().containsKey(node.id()) || assignments.get(deployment).get(node) > 0; } + private int getAssignedAllocations(Deployment deployment, Node node) { + int currentAllocations = getCurrentAllocations(deployment, node); + int assignmentAllocations = assignments.get(deployment).get(node); + return currentAllocations + assignmentAllocations; + } + + private static int getCurrentAllocations(Deployment m, Node n) { + return m.currentAllocationsByNodeId.containsKey(n.id()) ? 
m.currentAllocationsByNodeId.get(n.id()) : 0; + } + public void accountMemory(Deployment m, Node n) { - remainingNodeMemory.computeIfPresent(n, (k, v) -> v - m.memoryBytes()); - if (remainingNodeMemory.get(n) < 0) { + // TODO (#101612) remove or refactor unused method + long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n)); + accountMemory(m, n, requiredMemory); + } + + public void accountMemory(Deployment m, Node n, long requiredMemory) { + // TODO (#101612) computation of required memory should be done internally + remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory); + if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) { throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]"); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index 73b713cced32a..b1c017b1a784c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -115,8 +115,11 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.memoryBytes(), 1, m.threadsPerAllocation(), - m.currentAllocationsByNodeId(), - m.maxAssignedAllocations() + // don't rely on the current allocation + new HashMap<>(), + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -145,7 +148,9 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.allocations(), m.threadsPerAllocation(), currentAllocationsByNodeId, - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + 
m.perAllocationMemoryBytes() ); }).toList(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java index 90c5a2257d94d..bd97680e285cc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java @@ -68,6 +68,8 @@ class LinearProgrammingPlanSolver { private final Map normalizedMemoryPerNode; private final Map coresPerNode; private final Map normalizedMemoryPerModel; + private final Map normalizedMemoryPerAllocation; + private final Map normalizedMinimumDeploymentMemoryRequired; private final int maxNodeCores; private final long maxModelMemoryBytes; @@ -84,12 +86,17 @@ class LinearProgrammingPlanSolver { .filter(m -> m.threadsPerAllocation() <= maxNodeCores) .toList(); - maxModelMemoryBytes = this.deployments.stream().map(AssignmentPlan.Deployment::memoryBytes).max(Long::compareTo).orElse(1L); + // We use the maximum memory to deploy a model with one allocation as the normalization factor. 
+ maxModelMemoryBytes = this.deployments.stream().map(m -> m.minimumMemoryRequiredBytes()).max(Long::compareTo).orElse(1L); normalizedMemoryPerNode = this.nodes.stream() .collect(Collectors.toMap(Function.identity(), n -> n.availableMemoryBytes() / (double) maxModelMemoryBytes)); coresPerNode = this.nodes.stream().collect(Collectors.toMap(Function.identity(), Node::cores)); normalizedMemoryPerModel = this.deployments.stream() - .collect(Collectors.toMap(Function.identity(), m -> m.memoryBytes() / (double) maxModelMemoryBytes)); + .collect(Collectors.toMap(Function.identity(), m -> m.estimateMemoryUsageBytes(0) / (double) maxModelMemoryBytes)); + normalizedMemoryPerAllocation = this.deployments.stream() + .collect(Collectors.toMap(Function.identity(), m -> m.perAllocationMemoryBytes() / (double) maxModelMemoryBytes)); + normalizedMinimumDeploymentMemoryRequired = this.deployments.stream() + .collect(Collectors.toMap(Function.identity(), m -> m.minimumMemoryRequiredBytes() / (double) maxModelMemoryBytes)); } AssignmentPlan solvePlan(boolean useBinPackingOnly) { @@ -133,8 +140,8 @@ private double weightForAllocationVar( Node n, Map, Double> weights ) { - return (1 + weights.get(Tuple.tuple(m, n)) - (m.memoryBytes() > n.availableMemoryBytes() ? 10 : 0)) - L1 * normalizedMemoryPerModel - .get(m) / maxNodeCores; + return (1 + weights.get(Tuple.tuple(m, n)) - (m.minimumMemoryRequiredBytes() > n.availableMemoryBytes() ? 
10 : 0)) - L1 + * normalizedMemoryPerModel.get(m) / maxNodeCores; } private Tuple, Double>, AssignmentPlan> calculateWeightsAndBinPackingPlan() { @@ -156,9 +163,9 @@ private Tuple, Double>, AssignmentPlan> calculateWei .sorted(Comparator.comparingDouble(n -> descendingSizeAnyFitsNodeOrder(n, m, assignmentPlan))) .toList(); for (Node n : orderedNodes) { - int allocations = Math.min( - assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), - assignmentPlan.getRemainingAllocations(m) + int allocations = m.findOptimalAllocations( + Math.min(assignmentPlan.getRemainingCores(n) / m.threadsPerAllocation(), assignmentPlan.getRemainingAllocations(m)), + assignmentPlan.getRemainingMemory(n) ); if (allocations > 0 && assignmentPlan.canAssign(m, n, allocations)) { assignmentPlan.assignModelToNode(m, n, allocations); @@ -185,7 +192,8 @@ private Tuple, Double>, AssignmentPlan> calculateWei } private double descendingSizeAnyFitsModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -normalizedMemoryPerModel.get(m) * m.threadsPerAllocation(); + return (m.currentAllocationsByNodeId().isEmpty() ? 
1 : 2) * -normalizedMinimumDeploymentMemoryRequired.get(m) * m + .threadsPerAllocation(); } private double descendingSizeAnyFitsNodeOrder(Node n, AssignmentPlan.Deployment m, AssignmentPlan.Builder assignmentPlan) { @@ -307,7 +315,10 @@ private boolean solveLinearProgram( List modelMemories = new ArrayList<>(); deployments.stream().filter(m -> m.currentAllocationsByNodeId().containsKey(n.id()) == false).forEach(m -> { allocations.add(allocationVars.get(Tuple.tuple(m, n))); - modelMemories.add(normalizedMemoryPerModel.get(m) * m.threadsPerAllocation() / (double) coresPerNode.get(n)); + modelMemories.add( + (normalizedMemoryPerModel.get(m) / (double) coresPerNode.get(n) + normalizedMemoryPerAllocation.get(m)) * m + .threadsPerAllocation() + ); }); model.addExpression("used_memory_on_node_" + n.id() + "_not_more_than_available") .upper(normalizedMemoryPerNode.get(n)) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java index f10ece8f5a593..72109941ad477 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocations.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(Deployment m) { @Override protected int addPreservedAllocations(Node n, Deployment m) { - return m.currentAllocationsByNodeId().get(n.id()); + return m.currentAllocationsByNodeId().containsKey(n.id()) ? 
m.currentAllocationsByNodeId().get(n.id()) : 0; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java index 324e1a8d69a53..43b8860803596 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocation.java @@ -37,6 +37,6 @@ protected int calculatePreservedAllocations(AssignmentPlan.Deployment m) { @Override protected int addPreservedAllocations(Node n, AssignmentPlan.Deployment m) { - return 1; + return m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java index dafc07099f850..8bdc99998a0c2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/RandomizedAssignmentRounding.java @@ -135,8 +135,9 @@ private void assignUnderSubscribedNodes(Collection nodeSelection) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); if (assignments.get(assignment) > 0) { - totalModelMemory += m.memoryBytes(); - maxTotalThreads += (int) Math.ceil(allocations.get(assignment)) * m.threadsPerAllocation(); + int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); + totalModelMemory += m.estimateMemoryUsageBytes(roundedAllocations); + maxTotalThreads += roundedAllocations * m.threadsPerAllocation(); assignedDeployments.add(m); } } @@ 
-199,9 +200,12 @@ private void assignExcessCores(Node n) { if (resourceTracker.remainingNodeCores.get(n) <= 0) { break; } - int extraAllocations = Math.min( - resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) + int extraAllocations = m.findExcessAllocations( + Math.min( + resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), + resourceTracker.remainingModelAllocations.get(m) + ), + resourceTracker.remainingNodeMemory.get(n) ); allocations.compute(Tuple.tuple(m, n), (k, v) -> v + extraAllocations); resourceTracker.assign(m, n, extraAllocations); @@ -211,7 +215,7 @@ private void assignExcessCores(Node n) { } private static double remainingModelOrder(AssignmentPlan.Deployment m) { - return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.memoryBytes(); + return (m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -m.minimumMemoryRequiredBytes(); } private boolean hasSoftAssignments(Node n) { @@ -275,15 +279,17 @@ private void doRandomizedRounding(List> s int roundedAllocations = random.nextDouble() < roundUpProbability ? 
(int) Math.ceil(allocations.get(assignment)) : (int) Math.floor(allocations.get(assignment)); - - if (m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n) + if (m.estimateMemoryUsageBytes(roundedAllocations) > resourceTracker.remainingNodeMemory.get(n) || m.threadsPerAllocation() > resourceTracker.remainingNodeCores.get(n) || roundedAllocations == 0 || random.nextDouble() > assignments.get(assignment)) { unassign(assignment); assignUnderSubscribedNodes(Set.of(n)); } else { - roundedAllocations = Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()); + roundedAllocations = m.findOptimalAllocations( + Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()), + resourceTracker.remainingNodeMemory.get(n) + ); assignModelToNode(m, n, roundedAllocations); unassignOversizedModels(n); assignExcessCores(n); @@ -294,7 +300,8 @@ private void doRandomizedRounding(List> s private void unassignOversizedModels(Node n) { for (AssignmentPlan.Deployment m : deployments) { Tuple assignment = Tuple.tuple(m, n); - if (assignments.get(assignment) < 1.0 && m.memoryBytes() > resourceTracker.remainingNodeMemory.get(n)) { + int roundedAllocations = (int) Math.ceil(allocations.get(assignment)); + if (assignments.get(assignment) < 1.0 && m.minimumMemoryRequiredBytes() > resourceTracker.remainingNodeMemory.get(n)) { unassign(assignment); } } @@ -303,7 +310,11 @@ private void unassignOversizedModels(Node n) { private AssignmentPlan toPlan() { AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments); for (Map.Entry, Integer> assignment : tryAssigningRemainingCores().entrySet()) { - builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); + // TODO (#101612) The model should be assigned to the node only when it is possible. This means, that canAssign should be + // integrated into the assignModelToNode. 
+ if (builder.canAssign(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue())) { + builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue()); + } } return builder.build(); } @@ -338,7 +349,7 @@ private Map, Integer> tryAssigningRemaini .toList()) { for (Node n : nodes.stream() .filter( - n -> resourceTracker.remainingNodeMemory.get(n) >= m.memoryBytes() + n -> resourceTracker.remainingNodeMemory.get(n) >= m.minimumMemoryRequiredBytes() && resourceTracker.remainingNodeCores.get(n) >= m.threadsPerAllocation() && resultAllocations.get(Tuple.tuple(m, n)) == 0 ) @@ -354,10 +365,15 @@ private Map, Integer> tryAssigningRemaini ) ) .toList()) { - int assigningAllocations = Math.min( resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), - resourceTracker.remainingModelAllocations.get(m) + Math.min( + resourceTracker.remainingModelAllocations.get(m), + m.findOptimalAllocations( + resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(), + resourceTracker.remainingModelAllocations.get(m) + ) + ) ); resourceTracker.assign(m, n, assigningAllocations); resultAllocations.put(Tuple.tuple(m, n), assigningAllocations); @@ -427,7 +443,7 @@ private static class ResourceTracker { void assign(AssignmentPlan.Deployment m, Node n, int allocations) { if (assignments.contains(Tuple.tuple(m, n)) == false) { assignments.add(Tuple.tuple(m, n)); - remainingNodeMemory.compute(n, (k, v) -> v - m.memoryBytes()); + remainingNodeMemory.compute(n, (k, v) -> v - m.estimateMemoryUsageBytes(allocations)); } remainingNodeCores.compute(n, (k, v) -> v - allocations * m.threadsPerAllocation()); remainingModelAllocations.compute(m, (k, v) -> v - allocations); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 
9870aa93bf6ce..8c9499ca9e00c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -126,10 +126,12 @@ private AssignmentPlan computeZonePlan( modelIdToTargetAllocations.get(m.id()), m.threadsPerAllocation(), m.currentAllocationsByNodeId(), - // Only force assigning at least once previously assigned models that have not had any allocation yet (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations()) ? m.maxAssignedAllocations() - : 0 + : 0, + // Only force assigning at least once previously assigned models that have not had any allocation yet + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -151,7 +153,9 @@ private AssignmentPlan computePlanAcrossAllNodes(List plans) { m.allocations(), m.threadsPerAllocation(), allocationsByNodeIdByModelId.get(m.id()), - m.maxAssignedAllocations() + m.maxAssignedAllocations(), + m.perDeploymentMemoryBytes(), + m.perAllocationMemoryBytes() ) ) .toList(); @@ -180,9 +184,13 @@ private AssignmentPlan swapOriginalModelsInPlan( Node originalNode = originalNodeById.get(assignment.getKey().id()); planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue()); if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) { + // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder // As the node has all its available memory we need to manually account memory of models with // current allocations. 
- planBuilder.accountMemory(m, originalNode); + long requiredMemory = originalDeployment.estimateMemoryUsageBytes( + originalDeployment.currentAllocationsByNodeId().get(originalNode.id()) + ); + planBuilder.accountMemory(m, originalNode, requiredMemory); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 8ccf8839cfc08..334fdfbb8b922 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -44,7 +44,8 @@ public void testRebalance_GivenNoAssignments() { Map.of(), Map.of(), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments().isEmpty(), is(true)); } @@ -78,7 +79,8 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_ShouldMakeNoChanges() nodeLoads, Map.of(), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(currentMetadata, equalTo(result)); @@ -116,7 +118,8 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEn nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -140,7 +143,7 @@ public void testRebalance_GivenModelToAddAlreadyExists() { .build(); expectThrows( ResourceAlreadyExistsException.class, - () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1).rebalance() + () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1, false).rebalance() ); } @@ -154,7 +157,8 @@ public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws 
Exception { Map.of(), Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -181,7 +185,8 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exce nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -217,7 +222,8 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exceptio nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -253,7 +259,8 @@ public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws E nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -289,7 +296,8 @@ public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -322,7 +330,8 @@ public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(modelId); @@ -361,7 +370,8 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -425,7 +435,8 @@ public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception nodeLoads, Map.of(List.of(), List.of(node1, node2, 
node3)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -489,7 +500,8 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -559,7 +571,8 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(2))); @@ -608,7 +621,8 @@ public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exce nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); assertThat(result.allAssignments(), is(aMapWithSize(1))); @@ -642,7 +656,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() nodeLoads, Map.of(), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId); @@ -658,8 +673,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessors() throws Exception { long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes(); - DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 1); - DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 1); + DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 8); + DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 8); Map nodeLoads = new HashMap<>(); nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build()); @@ -688,7 +703,8 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), 
List.of(node2)), Optional.of(taskParams1), - 1 + 1, + false ).rebalance().build(); TrainedModelAssignment assignment = result.getDeploymentAssignment(deployment1); @@ -727,7 +743,8 @@ public void testRebalance_GivenMixedPriorityModels_NotEnoughMemoryForLowPriority nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -780,7 +797,8 @@ public void testRebalance_GivenMixedPriorityModels_TwoZones_EachNodeCanHoldOneMo nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); List assignedNodes = new ArrayList<>(); @@ -834,7 +852,8 @@ public void testRebalance_GivenModelUsingAllCpu_FittingLowPriorityModelCanStart( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -884,7 +903,8 @@ public void testRebalance_GivenMultipleLowPriorityModels_AndMultipleNodes() thro nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.empty(), - 1 + 1, + false ).rebalance().build(); { @@ -934,7 +954,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( nodeLoads, Map.of(List.of(), List.of(node1)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -986,7 +1007,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -1038,7 +1060,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust nodeLoads, Map.of(List.of(), List.of(node1, node2)), Optional.of(taskParams2), - 1 + 1, + false ).rebalance().build(); { @@ -1084,7 +1107,8 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 2 + 2, + false ).rebalance().build(); TrainedModelAssignment assignment = 
result.getDeploymentAssignment(modelId); @@ -1106,7 +1130,8 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { nodeLoads, Map.of(List.of(), List.of(node)), Optional.of(taskParams), - 1 + 1, + false ).rebalance().build(); assignment = result.getDeploymentAssignment(modelId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index 3ecdd5000ba35..cbbb38f1d1ddd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -24,109 +25,248 @@ public class AssignmentPlanTests extends ESTestCase { public void testBuilderCtor_GivenDuplicateNode() { Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n, n), List.of(m))); } public void testBuilderCtor_GivenDuplicateModel() { Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m))); } public 
void testAssignModelToNode_GivenNoPreviousAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + { // old memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(40).getBytes(), 1, 2, Map.of(), 0, 0, 0); - assertThat(builder.getRemainingCores(n), equalTo(4)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + assertThat(builder.getRemainingCores(n), equalTo(4)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(60L)); - assertThat(builder.getRemainingAllocations(m), equalTo(0)); - assertThat(builder.getRemainingThreads(m), equalTo(0)); + builder.assignModelToNode(m, n, 1); - AssignmentPlan plan = builder.build(); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(0)); + assertThat(builder.getRemainingThreads(m), equalTo(0)); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), 
is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { // new memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(20).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(30).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + assertThat(builder.getRemainingCores(n), equalTo(4)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(0L)); + assertThat(builder.getRemainingAllocations(m), equalTo(0)); + assertThat(builder.getRemainingThreads(m), equalTo(0)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } } public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 1), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); + { // old memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + builder.assignModelToNode(m, n, 1); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - 
assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - AssignmentPlan plan = builder.build(); + AssignmentPlan plan = builder.build(); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { // new memory format + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(25).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(325).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + + } } public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 2, 2, Map.of("n_1", 2), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4); + { + // old memory format + Deployment m = new 
Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, 0, 0); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 1); + builder.assignModelToNode(m, n, 1); - assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); - assertThat(builder.getRemainingAllocations(m), equalTo(1)); - assertThat(builder.getRemainingThreads(m), equalTo(2)); + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(300).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); - AssignmentPlan plan = builder.build(); + AssignmentPlan plan = builder.build(); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(false)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(false)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } + { + // new memory format + Deployment m = new Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 2, + 2, + Map.of("n_1", 2), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(25).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 1); + + assertThat(builder.getRemainingCores(n), equalTo(2)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(275).getBytes())); + assertThat(builder.getRemainingAllocations(m), equalTo(1)); + assertThat(builder.getRemainingThreads(m), equalTo(2)); + + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), 
is(false)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + } } public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1)); - assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign model [m_1]")); + assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign [1] allocations to deployment [m_1]")); } public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 2, 2, Map.of("n_1", 1), 0); + { // old memory format + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - builder.assignModelToNode(m, n, 2); - AssignmentPlan plan = builder.build(); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - assertThat(plan.models(), contains(m)); - assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + builder.assignModelToNode(m, n, 2); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + } + { 
// new memory format + Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 1), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(5).getBytes() + ); + + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + + builder.assignModelToNode(m, n, 2); + AssignmentPlan plan = builder.build(); + + assertThat(plan.models(), contains(m)); + assertThat(plan.satisfiesCurrentAssignments(), is(true)); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2))); + } } public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() { - Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 5, 1, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5)); @@ -138,8 +278,8 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati } public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() { - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3)); @@ -151,13 +291,22 @@ public void 
testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAlloc } public void testAssignModelToNode_GivenSameModelAssignedTwice() { - Node n = new Node("n_1", 100, 8); - Deployment m = new AssignmentPlan.Deployment("m_1", 60, 4, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); assertThat(builder.getRemainingCores(n), equalTo(8)); - assertThat(builder.getRemainingMemory(n), equalTo(100L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(1000).getBytes())); assertThat(builder.getRemainingAllocations(m), equalTo(4)); assertThat(builder.getRemainingThreads(m), equalTo(8)); assertThat(builder.canAssign(m, n, 1), is(true)); @@ -165,7 +314,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { builder.assignModelToNode(m, n, 1); assertThat(builder.getRemainingCores(n), equalTo(6)); - assertThat(builder.getRemainingMemory(n), equalTo(40L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(600).getBytes())); assertThat(builder.getRemainingAllocations(m), equalTo(3)); assertThat(builder.getRemainingThreads(m), equalTo(6)); assertThat(builder.canAssign(m, n, 2), is(true)); @@ -173,7 +322,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { builder.assignModelToNode(m, n, 2); assertThat(builder.getRemainingCores(n), equalTo(2)); - assertThat(builder.getRemainingMemory(n), equalTo(40L)); + assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(500).getBytes())); assertThat(builder.getRemainingAllocations(m), equalTo(1)); assertThat(builder.getRemainingThreads(m), equalTo(2)); @@ -186,7 +335,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { public 
void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -194,17 +343,33 @@ public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { } public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { - Node n = new Node("n_1", 100, 5); - Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of("n_1", 1), 0); - - AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); - - assertThat(builder.canAssign(m, n, 1), is(true)); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); + { + // old memory format + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, 0, 0); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + assertThat(builder.canAssign(m, n, 1), is(true)); + } + { + // new memory format + Deployment m = new Deployment( + "m_1", + ByteSizeValue.ofMb(25).getBytes(), + 1, + 1, + Map.of("n_1", 1), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); + assertThat(builder.canAssign(m, n, 1), is(true)); + } } public void testCanAssign_GivenEnoughMemory() { - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -216,16 +381,25 @@ public void 
testCanAssign_GivenEnoughMemory() { public void testCompareTo_GivenDifferenceInPreviousAssignments() { AssignmentPlan planSatisfyingPreviousAssignments; AssignmentPlan planNotSatisfyingPreviousAssignments; - Node n = new Node("n_1", 100, 5); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 2), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planSatisfyingPreviousAssignments = builder.build(); } { - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 3), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 3, + 2, + Map.of("n_1", 3), + 0, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planNotSatisfyingPreviousAssignments = builder.build(); @@ -238,8 +412,17 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { public void testCompareTo_GivenDifferenceInAllocations() { AssignmentPlan planWithMoreAllocations; AssignmentPlan planWithFewerAllocations; - Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 3, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); { AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -259,16 +442,25 @@ public void testCompareTo_GivenDifferenceInAllocations() { public void testCompareTo_GivenDifferenceInMemory() { AssignmentPlan planUsingMoreMemory; AssignmentPlan 
planUsingLessMemory; - Node n = new Node("n_1", 100, 5); + Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of("n_1", 1), 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingMoreMemory = builder.build(); } { - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 99, 3, 2, Map.of("n_1", 1), 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(29).getBytes(), + 3, + 2, + Map.of("n_1", 1), + 0, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingLessMemory = builder.build(); @@ -279,26 +471,96 @@ public void testCompareTo_GivenDifferenceInMemory() { } public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0); - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) - .assignModelToNode(deployment1, node1, 1) - .assignModelToNode(deployment2, node2, 2) - .assignModelToNode(deployment3, node1, 2) - .assignModelToNode(deployment3, node2, 2) - .build(); - assertThat(plan.satisfiesAllModels(), is(true)); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + { + // old memory format + AssignmentPlan.Deployment 
deployment1 = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + 0, + 0 + ); + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .assignModelToNode(deployment1, node1, 1) + .assignModelToNode(deployment2, node2, 2) + .assignModelToNode(deployment3, node1, 2) + .assignModelToNode(deployment3, node2, 2) + .build(); + assertThat(plan.satisfiesAllModels(), is(true)); + } + { + // new memory format + AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .assignModelToNode(deployment1, node1, 1) + .assignModelToNode(deployment2, node2, 2) + .assignModelToNode(deployment3, node1, 2) + .assignModelToNode(deployment3, node2, 2) + .build(); + assertThat(plan.satisfiesAllModels(), is(true)); + } } public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { - Node node1 = new Node("n_1", 100, 4); - Node 
node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 2) @@ -309,11 +571,11 @@ public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { } public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 1, Map.of(), 4); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); 
AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) @@ -322,10 +584,10 @@ public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { } public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 1) .build(); @@ -333,12 +595,39 @@ public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { } public void testCountPreviouslyAssignedThatAreStillAssigned() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 50, 1, 2, Map.of(), 3); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 30, 2, 1, Map.of(), 4); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 20, 4, 1, Map.of(), 1); - AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 20, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 
ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( + "m_2", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of(), + 4, + 0, + 0 + ); + AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment( + "m_3", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 1, + 0, + 0 + ); + AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment( + "m_4", + ByteSizeValue.ofMb(20).getBytes(), + 4, + 1, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3, deployment4)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index 82a291a8d9fb2..6a72ccf4c4445 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -33,50 +33,144 @@ public class AssignmentPlannerTests extends ESTestCase { + private static long scaleNodeSize(long nodeMemory) { + // 240 Mb is the size in StartTrainedModelDeploymentAction.MEMORY_OVERHEAD + return ByteSizeValue.ofMb(240 + 2 * nodeMemory).getBytes(); + } + public void testModelThatDoesNotFitInMemory() { - List nodes = List.of(new Node("n_1", 100, 4)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 101, 4, 1, Map.of(), 0); - AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); - assertThat(plan.assignments(deployment).isEmpty(), is(true)); + { // Without perDeploymentMemory and perAllocationMemory specified + List nodes = List.of(new 
Node("n_1", scaleNodeSize(50), 4)); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + } + { // With perDeploymentMemory and perAllocationMemory specified + List nodes = List.of(new Node("n_1", scaleNodeSize(55), 4)); + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(51).getBytes() + ); + AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + } } public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { - List nodes = List.of(new Node("n_1", 100, 4), new Node("n_2", 100, 5)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 1, 1, 6, Map.of(), 0); + List nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5)); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment).isEmpty(), is(true)); } public void testSingleModelThatFitsFullyOnSingleNode() { { - Node node = new Node("n_1", 100, 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 1, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + { + Node node = new Node("n_1", scaleNodeSize(1000), 
8); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { - Node node = new Node("n_1", 1000, 8); - Deployment deployment = new Deployment("m_1", 1000, 8, 1, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(10000), 16); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(10000).getBytes(), + 1, + 16, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { - Node node = new Node("n_1", 10000, 16); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 10000, 1, 16, Map.of(), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + } + + public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { + { + Node node = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); + AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertModelFullyAssignedToNode(plan, deployment, node); + } + { + Node node = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 8, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(100).getBytes(), 
+ ByteSizeValue.ofMb(100).getBytes() + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } } public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new Deployment("m_1", 100, 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 4); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); Map assignments = plan.assignments(deployment).get(); - if (assignments.get(node1) > 0) { + if (assignments.get(node1) != null) { + assertThat(assignments.get(node1), equalTo(4)); + } else { + assertThat(assignments.get(node2), equalTo(4)); + } + } + + public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode_NewMemoryFields() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + AssignmentPlan.Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 4, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(150).getBytes() + ); + + AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + + Map assignments = plan.assignments(deployment).get(); + if (assignments.get(node1) != null) { assertThat(assignments.get(node1), equalTo(4)); } else { assertThat(assignments.get(node2), equalTo(4)); @@ -84,10 +178,53 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully } public void 
testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() { - AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 10, 1, Map.of(), 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, 0, 0); + // Single node + { + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node), equalTo(4)); + } + // Two nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 2); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(4)); + assertThat(assignments.get(node2), equalTo(2)); + } + // Three nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 2); + Node node3 = new Node("n_3", scaleNodeSize(100), 3); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(4)); + assertThat(assignments.get(node2), equalTo(2)); + assertThat(assignments.get(node3), equalTo(3)); + } + } + + public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation_NewMemoryFields() { + AssignmentPlan.Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 10, + 1, 
+ Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); // Single node { - Node node = new Node("n_1", 100, 4); + Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -95,8 +232,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } // Two nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 2); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -105,9 +242,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } // Three nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 2); - Node node3 = new Node("n_3", 100, 3); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(700).getBytes(), 3); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -118,14 +255,105 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA } public void testMultipleModelsAndNodesWithSingleSolution() { - Node node1 = new Node("n_1", 100, 7); - Node node2 = new 
Node("n_2", 100, 7); - Node node3 = new Node("n_3", 100, 2); - Node node4 = new Node("n_4", 100, 2); - Deployment deployment1 = new Deployment("m_1", 50, 2, 4, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 2, 3, Map.of(), 0); - Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 50, 1, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment("m_4", 50, 2, 1, Map.of(), 0); + Node node1 = new Node("n_1", 2 * scaleNodeSize(50), 7); + Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7); + Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2); + Node node4 = new Node("n_4", 2 * scaleNodeSize(50), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); + Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, 0, 0); + + AssignmentPlan plan = new AssignmentPlanner( + List.of(node1, node2, node3, node4), + List.of(deployment1, deployment2, deployment3, deployment4) + ).computePlan(); + + { + assertThat(plan.assignments(deployment1).isPresent(), is(true)); + Map assignments = plan.assignments(deployment1).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(1)); + assertThat(assignments.get(node3), is(nullValue())); + assertThat(assignments.get(node4), is(nullValue())); + } + { + assertThat(plan.assignments(deployment2).isPresent(), is(true)); + Map assignments = plan.assignments(deployment2).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(1)); + assertThat(assignments.get(node3), is(nullValue())); + assertThat(assignments.get(node4), is(nullValue())); + } + { + 
assertThat(plan.assignments(deployment3).isPresent(), is(true)); + Map assignments = plan.assignments(deployment3).get(); + assertThat(assignments.get(node1), is(nullValue())); + assertThat(assignments.get(node2), is(nullValue())); + // Will either be on node 3 or 4 + Node assignedNode = assignments.get(node3) != null ? node3 : node4; + Node otherNode = assignedNode.equals(node3) ? node4 : node3; + assertThat(assignments.get(assignedNode), equalTo(1)); + assertThat(assignments.get(otherNode), is(nullValue())); + } + { + assertThat(plan.assignments(deployment4).isPresent(), is(true)); + Map assignments = plan.assignments(deployment4).get(); + assertThat(assignments.get(node1), is(nullValue())); + assertThat(assignments.get(node2), is(nullValue())); + // Will either be on node 3 or 4 + Node assignedNode = assignments.get(node3) != null ? node3 : node4; + Node otherNode = assignedNode.equals(node3) ? node4 : node3; + assertThat(assignments.get(assignedNode), equalTo(2)); + assertThat(assignments.get(otherNode), is(nullValue())); + } + } + + public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(900).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(900).getBytes(), 2); + Deployment deployment1 = new Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 4, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 3, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); + Deployment deployment3 = new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 2, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); 
+ Deployment deployment4 = new Deployment( + "m_4", + ByteSizeValue.ofMb(50).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); AssignmentPlan plan = new AssignmentPlanner( List.of(node1, node2, node3, node4), @@ -173,10 +401,53 @@ public void testMultipleModelsAndNodesWithSingleSolution() { } public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() { - Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 10, 3, Map.of(), 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, 0, 0); + // Single node + { + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node), equalTo(1)); + } + // Two nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 8); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(2)); + } + // Three nodes + { + Node node1 = new Node("n_1", scaleNodeSize(100), 4); + Node node2 = new Node("n_2", scaleNodeSize(100), 7); + Node node3 = new Node("n_3", scaleNodeSize(100), 15); + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); + assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); + Map assignments = assignmentPlan.assignments(deployment).get(); + 
assertThat(assignments.get(node1), equalTo(1)); + assertThat(assignments.get(node2), equalTo(2)); + assertThat(assignments.get(node3), equalTo(5)); + } + } + + public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation_NewMemoryFields() { + Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(50).getBytes(), + 10, + 3, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); // Single node { - Node node = new Node("n_1", 100, 4); + Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -184,8 +455,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } // Two nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 8); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 8); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -194,9 +465,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } // Three nodes { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 7); - Node node3 = new Node("n_3", 100, 15); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 7); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(800).getBytes(), 15); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), 
List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); Map assignments = assignmentPlan.assignments(deployment).get(); @@ -207,8 +478,17 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA } public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { - Node node = new Node("n_1", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 30, 4, 1, Map.of("n_1", 4), 0); + Node node = new Node("n_1", scaleNodeSize(100), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 4, + 1, + Map.of("n_1", 4), + 0, + 0, + 0 + ); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment).isPresent(), is(true)); @@ -217,26 +497,117 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation() { List nodes = List.of( - new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_3", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_4", ByteSizeValue.ofGb(6).getBytes(), 8), - new Node("n_5", ByteSizeValue.ofGb(16).getBytes(), 16), - new Node("n_6", ByteSizeValue.ofGb(8).getBytes(), 16) + new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16), + new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) ); List deployments = List.of( - new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0), - new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0), - new 
AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0), - new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0), - new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0), - new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0), - new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0), - new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0) + new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, 0, 0), + new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, 0, 0), + new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, 0, 0), + new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, 0, 0), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + ); + + AssignmentPlan assignmentPlan = new 
AssignmentPlanner(nodes, deployments).computePlan(); + + int usedCores = 0; + for (AssignmentPlan.Deployment m : deployments) { + Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); + usedCores += assignments.values().stream().mapToInt(Integer::intValue).sum(); + } + assertThat(usedCores, equalTo(64)); + + assertPreviousAssignmentsAreSatisfied(deployments, assignmentPlan); + } + + public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_NewMemoryFields() { + List nodes = List.of( + new Node("n_1", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_2", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_3", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_4", ByteSizeValue.ofGb(18).getBytes(), 8), + new Node("n_5", ByteSizeValue.ofGb(64).getBytes(), 16), + new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) + ); + // Use mix of old and new memory fields + List deployments = List.of( + new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 10, + 1, + Map.of("n_1", 5), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 3, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ), + new Deployment( + "m_4", + ByteSizeValue.ofMb(50).getBytes(), + 4, + 1, + Map.of("n_3", 2), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment( + "m_5", + ByteSizeValue.ofMb(500).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(800).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ), + new Deployment( + "m_6", + ByteSizeValue.ofMb(50).getBytes(), + 12, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(50).getBytes(), + ByteSizeValue.ofMb(20).getBytes() + ), + new Deployment( + "m_7", + ByteSizeValue.ofMb(50).getBytes(), + 12, + 1, + Map.of("n_2", 6), 
+ 0, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), + new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -297,6 +668,9 @@ public void testRandomBenchmark() { StopWatch stopWatch = new StopWatch(); stopWatch.start(); AssignmentPlan assignmentPlan = solver.computePlan(); + for (Node node : nodes) { + assertThat(assignmentPlan.getRemainingNodeMemory(node.id()), greaterThanOrEqualTo(0L)); + } stopWatch.stop(); Quality quality = computeQuality(nodes, deployments, assignmentPlan); @@ -336,7 +710,16 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); previousModelsPlusNew.add( - new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0) + new AssignmentPlan.Deployment( + m.id(), + m.memoryBytes(), + m.allocations(), + m.threadsPerAllocation(), + previousAssignments, + 0, + 0, + 0 + ) ); } previousModelsPlusNew.add(randomModel("new")); @@ -347,18 +730,20 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode } public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAssignments() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofGb(2).getBytes(), 2); - Node node3 = new Node("n_3", ByteSizeValue.ofGb(2).getBytes(), 2); + Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); 
+ Node node2 = new Node("n_2", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); + Node node3 = new Node("n_3", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); Deployment deployment1 = new AssignmentPlan.Deployment( "m_1", ByteSizeValue.ofMb(1200).getBytes(), 3, 1, Map.of("n_1", 2, "n_2", 1), + 0, + 0, 0 ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2)) .computePlan(); assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); @@ -381,15 +766,17 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss } public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(4).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofGb(4).getBytes(), 2); + Node node1 = new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 2); AssignmentPlan.Deployment deployment1 = new Deployment( "m_1", ByteSizeValue.ofMb(1200).getBytes(), 3, 1, Map.of("n_1", 2, "n_2", 1), - 3 + 3, + 0, + 0 ); AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( "m_2", @@ -397,35 +784,84 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 1, 2, Map.of(), - 1 + 1, + 0, + 0 ); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2)).computePlan(); Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2")); - assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); - assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + if 
(indexedBasedPlan.get("m_2").containsKey("n_1")) { + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_2", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_1", 1))); + } else { + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + } assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); assertThat(assignmentPlan.getRemainingNodeMemory("n_2"), greaterThanOrEqualTo(0L)); } public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { - Node node1 = new Node("n_1", ByteSizeValue.ofGb(2).getBytes(), 2); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1); + Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); + AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan(); assertThat(assignmentPlan.countPreviouslyAssignedModelsThatAreStillAssigned(), equalTo(1L)); } + public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() { + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 
4, 1, Map.of(), 0, 0, 0); + + // First only start m_1 + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); + + Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + + // Then start m_2 + assignmentPlan = new AssignmentPlanner( + List.of(node1, node2), + Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment2)).toList() + ).computePlan(); + + indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + + // Then start m_3 + assignmentPlan = new AssignmentPlanner( + List.of(node1, node2), + Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(deployment3)).toList() + ).computePlan(); + + indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); + + // First, one node goes away. 
+ assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); + assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); + } + public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -458,8 +894,8 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); // Now the cluster starts getting resized. - Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(2600).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(2600).getBytes(), 2); // First, one node goes away. 
assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); @@ -492,11 +928,65 @@ public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { public void testGivenClusterResize_ShouldRemoveAllocatedModels() { // Ensure that plan is removing previously allocated models if not enough memory is available - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, 0, 0); + + // Create a plan where all deployments are assigned at least once + AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) + .computePlan(); + Map> indexedBasedPlan = convertToIdIndexed(assignmentPlan); + assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3")); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1))); + assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); + assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L)); + 
assertThat(assignmentPlan.getRemainingNodeMemory(node2.id()), greaterThanOrEqualTo(0L)); + + // Now the cluster starts getting resized. Ensure that resources are not over-allocated. + assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan(); + assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2))); + assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L)); + assertThat(assignmentPlan.getRemainingNodeCores(node1.id()), greaterThanOrEqualTo(0)); + + } + + public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() { + // Ensure that plan is removing previously allocated models if not enough memory is available + Node node1 = new Node("n_1", ByteSizeValue.ofMb(700).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 2); + Deployment deployment1 = new Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(100).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(400).getBytes(), + ByteSizeValue.ofMb(150).getBytes() + ); + Deployment deployment3 = new Deployment( + "m_3", + ByteSizeValue.ofMb(50).getBytes(), + 1, + 1, + Map.of(), + 0, + ByteSizeValue.ofMb(250).getBytes(), + ByteSizeValue.ofMb(50).getBytes() + ); // Create a plan where all deployments are assigned at least once AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) @@ -536,7 +1026,9 @@ public static List createModelsFromPlan(AssignmentPlan plan) { m.allocations(), m.threadsPerAllocation(), currentAllocations, - Math.max(m.maxAssignedAllocations(), totalAllocations) + Math.max(m.maxAssignedAllocations(), totalAllocations), + 0, + 0 ) ); } @@ -579,7 +1071,7 @@ public static List randomNodes(int scale, 
String nodeIdPrefix) { for (int i = 0; i < 1 + 3 * scale; i++) { int cores = randomIntBetween(2, 32); long memBytesPerCore = randomFrom(memBytesPerCoreValues); - nodes.add(new Node(nodeIdPrefix + "n_" + i, cores * memBytesPerCore, cores)); + nodes.add(new Node(nodeIdPrefix + "n_" + i, scaleNodeSize(ByteSizeValue.ofBytes(cores * memBytesPerCore).getMb()), cores)); } return nodes; } @@ -594,14 +1086,30 @@ public static List randomModels(int scale, double load) { public static Deployment randomModel(String idSuffix) { int allocations = randomIntBetween(1, 32); - return new Deployment( - "m_" + idSuffix, - randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()), - randomIntBetween(1, 32), - randomIntBetween(1, 4), - Map.of(), - 0 - ); + // randomly choose between old and new memory fields format + if (randomBoolean()) { + return new Deployment( + "m_" + idSuffix, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()), + randomIntBetween(1, 32), + randomIntBetween(1, 4), + Map.of(), + 0, + 0, + 0 + ); + } else { + return new Deployment( + "m_" + idSuffix, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), + randomIntBetween(1, 32), + randomIntBetween(1, 4), + Map.of(), + 0, + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), + randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()) + ); + } } public static void assertPreviousAssignmentsAreSatisfied(List deployments, AssignmentPlan assignmentPlan) { @@ -628,7 +1136,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) { } List deployments = new ArrayList<>(); for (int i = 0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0)); + deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0)); } // Check plan is 
computed without OOM exception diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index 4a9b01e535d88..c45ce36394109 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -14,7 +15,6 @@ import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -22,77 +22,179 @@ public class PreserveAllAllocationsTests extends ESTestCase { public void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new Deployment("m_1", 30, 2, 1, Map.of(), 0); - Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), 
List.of(deployment1, deployment2) ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, contains(node1, node2)); - - List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, contains(deployment1, deployment2)); } public void testGivenPreviousAssignments() { - Node node1 = new Node("n_1", 100, 8); - Node node2 = new Node("n_2", 100, 8); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); - Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); - PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( - List.of(node1, node2), - List.of(deployment1, deployment2) - ); - - List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L)); - assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); - - List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - 
assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).isEmpty(), is(true)); - - plan = preserveAllAllocations.mergePreservedAllocations(plan); - - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + { + // old memory format + Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); + Deployment deployment1 = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of("n_1", 1), + 1, + 0, + 0 + ); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); + PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = preserveAllAllocations.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + 
assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + // 640 - [(2*30 + 240) + (2*50 + 240)] = 0: deployments use 640 MB on the node 1 + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L)); + // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + // 640 - (50*2+240) = 300 : deployments use 340MB on the node + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - (2*4) = 0 : preserving all allocations of deployment 2 should use 8 cores on the node + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); + + List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); + + // Now we have a plan with 2 deployments assigned to 2 nodes. + // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more.
It's more than 2 allocations defined during + // initialization of deployment1, but we don't care at this point. + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).isEmpty(), is(true)); + + plan = preserveAllAllocations.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + + // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); + // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + // Nothing changed for Node 2 + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // Nothing changed for Node 2 + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + } + { + // new memory format + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment deployment1 = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of("n_1", 1), + 1, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = 
preserveAllAllocations.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + // 1000 - [(30 + 300+10) + (50 + 300 + 10)] = 300: deployments use 700 MB on the node 1 + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes())); + // 8 - (2*4) = 0 : preserving all allocations of deployment 2 should use 8 cores on the node + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(0)); + + List modelsPreservingAllocations = preserveAllAllocations.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); + 
assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(3)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 0))); + + // Now we have a plan with 2 deployments assigned to 2 nodes. + // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. It's more than 2 allocations defined during + // initialization of deployment1, but we don't care at this point. + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).isEmpty(), is(true)); + + plan = preserveAllAllocations.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + + // 1000 - ((30 + 300 + 3*10) + (50 + 300 + 10)) = 280 : deployments use 720 MB on the node 1 + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); + // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + // Nothing changed for Node 2 + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes())); + // Nothing changed for Node 2 + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + } } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { - Node node = new Node("n_1", 100, 4); - 
AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2); + Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); @@ -101,7 +203,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments plan = preserveAllAllocations.mergePreservedAllocations(plan); assertThat(plan.assignments(deployment).isPresent(), is(true)); assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 2))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); assertThat(plan.getRemainingNodeCores("n_1"), equalTo(0)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index d8c3b09422e92..f646bf5cb2e9d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -22,10 +23,10 @@ public class PreserveOneAllocationTests extends ESTestCase { public 
void testGivenNoPreviousAssignments() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 30, 2, 4, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); @@ -36,67 +37,204 @@ public void testGivenNoPreviousAssignments() { } public void testGivenPreviousAssignments() { - Node node1 = new Node("n_1", 100, 8); - Node node2 = new Node("n_2", 100, 8); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 30, 2, 1, Map.of("n_1", 1), 1); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 50, 6, 4, Map.of("n_1", 1, "n_2", 2), 3); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); - - List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); - assertThat(nodesPreservingAllocations, hasSize(2)); - - assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); - assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(20L)); - assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); - - assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); - assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(50L)); - 
assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); - - List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); - assertThat(modelsPreservingAllocations, hasSize(2)); - - assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); - assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(30L)); - assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); - assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); - - assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); - assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(50L)); - assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); - assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) - .assignModelToNode(deployment1, node1, 2) - .assignModelToNode(deployment2, node2, 1) - .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(20L)); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); - assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(50L)); - assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + { + // old memory format + Node node1 = new Node("n_1", 
ByteSizeValue.ofMb(640).getBytes(), 8); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + // 640 - [(30*2+240)+(50*2+240)] = 0 : deployments use all memory on the node + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(0L)); + // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + // 640 - (50*2+240) = 300 : deployments use 340MB on the node + assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - (1*4) = 4 : preserving 1 allocation of deployment 2 should use 4 cores on the node + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); + + List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + 
assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(0).getBytes())); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); + + // Now we have a plan with 2 deployments assigned to 2 nodes. + // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. It's more than 2 allocations defined during + // initialization of deployment1, but we don't care at this point. + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .assignModelToNode(deployment2, node2, 1) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. 
+ assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); + // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + // Node 2 already had deployment 2 assigned to it so adding more allocation doesn't change memory usage. + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - [(1*4) + (1*4)] = 0 : deployment 2 should use all cores on the node + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + } + { + // new memory format + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8); + Deployment deployment1 = new Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 1, + Map.of("n_1", 1), + 1, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( + List.of(node1, node2), + List.of(deployment1, deployment2) + ); + + List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); + assertThat(nodesPreservingAllocations, hasSize(2)); + + assertThat(nodesPreservingAllocations.get(0).id(), equalTo("n_1")); + // 1000 - [(30+300+10)+(50 + 300 +10)] = 300 : deployments use 700MB on the node + assertThat(nodesPreservingAllocations.get(0).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + // 8 - (1*1+1*4) = 3 : deployments use 5 cores on the node + assertThat(nodesPreservingAllocations.get(0).cores(), equalTo(3)); + + assertThat(nodesPreservingAllocations.get(1).id(), equalTo("n_2")); + // 1000 - (50 +300 + 2*10) = 630 : deployments use 370MB on the node + 
assertThat(nodesPreservingAllocations.get(1).availableMemoryBytes(), equalTo(ByteSizeValue.ofMb(630).getBytes())); + // 8 - (1*4) = 4 : preserving 1 allocation of deployment 2 should use 4 cores on the node + assertThat(nodesPreservingAllocations.get(1).cores(), equalTo(4)); + + List modelsPreservingAllocations = preserveOneAllocation.modelsPreservingAllocations(); + assertThat(modelsPreservingAllocations, hasSize(2)); + + assertThat(modelsPreservingAllocations.get(0).id(), equalTo("m_1")); + assertThat(modelsPreservingAllocations.get(0).memoryBytes(), equalTo(ByteSizeValue.ofMb(30).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + assertThat(modelsPreservingAllocations.get(0).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); + assertThat(modelsPreservingAllocations.get(0).allocations(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).threadsPerAllocation(), equalTo(1)); + assertThat(modelsPreservingAllocations.get(0).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0))); + + assertThat(modelsPreservingAllocations.get(1).id(), equalTo("m_2")); + assertThat(modelsPreservingAllocations.get(1).memoryBytes(), equalTo(ByteSizeValue.ofMb(50).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perDeploymentMemoryBytes(), equalTo(ByteSizeValue.ofMb(300).getBytes())); + assertThat(modelsPreservingAllocations.get(1).perAllocationMemoryBytes(), equalTo(ByteSizeValue.ofMb(10).getBytes())); + assertThat(modelsPreservingAllocations.get(1).allocations(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).threadsPerAllocation(), equalTo(4)); + assertThat(modelsPreservingAllocations.get(1).currentAllocationsByNodeId(), equalTo(Map.of("n_1", 0, "n_2", 1))); + + // Now we have a plan with 2 deployments assigned to 2 nodes. + // Note that deployment 1 has already 1 allocation on node 1, and it gets 2 more. 
It's more than 2 allocations defined during + // initialization of deployment1, but we don't care at this point. + AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) + .assignModelToNode(deployment1, node1, 2) + .assignModelToNode(deployment2, node2, 1) + .build(); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + // 1000 - [(30+300+3*10) + (50+300+10)] = 280 : deployments use 720MB on the node + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); + // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(1)); + // 1000 - (50 + 300 + 2*10) = 630 : deployments use 370MB on the node + assertThat(plan.getRemainingNodeMemory("n_2"), equalTo(ByteSizeValue.ofMb(630).getBytes())); + // 8 - [(1*4) + (1*4)] = 0 : deployment 2 should use all cores on the node + assertThat(plan.getRemainingNodeCores("n_2"), equalTo(0)); + + } } public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { - Node node = new Node("n_1", 100, 4); - AssignmentPlan.Deployment deployment = new Deployment("m_1", 30, 2, 2, Map.of("n_1", 2), 2); - PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); - - AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); - assertThat(plan.assignments(deployment).isEmpty(), is(true)); - - plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment).isPresent(), is(true)); - assertThat(plan.assignments(deployment).get(), 
equalTo(Map.of(node, 1))); - assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(70L)); - assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); + { + // old memory format + Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); + // 400 - (30*2 + 240) = 100 : deployments use 300MB on the node + assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); + } + { + // new memory format + Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); + Deployment deployment = new Deployment( + "m_1", + ByteSizeValue.ofMb(30).getBytes(), + 2, + 2, + Map.of("n_1", 2), + 2, + ByteSizeValue.ofMb(300).getBytes(), + ByteSizeValue.ofMb(10).getBytes() + ); + PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); + + AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); + assertThat(plan.assignments(deployment).isEmpty(), is(true)); + + plan = preserveOneAllocation.mergePreservedAllocations(plan); + assertThat(plan.assignments(deployment).isPresent(), is(true)); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 1))); + // 400 - (30 + 300 + 10) = 60 : deployments use 340MB on the node + assertThat(plan.getRemainingNodeMemory("n_1"), 
equalTo(ByteSizeValue.ofMb(60).getBytes())); + assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index 7ceb8bbb86869..651e4764cb894 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase { public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { Node node = new Node("n_1", 100, 1); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -44,8 +44,17 @@ public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { } public void testGivenOneModel_OneNode_OneZone_FullyFits() { - Node node = new Node("n_1", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); + Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -53,8 +62,17 @@ public void testGivenOneModel_OneNode_OneZone_FullyFits() { } public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { - Node node = new 
Node("n_1", 100, 5); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 2, Map.of(), 0); + Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -64,9 +82,18 @@ public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { } public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 1, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2)), @@ -82,9 +109,18 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { } public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 2, 2, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 2, + 2, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), 
List.of(node2)), @@ -99,9 +135,18 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { } public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 3, 3, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 3, + Map.of(), + 0, + 0, + 0 + ); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)), @@ -117,15 +162,15 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { } public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Node node3 = new Node("n_3", 100, 4); - Node node4 = new Node("n_4", 100, 4); - Node node5 = new Node("n_5", 100, 4); - Node node6 = new Node("n_6", 100, 4); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", 25, 4, 1, Map.of(), 0); - Deployment deployment2 = new AssignmentPlan.Deployment("m_2", 25, 6, 2, Map.of(), 0); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", 25, 2, 3, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", 
ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, 0, 0); Map, List> nodesByZone = Map.of( List.of("z_1"), @@ -168,11 +213,11 @@ public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { } public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() { - Node node1 = new Node("n_1", 100, 4); - Node node2 = new Node("n_2", 100, 4); - Node node3 = new Node("n_3", 100, 4); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", 25, 1, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", 25, 1, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)), @@ -203,7 +248,16 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); previousModelsPlusNew.add( - new AssignmentPlan.Deployment(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0) + new AssignmentPlan.Deployment( + m.id(), + m.memoryBytes(), + m.allocations(), + m.threadsPerAllocation(), + previousAssignments, + 0, + 0, + 0 + ) ); } previousModelsPlusNew.add(randomModel("new")); @@ -214,11 +268,11 @@ public void 
testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode } public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() { - Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2); - Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2); - AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0); - AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0); + Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2); + Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1)) @@ -252,8 +306,8 @@ public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOn assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1))); // Now the cluster starts getting resized. - Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2); - Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2); + Node node3 = new Node("n_3", ByteSizeValue.ofMb(5160).getBytes(), 2); + Node node4 = new Node("n_4", ByteSizeValue.ofMb(5160).getBytes(), 2); // First, one node goes away. 
assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1)), createModelsFromPlan(assignmentPlan)) diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java new file mode 100644 index 0000000000000..fc78bf36c72fb --- /dev/null +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java @@ -0,0 +1,289 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.upgrades; + +import org.elasticsearch.Version; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.elasticsearch.client.WarningsHandler.PERMISSIVE; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class MlAssignmentPlannerUpgradeIT extends AbstractUpgradeTestCase { + + private Logger logger = LogManager.getLogger(MlAssignmentPlannerUpgradeIT.class); + + // See PyTorchModelIT for how this model was created + static final String BASE_64_ENCODED_MODEL = + "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" + + 
"TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" + + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" + + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh" + + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele" + + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k" + + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ" + + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq" + + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7" + + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3" + + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28" + + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw" + + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW" + + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0" + + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts" + + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs" + + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn" + + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU" + + 
"EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe"
        + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE"
        + "AAJIEAAAAAA==";

    // Size of the model definition before base64 encoding, i.e. the length of the
    // decoded bytes. Used both as "total_definition_length" when uploading the
    // definition and in the expected-memory arithmetic of the assertions below.
    static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
    static {
        RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
    }

    /**
     * Verifies that trained model deployments created before the upgrade keep
     * reporting the expected required native memory through each upgrade phase.
     * Two deployments are created in the OLD cluster: one relying on the legacy
     * memory estimate ("old_memory_format") and one carrying explicit
     * per-deployment / per-allocation memory metadata ("new_memory_format").
     * The new-format estimate is only expected once the original cluster is at
     * least 8.11.0 (the version that introduced the new memory fields — see the
     * NEW_ALLOCATION_MEMORY_VERSION constant in TrainedModelAssignmentClusterService).
     */
    public void testMlAssignmentPlannerUpgrade() throws Exception {
        assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0));

        logger.info("Starting testMlAssignmentPlannerUpgrade, model size {}", RAW_MODEL_SIZE);

        switch (CLUSTER_TYPE) {
            case OLD -> {
                // setup deployments using old and new memory format
                setupDeployments();

                waitForDeploymentStarted("old_memory_format");
                waitForDeploymentStarted("new_memory_format");

                // assert correct memory format is used
                assertOldMemoryFormat("old_memory_format");
                if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) {
                    assertNewMemoryFormat("new_memory_format");
                } else {
                    // the metadata is present but an old master ignores it
                    assertOldMemoryFormat("new_memory_format");
                }
            }
            case MIXED -> {
                ensureHealth(".ml-inference-*,.ml-config*", (request -> {
                    request.addParameter("wait_for_status", "yellow");
                    request.addParameter("timeout", "70s");
                }));
                waitForDeploymentStarted("old_memory_format");
                waitForDeploymentStarted("new_memory_format");

                // assert correct memory format is used
                assertOldMemoryFormat("old_memory_format");
                if (isOriginalClusterVersionAtLeast(Version.V_8_11_0)) {
                    assertNewMemoryFormat("new_memory_format");
                } else {
                    assertOldMemoryFormat("new_memory_format");
                }

            }
            case UPGRADED -> {
                ensureHealth(".ml-inference-*,.ml-config*", (request -> {
                    request.addParameter("wait_for_status", "yellow");
                    request.addParameter("timeout", "70s");
                }));
                waitForDeploymentStarted("old_memory_format");
                waitForDeploymentStarted("new_memory_format");

                // assert correct memory format is used
                assertOldMemoryFormat("old_memory_format");
                assertNewMemoryFormat("new_memory_format");

                cleanupDeployments();
            }
        }
    }

    /**
     * Polls the trained model stats until the deployment reports state "started"
     * (up to 30 seconds).
     */
    @SuppressWarnings("unchecked")
    private void waitForDeploymentStarted(String modelId) throws Exception {
        assertBusy(() -> {
            var response = getTrainedModelStats(modelId);
            Map<String, Object> map = entityAsMap(response);
            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
            assertThat(stats, hasSize(1));
            var stat = stats.get(0);
            assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started"));
        }, 30, TimeUnit.SECONDS);
    }

    /**
     * Asserts the legacy memory estimate: a fixed native-memory overhead plus
     * twice the raw model definition size.
     */
    @SuppressWarnings("unchecked")
    private void assertOldMemoryFormat(String modelId) throws Exception {
        // There was a change in the MEMORY_OVERHEAD value in 8.3.0, see #86416
        // NOTE(review): the gate below uses V_8_2_1, not V_8_3_0 — presumably the
        // overhead change was backported; confirm against #86416 before changing.
        long memoryOverheadMb = Version.fromString(UPGRADE_FROM_VERSION).onOrAfter(Version.V_8_2_1) ? 240 : 270;
        var response = getTrainedModelStats(modelId);
        Map<String, Object> map = entityAsMap(response);
        List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
        assertThat(stats, hasSize(1));
        var stat = stats.get(0);
        Long expectedMemoryUsage = ByteSizeValue.ofMb(memoryOverheadMb).getBytes() + RAW_MODEL_SIZE * 2;
        Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat);
        assertThat(
            Strings.format("Memory usage mismatch for the model %s in cluster state %s", modelId, CLUSTER_TYPE.toString()),
            actualMemoryUsage,
            equalTo(expectedMemoryUsage.intValue())
        );
    }

    /**
     * Asserts the new memory estimate: per-deployment memory (300 MB) + raw model
     * size + per-allocation memory (10 MB), matching the metadata values supplied
     * by {@code setupDeployments()} for the "new_memory_format" model.
     */
    @SuppressWarnings("unchecked")
    private void assertNewMemoryFormat(String modelId) throws Exception {
        var response = getTrainedModelStats(modelId);
        Map<String, Object> map = entityAsMap(response);
        List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
        assertThat(stats, hasSize(1));
        var stat = stats.get(0);
        Long expectedMemoryUsage = ByteSizeValue.ofMb(300).getBytes() + RAW_MODEL_SIZE + ByteSizeValue.ofMb(10).getBytes();
        Integer actualMemoryUsage = (Integer) XContentMapValues.extractValue("model_size_stats.required_native_memory_bytes", stat);
        assertThat(stat.toString(), actualMemoryUsage.toString(), equalTo(expectedMemoryUsage.toString()));
    }

    /** Fetches the trained model stats for {@code modelId}, asserting a 2xx response. */
    private Response getTrainedModelStats(String modelId) throws IOException {
        Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
        var response = client().performRequest(request);
        assertOK(response);
        return response;
    }

    /**
     * Runs inference for a single document against the deployed model.
     * NOTE(review): not called from the visible test path — presumably kept as a
     * debugging aid; confirm before removing.
     */
    private Response infer(String input, String modelId) throws IOException {
        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
        request.setJsonEntity(Strings.format("""
            {  "docs": [{"input":"%s"}] }
            """, input));
        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
        var response = client().performRequest(request);
        assertOK(response);
        return response;
    }

    /** Uploads the single-part base64 model definition for {@code modelId}. */
    private void putModelDefinition(String modelId) throws IOException {
        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
        request.setJsonEntity(Strings.format("""
            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
        client().performRequest(request);
    }

    /**
     * Uploads a vocabulary for {@code modelId}, prefixing the required [PAD] and
     * [UNK] special tokens to the supplied words.
     */
    private void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
        List<String> vocabularyWithPad = new ArrayList<>();
        vocabularyWithPad.add("[PAD]");
        vocabularyWithPad.add("[UNK]");
        vocabularyWithPad.addAll(vocabulary);
        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));

        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
        request.setJsonEntity(Strings.format("""
            { "vocabulary": [%s] }
            """, quotedWords));
        client().performRequest(request);
    }

    /**
     * Creates and starts both test deployments: "old_memory_format" with no memory
     * metadata (legacy estimate) and "new_memory_format" with explicit
     * per-deployment (300 MB) and per-allocation (10 MB) memory metadata.
     */
    private void setupDeployments() throws Exception {
        createTrainedModel("old_memory_format", 0, 0);
        putModelDefinition("old_memory_format");
        putVocabulary(List.of("these", "are", "my", "words"), "old_memory_format");
        startDeployment("old_memory_format");

        createTrainedModel("new_memory_format", ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes());
        putModelDefinition("new_memory_format");
        putVocabulary(List.of("these", "are", "my", "words"), "new_memory_format");
        startDeployment("new_memory_format");
    }

    /** Stops and deletes both test deployments (run once in the UPGRADED phase). */
    private void cleanupDeployments() throws IOException {
        stopDeployment("old_memory_format");
        deleteTrainedModel("old_memory_format");
        stopDeployment("new_memory_format");
        deleteTrainedModel("new_memory_format");
    }

    /**
     * Creates a minimal pass-through pytorch model. When both memory values are
     * positive they are recorded in the model metadata, which newer clusters use
     * for the new per-deployment/per-allocation memory estimate; otherwise the
     * metadata block is omitted entirely.
     */
    private void createTrainedModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
        if (perAllocationMemoryBytes > 0 && perDeploymentMemoryBytes > 0) {
            request.setJsonEntity(Strings.format("""
                {
                   "description": "simple model for testing",
                   "model_type": "pytorch",
                   "inference_config": {
                     "pass_through": {
                       "tokenization": {
                         "bert": {
                           "with_special_tokens": false
                         }
                       }
                     }
                   },
                   "metadata": {
                     "per_deployment_memory_bytes": %s,
                     "per_allocation_memory_bytes": %s
                   }
                 }""", perDeploymentMemoryBytes, perAllocationMemoryBytes));
        } else {
            request.setJsonEntity("""
                {
                   "description": "simple model for testing",
                   "model_type": "pytorch",
                   "inference_config": {
                     "pass_through": {
                       "tokenization": {
                         "bert": {
                           "with_special_tokens": false
                         }
                       }
                     }
                   }
                 }""");
        }
        client().performRequest(request);
    }

    /** Deletes the trained model configuration for {@code modelId}. */
    private void deleteTrainedModel(String modelId) throws IOException {
        Request request = new Request("DELETE", "_ml/trained_models/" + modelId);
        client().performRequest(request);
    }

    /** Starts a deployment and waits for it to reach the "started" state. */
    private Response startDeployment(String modelId) throws IOException {
        return startDeployment(modelId, "started");
    }

    /**
     * Starts a deployment for {@code modelId}, waiting up to 40s for
     * {@code waitForState}. Warnings are tolerated because older nodes emit
     * deprecation warnings for the inference_threads/model_threads parameters.
     */
    private Response startDeployment(String modelId, String waitForState) throws IOException {
        Request request = new Request(
            "POST",
            "/_ml/trained_models/"
                + modelId
                + "/deployment/_start?timeout=40s&wait_for="
                + waitForState
                + "&inference_threads=1&model_threads=1"
        );
        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
        var response = client().performRequest(request);
        assertOK(response);
        return response;
    }

    /** Stops the deployment for {@code modelId} (no response assertion — best effort cleanup). */
    private void stopDeployment(String modelId) throws IOException {
        String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
        Request request = new Request("POST", endpoint);
        client().performRequest(request);
    }
}