Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make job fail when partial tasks' pre-dependent tasks finished and exceeds the waiting timeout #621

Merged
merged 4 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -334,4 +334,23 @@ public static String getContainerDockerMountKey() {

public static final String TB_GPUS = TB_JOB_PREFIX + "gpus";
public static final int DEFAULT_TB_GPUS = 0;

/**
* Introduce the group dependency waiting time(sec), like as follows:
* tony.application.group.a = worker,chief
* tony.application.group.b = evaluator
*
* tony.application.dependency.b.timeout.after.a = 3600
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does b have to be a group? I think worker.timeout.after.a also works from conversations. Can you clarify here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mistake for code comment, i will fix it.
And tony.application.dependency.evaluator.timeout.after.a is best practise.

*/
public static final String GROUP_REGEX = TONY_APPLICATION_PREFIX + "group\\.([A-Za-z]+)$";
public static final String GROUP_DEPEND_TIMEOUT_REGEX =
TONY_APPLICATION_PREFIX + "dependency\\.([A-Za-z]+)\\.timeout\\.after\\.([A-Za-z]+)$";
Comment on lines +344 to +346
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also add some unit tests for these regex..

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been tested in Utils.getAllGroupJobTypes and Utils.getJobTypeDependentGrps


public static String getGroupKey(String groupName) {
return String.format(TONY_APPLICATION_PREFIX + "group.%s", groupName);
}

public static String getGroupDependentKey(String grp, String dependentGrp) {
return String.format(TONY_APPLICATION_PREFIX + "dependency.%s.timeout.after.%s", grp, dependentGrp);
}
}
35 changes: 33 additions & 2 deletions tony-core/src/main/java/com/linkedin/tony/TonySession.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
Expand Down Expand Up @@ -196,6 +198,10 @@ public int getNumFailedTasks() {
return (int) jobTasks.values().stream().flatMap(arr -> Arrays.stream(arr)).filter(task -> task != null && task.isFailed()).count();
}

public List<TonyTask> getRunningTasks() {
return jobTasks.values().stream().flatMap(x -> Arrays.stream(x)).filter(task -> task != null && !task.isCompleted()).collect(Collectors.toList());
}

/** Number of expected tasks that have been scheduled at current time **/
public int getNumExpectedTasks() {
return numExpectedTasks;
Expand Down Expand Up @@ -262,6 +268,7 @@ public void onTaskCompleted(String jobName, String jobIndex, int exitCode, Strin
TonyTask task = getTask(jobName, jobIndex);
Preconditions.checkNotNull(task);
task.setExitStatus(exitCode);
task.setEndTime(System.currentTimeMillis());
// If the chief worker failed[chief or worker 0], short circuit and stop the training. Note that even though other
// worker failures will also fail the job but we don't short circuit the training because the training can still
// continue, while if chief worker is dead, TensorFlow training will hang.
Expand Down Expand Up @@ -361,7 +368,7 @@ public synchronized void setFinalStatus(FinalApplicationStatus status, String me
}
}

private TonyTask getTask(String jobName, String taskIndex) {
public TonyTask getTask(String jobName, String taskIndex) {
for (Map.Entry<String, TonyTask[]> entry : jobTasks.entrySet()) {
TonyTask[] tasks = entry.getValue();
for (TonyTask task : tasks) {
Expand All @@ -377,6 +384,15 @@ private TonyTask getTask(String jobName, String taskIndex) {
return null;
}

@VisibleForTesting
public void addTask(TonyTask tonyTask) {
String jobName = tonyTask.getJobName();
TonyTask[] tasks = jobTasks.getOrDefault(jobName, new TonyTask[]{});
List<TonyTask> newTasks = new ArrayList<>(Arrays.asList(tasks));
newTasks.add(tonyTask);
jobTasks.put(jobName, newTasks.toArray(new TonyTask[newTasks.size()]));
}

/**
* Returns true if the job is "chief" or if there is no "chief" job and ("worker", "0") is passed in.
*/
Expand Down Expand Up @@ -441,6 +457,7 @@ public class TonyTask implements Comparable<TonyTask> {
private int port = -1;
private TaskInfo taskInfo;
private final long startTime;
private long endTime;

/**
* The container the task is running in. Set once a container has been allocated for the task.
Expand Down Expand Up @@ -503,7 +520,7 @@ synchronized int getExitStatus() {
return exitStatus;
}

synchronized void setExitStatus(int status) {
public synchronized void setExitStatus(int status) {
// Only set exit status if it hasn't been set yet
if (exitStatus == -1) {
this.exitStatus = status;
Expand Down Expand Up @@ -533,6 +550,12 @@ public void setTaskInfo(Container container) {
taskInfo = new TaskInfo(jobName, taskIndex, Utils.constructContainerUrl(container));
}

// just for test case
@VisibleForTesting
public void setTaskInfo() {
taskInfo = new TaskInfo(jobName, taskIndex, "");
}

TonyTask(String jobName, String taskIndex, int sessionId, long startTime) {
this.jobName = jobName;
this.taskIndex = taskIndex;
Expand All @@ -545,6 +568,14 @@ public void addContainer(Container container) {
containerIdMap.put(container.getId(), this);
}

public long getEndTime() {
return endTime;
}

public void setEndTime(long endTime) {
this.endTime = endTime;
}

/**
* Combination of jobName and taskIndex.
* @return Id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
package com.linkedin.tony.runtime;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;

Expand All @@ -33,6 +37,7 @@
import com.linkedin.tony.TaskExecutor;
import com.linkedin.tony.TonyConfigurationKeys;
import com.linkedin.tony.TonySession;
import com.linkedin.tony.util.Utils;

import static com.linkedin.tony.Constants.SIDECAR_TB_ROLE_NAME;

Expand All @@ -55,6 +60,12 @@ class AM implements Framework.ApplicationMasterAdapter {
private long lastRegisterWorkerTime = System.currentTimeMillis();
private long runtimeInitialTime = System.currentTimeMillis();

// Group dependencies policy.
Map<String, List<String>> grpWithMembersIndex;
Map<String, List<String>> taskInGrpsIndex;
// todo: Need to support single group dependent multiple other groups
Map<String, Pair<String, Long>> taskWithDependentGrpsIndex;

@Override
public String constructClusterSpec(String taskId) throws IOException {
assert session != null;
Expand Down Expand Up @@ -120,9 +131,118 @@ public boolean isHealthy(Configuration tonyConf) {
session.setFinalStatus(FinalApplicationStatus.FAILED, "Container allocation timeout.");
return false;
}

/**
* Checking the task roles completion timeout when its' pre-dependency tasks finished
*
* For example, tensorflow estimator training job will include some roles of ps/worker/evaluator/chief.
* Actually, due to the bug of tensorflow or misusing the estimator api, sometimes evaluator will hang.
* So if we use the configuration as follows, when evaluator is still running after timeout and
* chief/workers are finished, the mechanism of dependency group timeout will make job failed.
*
* Dependency group timeout configuration as follows:
*
* tony.application.group.A = worker,chief
* tony.application.dependency.evaluator.timeout.after.A = 3600
*
*/
String errorMsg = groupDependencyTimeout(tonyConf);
if (errorMsg != null) {
session.setFinalStatus(FinalApplicationStatus.FAILED, errorMsg);
return false;
}
return true;
}

@VisibleForTesting
protected String groupDependencyTimeout(Configuration tonyConf) {
if (taskWithDependentGrpsIndex == null) {
taskWithDependentGrpsIndex = Utils.getJobTypeDependentGrps(tonyConf);
}
// groupDependencies is map, key: waiting role, value: pre-dependent groups and waiting timeout
if (taskWithDependentGrpsIndex == null || taskWithDependentGrpsIndex.isEmpty()) {
return null;
}

// groupMembers is map, key: groupName, value: its members in this group
if (grpWithMembersIndex == null) {
grpWithMembersIndex = Utils.getAllGroupJobTypes(tonyConf);
}

// memberInGroups is map. key: jobtype name, value: in which groups
if (taskInGrpsIndex == null) {
taskInGrpsIndex = getMemberInGroups(grpWithMembersIndex);
}

Map<String, TonySession.TonyTask[]> allTasks = session.getTonyTasks();
List<TonySession.TonyTask> runningTasks = session.getRunningTasks();

// Get the running jobs' type, like the tf roles of ps/worker/chief/evaluator
Set<String> runningJobTypes = runningTasks.stream()
.map(TonySession.TonyTask::getJobName)
.filter(jobname -> taskWithDependentGrpsIndex.containsKey(jobname))
.collect(Collectors.toSet());

for (String runningTaskType : runningJobTypes) {
Pair<String, Long> dependentGroupPair = taskWithDependentGrpsIndex.get(runningTaskType);
String dependentGroupName = dependentGroupPair.getKey();
long timeout = dependentGroupPair.getValue() * 1000;

if (!grpWithMembersIndex.containsKey(dependentGroupName)) {
continue;
}

boolean allDependentTaskFinished = true;
long latestEndTimeInAllDependentTasks = 0L;
for (String dependentsGroupJobtype : grpWithMembersIndex.get(dependentGroupName)) {

if (Utils.existRunningTasksWithJobtype(runningTasks, dependentsGroupJobtype)) {
allDependentTaskFinished = false;
break;
}

// Find out the latest finished task in this task type, if the specified timeout exceed,
// make the job fail.
latestEndTimeInAllDependentTasks = Math.max(
Arrays.stream(allTasks.get(dependentsGroupJobtype))
.mapToLong(x -> x.getEndTime())
.max().getAsLong(),
latestEndTimeInAllDependentTasks
);
}

if (!allDependentTaskFinished) {
continue;
}

if (System.currentTimeMillis() - latestEndTimeInAllDependentTasks > timeout) {
return String.format("Jobtype: %s runs exceeded timeout because it's "
+ "dependent group: %s (task set: [%s]) has been finished.",
runningTaskType, dependentGroupName,
StringUtils.join(grpWithMembersIndex.get(dependentGroupName), ","));
}
}

return null;
}

private Map<String, List<String>> getMemberInGroups(Map<String, List<String>> groupMembers) {
/**
* key: job type name
* value: the list of groups
*/
Map<String, List<String>> memberInGroups = new HashMap<>();
for (Map.Entry<String, List<String>> entry : groupMembers.entrySet()) {
String group = entry.getKey();
List<String> members = entry.getValue();
for (String member : members) {
memberInGroups.putIfAbsent(member, new ArrayList<>());
memberInGroups.get(member).add(group);
}
}
return memberInGroups;
}

private boolean containerAllocationTimeout(Configuration tonyConf) {
String distributedModeVal = tonyConf.get(TonyConfigurationKeys.APPLICATION_DISTRIBUTED_MODE,
TonyConfigurationKeys.DEFAULT_APPLICATION_DISTRIBUTED_MODE);
Expand Down
64 changes: 59 additions & 5 deletions tony-core/src/main/java/com/linkedin/tony/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.apache.commons.cli.ParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
Expand Down Expand Up @@ -68,6 +70,7 @@
import com.linkedin.tony.LocalizableResource;
import com.linkedin.tony.TonyConfig;
import com.linkedin.tony.TonyConfigurationKeys;
import com.linkedin.tony.TonySession;
import com.linkedin.tony.horovod.HorovodClusterSpec;
import com.linkedin.tony.rpc.TaskInfo;
import com.linkedin.tony.models.JobContainerRequest;
Expand Down Expand Up @@ -459,21 +462,68 @@ public static int getNumTotalTasks(Configuration conf) {
.sum();
}

public static Map<String, Pair<String, Long>> getJobTypeDependentGrps(Configuration tonyConf) {
return tonyConf.getValByRegex(TonyConfigurationKeys.GROUP_DEPEND_TIMEOUT_REGEX).keySet().stream()
.map(Utils::getDependentGrps)
.map(pair -> Utils.getDependentTimeout(tonyConf, pair))
.collect(Collectors.toMap(Triple::getLeft, x -> Pair.of(x.getMiddle(), x.getRight()), (oldV, newV) -> newV));
}

private static Triple<String, String, Long> getDependentTimeout(Configuration tonyConf, Pair<String, String> pair) {
String grp = pair.getKey();
String dependentGrp = pair.getValue();
long timeout = tonyConf.getLong(TonyConfigurationKeys.getGroupDependentKey(grp, dependentGrp), 0L);
return Triple.of(grp, dependentGrp, timeout);
}

private static Pair<String, String> getDependentGrps(String confKey) {
Pattern instancePattern = Pattern.compile(TonyConfigurationKeys.GROUP_DEPEND_TIMEOUT_REGEX);
Matcher instanceMatcher = instancePattern.matcher(confKey);
if (instanceMatcher.matches()) {
return Pair.of(instanceMatcher.group(1), instanceMatcher.group(2));
}
return null;
}

public static Map<String, List<String>> getAllGroupJobTypes(Configuration conf) {
return conf.getValByRegex(TonyConfigurationKeys.GROUP_REGEX).keySet().stream()
.map(Utils::getGroupName)
.map(groupName -> Utils.getGroupMembers(conf, groupName))
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight));
}

private static Pair<String, List<String>> getGroupMembers(Configuration conf, String groupName) {
return Pair.of(groupName, Arrays.asList(conf.getStrings(TonyConfigurationKeys.getGroupKey(groupName))));
}

/**
* Extracts TensorFlow job name from configuration key of the form "tony.*.instances".
* Extracts group name from configuration key of the form "tony.application.group.*".
* @param confKey Name of the configuration key
* @return TensorFlow job name
* @return group name
*/
public static String getTaskType(String confKey) {
Pattern instancePattern = Pattern.compile(TonyConfigurationKeys.INSTANCES_REGEX);
Matcher instanceMatcher = instancePattern.matcher(confKey);
private static String getGroupName(String confKey) {
return getRegexKey(confKey, TonyConfigurationKeys.GROUP_REGEX);
}

private static String getRegexKey(String conf, String regex) {
Pattern instancePattern = Pattern.compile(regex);
Matcher instanceMatcher = instancePattern.matcher(conf);
if (instanceMatcher.matches()) {
return instanceMatcher.group(1);
} else {
return null;
}
}

/**
* Extracts TensorFlow job name from configuration key of the form "tony.*.instances".
* @param confKey Name of the configuration key
* @return TensorFlow job name
*/
public static String getTaskType(String confKey) {
return getRegexKey(confKey, TonyConfigurationKeys.INSTANCES_REGEX);
}

public static boolean isArchive(String path) {
File f = new File(path);
int fileSignature = 0;
Expand Down Expand Up @@ -796,5 +846,9 @@ public static HorovodClusterSpec parseClusterSpecForHorovod(String clusterSpec)
return spec;
}

public static boolean existRunningTasksWithJobtype(List<TonySession.TonyTask> runningTasks, String jobtype) {
return runningTasks.stream().anyMatch(x -> x.getJobName().equals(jobtype));
}

private Utils() { }
}
Loading