Skip to content

Commit

Permalink
Merge branch 'branch-0.3.20' into 'branch-0.3.20'
Browse files Browse the repository at this point in the history
Backport: Make job fail when partial tasks' pre-dependent tasks finished and exceeds the waiting timeout (tony-framework#621)

See merge request !72
  • Loading branch information
zhangjunfan committed Dec 1, 2021
2 parents 0e7b85d + 5dd9270 commit e7431e9
Show file tree
Hide file tree
Showing 9 changed files with 317 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,22 @@ public static String getContainerDockerMountKey() {

public static final String TB_GPUS = TB_JOB_PREFIX + "gpus";
public static final int DEFAULT_TB_GPUS = 0;

/**
* Introduce the group dependency waiting time(sec), like as follows:
* tony.application.group.A = worker,chief
*
* tony.application.dependency.evaluator.timeout.after.A = 3600
*/
public static final String GROUP_REGEX = TONY_APPLICATION_PREFIX + "group\\.([A-Za-z]+)$";
public static final String GROUP_DEPEND_TIMEOUT_REGEX =
TONY_APPLICATION_PREFIX + "dependency\\.([A-Za-z]+)\\.timeout\\.after\\.([A-Za-z]+)$";

public static String getGroupKey(String groupName) {
return String.format(TONY_APPLICATION_PREFIX + "group.%s", groupName);
}

public static String getGroupDependentKey(String grp, String dependentGrp) {
return String.format(TONY_APPLICATION_PREFIX + "dependency.%s.timeout.after.%s", grp, dependentGrp);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
package com.linkedin.tony.runtime;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
Expand All @@ -34,6 +38,7 @@
import com.linkedin.tony.TaskExecutor;
import com.linkedin.tony.TonyConfigurationKeys;
import com.linkedin.tony.tensorflow.TonySession;
import com.linkedin.tony.util.Utils;

import static com.linkedin.tony.Constants.SIDECAR_TB_ROLE_NAME;

Expand All @@ -51,6 +56,12 @@ public abstract class MLGenericRuntime implements FrameworkRuntime {

// ===================For AM =======================

// Group dependencies policy.
Map<String, List<String>> grpWithMembersIndex;
Map<String, List<String>> taskInGrpsIndex;
// todo: Need to support single group dependent multiple other groups
Map<String, Pair<String, Long>> taskWithDependentGrpsIndex;

@Override
public String constructClusterSpec(String taskId) throws IOException {
assert session != null;
Expand Down Expand Up @@ -116,9 +127,127 @@ public boolean isHealthy(Configuration tonyConf) {
session.setFinalStatus(FinalApplicationStatus.FAILED, "Container allocation timeout.");
return false;
}

/**
* Checking the task roles completion timeout when its' pre-dependency tasks finished
*
* For example, tensorflow estimator training job will include some roles of ps/worker/evaluator/chief.
* Actually, due to the bug of tensorflow or misusing the estimator api, sometimes evaluator will hang.
* So if we use the configuration as follows, when evaluator is still running after timeout and
* chief/workers are finished, the mechanism of dependency group timeout will make job failed.
*
* Dependency group timeout configuration as follows:
*
* tony.application.group.A = worker,chief
* tony.application.dependency.evaluator.timeout.after.A = 3600
*
*/
String errorMsg = null;
try {
errorMsg = groupDependencyTimeout(tonyConf);
} catch (Exception e) {
log.error("Failed to check dependency timeout.", e);
}
if (errorMsg != null) {
session.setFinalStatus(FinalApplicationStatus.FAILED, errorMsg);
return false;
}
return true;
}

@VisibleForTesting
protected String groupDependencyTimeout(Configuration tonyConf) {
if (taskWithDependentGrpsIndex == null) {
taskWithDependentGrpsIndex = Utils.getJobTypeDependentGrps(tonyConf);
log.info("Task types dependent grp: " + taskWithDependentGrpsIndex);
}
// groupDependencies is map, key: waiting role, value: pre-dependent groups and waiting timeout
if (taskWithDependentGrpsIndex == null || taskWithDependentGrpsIndex.isEmpty()) {
return null;
}

// groupMembers is map, key: groupName, value: its members in this group
if (grpWithMembersIndex == null) {
grpWithMembersIndex = Utils.getAllGroupJobTypes(tonyConf);
log.info("Group members: " + grpWithMembersIndex);
}

// memberInGroups is map. key: jobtype name, value: in which groups
if (taskInGrpsIndex == null) {
taskInGrpsIndex = getMemberInGroups(grpWithMembersIndex);
}

Map<String, TonySession.TonyTask[]> allTasks = session.getTonyTasks();
List<TonySession.TonyTask> runningTasks = session.getRunningTasks();

// Get the running jobs' type, like the tf roles of ps/worker/chief/evaluator
Set<String> runningJobTypes = runningTasks.stream()
.map(TonySession.TonyTask::getJobName)
.filter(jobname -> taskWithDependentGrpsIndex.containsKey(jobname))
.collect(Collectors.toSet());

for (String runningTaskType : runningJobTypes) {
Pair<String, Long> dependentGroupPair = taskWithDependentGrpsIndex.get(runningTaskType);
String dependentGroupName = dependentGroupPair.getKey();
long timeout = dependentGroupPair.getValue() * 1000;

if (!grpWithMembersIndex.containsKey(dependentGroupName)) {
continue;
}

boolean allDependentTaskFinished = true;
long latestEndTimeInAllDependentTasks = 0L;
for (String dependentsGroupJobtype : grpWithMembersIndex.get(dependentGroupName)) {

if (Utils.existRunningTasksWithJobtype(runningTasks, dependentsGroupJobtype)) {
allDependentTaskFinished = false;
break;
}

// Find out the latest finished task in this task type, if the specified timeout exceed,
// make the job fail.
latestEndTimeInAllDependentTasks = Math.max(
Arrays.stream(allTasks.get(dependentsGroupJobtype))
.mapToLong(x -> x.getEndTime())
.max().getAsLong(),
latestEndTimeInAllDependentTasks
);
}

if (!allDependentTaskFinished) {
continue;
}

log.info("Running job type: " + runningTaskType + ", all dependent task finished: " + allDependentTaskFinished);

if (System.currentTimeMillis() - latestEndTimeInAllDependentTasks > timeout) {
return String.format("Jobtype: %s runs exceeded timeout(%s sec) because it's "
+ "dependent group: %s (task set: [%s]) has been finished",
runningTaskType, dependentGroupPair.getValue(), dependentGroupName,
StringUtils.join(grpWithMembersIndex.get(dependentGroupName), ","));
}
}

return null;
}

private Map<String, List<String>> getMemberInGroups(Map<String, List<String>> groupMembers) {
/**
* key: job type name
* value: the list of groups
*/
Map<String, List<String>> memberInGroups = new HashMap<>();
for (Map.Entry<String, List<String>> entry : groupMembers.entrySet()) {
String group = entry.getKey();
List<String> members = entry.getValue();
for (String member : members) {
memberInGroups.putIfAbsent(member, new ArrayList<>());
memberInGroups.get(member).add(group);
}
}
return memberInGroups;
}

private boolean containerAllocationTimeout(Configuration tonyConf) {
String distributedModeVal = tonyConf.get(TonyConfigurationKeys.APPLICATION_DISTRIBUTED_MODE,
TonyConfigurationKeys.DEFAULT_APPLICATION_DISTRIBUTED_MODE);
Expand Down Expand Up @@ -237,4 +366,4 @@ private boolean enableSidecarTB(Configuration tonyConf) {
}

protected abstract void buildTaskEnv(TaskExecutor executor) throws Exception;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -198,9 +199,13 @@ public int getNotCompletedTrackedTasks() {

public int getNotCompletedTrackedTasks(String jobType) {
return (int) jobTasks.entrySet().stream().filter(entry -> Utils.isJobTypeMonitored(entry.getKey(), tonyConf))
.filter(entry -> entry.getKey().equals(jobType))
.flatMap(entry -> Arrays.stream(entry.getValue()))
.filter(task -> task == null || !task.isCompleted()).count();
.filter(entry -> entry.getKey().equals(jobType))
.flatMap(entry -> Arrays.stream(entry.getValue()))
.filter(task -> task == null || !task.isCompleted()).count();
}

public List<TonyTask> getRunningTasks() {
return jobTasks.values().stream().flatMap(x -> Arrays.stream(x)).filter(task -> task != null && !task.isCompleted()).collect(Collectors.toList());
}

public int getNumFailedTasks() {
Expand Down Expand Up @@ -265,6 +270,7 @@ public void onTaskCompleted(String jobName, String jobIndex, int exitCode, Strin
TonyTask task = getTask(jobName, jobIndex);
Preconditions.checkNotNull(task);
task.setExitStatus(exitCode);
task.setEndTime(System.currentTimeMillis());
// If the chief worker failed[chief or worker 0], short circuit and stop the training. Note that even though other
// worker failures will also fail the job but we don't short circuit the training because the training can still
// continue, while if chief worker is dead, TensorFlow training will hang.
Expand Down Expand Up @@ -354,7 +360,7 @@ public synchronized void setFinalStatus(FinalApplicationStatus status, String me
}
}

private TonyTask getTask(String jobName, String taskIndex) {
public TonyTask getTask(String jobName, String taskIndex) {
for (Map.Entry<String, TonyTask[]> entry : jobTasks.entrySet()) {
TonyTask[] tasks = entry.getValue();
for (TonyTask task : tasks) {
Expand All @@ -368,6 +374,15 @@ private TonyTask getTask(String jobName, String taskIndex) {
return null;
}

@VisibleForTesting
public void addTask(TonyTask tonyTask) {
String jobName = tonyTask.getJobName();
TonyTask[] tasks = jobTasks.getOrDefault(jobName, new TonyTask[]{});
List<TonyTask> newTasks = new ArrayList<>(Arrays.asList(tasks));
newTasks.add(tonyTask);
jobTasks.put(jobName, newTasks.toArray(new TonyTask[newTasks.size()]));
}

/**
* Returns true if the job is "chief" or if there is no "chief" job and ("worker", "0") is passed in.
*/
Expand Down Expand Up @@ -424,6 +439,7 @@ public class TonyTask implements Comparable<TonyTask> {
private int port = -1;
private TaskInfo taskInfo;
private final long startTime;
private long endTime;

/**
* The container the task is running in. Set once a container has been allocated for the task.
Expand Down Expand Up @@ -486,7 +502,7 @@ synchronized int getExitStatus() {
return exitStatus;
}

synchronized void setExitStatus(int status) {
public synchronized void setExitStatus(int status) {
// Only set exit status if it hasn't been set yet
if (exitStatus == -1) {
this.exitStatus = status;
Expand Down Expand Up @@ -516,6 +532,12 @@ public void setTaskInfo(Container container) {
taskInfo = new TaskInfo(jobName, taskIndex, Utils.constructContainerUrl(container));
}

// just for test case
@VisibleForTesting
public void setTaskInfo() {
taskInfo = new TaskInfo(jobName, taskIndex, "");
}

TonyTask(String jobName, String taskIndex, int sessionId, long startTime) {
this.jobName = jobName;
this.taskIndex = taskIndex;
Expand All @@ -528,6 +550,14 @@ public void addContainer(Container container) {
containerIdMap.put(container.getId(), this);
}

public long getEndTime() {
return endTime;
}

public void setEndTime(long endTime) {
this.endTime = endTime;
}

/**
* Combination of jobName and taskIndex.
* @return Id
Expand Down
Loading

0 comments on commit e7431e9

Please sign in to comment.