Skip to content

Commit

Permalink
[jobframework] Add support for partial admission
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed May 15, 2023
1 parent 2ee58d7 commit 77f462e
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 30 deletions.
8 changes: 8 additions & 0 deletions pkg/controller/jobframework/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,13 @@ const (
// node selectors are recorded upon a workload admission. This information
// will be used to restore them when the job is suspended.
// The content is a json marshaled slice of selectors.
//
// Deprecated: Use OriginalPodSetsInfoAnnotation.
OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors"

// OriginalPodSetsInfoAnnotation is the annotation in which the original
// node selectors and podSet counts are recorded upon a workload admission.
// This information will be used to restore them when the job is suspended.
// The content is a json marshaled slice of PodSetInfo.
OriginalPodSetsInfoAnnotation = "kueue.x-k8s.io/original-pod-sets-info"
)
8 changes: 4 additions & 4 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ type GenericJob interface {
// ResetStatus will reset the job status to the original state.
// If true, status is modified, if not, status is as it was.
ResetStatus() bool
// RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job.
RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RestoreNodeAffinity will restore the original node affinity of job.
RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector)
// RunWithPodSetsInfo will inject the node affinity extracted from the workload into the job and unsuspend the job.
RunWithPodSetsInfo(nodeSelectors []PodSetInfo)
// RestorePodSetsInfo will restore the original node affinity of job.
RestorePodSetsInfo(nodeSelectors []PodSetInfo)
// Finished means whether the job is completed/failed or not,
// condition represents the workload finished condition.
Finished() (condition metav1.Condition, finished bool)
Expand Down
66 changes: 45 additions & 21 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package jobframework
import (
"context"
"encoding/json"
"errors"
"fmt"

corev1 "k8s.io/api/core/v1"
Expand All @@ -32,11 +33,13 @@ import (
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/constants"
utilpriority "sigs.k8s.io/kueue/pkg/util/priority"
utilslice "sigs.k8s.io/kueue/pkg/util/slice"
"sigs.k8s.io/kueue/pkg/workload"
)

var (
errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation)
errPodSetsInfoNotFound = fmt.Errorf("annotation %s or %s not found", OriginalNodeSelectorsAnnotation, OriginalPodSetsInfoAnnotation)
errUnknownPodSetName = errors.New("unknown podSet name")
)

// JobReconciler reconciles a GenericJob object
Expand Down Expand Up @@ -333,16 +336,16 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec
// startJob will unsuspend the job, and also inject the node affinity.
func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error {
//get the original selectors and store them in the job object
originalSelectors := r.getNodeSelectorsFromPodSets(wl)
originalSelectors := r.getPodSetsInfoFromSpec(wl)
if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil {
return fmt.Errorf("startJob, record original node selectors: %w", err)
}

nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl)
nodeSelectors, err := r.getPodSetsInfoFromAdmission(ctx, wl)
if err != nil {
return err
}
job.RunWithNodeAffinity(nodeSelectors)
job.RunWithPodSetsInfo(nodeSelectors)

if err := r.client.Update(ctx, object); err != nil {
return err
Expand Down Expand Up @@ -373,11 +376,11 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie
}

log.V(3).Info("restore node selectors from annotation")
selectors, err := getNodeSelectorsFromObjectAnnotation(object)
selectors, err := getPodSetsInfoFromObjectAnnotation(object)
if err != nil {
log.V(3).Error(err, "Unable to get original node selectors")
} else {
job.RestoreNodeAffinity(selectors)
job.RestorePodSetsInfo(selectors)
return r.client.Update(ctx, object)
}

Expand Down Expand Up @@ -412,24 +415,26 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o
return wl, nil
}

type PodSetNodeSelector struct {
type PodSetInfo struct {
Name string `json:"name"`
NodeSelector map[string]string `json:"nodeSelector"`
Count int32 `json:"count"`
}

// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads.
func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) {
// getPodSetsInfoFromAdmission will extract node selectors and podSets count from admitted workloads.
func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) {
if len(w.Status.Admission.PodSetAssignments) == 0 {
return nil, nil
}

nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments))
nodeSelectors := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments))

for i, podSetFlavor := range w.Status.Admission.PodSetAssignments {
processedFlvs := sets.NewString()
nodeSelector := PodSetNodeSelector{
nodeSelector := PodSetInfo{
Name: podSetFlavor.Name,
NodeSelector: make(map[string]string),
Count: podSetFlavor.Count,
}
for _, flvRef := range podSetFlavor.Flavors {
flvName := string(flvRef)
Expand All @@ -452,18 +457,19 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku
return nodeSelectors, nil
}

// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets.
func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector {
// getPodSetsInfoFromSpec will extract node selectors and podSet's counts from a workload's spec.
func (r *JobReconciler) getPodSetsInfoFromSpec(w *kueue.Workload) []PodSetInfo {
podSets := w.Spec.PodSets
if len(podSets) == 0 {
return nil
}
ret := make([]PodSetNodeSelector, len(podSets))
ret := make([]PodSetInfo, len(podSets))
for psi := range podSets {
ps := &podSets[psi]
ret[psi] = PodSetNodeSelector{
ret[psi] = PodSetInfo{
Name: ps.Name,
NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector),
Count: ps.Count,
}
}
return ret
Expand Down Expand Up @@ -519,24 +525,42 @@ func cloneNodeSelector(src map[string]string) map[string]string {
return ret
}

// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the
// getPodSetsInfoFromObjectAnnotation tries to retrieve a node selectors slice from the
// object's annotations fails if it's not found or is unable to unmarshal
func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) {
str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet) ([]PodSetInfo, error) {
hasCounts := true
str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation]
if !found {
return nil, errNodeSelectorsNotFound
hasCounts = false
} else {
str, found = obj.GetAnnotations()[OriginalNodeSelectorsAnnotation]
if !found {
return nil, errPodSetsInfoNotFound
}
}
// unmarshal
ret := []PodSetNodeSelector{}
ret := []PodSetInfo{}
if err := json.Unmarshal([]byte(str), &ret); err != nil {
return nil, err
}

if !hasCounts {
psMap := utilslice.ToRefMap(spec, func(ps *kueue.PodSet) string { return ps.Name })
for i := range ret {
info := &ret[i]
ps, found := psMap[info.Name]
if !found {
return nil, fmt.Errorf("%w: %s", errUnknownPodSetName, info.Name)
}
info.Count = ps.Count
}
}
return ret, nil
}

// setNodeSelectorsInAnnotation sets an annotation containing the provided node selectors on
// a job object. Although very unlikely, it can return an error from JSON marshaling.
func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error {
func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetInfo) error {
nodeSelectorsBytes, err := json.Marshal(nodeSelectors)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err
allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state"))
}
} else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found {
out := []PodSetNodeSelector{}
out := []PodSetInfo{}
if err := json.Unmarshal([]byte(av), &out); err != nil {
allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error()))
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (j *Job) PodSets() []kueue.PodSet {
}
}

func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetInfo) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
Expand All @@ -195,7 +195,7 @@ func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelecto
}
}

func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetInfo) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
return
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (j *MPIJob) PodSets() []kueue.PodSet {
return podSets
}

func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetInfo) {
j.Spec.RunPolicy.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
return
Expand All @@ -144,7 +144,7 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSele
}
}

func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) {
func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetInfo) {
orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
for index, nodeSelector := range nodeSelectors {
replicaType := orderedReplicaTypes[index]
Expand Down

0 comments on commit 77f462e

Please sign in to comment.