From 8682007bacfe41619e34fd8bfe6cfe125c393cff Mon Sep 17 00:00:00 2001 From: Traian Schiau Date: Wed, 17 May 2023 07:25:47 +0300 Subject: [PATCH] [batch/job] Partial admission --- pkg/controller/jobs/job/job_controller.go | 54 +++- .../jobs/job/job_controller_test.go | 235 ++++++++++++++++++ pkg/controller/jobs/job/job_webhook.go | 38 ++- pkg/controller/jobs/job/job_webhook_test.go | 63 +++++ pkg/util/testingjobs/job/wrappers.go | 5 + .../controller/job/job_controller_test.go | 63 +++++ 6 files changed, 448 insertions(+), 10 deletions(-) diff --git a/pkg/controller/jobs/job/job_controller.go b/pkg/controller/jobs/job/job_controller.go index e18bc17019..1493307f2c 100644 --- a/pkg/controller/jobs/job/job_controller.go +++ b/pkg/controller/jobs/job/job_controller.go @@ -18,6 +18,7 @@ package job import ( "context" + "strconv" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" @@ -48,6 +49,10 @@ var ( FrameworkName = "batch/job" ) +const ( + JobMinParallelismAnnotation = "kueue.x-k8s.io/job-min-parallelism" +) + func init() { utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{ SetupIndexes: SetupIndexes, @@ -178,33 +183,46 @@ func (j *Job) PodSets() []kueue.PodSet { Name: kueue.DefaultPodSetName, Template: *j.Spec.Template.DeepCopy(), Count: j.podsCount(), + MinCount: j.minPodsCount(), }, } } -func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { +func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) { j.Spec.Suspend = pointer.Bool(false) - if len(nodeSelectors) == 0 { + if len(podSetsInfo) == 0 { return } if j.Spec.Template.Spec.NodeSelector == nil { - j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector + j.Spec.Template.Spec.NodeSelector = podSetsInfo[0].NodeSelector } else { - for k, v := range nodeSelectors[0].NodeSelector { + for k, v := range podSetsInfo[0].NodeSelector { j.Spec.Template.Spec.NodeSelector[k] = v } } + if j.minPodsCount() != nil { + j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count) + } } -func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) { - if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) { +func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) { + if len(podSetsInfo) == 0 { + return + } + + // if the job accepts partial admission + if j.minPodsCount() != nil { + j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count) + } + + if equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSetsInfo[0].NodeSelector) { return } j.Spec.Template.Spec.NodeSelector = map[string]string{} - for k, v := range nodeSelectors[0].NodeSelector { + for k, v := range podSetsInfo[0].NodeSelector { j.Spec.Template.Spec.NodeSelector[k] = v } } @@ -239,7 +257,18 @@ func (j *Job) EquivalentToWorkload(wl kueue.Workload) bool { return false } - if *j.Spec.Parallelism != wl.Spec.PodSets[0].Count { + ps0 := &wl.Spec.PodSets[0] + + // if the job accepts partial admission + if mpc := j.minPodsCount(); mpc != nil { + if pointer.Int32Deref(ps0.MinCount, -1) != *mpc { + return false + } + + if j.IsSuspended() && j.podsCount() != ps0.Count { + return false + } + } else if j.podsCount() != ps0.Count { return false } @@ -271,6 +300,15 @@ func (j *Job) podsCount() int32 { return podsCount } +func (j *Job) minPodsCount() *int32 { + if strVal, found := j.GetAnnotations()[JobMinParallelismAnnotation]; found { + if iVal, err := strconv.Atoi(strVal); err == nil { + return 
pointer.Int32(int32(iVal)) + } + } + return nil +} + // SetupWithManager sets up the controller with the Manager. It indexes workloads // based on the owning jobs. func (r *JobReconciler) SetupWithManager(mgr ctrl.Manager) error { diff --git a/pkg/controller/jobs/job/job_controller_test.go b/pkg/controller/jobs/job/job_controller_test.go index 735c1b0d02..4ddba4d937 100644 --- a/pkg/controller/jobs/job/job_controller_test.go +++ b/pkg/controller/jobs/job/job_controller_test.go @@ -19,9 +19,14 @@ package job import ( "testing" + "github.com/google/go-cmp/cmp" batchv1 "k8s.io/api/batch/v1" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/pkg/controller/jobframework" "sigs.k8s.io/kueue/pkg/util/pointer" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + utiltestingjob "sigs.k8s.io/kueue/pkg/util/testingjobs/job" ) func TestPodsReady(t *testing.T) { @@ -148,3 +153,233 @@ func TestPodsReady(t *testing.T) { }) } } + +func TestPodSetsInfo(t *testing.T) { + testcases := map[string]struct { + job *batchv1.Job + runInfo, restoreInfo []jobframework.PodSetInfo + wantUnsuspended *batchv1.Job + }{ + "append": { + job: utiltestingjob.MakeJob("job", "ns"). + Parallelism(1). + NodeSelector("orig-key", "orig-val"). + Obj(), + runInfo: []jobframework.PodSetInfo{ + { + NodeSelector: map[string]string{ + "new-key": "new-val", + }, + }, + }, + wantUnsuspended: utiltestingjob.MakeJob("job", "ns"). + Parallelism(1). + NodeSelector("orig-key", "orig-val"). + NodeSelector("new-key", "new-val"). + Suspend(false). + Obj(), + restoreInfo: []jobframework.PodSetInfo{ + { + NodeSelector: map[string]string{ + "orig-key": "orig-val", + }, + }, + }, + }, + "update": { + job: utiltestingjob.MakeJob("job", "ns"). + Parallelism(1). + NodeSelector("orig-key", "orig-val"). + Obj(), + runInfo: []jobframework.PodSetInfo{ + { + NodeSelector: map[string]string{ + "orig-key": "new-val", + }, + }, + }, + wantUnsuspended: utiltestingjob.MakeJob("job", "ns"). + Parallelism(1). + NodeSelector("orig-key", "new-val"). + Suspend(false). + Obj(), + restoreInfo: []jobframework.PodSetInfo{ + { + NodeSelector: map[string]string{ + "orig-key": "orig-val", + }, + }, + }, + }, + "parallelism": { + job: utiltestingjob.MakeJob("job", "ns"). + Parallelism(5). + SetAnnotation(JobMinParallelismAnnotation, "2"). + Obj(), + runInfo: []jobframework.PodSetInfo{ + { + Count: 2, + }, + }, + wantUnsuspended: utiltestingjob.MakeJob("job", "ns"). + Parallelism(2). + SetAnnotation(JobMinParallelismAnnotation, "2"). + Suspend(false). + Obj(), + restoreInfo: []jobframework.PodSetInfo{ + { + Count: 5, + }, + }, + }, + } + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + origSpec := *tc.job.Spec.DeepCopy() + job := Job{tc.job} + + job.RunWithPodSetsInfo(tc.runInfo) + + if diff := cmp.Diff(job.Spec, tc.wantUnsuspended.Spec); diff != "" { + t.Errorf("node selectors mismatch (-want +got):\n%s", diff) + } + job.RestorePodSetsInfo(tc.restoreInfo) + job.Suspend() + if diff := cmp.Diff(job.Spec, origSpec); diff != "" { + t.Errorf("node selectors mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestEquivalentToWorkload(t *testing.T) { + baseJob := &Job{utiltestingjob.MakeJob("job", "ns"). + Parallelism(2). + Obj()} + baseJobPartialAdmission := &Job{utiltestingjob.MakeJob("job", "ns"). + SetAnnotation(JobMinParallelismAnnotation, "2"). + Parallelism(2). + Obj()} + podSets := []kueue.PodSet{ + *utiltesting.MakePodSet("main", 2). + Containers(*baseJob.Spec.Template.Spec.Containers[0].DeepCopy()). 
+ Obj(), + } + basWorkload := utiltesting.MakeWorkload("wl", "ns").PodSets(podSets...).Obj() + basWorkloadPartialAdmission := basWorkload.DeepCopy() + basWorkloadPartialAdmission.Spec.PodSets[0].MinCount = pointer.Int32(2) + cases := map[string]struct { + wl *kueue.Workload + job *Job + wantResult bool + }{ + "wrong podsets number": { + job: baseJob, + wl: func() *kueue.Workload { + wl := basWorkload.DeepCopy() + wl.Spec.PodSets = append(wl.Spec.PodSets, wl.Spec.PodSets...) + return wl + }(), + }, + "different pods count": { + job: baseJob, + wl: func() *kueue.Workload { + wl := basWorkload.DeepCopy() + wl.Spec.PodSets[0].Count = 3 + return wl + }(), + }, + "different container": { + job: baseJob, + wl: func() *kueue.Workload { + wl := basWorkload.DeepCopy() + wl.Spec.PodSets[0].Template.Spec.Containers[0].Image = "other-image" + return wl + }(), + }, + "equivalent": { + job: baseJob, + wl: basWorkload.DeepCopy(), + wantResult: true, + }, + "partial admission bad count (suspended)": { + job: baseJobPartialAdmission, + wl: func() *kueue.Workload { + wl := basWorkloadPartialAdmission.DeepCopy() + wl.Spec.PodSets[0].Count = 3 + return wl + }(), + }, + "partial admission different count (unsuspended)": { + job: func() *Job { + j := &Job{baseJobPartialAdmission.DeepCopy()} + j.Spec.Suspend = pointer.Bool(false) + return j + }(), + wl: func() *kueue.Workload { + wl := basWorkloadPartialAdmission.DeepCopy() + wl.Spec.PodSets[0].Count = 3 + return wl + }(), + wantResult: true, + }, + "partial admission bad minCount": { + job: baseJobPartialAdmission, + wl: func() *kueue.Workload { + wl := basWorkloadPartialAdmission.DeepCopy() + wl.Spec.PodSets[0].MinCount = pointer.Int32(3) + return wl + }(), + }, + "equivalent partial admission": { + job: baseJobPartialAdmission, + wl: basWorkloadPartialAdmission.DeepCopy(), + wantResult: true, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if tc.job.EquivalentToWorkload(*tc.wl) != tc.wantResult { + t.Fatalf("Unexpected result, wanted: %v", tc.wantResult) + } + }) + } +} + +func TestPodSets(t *testing.T) { + podTemplate := utiltestingjob.MakeJob("job", "ns").Spec.Template.DeepCopy() + cases := map[string]struct { + job *batchv1.Job + wantPodSets []kueue.PodSet + }{ + "no partial admission": { + job: utiltestingjob.MakeJob("job", "ns").Parallelism(3).Obj(), + wantPodSets: []kueue.PodSet{ + { + Name: "main", + Template: *podTemplate.DeepCopy(), + Count: 3, + }, + }, + }, + "partial admission": { + job: utiltestingjob.MakeJob("job", "ns").Parallelism(3).SetAnnotation(JobMinParallelismAnnotation, "2").Obj(), + wantPodSets: []kueue.PodSet{ + { + Name: "main", + Template: *podTemplate.DeepCopy(), + Count: 3, + MinCount: pointer.Int32(2), + }, + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + gotPodSets := (&Job{tc.job}).PodSets() + if diff := cmp.Diff(tc.wantPodSets, gotPodSets); diff != "" { + t.Errorf("node selectors mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/controller/jobs/job/job_webhook.go b/pkg/controller/jobs/job/job_webhook.go index 5d63896e38..5552c6c53f 100644 --- a/pkg/controller/jobs/job/job_webhook.go +++ b/pkg/controller/jobs/job/job_webhook.go @@ -18,18 +18,25 @@ package job import ( "context" + "fmt" + "strconv" batchv1 "k8s.io/api/batch/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" + "k8s.io/utils/pointer" ctrl "sigs.k8s.io/controller-runtime" 
"sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/kueue/pkg/controller/jobframework" ) +var ( + minPodsCountAnnotationsPath = field.NewPath("metadata", "annotations").Key(JobMinParallelismAnnotation) +) + type JobWebhook struct { manageJobsWithoutQueueName bool } @@ -87,10 +94,26 @@ func (w *JobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) err return validateCreate(&Job{job}).ToAggregate() } -func validateCreate(job jobframework.GenericJob) field.ErrorList { +func validateCreate(job *Job) field.ErrorList { var allErrs field.ErrorList allErrs = append(allErrs, jobframework.ValidateAnnotationAsCRDName(job, jobframework.ParentWorkloadAnnotation)...) allErrs = append(allErrs, jobframework.ValidateCreateForQueueName(job)...) + allErrs = append(allErrs, validatePartialAdmissionCreate(job)...) + return allErrs +} + +func validatePartialAdmissionCreate(job *Job) field.ErrorList { + var allErrs field.ErrorList + if strVal, found := job.Annotations[JobMinParallelismAnnotation]; found { + v, err := strconv.Atoi(strVal) + if err != nil { + allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, job.Annotations[JobMinParallelismAnnotation], err.Error())) + } else { + if int32(v) >= job.podsCount() || v <= 0 { + allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, v, fmt.Sprintf("should be between 0 and %d", job.podsCount()-1))) + } + } + } return allErrs } @@ -103,12 +126,23 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime. return validateUpdate(&Job{oldJob}, &Job{newJob}).ToAggregate() } -func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList { +func validateUpdate(oldJob, newJob *Job) field.ErrorList { allErrs := validateCreate(newJob) allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...) allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...) + allErrs = append(allErrs, validatePartialAdmissionUpdate(oldJob, newJob)...) + return allErrs +} + +func validatePartialAdmissionUpdate(oldJob, newJob *Job) field.ErrorList { + var allErrs field.ErrorList + if _, found := oldJob.Annotations[JobMinParallelismAnnotation]; found { + if !oldJob.IsSuspended() && pointer.Int32Deref(oldJob.Spec.Parallelism, 1) != pointer.Int32Deref(newJob.Spec.Parallelism, 1) { + allErrs = append(allErrs, field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended")) + } + } return allErrs } diff --git a/pkg/controller/jobs/job/job_webhook_test.go b/pkg/controller/jobs/job/job_webhook_test.go index 8fcaa155da..303e61aa9c 100644 --- a/pkg/controller/jobs/job/job_webhook_test.go +++ b/pkg/controller/jobs/job/job_webhook_test.go @@ -83,6 +83,37 @@ func TestValidateCreate(t *testing.T) { field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message), }, }, + { + name: "invalid partial admission annotation (format)", + job: testingutil.MakeJob("job", "default"). + Parallelism(4). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "NaN"). 
+ Obj(), + wantErr: field.ErrorList{ + field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"), + }, + }, + { + name: "invalid partial admission annotation (badValue)", + job: testingutil.MakeJob("job", "default"). + Parallelism(4). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "5"). + Obj(), + wantErr: field.ErrorList{ + field.Invalid(minPodsCountAnnotationsPath, 5, "should be between 0 and 3"), + }, + }, + { + name: "valid partial admission annotation", + job: testingutil.MakeJob("job", "default"). + Parallelism(4). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "3"). + Obj(), + wantErr: nil, + }, } for _, tc := range testcases { @@ -199,6 +230,38 @@ func TestValidateUpdate(t *testing.T) { field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"), }, }, + { + name: "immutable parallelism while unsuspended with partial admission enabled", + oldJob: testingutil.MakeJob("job", "default"). + Suspend(false). + Parallelism(4). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "3"). + Obj(), + newJob: testingutil.MakeJob("job", "default"). + Suspend(false). + Parallelism(5). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "3"). + Obj(), + wantErr: field.ErrorList{ + field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"), + }, + }, + { + name: "mutable parallelism while suspended with partial admission enabled", + oldJob: testingutil.MakeJob("job", "default"). + Parallelism(4). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "3"). + Obj(), + newJob: testingutil.MakeJob("job", "default"). + Parallelism(5). + Completions(6). + SetAnnotation(JobMinParallelismAnnotation, "3"). + Obj(), + wantErr: nil, + }, } for _, tc := range testcases { diff --git a/pkg/util/testingjobs/job/wrappers.go b/pkg/util/testingjobs/job/wrappers.go index 61bf27f707..490d612ca2 100644 --- a/pkg/util/testingjobs/job/wrappers.go +++ b/pkg/util/testingjobs/job/wrappers.go @@ -114,6 +114,11 @@ func (j *JobWrapper) OriginalNodeSelectorsAnnotation(content string) *JobWrapper return j } +func (j *JobWrapper) SetAnnotation(key, content string) *JobWrapper { + j.Annotations[key] = content + return j +} + // Toleration adds a toleration to the job. func (j *JobWrapper) Toleration(t corev1.Toleration) *JobWrapper { j.Spec.Template.Spec.Tolerations = append(j.Spec.Template.Spec.Tolerations, t) diff --git a/test/integration/controller/job/job_controller_test.go b/test/integration/controller/job/job_controller_test.go index b4fcfe3c0d..7bb6a481e6 100644 --- a/test/integration/controller/job/job_controller_test.go +++ b/test/integration/controller/job/job_controller_test.go @@ -1027,4 +1027,67 @@ var _ = ginkgo.Describe("Job controller interacting with scheduler", func() { }) }) + ginkgo.It("Should schedule jobs when partial admission is enabled", func() { + prodLocalQ = testing.MakeLocalQueue("prod-queue", ns.Name).ClusterQueue(prodClusterQ.Name).Obj() + job1 := testingjob.MakeJob("job1", ns.Name). + Queue(prodLocalQ.Name). + Parallelism(5). + Completions(6). + Request(corev1.ResourceCPU, "2"). 
+ Obj() + jobKey := types.NamespacedName{Name: job1.Name, Namespace: job1.Namespace} + wlKey := types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(job1.Name), Namespace: job1.Namespace} + + ginkgo.By("creating localQueues") + gomega.Expect(k8sClient.Create(ctx, prodLocalQ)).Should(gomega.Succeed()) + + ginkgo.By("creating the job") + gomega.Expect(k8sClient.Create(ctx, job1)).Should(gomega.Succeed()) + + createdJob := &batchv1.Job{} + ginkgo.By("the job should stay suspended", func() { + gomega.Consistently(func() *bool { + gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.ConsistentDuration, util.Interval).Should(gomega.Equal(pointer.Bool(true))) + }) + + ginkgo.By("enable partial admission", func() { + gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed()) + if createdJob.Annotations == nil { + createdJob.Annotations = map[string]string{ + workloadjob.JobMinParallelismAnnotation: "1", + } + } else { + createdJob.Annotations[workloadjob.JobMinParallelismAnnotation] = "1" + } + + gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) + }) + + wl := &kueue.Workload{} + ginkgo.By("the job should be unsuspended with a lower parallelism", func() { + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) + gomega.Expect(*createdJob.Spec.Parallelism).To(gomega.BeEquivalentTo(2)) + + gomega.Expect(k8sClient.Get(ctx, wlKey, wl)).To(gomega.Succeed()) + gomega.Expect(wl.Spec.PodSets[0].MinCount).ToNot(gomega.BeNil()) + gomega.Expect(*wl.Spec.PodSets[0].MinCount).To(gomega.BeEquivalentTo(1)) + }) + + ginkgo.By("changing the min parallelism the job should be suspended and its parallelism restored", func() { + gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed()) + createdJob.Annotations[workloadjob.JobMinParallelismAnnotation] = "4" + gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) + + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, jobKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(true))) + gomega.Expect(*createdJob.Spec.Parallelism).To(gomega.BeEquivalentTo(5)) + }) + }) })
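
For reference, a minimal, self-contained sketch of the annotation handling this patch introduces. The constant matches the one added to job_controller.go; the free functions minPodsCount and validateMinPodsCount are illustrative stand-ins (not part of the patch) for the Job.minPodsCount method and the webhook's validatePartialAdmissionCreate, with parallelism passed explicitly instead of read from Spec.Parallelism:

package main

import (
	"fmt"
	"strconv"
)

const JobMinParallelismAnnotation = "kueue.x-k8s.io/job-min-parallelism"

// minPodsCount mirrors the controller's lenient read path: it returns the
// parsed annotation value, or nil when the annotation is absent or malformed.
// A nil result means the job does not opt into partial admission.
func minPodsCount(annotations map[string]string) *int32 {
	if strVal, found := annotations[JobMinParallelismAnnotation]; found {
		if iVal, err := strconv.Atoi(strVal); err == nil {
			v := int32(iVal)
			return &v
		}
	}
	return nil
}

// validateMinPodsCount mirrors the webhook's strict path: the value must parse
// as an integer and fall strictly between 0 and the job's parallelism.
func validateMinPodsCount(annotations map[string]string, parallelism int32) error {
	strVal, found := annotations[JobMinParallelismAnnotation]
	if !found {
		return nil // annotation absent: partial admission is simply disabled
	}
	v, err := strconv.Atoi(strVal)
	if err != nil {
		return fmt.Errorf("%s: %v", JobMinParallelismAnnotation, err)
	}
	if int32(v) >= parallelism || v <= 0 {
		return fmt.Errorf("%s: should be between 0 and %d", JobMinParallelismAnnotation, parallelism-1)
	}
	return nil
}

func main() {
	ann := map[string]string{JobMinParallelismAnnotation: "2"}
	if mpc := minPodsCount(ann); mpc != nil {
		// Kueue may admit the job with any pod count in [*mpc, parallelism].
		fmt.Printf("min parallelism: %d\n", *mpc)
	}
	fmt.Println(validateMinPodsCount(ann, 5)) // <nil>: 2 lies strictly between 0 and 5
	ann[JobMinParallelismAnnotation] = "5"
	fmt.Println(validateMinPodsCount(ann, 5)) // rejected: must be below parallelism
}

A job opts in by carrying the annotation while suspended, for example kueue.x-k8s.io/job-min-parallelism: "2" on a Job with parallelism 5. RunWithPodSetsInfo may then unsuspend it with Spec.Parallelism lowered to any admitted count down to 2, and RestorePodSetsInfo puts the original parallelism back when the job is suspended again, as exercised by the integration test above.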