Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dispatch ML task to ML node first #346

Merged
merged 1 commit into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -115,14 +113,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
private ClusterService clusterService;
private ThreadPool threadPool;

public static final Setting<Boolean> IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope);
Copy link
Collaborator

@jackiehanyang jackiehanyang Jun 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curiosity question: what is the purpose of IS_ML_NODE_SETTING before and why we don't need it now? I saw this part of logic was moved to TestHelper class, what's the reason for that?

Copy link
Collaborator Author

@ylwu-amzn ylwu-amzn Jun 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That part was from prototype when we tried to support ML node. But actually it's not being used in our formal release as OpenSearch core doesn't support ML role and we have to postpone that. Now OpenSearch plan to support ML role with dynamic role feature in 2.1. We can add this back but we don't need this prototype/experiment code any more. Just move it to test part.


public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "l") {
@Override
public Setting<Boolean> legacySetting() {
return IS_ML_NODE_SETTING;
}
};
public static final String ML_ROLE_NAME = "ml";

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.task;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -23,6 +24,7 @@
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.utils.MLNodeUtils;

import com.google.common.collect.ImmutableSet;

Expand All @@ -49,9 +51,7 @@ public MLTaskDispatcher(ClusterService clusterService, Client client) {
* @param listener Action listener
*/
public void dispatchTask(ActionListener<DiscoveryNode> listener) {
// todo: add ML node type setting check
// DiscoveryNode[] mlNodes = getEligibleMLNodes();
DiscoveryNode[] mlNodes = getEligibleDataNodes();
DiscoveryNode[] mlNodes = getEligibleNodes();
MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(mlNodes);
MLStatsNodesRequest
.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE));
Expand Down Expand Up @@ -107,14 +107,32 @@ public void dispatchTask(ActionListener<DiscoveryNode> listener) {
}));
}

private DiscoveryNode[] getEligibleDataNodes() {
/**
* Get eligible node to run ML task. If there are nodes with ml role, will return all these
* ml nodes; otherwise return all data nodes.
*
* @return array of discovery node
*/
protected DiscoveryNode[] getEligibleNodes() {
Comment on lines +110 to +116
Copy link
Collaborator

@Zhangxunmt Zhangxunmt Jun 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it only preferable to run ML tasks in ml node? I assume ml-common can run in data node as well. Also is there any logic in the ClusterState.nodes() to evaluate if any ml node is overloaded, etc? I just wonder, in the future, if we want to add more priority based strategy here to prioritize ML node, but still use data node if ML node is heavy loaded, etc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it only preferable to run ML tasks in ml node? I assume ml-common can run in data node as well.

Check the comment "If there are nodes with ml role, will return all these ml nodes; otherwise return all data nodes."

Also is there any logic in the ClusterState.nodes() to evaluate if any ml node is overloaded, etc?

Yes, we check JVM heap usage and how many ML task running on a node. If exceeds limit, will not dispatch new ML task to that node.

I just wonder, in the future, if we want to add more priority based strategy here to prioritize ML node, but still use data node if ML node is heavy loaded, etc.

I think we'd better ask user to scale the cluster by adding more ML node or switch to more powerful node type if ML node is heavy/over loaded. But this is not the one way door, we can always tune the code if cx really needs to run model on data nodes if ML node overloaded.

ClusterState state = this.clusterService.state();
final List<DiscoveryNode> eligibleMLNodes = new ArrayList<>();
final List<DiscoveryNode> eligibleDataNodes = new ArrayList<>();
for (DiscoveryNode node : state.nodes()) {
if (MLNodeUtils.isMLNode(node)) {
eligibleMLNodes.add(node);
}
if (node.isDataNode()) {
eligibleDataNodes.add(node);
}
}
return eligibleDataNodes.toArray(new DiscoveryNode[0]);
if (eligibleMLNodes.size() > 0) {
DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]);
log.debug("Find {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes));
return mlNodes;
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: This "else" should be redundant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's code style preference. People have different preference for "No-else-after-return" or not, check https://stackoverflow.com/questions/46875442/unnecessary-else-after-return-no-else-return.

For me, I feel the code is more readable to keep else to make the returns have same indentation.

DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]);
log.debug("Find no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes));
return dataNodes;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

package org.opensearch.ml.utils;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME;

import java.io.IOException;

import lombok.experimental.UtilityClass;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.*;
import org.opensearch.ml.plugin.MachineLearningPlugin;

@UtilityClass
public class MLNodeUtils {
public boolean isMLNode(DiscoveryNode node) {
return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(MachineLearningPlugin.ML_ROLE.roleName()));
return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(ML_ROLE_NAME));
}

public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import static org.mockito.Mockito.*;
import static org.opensearch.ml.common.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

Expand All @@ -35,6 +36,8 @@
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableSet;

public class MLTaskDispatcherTests extends OpenSearchTestCase {

@Mock
Expand All @@ -48,8 +51,9 @@ public class MLTaskDispatcherTests extends OpenSearchTestCase {

MLTaskDispatcher taskDispatcher;
ClusterState testState;
DiscoveryNode node1;
DiscoveryNode node2;
DiscoveryNode dataNode1;
DiscoveryNode dataNode2;
DiscoveryNode mlNode;
MLStatsNodesResponse mlStatsNodesResponse;
String clusterName = "test cluster";

Expand All @@ -59,11 +63,12 @@ public void setup() {

taskDispatcher = spy(new MLTaskDispatcher(clusterService, client));

Set<DiscoveryNodeRole> roleSet = new HashSet<>();
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT);
node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT);
DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build();
Set<DiscoveryNodeRole> dataRoleSet = ImmutableSet.of(DiscoveryNodeRole.DATA_ROLE);
dataNode1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT);
dataNode2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT);
Set<DiscoveryNodeRole> mlRoleSet = ImmutableSet.of(ML_ROLE);
mlNode = new DiscoveryNode("mlNode", buildNewFakeTransportAddress(), new HashMap<>(), mlRoleSet, Version.CURRENT);
DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).build();
testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false);
when(clusterService.state()).thenReturn(testState);

Expand Down Expand Up @@ -111,12 +116,34 @@ public void testDispatchTask_TaskCountExceedLimit() {
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

public void testGetEligibleNodes_DataNodeOnly() {
DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes();
assertEquals(2, eligibleNodes.length);
for (DiscoveryNode node : eligibleNodes) {
assertTrue(node.isDataNode());
}
}

public void testGetEligibleNodes_MlAndDataNodes() {
DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).add(mlNode).build();
testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false);
when(clusterService.state()).thenReturn(testState);

DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes();
assertEquals(1, eligibleNodes.length);
for (DiscoveryNode node : eligibleNodes) {
assertFalse(node.isDataNode());
DiscoveryNodeRole[] discoveryNodeRoles = node.getRoles().toArray(new DiscoveryNodeRole[0]);
assertEquals(ML_ROLE_NAME, discoveryNodeRoles[0].roleName());
}
}

private MLStatsNodesResponse getMlStatsNodesResponse() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -127,8 +154,8 @@ private MLStatsNodesResponse getMlStatsNodesResponse() {
private MLStatsNodesResponse getNodesResponse_NoTaskCounts() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -140,8 +167,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -153,8 +180,8 @@ private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.utils;

import static java.util.Collections.emptyMap;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.io.IOException;
import java.util.HashSet;
Expand All @@ -22,7 +23,6 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.test.OpenSearchTestCase;

public class MLNodeUtilsTests extends OpenSearchTestCase {
Expand All @@ -34,7 +34,7 @@ public void testIsMLNode() {
DiscoveryNode normalNode = new DiscoveryNode("Normal node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT);
Assert.assertFalse(MLNodeUtils.isMLNode(normalNode));

roleSet.add(MachineLearningPlugin.ML_ROLE);
roleSet.add(ML_ROLE);
DiscoveryNode mlNode = new DiscoveryNode("ML node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT);
Assert.assertTrue(MLNodeUtils.isMLNode(mlNode));
}
Expand Down
12 changes: 12 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand All @@ -56,6 +58,16 @@
import com.google.common.collect.ImmutableMap;

public class TestHelper {

public static final Setting<Boolean> IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope);

public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "ml") {
@Override
public Setting<Boolean> legacySetting() {
return IS_ML_NODE_SETTING;
}
};

public static XContentParser parser(String xc) throws IOException {
return parser(xc, true);
}
Expand Down