diff --git a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java index 44967594..b276383d 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java @@ -331,4 +331,22 @@ public static String getContainerDockerMountKey() { public static final String TB_GPUS = TB_JOB_PREFIX + "gpus"; public static final int DEFAULT_TB_GPUS = 0; + + /** + * Keys for defining task groups and a group dependency timeout (in seconds), configured as follows: + * tony.application.group.A = worker,chief + * + * tony.application.dependency.evaluator.timeout.after.A = 3600 + */ + public static final String GROUP_REGEX = TONY_APPLICATION_PREFIX + "group\\.([A-Za-z]+)$"; + public static final String GROUP_DEPEND_TIMEOUT_REGEX = + TONY_APPLICATION_PREFIX + "dependency\\.([A-Za-z]+)\\.timeout\\.after\\.([A-Za-z]+)$"; + + public static String getGroupKey(String groupName) { + return String.format(TONY_APPLICATION_PREFIX + "group.%s", groupName); + } + + public static String getGroupDependentKey(String grp, String dependentGrp) { + return String.format(TONY_APPLICATION_PREFIX + "dependency.%s.timeout.after.%s", grp, dependentGrp); + } } diff --git a/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java b/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java index e92fb063..4db32ad8 100644 --- a/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java +++ b/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java @@ -16,13 +16,17 @@ package com.linkedin.tony.runtime; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang.StringUtils; +import
org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -34,6 +38,7 @@ import com.linkedin.tony.TaskExecutor; import com.linkedin.tony.TonyConfigurationKeys; import com.linkedin.tony.tensorflow.TonySession; +import com.linkedin.tony.util.Utils; import static com.linkedin.tony.Constants.SIDECAR_TB_ROLE_NAME; @@ -51,6 +56,12 @@ public abstract class MLGenericRuntime implements FrameworkRuntime { // ===================For AM ======================= + // Group dependencies policy. + Map> grpWithMembersIndex; + Map> taskInGrpsIndex; + // todo: Need to support single group dependent multiple other groups + Map> taskWithDependentGrpsIndex; + @Override public String constructClusterSpec(String taskId) throws IOException { assert session != null; @@ -116,9 +127,127 @@ public boolean isHealthy(Configuration tonyConf) { session.setFinalStatus(FinalApplicationStatus.FAILED, "Container allocation timeout."); return false; } + + /** + * Checks whether a task role has exceeded its timeout after the tasks it depends on have finished. + * + * For example, a TensorFlow estimator training job includes the roles ps/worker/evaluator/chief. + * Due to a TensorFlow bug or misuse of the estimator API, the evaluator sometimes hangs. + * With the configuration below, if the evaluator is still running after the timeout once the + * chief/workers have finished, the dependency group timeout mechanism fails the job.
+ * + * Dependency group timeout configuration as follows: + * + * tony.application.group.A = worker,chief + * tony.application.dependency.evaluator.timeout.after.A = 3600 + * + */ + String errorMsg = null; + try { + errorMsg = groupDependencyTimeout(tonyConf); + } catch (Exception e) { + log.error("Failed to check dependency timeout.", e); + } + if (errorMsg != null) { + session.setFinalStatus(FinalApplicationStatus.FAILED, errorMsg); + return false; + } return true; } + @VisibleForTesting + protected String groupDependencyTimeout(Configuration tonyConf) { + if (taskWithDependentGrpsIndex == null) { + taskWithDependentGrpsIndex = Utils.getJobTypeDependentGrps(tonyConf); + log.info("Task types dependent grp: " + taskWithDependentGrpsIndex); + } + // taskWithDependentGrpsIndex is a map, key: waiting role, value: its pre-dependent group and timeout (sec) + if (taskWithDependentGrpsIndex == null || taskWithDependentGrpsIndex.isEmpty()) { + return null; + } + + // grpWithMembersIndex is a map, key: group name, value: the job types that are members of the group + if (grpWithMembersIndex == null) { + grpWithMembersIndex = Utils.getAllGroupJobTypes(tonyConf); + log.info("Group members: " + grpWithMembersIndex); + } + + // memberInGroups is map.
key: jobtype name, value: in which groups + if (taskInGrpsIndex == null) { + taskInGrpsIndex = getMemberInGroups(grpWithMembersIndex); + } + + Map allTasks = session.getTonyTasks(); + List runningTasks = session.getRunningTasks(); + + // Get the running jobs' type, like the tf roles of ps/worker/chief/evaluator + Set runningJobTypes = runningTasks.stream() + .map(TonySession.TonyTask::getJobName) + .filter(jobname -> taskWithDependentGrpsIndex.containsKey(jobname)) + .collect(Collectors.toSet()); + + for (String runningTaskType : runningJobTypes) { + Pair dependentGroupPair = taskWithDependentGrpsIndex.get(runningTaskType); + String dependentGroupName = dependentGroupPair.getKey(); + long timeout = dependentGroupPair.getValue() * 1000; + + if (!grpWithMembersIndex.containsKey(dependentGroupName)) { + continue; + } + + boolean allDependentTaskFinished = true; + long latestEndTimeInAllDependentTasks = 0L; + for (String dependentsGroupJobtype : grpWithMembersIndex.get(dependentGroupName)) { + + if (Utils.existRunningTasksWithJobtype(runningTasks, dependentsGroupJobtype)) { + allDependentTaskFinished = false; + break; + } + + // Find out the latest finished task in this task type, if the specified timeout exceed, + // make the job fail. 
+ latestEndTimeInAllDependentTasks = Math.max( + Arrays.stream(allTasks.get(dependentsGroupJobtype)) + .mapToLong(x -> x.getEndTime()) + .max().getAsLong(), + latestEndTimeInAllDependentTasks + ); + } + + if (!allDependentTaskFinished) { + continue; + } + + log.info("Running job type: " + runningTaskType + ", all dependent task finished: " + allDependentTaskFinished); + + if (System.currentTimeMillis() - latestEndTimeInAllDependentTasks > timeout) { + return String.format("Jobtype: %s runs exceeded timeout(%s sec) because it's " + + "dependent group: %s (task set: [%s]) has been finished", + runningTaskType, dependentGroupPair.getValue(), dependentGroupName, + StringUtils.join(grpWithMembersIndex.get(dependentGroupName), ",")); + } + } + + return null; + } + + private Map> getMemberInGroups(Map> groupMembers) { + /** + * key: job type name + * value: the list of groups + */ + Map> memberInGroups = new HashMap<>(); + for (Map.Entry> entry : groupMembers.entrySet()) { + String group = entry.getKey(); + List members = entry.getValue(); + for (String member : members) { + memberInGroups.putIfAbsent(member, new ArrayList<>()); + memberInGroups.get(member).add(group); + } + } + return memberInGroups; + } + private boolean containerAllocationTimeout(Configuration tonyConf) { String distributedModeVal = tonyConf.get(TonyConfigurationKeys.APPLICATION_DISTRIBUTED_MODE, TonyConfigurationKeys.DEFAULT_APPLICATION_DISTRIBUTED_MODE); @@ -237,4 +366,4 @@ private boolean enableSidecarTB(Configuration tonyConf) { } protected abstract void buildTaskEnv(TaskExecutor executor) throws Exception; -} \ No newline at end of file +} diff --git a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java index 016818f9..134e460e 100644 --- a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java +++ b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java @@ -20,6 +20,7 @@ import 
java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -198,9 +199,13 @@ public int getNotCompletedTrackedTasks() { public int getNotCompletedTrackedTasks(String jobType) { return (int) jobTasks.entrySet().stream().filter(entry -> Utils.isJobTypeMonitored(entry.getKey(), tonyConf)) - .filter(entry -> entry.getKey().equals(jobType)) - .flatMap(entry -> Arrays.stream(entry.getValue())) - .filter(task -> task == null || !task.isCompleted()).count(); + .filter(entry -> entry.getKey().equals(jobType)) + .flatMap(entry -> Arrays.stream(entry.getValue())) + .filter(task -> task == null || !task.isCompleted()).count(); + } + + public List getRunningTasks() { + return jobTasks.values().stream().flatMap(x -> Arrays.stream(x)).filter(task -> task != null && !task.isCompleted()).collect(Collectors.toList()); } public int getNumFailedTasks() { @@ -265,6 +270,7 @@ public void onTaskCompleted(String jobName, String jobIndex, int exitCode, Strin TonyTask task = getTask(jobName, jobIndex); Preconditions.checkNotNull(task); task.setExitStatus(exitCode); + task.setEndTime(System.currentTimeMillis()); // If the chief worker failed[chief or worker 0], short circuit and stop the training. Note that even though other // worker failures will also fail the job but we don't short circuit the training because the training can still // continue, while if chief worker is dead, TensorFlow training will hang. 
@@ -354,7 +360,7 @@ public synchronized void setFinalStatus(FinalApplicationStatus status, String me } } - private TonyTask getTask(String jobName, String taskIndex) { + public TonyTask getTask(String jobName, String taskIndex) { for (Map.Entry entry : jobTasks.entrySet()) { TonyTask[] tasks = entry.getValue(); for (TonyTask task : tasks) { @@ -368,6 +374,15 @@ private TonyTask getTask(String jobName, String taskIndex) { return null; } + @VisibleForTesting + public void addTask(TonyTask tonyTask) { + String jobName = tonyTask.getJobName(); + TonyTask[] tasks = jobTasks.getOrDefault(jobName, new TonyTask[]{}); + List newTasks = new ArrayList<>(Arrays.asList(tasks)); + newTasks.add(tonyTask); + jobTasks.put(jobName, newTasks.toArray(new TonyTask[newTasks.size()])); + } + /** * Returns true if the job is "chief" or if there is no "chief" job and ("worker", "0") is passed in. */ @@ -424,6 +439,7 @@ public class TonyTask implements Comparable { private int port = -1; private TaskInfo taskInfo; private final long startTime; + private long endTime; /** * The container the task is running in. Set once a container has been allocated for the task. 
@@ -486,7 +502,7 @@ synchronized int getExitStatus() { return exitStatus; } - synchronized void setExitStatus(int status) { + public synchronized void setExitStatus(int status) { // Only set exit status if it hasn't been set yet if (exitStatus == -1) { this.exitStatus = status; @@ -516,6 +532,12 @@ public void setTaskInfo(Container container) { taskInfo = new TaskInfo(jobName, taskIndex, Utils.constructContainerUrl(container)); } + // just for test case + @VisibleForTesting + public void setTaskInfo() { + taskInfo = new TaskInfo(jobName, taskIndex, ""); + } + TonyTask(String jobName, String taskIndex, int sessionId, long startTime) { this.jobName = jobName; this.taskIndex = taskIndex; @@ -528,6 +550,14 @@ public void addContainer(Container container) { containerIdMap.put(container.getId(), this); } + public long getEndTime() { + return endTime; + } + + public void setEndTime(long endTime) { + this.endTime = endTime; + } + /** * Combination of jobName and taskIndex. * @return Id diff --git a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java index 767f1873..466803fe 100644 --- a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java +++ b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java @@ -35,6 +35,8 @@ import org.apache.commons.cli.ParseException; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -65,6 +67,7 @@ import com.linkedin.tony.horovod.HorovodClusterSpec; import com.linkedin.tony.rpc.TaskInfo; import com.linkedin.tony.tensorflow.TensorFlowContainerRequest; +import com.linkedin.tony.tensorflow.TonySession; import net.lingala.zip4j.core.ZipFile; import net.lingala.zip4j.exception.ZipException; @@ -394,14 +397,52 @@ public static 
int getNumTotalTasks(Configuration conf) { .sum(); } + public static Map> getJobTypeDependentGrps(Configuration tonyConf) { + return tonyConf.getValByRegex(TonyConfigurationKeys.GROUP_DEPEND_TIMEOUT_REGEX).keySet().stream() + .map(Utils::getDependentGrps) + .map(pair -> Utils.getDependentTimeout(tonyConf, pair)) + .collect(Collectors.toMap(Triple::getLeft, x -> Pair.of(x.getMiddle(), x.getRight()), (oldV, newV) -> newV)); + } + + private static Triple getDependentTimeout(Configuration tonyConf, Pair pair) { + String grp = pair.getKey(); + String dependentGrp = pair.getValue(); + long timeout = tonyConf.getLong(TonyConfigurationKeys.getGroupDependentKey(grp, dependentGrp), 0L); + return Triple.of(grp, dependentGrp, timeout); + } + + private static Pair getDependentGrps(String confKey) { + Pattern instancePattern = Pattern.compile(TonyConfigurationKeys.GROUP_DEPEND_TIMEOUT_REGEX); + Matcher instanceMatcher = instancePattern.matcher(confKey); + if (instanceMatcher.matches()) { + return Pair.of(instanceMatcher.group(1), instanceMatcher.group(2)); + } + return null; + } + + public static Map> getAllGroupJobTypes(Configuration conf) { + return conf.getValByRegex(TonyConfigurationKeys.GROUP_REGEX).keySet().stream() + .map(Utils::getGroupName) + .map(groupName -> Utils.getGroupMembers(conf, groupName)) + .collect(Collectors.toMap(Pair::getLeft, Pair::getRight)); + } + + private static Pair> getGroupMembers(Configuration conf, String groupName) { + return Pair.of(groupName, Arrays.asList(conf.getStrings(TonyConfigurationKeys.getGroupKey(groupName)))); + } + /** - * Extracts TensorFlow job name from configuration key of the form "tony.*.instances". + * Extracts group name from configuration key of the form "tony.application.group.*". 
* @param confKey Name of the configuration key - * @return TensorFlow job name + * @return group name */ - public static String getTaskType(String confKey) { - Pattern instancePattern = Pattern.compile(TonyConfigurationKeys.INSTANCES_REGEX); - Matcher instanceMatcher = instancePattern.matcher(confKey); + private static String getGroupName(String confKey) { + return getRegexKey(confKey, TonyConfigurationKeys.GROUP_REGEX); + } + + private static String getRegexKey(String conf, String regex) { + Pattern instancePattern = Pattern.compile(regex); + Matcher instanceMatcher = instancePattern.matcher(conf); if (instanceMatcher.matches()) { return instanceMatcher.group(1); } else { @@ -409,6 +450,15 @@ public static String getTaskType(String confKey) { } } + /** + * Extracts TensorFlow job name from configuration key of the form "tony.*.instances". + * @param confKey Name of the configuration key + * @return TensorFlow job name + */ + public static String getTaskType(String confKey) { + return getRegexKey(confKey, TonyConfigurationKeys.INSTANCES_REGEX); + } + public static boolean isArchive(String path) { File f = new File(path); int fileSignature = 0; @@ -717,5 +767,9 @@ public static HorovodClusterSpec parseClusterSpecForHorovod(String clusterSpec) return spec; } + public static boolean existRunningTasksWithJobtype(List runningTasks, String jobtype) { + return runningTasks.stream().anyMatch(x -> x.getJobName().equals(jobtype)); + } + private Utils() { } } diff --git a/tony-core/src/main/resources/tony-default.xml b/tony-core/src/main/resources/tony-default.xml index c28a6276..11a423e2 100644 --- a/tony-core/src/main/resources/tony-default.xml +++ b/tony-core/src/main/resources/tony-default.xml @@ -389,4 +389,14 @@ tony.horovod.driver.mode.debug false + + + tony.application.group.TFESTIMATOR + chief + + + + tony.application.dependency.worker.timeout.after.TFESTIMATOR + 7200 + diff --git a/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java 
b/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java index 8b4b5952..ef8597bd 100644 --- a/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java +++ b/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java @@ -553,6 +553,30 @@ public void testAttachedTensorboardShouldPass() throws ParseException, IOExcepti client.removeListener(handler); } + /** + * Workers never exit on their own, so 10s after the chief finishes the group dependency timeout should fail the job. + */ + @Test + public void testGroupDependencyTimeoutShouldPass() throws ParseException, IOException { + client.init(new String[]{ + "--src_dir", "tony-core/src/test/resources/scripts", + "--hdfs_classpath", libPath, + "--container_env", Constants.SKIP_HADOOP_PATH + "=true", + "--python_venv", "tony-core/src/test/resources/test.zip", + "--executes", "python exit_0.py", + "--conf", "tony.chief.instances=1", + "--conf", "tony.worker.instances=2", + "--conf", "tony.worker.command=python forever_not_exit.py", + "--conf", "tony.application.framework=tensorflow", + "--conf", "tony.application.group.A=chief", + "--conf", "tony.application.dependency.worker.timeout.after.A=10", + }); + client.addListener(handler); + int exitCode = client.start(); + Assert.assertEquals(exitCode, -1); + client.removeListener(handler); + } + /** * Since we are switching from passing arguments to ApplicationMaster & TaskExecutor * to passing tony configuration file. It is critical to make sure all fields in diff --git a/tony-core/src/test/java/com/linkedin/tony/runtime/TestMLGenericRuntime.java b/tony-core/src/test/java/com/linkedin/tony/runtime/TestMLGenericRuntime.java deleted file mode 100644 index 9819878d..00000000 --- a/tony-core/src/test/java/com/linkedin/tony/runtime/TestMLGenericRuntime.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2021 LinkedIn Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License.
You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - */ -package com.linkedin.tony.runtime; - -import org.apache.hadoop.conf.Configuration; -import org.testng.Assert; -import org.testng.annotations.BeforeTest; -import org.testng.annotations.Test; - -import com.linkedin.tony.TaskExecutor; - -import static com.linkedin.tony.TonyConfigurationKeys.TENSORBOARD_LOG_DIR; - -public class TestMLGenericRuntime { - private MLGenericRuntime runtime; - - static class TestRuntime extends MLGenericRuntime { - @Override - protected void buildTaskEnv(TaskExecutor executor) throws Exception { - return; - } - } - - @BeforeTest - public void before() { - runtime = new TestRuntime(); - } - - /** - * Test MLGenericRuntime when in task executor. 
- */ - @Test - public void testNeedReserveTBPort() { - TaskExecutor taskExecutor = new TaskExecutor(); - taskExecutor.setJobName("chief"); - - runtime.initTaskExecutorResource(taskExecutor); - - taskExecutor.setChief(true); - Assert.assertTrue(runtime.needReserveTBPort()); - - Configuration conf1 = new Configuration(); - conf1.set(TENSORBOARD_LOG_DIR, "/tmp"); - taskExecutor.setTonyConf(conf1); - Assert.assertFalse(runtime.needReserveTBPort()); - - taskExecutor.setChief(false); - Assert.assertFalse(runtime.needReserveTBPort()); - - taskExecutor.setJobName("tensorboard"); - Assert.assertTrue(runtime.needReserveTBPort()); - } -} diff --git a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java index 4ac7716e..9e56d596 100644 --- a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java +++ b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java @@ -18,7 +18,10 @@ import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; + +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.yarn.api.records.Container; @@ -271,4 +274,34 @@ public void testGetContainerEnvForDocker() { assertEquals(Utils.getContainerEnvForDocker(conf, "tony.worker.gpus"), new HashMap<>()); } + + @Test + public void testGetAllGroupJobTypes() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.set("tony.application.group.A", "worker,chief"); + conf.set("tony.application.group.B", "evaluator"); + + Map> groupIndex = Utils.getAllGroupJobTypes(conf); + assertTrue(groupIndex.containsKey("A")); + assertTrue(groupIndex.containsKey("B")); + assertEquals(groupIndex.get("A"), Arrays.asList("worker", "chief")); + assertEquals(groupIndex.get("B"), Arrays.asList("evaluator")); + } + + @Test + public void 
testGetGroupDependencies() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.set("tony.application.dependency.A.timeout.after.B", "3600"); + conf.set("tony.application.dependency.B.timeout.after.C", "3600"); + + Map> dependenciesIndex = Utils.getJobTypeDependentGrps(conf); + assertTrue(dependenciesIndex.containsKey("A")); + assertTrue(dependenciesIndex.containsKey("B")); + assertEquals(dependenciesIndex.get("A").getKey(), "B"); + assertEquals(dependenciesIndex.get("A").getValue(), Long.valueOf("3600")); + assertEquals(dependenciesIndex.get("B").getKey(), "C"); + assertEquals(dependenciesIndex.get("B").getValue(), Long.valueOf("3600")); + } } diff --git a/tony-core/src/test/resources/scripts/forever_not_exit.py b/tony-core/src/test/resources/scripts/forever_not_exit.py new file mode 100644 index 00000000..787688d2 --- /dev/null +++ b/tony-core/src/test/resources/scripts/forever_not_exit.py @@ -0,0 +1,8 @@ +# +# Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. +# See LICENSE in the project root for license information. +# +import time + +while True: + time.sleep(1) \ No newline at end of file