From 8481a0aad8b3ea42f2b28777c3c9a2a3d8408d16 Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Mon, 15 May 2023 15:45:01 +0300 Subject: [PATCH] [jobframework] Add support for partial admission --- pkg/controller/jobframework/constants.go | 8 ++ pkg/controller/jobframework/interface.go | 8 +- pkg/controller/jobframework/reconciler.go | 84 ++++++++++++------- pkg/controller/jobframework/validation.go | 19 ++++- pkg/controller/jobs/job/job_controller.go | 4 +- pkg/controller/jobs/job/job_webhook.go | 1 + .../jobs/mpijob/mpijob_controller.go | 12 +-- pkg/controller/jobs/mpijob/mpijob_webhook.go | 1 + .../jobs/rayjob/rayjob_controller.go | 20 ++--- .../jobs/rayjob/rayjob_controller_test.go | 6 +- 10 files changed, 107 insertions(+), 56 deletions(-) diff --git a/pkg/controller/jobframework/constants.go b/pkg/controller/jobframework/constants.go index 8fbbec5a1d..775f2ee2c1 100644 --- a/pkg/controller/jobframework/constants.go +++ b/pkg/controller/jobframework/constants.go @@ -37,5 +37,13 @@ const ( // node selectors are recorded upon a workload admission. This information // will be used to restore them when the job is suspended. // The content is a json marshaled slice of selectors. + // + // DEPRECATED: Use OriginalPodSetsInfoAnnotation. OriginalNodeSelectorsAnnotation = "kueue.x-k8s.io/original-node-selectors" + + // OriginalPodSetsInfoAnnotation is the annotation in which the original + // node selectors and podSet counts are recorded upon a workload admission. + // This information will be used to restore them when the job is suspended. + // The content is a json marshaled slice of PodSetInfo. + OriginalPodSetsInfoAnnotation = "kueue.x-k8s.io/original-pod-sets-info" ) diff --git a/pkg/controller/jobframework/interface.go b/pkg/controller/jobframework/interface.go index 2f6dfba821..7ecb67501a 100644 --- a/pkg/controller/jobframework/interface.go +++ b/pkg/controller/jobframework/interface.go @@ -31,10 +31,10 @@ type GenericJob interface { // ResetStatus will reset the job status to the original state. // If true, status is modified, if not, status is as it was. ResetStatus() bool - // RunWithNodeAffinity will inject the node affinity extracting from workload to job and unsuspend the job. - RunWithNodeAffinity(nodeSelectors []PodSetNodeSelector) - // RestoreNodeAffinity will restore the original node affinity of job. - RestoreNodeAffinity(nodeSelectors []PodSetNodeSelector) + // RunWithPodSetsInfo will inject the node affinity and podSet counts extracting from workload to job and unsuspend it. + RunWithPodSetsInfo(nodeSelectors []PodSetInfo) + // RestorePodSetsInfo will restore the original node affinity and podSet counts of the job. + RestorePodSetsInfo(nodeSelectors []PodSetInfo) // Finished means whether the job is completed/failed or not, // condition represents the workload finished condition. Finished() (condition metav1.Condition, finished bool) diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 777a16ec86..6631f98ecb 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -16,6 +16,7 @@ package jobframework import ( "context" "encoding/json" + "errors" "fmt" corev1 "k8s.io/api/core/v1" @@ -32,11 +33,13 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/constants" utilpriority "sigs.k8s.io/kueue/pkg/util/priority" + "sigs.k8s.io/kueue/pkg/util/slices" "sigs.k8s.io/kueue/pkg/workload" ) var ( - errNodeSelectorsNotFound = fmt.Errorf("annotation %s not found", OriginalNodeSelectorsAnnotation) + errPodSetsInfoNotFound = fmt.Errorf("annotation %s or %s not found", OriginalNodeSelectorsAnnotation, OriginalPodSetsInfoAnnotation) + errUnknownPodSetName = errors.New("unknown podSet name") ) // JobReconciler reconciles a GenericJob object @@ -332,17 +335,17 @@ func (r *JobReconciler) equivalentToWorkload(job GenericJob, object client.Objec // startJob will unsuspend the job, and also inject the node affinity. func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) error { - //get the original selectors and store them in the job object - originalSelectors := r.getNodeSelectorsFromPodSets(wl) - if err := setNodeSelectorsInAnnotation(object, originalSelectors); err != nil { - return fmt.Errorf("startJob, record original node selectors: %w", err) + //get the original podSetsInfo and store them in the job object + originalPodSetsInfo := r.getPodSetsInfoFromSpec(wl) + if err := setNodeSelectorsInAnnotation(object, originalPodSetsInfo); err != nil { + return fmt.Errorf("startJob, record original podSetsInfo: %w", err) } - nodeSelectors, err := r.getNodeSelectorsFromAdmission(ctx, wl) + info, err := r.getPodSetsInfoFromAdmission(ctx, wl) if err != nil { return err } - job.RunWithNodeAffinity(nodeSelectors) + job.RunWithPodSetsInfo(info) if err := r.client.Update(ctx, object); err != nil { return err @@ -372,12 +375,12 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie } } - log.V(3).Info("restore node selectors from annotation") - selectors, err := getNodeSelectorsFromObjectAnnotation(object) + log.V(3).Info("restore podSetsInfo from annotation") + info, err := getPodSetsInfoFromObjectAnnotation(object, job) if err != nil { - log.V(3).Error(err, "Unable to get original node selectors") + log.V(3).Error(err, "Unable to get original podSetsInfo") } else { - job.RestoreNodeAffinity(selectors) + job.RestorePodSetsInfo(info) return r.client.Update(ctx, object) } @@ -412,24 +415,26 @@ func (r *JobReconciler) constructWorkload(ctx context.Context, job GenericJob, o return wl, nil } -type PodSetNodeSelector struct { +type PodSetInfo struct { Name string `json:"name"` NodeSelector map[string]string `json:"nodeSelector"` + Count int32 `json:"count"` } -// getNodeSelectorsFromAdmission will extract node selectors from admitted workloads. -func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetNodeSelector, error) { +// getPodSetsInfoFromAdmission will extract podSetsInfo and podSets count from admitted workloads. +func (r *JobReconciler) getPodSetsInfoFromAdmission(ctx context.Context, w *kueue.Workload) ([]PodSetInfo, error) { if len(w.Status.Admission.PodSetAssignments) == 0 { return nil, nil } - nodeSelectors := make([]PodSetNodeSelector, len(w.Status.Admission.PodSetAssignments)) + nodeSelectors := make([]PodSetInfo, len(w.Status.Admission.PodSetAssignments)) for i, podSetFlavor := range w.Status.Admission.PodSetAssignments { processedFlvs := sets.NewString() - nodeSelector := PodSetNodeSelector{ + nodeSelector := PodSetInfo{ Name: podSetFlavor.Name, NodeSelector: make(map[string]string), + Count: podSetFlavor.Count, } for _, flvRef := range podSetFlavor.Flavors { flvName := string(flvRef) @@ -452,18 +457,19 @@ func (r *JobReconciler) getNodeSelectorsFromAdmission(ctx context.Context, w *ku return nodeSelectors, nil } -// getNodeSelectorsFromPodSets will extract node selectors from a workload's podSets. -func (r *JobReconciler) getNodeSelectorsFromPodSets(w *kueue.Workload) []PodSetNodeSelector { +// getPodSetsInfoFromSpec will extract podSetsInfo and podSet's counts from a workload's spec. +func (r *JobReconciler) getPodSetsInfoFromSpec(w *kueue.Workload) []PodSetInfo { podSets := w.Spec.PodSets if len(podSets) == 0 { return nil } - ret := make([]PodSetNodeSelector, len(podSets)) + ret := make([]PodSetInfo, len(podSets)) for psi := range podSets { ps := &podSets[psi] - ret[psi] = PodSetNodeSelector{ + ret[psi] = PodSetInfo{ Name: ps.Name, NodeSelector: cloneNodeSelector(ps.Template.Spec.NodeSelector), + Count: ps.Count, } } return ret @@ -519,34 +525,52 @@ func cloneNodeSelector(src map[string]string) map[string]string { return ret } -// getNodeSelectorsFromObjectAnnotation tries to retrieve a node selectors slice from the +// getPodSetsInfoFromObjectAnnotation tries to retrieve a podSetsInfo slice from the // object's annotations fails if it's not found or is unable to unmarshal -func getNodeSelectorsFromObjectAnnotation(obj client.Object) ([]PodSetNodeSelector, error) { - str, found := obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] +func getPodSetsInfoFromObjectAnnotation(obj client.Object, job GenericJob) ([]PodSetInfo, error) { + hasCounts := true + str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation] if !found { - return nil, errNodeSelectorsNotFound + hasCounts = false + str, found = obj.GetAnnotations()[OriginalNodeSelectorsAnnotation] + if !found { + return nil, errPodSetsInfoNotFound + } } // unmarshal - ret := []PodSetNodeSelector{} + ret := []PodSetInfo{} if err := json.Unmarshal([]byte(str), &ret); err != nil { return nil, err } + + if !hasCounts { + podSets := job.PodSets() + psMap := slices.ToRefMap(podSets, func(ps *kueue.PodSet) string { return ps.Name }) + for i := range ret { + info := &ret[i] + ps, found := psMap[info.Name] + if !found { + return nil, fmt.Errorf("%w: %s", errUnknownPodSetName, info.Name) + } + info.Count = ps.Count + } + } return ret, nil } -// setNodeSelectorsInAnnotation - sets an annotation containing the provided node selectors into +// setNodeSelectorsInAnnotation - sets an annotation containing the provided podSetsInfo into // a job object, even if very unlikely it could return an error related to json.marshaling -func setNodeSelectorsInAnnotation(obj client.Object, nodeSelectors []PodSetNodeSelector) error { - nodeSelectorsBytes, err := json.Marshal(nodeSelectors) +func setNodeSelectorsInAnnotation(obj client.Object, info []PodSetInfo) error { + nodeSelectorsBytes, err := json.Marshal(info) if err != nil { return err } annotations := obj.GetAnnotations() if annotations == nil { - annotations = map[string]string{OriginalNodeSelectorsAnnotation: string(nodeSelectorsBytes)} + annotations = map[string]string{OriginalPodSetsInfoAnnotation: string(nodeSelectorsBytes)} } else { - annotations[OriginalNodeSelectorsAnnotation] = string(nodeSelectorsBytes) + annotations[OriginalPodSetsInfoAnnotation] = string(nodeSelectorsBytes) } obj.SetAnnotations(annotations) return nil diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index 4a745777c9..74b763f31c 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -29,6 +29,7 @@ var ( queueNameLabelPath = labelsPath.Key(QueueLabel) originalNodeSelectorsWorkloadKeyPath = annotationsPath.Key(OriginalNodeSelectorsAnnotation) + originalPodSetsInfosWorkloadKeyPath = annotationsPath.Key(OriginalPodSetsInfoAnnotation) ) func ValidateCreateForQueueName(job GenericJob) field.ErrorList { @@ -83,10 +84,26 @@ func ValidateUpdateForOriginalNodeSelectors(oldJob, newJob GenericJob) field.Err allErrs = append(allErrs, field.Forbidden(originalNodeSelectorsWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) } } else if av, found := newJob.Object().GetAnnotations()[OriginalNodeSelectorsAnnotation]; found { - out := []PodSetNodeSelector{} + out := []PodSetInfo{} if err := json.Unmarshal([]byte(av), &out); err != nil { allErrs = append(allErrs, field.Invalid(originalNodeSelectorsWorkloadKeyPath, av, err.Error())) } } return allErrs } + +func ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob GenericJob) field.ErrorList { + var allErrs field.ErrorList + if oldJob.IsSuspended() == newJob.IsSuspended() { + if errList := apivalidation.ValidateImmutableField(oldJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation], + newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation], originalPodSetsInfosWorkloadKeyPath); len(errList) > 0 { + allErrs = append(allErrs, field.Forbidden(originalPodSetsInfosWorkloadKeyPath, "this annotation is immutable while the job is not changing its suspended state")) + } + } else if av, found := newJob.Object().GetAnnotations()[OriginalPodSetsInfoAnnotation]; found { + out := []PodSetInfo{} + if err := json.Unmarshal([]byte(av), &out); err != nil { + allErrs = append(allErrs, field.Invalid(originalPodSetsInfosWorkloadKeyPath, av, err.Error())) + } + } + return allErrs +} diff --git a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index d8d81b1e54..e18bc17019 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -182,7 +182,7 @@ func (j *Job) PodSets() []kueue.PodSet { } } -func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { j.Spec.Suspend = pointer.Bool(false) if len(nodeSelectors) == 0 { return @@ -197,7 +197,7 @@ func (j *Job) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelecto } } -func (j *Job) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) { return } diff --git a/pkg/controller/jobs/job/job_webhook.go b/pkg/controller/jobs/job/job_webhook.go index f3d4d30c68..5d63896e38 100644 --- a/pkg/controller/jobs/job/job_webhook.go +++ b/pkg/controller/jobs/job/job_webhook.go @@ -106,6 +106,7 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime. func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList { allErrs := validateCreate(newJob) allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...) return allErrs diff --git a/pkg/controller/jobs/mpijob/mpijob_controller.go b/pkg/controller/jobs/mpijob/mpijob_controller.go index 8dfe422819..633cc2902d 100644 --- a/pkg/controller/jobs/mpijob/mpijob_controller.go +++ b/pkg/controller/jobs/mpijob/mpijob_controller.go @@ -123,17 +123,17 @@ func (j *MPIJob) PodSets() []kueue.PodSet { return podSets } -func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *MPIJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) { j.Spec.RunPolicy.Suspend = pointer.Bool(false) - if len(nodeSelectors) == 0 { + if len(podSetInfos) == 0 { return } // The node selectors are provided in the same order as the generated list of // podSets, use the same ordering logic to restore them. orderedReplicaTypes := orderedReplicaTypes(&j.Spec) - for index := range nodeSelectors { + for index := range podSetInfos { replicaType := orderedReplicaTypes[index] - nodeSelector := nodeSelectors[index] + nodeSelector := podSetInfos[index] if len(nodeSelector.NodeSelector) != 0 { if j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector.NodeSelector @@ -146,9 +146,9 @@ func (j *MPIJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSele } } -func (j *MPIJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *MPIJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) { orderedReplicaTypes := orderedReplicaTypes(&j.Spec) - for index, nodeSelector := range nodeSelectors { + for index, nodeSelector := range podSetInfos { replicaType := orderedReplicaTypes[index] if !equality.Semantic.DeepEqual(j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, nodeSelector) { j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{} diff --git a/pkg/controller/jobs/mpijob/mpijob_webhook.go b/pkg/controller/jobs/mpijob/mpijob_webhook.go index b66227157b..b14851c71a 100644 --- a/pkg/controller/jobs/mpijob/mpijob_webhook.go +++ b/pkg/controller/jobs/mpijob/mpijob_webhook.go @@ -88,6 +88,7 @@ func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runti log := ctrl.LoggerFrom(ctx).WithName("job-webhook") log.Info("Validating update", "job", klog.KObj(newJob)) allErrs := jobframework.ValidateUpdateForQueueName(oldGenJob, newGenJob) + allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldGenJob, newGenJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldGenJob, newGenJob)...) return allErrs.ToAggregate() } diff --git a/pkg/controller/jobs/rayjob/rayjob_controller.go b/pkg/controller/jobs/rayjob/rayjob_controller.go index d1c5f0a9c8..9a04e97c3d 100644 --- a/pkg/controller/jobs/rayjob/rayjob_controller.go +++ b/pkg/controller/jobs/rayjob/rayjob_controller.go @@ -137,20 +137,20 @@ func applySelectors(dst, src map[string]string) map[string]string { return dst } -func (j *RayJob) RunWithNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { +func (j *RayJob) RunWithPodSetsInfo(podSetInfos []jobframework.PodSetInfo) { j.Spec.Suspend = false - if len(nodeSelectors) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 { + if len(podSetInfos) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 { return } // head headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - headPodSpec.NodeSelector = applySelectors(headPodSpec.NodeSelector, nodeSelectors[0].NodeSelector) + headPodSpec.NodeSelector = applySelectors(headPodSpec.NodeSelector, podSetInfos[0].NodeSelector) // workers for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs { workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec - workerPodSpec.NodeSelector = applySelectors(workerPodSpec.NodeSelector, nodeSelectors[index+1].NodeSelector) + workerPodSpec.NodeSelector = applySelectors(workerPodSpec.NodeSelector, podSetInfos[index+1].NodeSelector) } } @@ -162,22 +162,22 @@ func cloneSelectors(src map[string]string) map[string]string { return dst } -func (j *RayJob) RestoreNodeAffinity(nodeSelectors []jobframework.PodSetNodeSelector) { - if len(nodeSelectors) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 { +func (j *RayJob) RestorePodSetsInfo(podSetInfos []jobframework.PodSetInfo) { + if len(podSetInfos) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 { return } // head headPodSpec := &j.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec - if !equality.Semantic.DeepEqual(headPodSpec.NodeSelector, nodeSelectors[0].NodeSelector) { - headPodSpec.NodeSelector = cloneSelectors(nodeSelectors[0].NodeSelector) + if !equality.Semantic.DeepEqual(headPodSpec.NodeSelector, podSetInfos[0].NodeSelector) { + headPodSpec.NodeSelector = cloneSelectors(podSetInfos[0].NodeSelector) } // workers for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs { workerPodSpec := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template.Spec - if !equality.Semantic.DeepEqual(workerPodSpec.NodeSelector, nodeSelectors[index+1].NodeSelector) { - workerPodSpec.NodeSelector = cloneSelectors(nodeSelectors[index+1].NodeSelector) + if !equality.Semantic.DeepEqual(workerPodSpec.NodeSelector, podSetInfos[index+1].NodeSelector) { + workerPodSpec.NodeSelector = cloneSelectors(podSetInfos[index+1].NodeSelector) } } } diff --git a/pkg/controller/jobs/rayjob/rayjob_controller_test.go b/pkg/controller/jobs/rayjob/rayjob_controller_test.go index 43f53461cb..12070426ea 100644 --- a/pkg/controller/jobs/rayjob/rayjob_controller_test.go +++ b/pkg/controller/jobs/rayjob/rayjob_controller_test.go @@ -297,8 +297,8 @@ func TestNodeSelectors(t *testing.T) { }). Obj()) - // run with Affinity should append or update the node selectors - job.RunWithNodeAffinity([]jobframework.PodSetNodeSelector{ + // RunWithPodSetsInfo should append or update the node selectors + job.RunWithPodSetsInfo([]jobframework.PodSetInfo{ { NodeSelector: map[string]string{ "newKey": "newValue", @@ -341,7 +341,7 @@ func TestNodeSelectors(t *testing.T) { } // restore should replace node selectors - job.RestoreNodeAffinity([]jobframework.PodSetNodeSelector{ + job.RestorePodSetsInfo([]jobframework.PodSetInfo{ { NodeSelector: map[string]string{ // clean it all