Skip to content

Commit

Permalink
[jobframework] Add support for partial admission
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed May 31, 2023
1 parent f034ed4 commit 8481a0a
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 56 deletions.
8 changes: 8 additions & 0 deletions pkg/controller/jobframework/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,13 @@ const (
// node selectors are recorded upon a workload admission. This information
// will be used to restore them when the job is suspended.
// The content is a json marshaled slice of selectors.
//
// Deprecated: use OriginalPodSetsInfoAnnotation instead.
OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors"

// OriginalPodSetsInfoAnnotation is the annotation in which the original
// node selectors and podSet counts are recorded upon a workload admission.
// This information will be used to restore them when the job is suspended.
// The content is a json marshaled slice of PodSetInfo.
OriginalPodSetsInfoAnnotation = "kueue.x-k8s.io/original-pod-sets-info"
)
8 changes: 4 additions & 4 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ type GenericJob interface {
// ResetStatus will reset the job status to the original state.
// If true, status is modified, if not, status is as it was.
ResetStatus() bool
// RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job.
RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RestoreNodeAffinity will restore the original node affinity of job.
RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RunWithPodSetsInfo will inject the node affinity and podSet counts extracted from the workload into the job and unsuspend it.
RunWithPodSetsInfo(nodeSelectors []PodSetInfo)
// RestorePodSetsInfo will restore the original node affinity and podSet counts of the job.
RestorePodSetsInfo(nodeSelectors []PodSetInfo)
// Finished means whether the job is completed/failed or not,
// condition represents the workload finished condition.
Finished() (condition metav1.Condition, finished bool)
Expand Down
84 changes: 54 additions & 30 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package jobframework
import (
"context"
"encoding/json"
"errors"
"fmt"

corev1 "k8s.io/api/core/v1"
Expand All @@ -32,11 +33,13 @@ import (
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/constants"
utilpriority "sigs.k8s.io/kueue/pkg/util/priority"
"sigs.k8s.io/kueue/pkg/util/slices"
"sigs.k8s.io/kueue/pkg/workload"
)

var (
errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation)
errPodSetsInfoNotFound = fmt.Errorf("annotation %s or %s not found", OriginalNodeSelectorsAnnotation, OriginalPodSetsInfoAnnotation)
errUnknownPodSetName = errors.New("unknown podSet name")
)

// JobReconciler reconciles a GenericJob object
Expand Down Expand Up @@ -332,17 +335,17 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec

// startJob will unsuspend the job, and also inject the node affinity.
func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error {
//get the original selectors and store them in the job object
originalSelectors := r.getNodeSelectorsFromPodSets(wl)
if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil {
return fmt.Errorf("startJob, record original node selectors: %w", err)
//get the original podSetsInfo and store them in the job object
originalPodSetsInfo := r.getPodSetsInfoFromSpec(wl)
if err := setNodeSelectorsInAnnotation(object, originalPodSetsInfo); err != nil {
return fmt.Errorf("startJob, record original podSetsInfo: %w", err)
}

nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl)
info, err := r.getPodSetsInfoFromAdmission(ctx, wl)
if err != nil {
return err
}
job.RunWithNodeAffinity(nodeSelectors)
job.RunWithPodSetsInfo(info)

if err := r.client.Update(ctx, object); err != nil {
return err
Expand Down Expand Up @@ -372,12 +375,12 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie
}
}

log.V(3).Info("restore node selectors from annotation")
selectors, err := getNodeSelectorsFromObjectAnnotation(object)
log.V(3).Info("restore podSetsInfo from annotation")
info, err := getPodSetsInfoFromObjectAnnotation(object, job)
if err != nil {
log.V(3).Error(err, "Unable to get original node selectors")
log.V(3).Error(err, "Unable to get original podSetsInfo")
} else {
job.RestoreNodeAffinity(selectors)
job.RestorePodSetsInfo(info)
return r.client.Update(ctx, object)
}

Expand Down Expand Up @@ -412,24 +415,26 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
return wl, nil
}

type PodSetNodeSelector struct {
type PodSetInfo struct {
Name string `json:"name"`
NodeSelector map[string]string `json:"nodeSelector"`
Count int32 `json:"count"`
}

// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) {
// getPodSetsInfoFromAdmission will extract podSetsInfo and podSets count from admitted workloads.
func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) {
if len(w.Status.Admission.PodSetAssignments) == 0 {
return nil, nil
}

nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments))
nodeSelectors := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments))

for i, podSetFlavor := range w.Status.Admission.PodSetAssignments {
processedFlvs := sets.NewString()
nodeSelector := PodSetNodeSelector{
nodeSelector := PodSetInfo{
Name: podSetFlavor.Name,
NodeSelector: make(map[string]string),
Count: podSetFlavor.Count,
}
for _, flvRef := range podSetFlavor.Flavors {
flvName := string(flvRef)
Expand All @@ -452,18 +457,19 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku
return nodeSelectors, nil
}

// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets.
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector {
// getPodSetsInfoFromSpec will extract podSetsInfo and podSet's counts from a workload's spec.
func (r *JobReconciler) getPodSetsInfoFromSpec(w *kueue.Workload) []PodSetInfo {
podSets := w.Spec.PodSets
if len(podSets) == 0 {
return nil
}
ret := make([]PodSetNodeSelector, len(podSets))
ret := make([]PodSetInfo, len(podSets))
for psi := range podSets {
ps := &podSets[psi]
ret[psi] = PodSetNodeSelector{
ret[psi] = PodSetInfo{
Name: ps.Name,
NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector),
Count: ps.Count,
}
}
return ret
Expand Down Expand Up @@ -519,34 +525,52 @@ func cloneNodeSelector(src map[string]string) map[string]string {
return ret
}

// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the
// getPodSetsInfoFromObjectAnnotation tries to retrieve a podSetsInfo slice from the
// object's annotations. It fails if the annotation is not found or cannot be unmarshaled.
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) {
str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
func getPodSetsInfoFromObjectAnnotation(obj client.Object, job GenericJob) ([]PodSetInfo, error) {
hasCounts := true
str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation]
if !found {
return nil, errNodeSelectorsNotFound
hasCounts = false
str, found = obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
if !found {
return nil, errPodSetsInfoNotFound
}
}
// unmarshal
ret := []PodSetNodeSelector{}
ret := []PodSetInfo{}
if err := json.Unmarshal([]byte(str), &ret); err != nil {
return nil, err
}

if !hasCounts {
podSets := job.PodSets()
psMap := slices.ToRefMap(podSets, func(ps *kueue.PodSet) string { return ps.Name })
for i := range ret {
info := &ret[i]
ps, found := psMap[info.Name]
if !found {
return nil, fmt.Errorf("%w: %s", errUnknownPodSetName, info.Name)
}
info.Count = ps.Count
}
}
return ret, nil
}

// setNodeSelectorsInAnnotation - sets an annotation containing the provided node selectors into
// setNodeSelectorsInAnnotation - sets an annotation containing the provided podSetsInfo into
// a job object. Although very unlikely, it can return an error from json marshaling.
func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error {
nodeSelectorsBytes, err := json.Marshal(nodeSelectors)
func setNodeSelectorsInAnnotation(obj client.Object, info []PodSetInfo) error {
nodeSelectorsBytes, err := json.Marshal(info)
if err != nil {
return err
}

annotations := obj.GetAnnotations()
if annotations == nil {
annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)}
annotations = map[string]string{OriginalPodSetsInfoAnnotation: string(nodeSelectorsBytes)}
} else {
annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes)
annotations[OriginalPodSetsInfoAnnotation] = string(nodeSelectorsBytes)
}
obj.SetAnnotations(annotations)
return nil
Expand Down
19 changes: 18 additions & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
queueNameLabelPath = labelsPath.Key(QueueLabel)

originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation)
originalPodSetsInfosWorkloadKeyPath = annotationsPath.Key(OriginalPodSetsInfoAnnotation)
)

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
Expand Down Expand Up @@ -83,10 +84,26 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
}
} else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found {
out := []PodSetNodeSelector{}
out := []PodSetInfo{}
if err := json.Unmarshal([]byte(av), &out); err != nil {
allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error()))
}
}
return allErrs
}

// ValidateUpdateForOriginalPodSetsInfo validates the
// OriginalPodSetsInfoAnnotation annotation when a job is updated.
//
// If oldJob and newJob report the same suspended state, the annotation value
// must be unchanged; otherwise a Forbidden error is returned for it.
// If the suspended state differs and newJob carries the annotation, its value
// must unmarshal as a JSON slice of PodSetInfo; otherwise an Invalid error is
// returned for it.
func ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob GenericJob) field.ErrorList {
	var allErrs field.ErrorList

	if oldJob.IsSuspended() != newJob.IsSuspended() {
		// The suspended state is flipping: the annotation may change, but
		// whatever is set on the new object has to be well-formed.
		raw, present := newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation]
		if !present {
			return allErrs
		}
		var decoded []PodSetInfo
		if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
			allErrs = append(allErrs, field.Invalid(originalPodSetsInfosWorkloadKeyPath, raw, err.Error()))
		}
		return allErrs
	}

	// Same suspended state on both sides: the annotation is immutable.
	before := oldJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation]
	after := newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation]
	immutableErrs := apivalidation.ValidateImmutableField(before, after, originalPodSetsInfosWorkloadKeyPath)
	if len(immutableErrs) > 0 {
		allErrs = append(allErrs, field.Forbidden(originalPodSetsInfosWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
	}
	return allErrs
}
4 changes: 2 additions & 2 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (j *Job) PodSets() []kueue.PodSet {
}
}

func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
Expand All @@ -197,7 +197,7 @@ func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelecto
}
}

func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
return
}
Expand Down
1 change: 1 addition & 0 deletions pkg/controller/jobs/job/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
return allErrs
Expand Down
12 changes: 6 additions & 6 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,17 @@ func (j *MPIJob) PodSets() []kueue.PodSet {
return podSets
}

func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *MPIJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) {
j.Spec.RunPolicy.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
if len(podSetInfos) == 0 {
return
}
// The node selectors are provided in the same order as the generated list of
// podSets, use the same ordering logic to restore them.
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index := range nodeSelectors {
for index := range podSetInfos {
replicaType := orderedReplicaTypes[index]
nodeSelector := nodeSelectors[index]
nodeSelector := podSetInfos[index]
if len(nodeSelector.NodeSelector) != 0 {
if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector
Expand All @@ -146,9 +146,9 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSele
}
}

func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *MPIJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) {
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index, nodeSelector := range nodeSelectors {
for index, nodeSelector := range podSetInfos {
replicaType := orderedReplicaTypes[index]
if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{}
Expand Down
1 change: 1 addition & 0 deletions pkg/controller/jobs/mpijob/mpijob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runti
log := ctrl.LoggerFrom(ctx).WithName("job-webhook")
log.Info("Validating update", "job", klog.KObj(newJob))
allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldGenJob, newGenJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...)
return allErrs.ToAggregate()
}
Expand Down
20 changes: 10 additions & 10 deletions pkg/controller/jobs/rayjob/rayjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,20 @@ func applySelectors(dst, src map[string]string) map[string]string {
return dst
}

func (j *RayJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *RayJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) {
j.Spec.Suspend = false
if len(nodeSelectors) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
if len(podSetInfos) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
return
}

// head
headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
headPodSpec.NodeSelector = applySelectors(headPodSpec.NodeSelector, nodeSelectors[0].NodeSelector)
headPodSpec.NodeSelector = applySelectors(headPodSpec.NodeSelector, podSetInfos[0].NodeSelector)

// workers
for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs {
workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec
workerPodSpec.NodeSelector = applySelectors(workerPodSpec.NodeSelector, nodeSelectors[index+1].NodeSelector)
workerPodSpec.NodeSelector = applySelectors(workerPodSpec.NodeSelector, podSetInfos[index+1].NodeSelector)
}
}

Expand All @@ -162,22 +162,22 @@ func cloneSelectors(src map[string]string) map[string]string {
return dst
}

func (j *RayJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
if len(nodeSelectors) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
func (j *RayJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) {
if len(podSetInfos) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
return
}

// head
headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
if !equality.Semantic.DeepEqual(headPodSpec.NodeSelector, nodeSelectors[0].NodeSelector) {
headPodSpec.NodeSelector = cloneSelectors(nodeSelectors[0].NodeSelector)
if !equality.Semantic.DeepEqual(headPodSpec.NodeSelector, podSetInfos[0].NodeSelector) {
headPodSpec.NodeSelector = cloneSelectors(podSetInfos[0].NodeSelector)
}

// workers
for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs {
workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec
if !equality.Semantic.DeepEqual(workerPodSpec.NodeSelector, nodeSelectors[index+1].NodeSelector) {
workerPodSpec.NodeSelector = cloneSelectors(nodeSelectors[index+1].NodeSelector)
if !equality.Semantic.DeepEqual(workerPodSpec.NodeSelector, podSetInfos[index+1].NodeSelector) {
workerPodSpec.NodeSelector = cloneSelectors(podSetInfos[index+1].NodeSelector)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/jobs/rayjob/rayjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ func TestNodeSelectors(t *testing.T) {
}).
Obj())

// run with Affinity should append or update the node selectors
job.RunWithNodeAffinity([]jobframework.PodSetNodeSelector{
// RunWithPodSetsInfo should append or update the node selectors
job.RunWithPodSetsInfo([]jobframework.PodSetInfo{
{
NodeSelector: map[string]string{
"newKey": "newValue",
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestNodeSelectors(t *testing.T) {
}

// restore should replace node selectors
job.RestoreNodeAffinity([]jobframework.PodSetNodeSelector{
job.RestorePodSetsInfo([]jobframework.PodSetInfo{
{
NodeSelector: map[string]string{
// clean it all
Expand Down

0 comments on commit 8481a0a

Please sign in to comment.