Skip to content

Commit

Permalink
[batch/job] Partial admission
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed May 17, 2023
1 parent eaecfbb commit 7297247
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 17 deletions.
7 changes: 4 additions & 3 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, object clie
}

log.V(3).Info("restore podSetsInfo from annotation")
info, err := getPodSetsInfoFromObjectAnnotation(object, wl.Spec.PodSets)
info, err := getPodSetsInfoFromObjectAnnotation(object, job)
if err != nil {
log.V(3).Error(err, "Unable to get original podSetsInfo")
} else {
Expand Down Expand Up @@ -527,7 +527,7 @@ func cloneNodeSelector(src map[string]string) map[string]string {

// getPodSetsInfoFromObjectAnnotation tries to retrieve a podSetsInfo slice from the
// object's annotations; it fails if the annotation is not found or cannot be unmarshaled.
func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet) ([]PodSetInfo, error) {
func getPodSetsInfoFromObjectAnnotation(obj client.Object, job GenericJob) ([]PodSetInfo, error) {
hasCounts := true
str, found := obj.GetAnnotations()[OriginalPodSetsInfoAnnotation]
if !found {
Expand All @@ -544,7 +544,8 @@ func getPodSetsInfoFromObjectAnnotation(obj client.Object, spec []kueue.PodSet)
}

if !hasCounts {
psMap := utilslice.ToRefMap(spec, func(ps *kueue.PodSet) string { return ps.Name })
podSets := job.PodSets()
psMap := utilslice.ToRefMap(podSets, func(ps *kueue.PodSet) string { return ps.Name })
for i := range ret {
info := &ret[i]
ps, found := psMap[info.Name]
Expand Down
52 changes: 41 additions & 11 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package job

import (
"context"
"strconv"

batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -48,6 +49,10 @@ var (
FrameworkName = "batch/job"
)

const (
// JobMinParallelismAnnotation is the job annotation key read by minPodsCount.
// Its value is parsed as a decimal integer and, when valid, is reported as the
// MinimumCount of the job's pod set.
JobMinParallelismAnnotation = "kueue.x-k8s.io/job-min-parallelism"
)

func init() {
utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
SetupIndexes: SetupIndexes,
Expand Down Expand Up @@ -173,36 +178,47 @@ func (j *Job) ReclaimablePods() []kueue.ReclaimablePod {
func (j *Job) PodSets() []kueue.PodSet {
return []kueue.PodSet{
{
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Count: j.podsCount(),
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Count: j.podsCount(),
MinimumCount: j.minPodsCount(),
},
}
}

func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
if len(podSetsInfo) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
j.Spec.Template.Spec.NodeSelector = podSetsInfo[0].NodeSelector
} else {
for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}

func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
if len(podSetsInfo) == 0 {
return
}

// if partial admission is enabled
if j.minPodsCount() != nil {
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}

if equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSetsInfo[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
Expand Down Expand Up @@ -237,7 +253,12 @@ func (j *Job) EquivalentToWorkload(wl kueue.Workload) bool {
return false
}

if *j.Spec.Parallelism != wl.Spec.PodSets[0].Count {
ps0 := &wl.Spec.PodSets[0]
if mpc := j.minPodsCount(); mpc != nil {
if ps0.MinimumCount == nil || *ps0.MinimumCount != *mpc || ps0.Count < *mpc || ps0.Count > j.podsCount() {
return false
}
} else if j.podsCount() != ps0.Count {
return false
}

Expand Down Expand Up @@ -269,6 +290,15 @@ func (j *Job) podsCount() int32 {
return podsCount
}

// minPodsCount returns the minimum pods count requested through the
// JobMinParallelismAnnotation annotation. It returns nil when the annotation
// is missing or when its value is not a valid 32-bit decimal integer.
func (j *Job) minPodsCount() *int32 {
	strVal, found := j.GetAnnotations()[JobMinParallelismAnnotation]
	if !found {
		return nil
	}

	// Parse with an explicit 32-bit size: values that do not fit in an int32
	// are rejected instead of being silently truncated by a later conversion.
	iVal, err := strconv.ParseInt(strVal, 10, 32)
	if err != nil {
		return nil
	}
	return pointer.Int32(int32(iVal))
}

// SetupWithManager sets up the controller with the Manager. It indexes workloads
// based on the owning jobs.
func (r *JobReconciler) SetupWithManager(mgr ctrl.Manager) error {
Expand Down
42 changes: 40 additions & 2 deletions pkg/controller/jobs/job/job_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,25 @@ package job

import (
"context"
"fmt"
"strconv"

batchv1 "k8s.io/api/batch/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/klog/v2"
"k8s.io/utils/pointer"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook"

"sigs.k8s.io/kueue/pkg/controller/jobframework"
)

var (
// minPodsCountAnnotationsPath is the field path of the JobMinParallelismAnnotation
// key under metadata.annotations; it is attached to the validation errors
// reported for that annotation.
minPodsCountAnnotationsPath = field.NewPath("metadata", "annotations").Key(JobMinParallelismAnnotation)
)

type JobWebhook struct {
manageJobsWithoutQueueName bool
}
Expand Down Expand Up @@ -87,10 +94,30 @@ func (w *JobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) err
return validateCreate(&Job{job}).ToAggregate()
}

func validateCreate(job jobframework.GenericJob) field.ErrorList {
func validateCreate(job *Job) field.ErrorList {
var allErrs field.ErrorList
allErrs = append(allErrs, jobframework.ValidateAnnotationAsCRDName(job, jobframework.ParentWorkloadAnnotation)...)
allErrs = append(allErrs, jobframework.ValidateCreateForQueueName(job)...)
allErrs = append(allErrs, validatePartialAdmissionCreate(job)...)
return allErrs
}

// validatePartialAdmissionCreate validates the partial admission settings of job.
//
// Nothing is checked unless the JobMinParallelismAnnotation annotation is set.
// When it is set, the returned list contains an error if the annotation value
// is not an integer, if it is not strictly between 0 and the job's pods count,
// or if spec.completions is not explicitly provided.
func validatePartialAdmissionCreate(job *Job) field.ErrorList {
	var allErrs field.ErrorList

	strVal, found := job.Annotations[JobMinParallelismAnnotation]
	if !found {
		return allErrs
	}

	podsCount := job.podsCount()
	v, err := strconv.Atoi(strVal)
	switch {
	case err != nil:
		allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, strVal, err.Error()))
	case v <= 0 || v >= int(podsCount):
		// The comparison is done on int values: converting v to int32 first
		// would wrap oversized values around and let them pass the check.
		allErrs = append(allErrs, field.Invalid(minPodsCountAnnotationsPath, v, fmt.Sprintf("should be between 0 and %d", podsCount-1)))
	}

	// the completions should be explicitly provided
	if job.Spec.Completions == nil {
		allErrs = append(allErrs, field.Invalid(field.NewPath("spec", "completions"), nil, "should be explicitly provided when partial admission is enabled"))
	}
	return allErrs
}

Expand All @@ -103,12 +130,23 @@ func (w *JobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
return validateUpdate(&Job{oldJob}, &Job{newJob}).ToAggregate()
}

func validateUpdate(oldJob, newJob jobframework.GenericJob) field.ErrorList {
func validateUpdate(oldJob, newJob *Job) field.ErrorList {
allErrs := validateCreate(newJob)
allErrs = append(allErrs, jobframework.ValidateUpdateForParentWorkload(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalPodSetsInfo(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForOriginalNodeSelectors(oldJob, newJob)...)
allErrs = append(allErrs, jobframework.ValidateUpdateForQueueName(oldJob, newJob)...)
allErrs = append(allErrs, validatePartialAdmissionUpdate(oldJob, newJob)...)
return allErrs
}

// validatePartialAdmissionUpdate rejects a change of spec.parallelism (a nil
// value counts as 1) when the old job carries the JobMinParallelismAnnotation
// annotation and is not suspended. In every other case it returns no errors.
func validatePartialAdmissionUpdate(oldJob, newJob *Job) field.ErrorList {
	if _, enabled := oldJob.Annotations[JobMinParallelismAnnotation]; !enabled {
		return nil
	}
	if oldJob.IsSuspended() {
		return nil
	}

	before := pointer.Int32Deref(oldJob.Spec.Parallelism, 1)
	after := pointer.Int32Deref(newJob.Spec.Parallelism, 1)
	if before == after {
		return nil
	}

	return field.ErrorList{
		field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"),
	}
}

Expand Down
83 changes: 83 additions & 0 deletions pkg/controller/jobs/job/job_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,57 @@ func TestValidateCreate(t *testing.T) {
field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message),
},
},
{
name: "invalid partial admission annotation (format)",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "NaN").
Obj(),
wantErr: field.ErrorList{
field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"),
},
},
{
name: "invalid partial admission annotation (badValue)",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "5").
Obj(),
wantErr: field.ErrorList{
field.Invalid(minPodsCountAnnotationsPath, 5, "should be between 0 and 3"),
},
},
{
name: "partial admission, nil completions",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: field.ErrorList{
field.Invalid(field.NewPath("spec", "completions"), nil, "should be explicitly provided when partial admission is enabled"),
},
},
{
name: "partial admission, changed ",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: field.ErrorList{
field.Invalid(field.NewPath("spec", "completions"), nil, "should be explicitly provided when partial admission is enabled"),
},
},
{
name: "partial admission annotation valid",
job: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: nil,
},
}

for _, tc := range testcases {
Expand Down Expand Up @@ -199,6 +250,38 @@ func TestValidateUpdate(t *testing.T) {
field.Forbidden(originalNodeSelectorsKeyPath, "this annotation is immutable while the job is not changing its suspended state"),
},
},
{
name: "immutable parallelism while unsuspended with partial admission enabled",
oldJob: testingutil.MakeJob("job", "default").
Suspend(false).
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Suspend(false).
Parallelism(5).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: field.ErrorList{
field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"),
},
},
{
name: "mutable parallelism while suspended with partial admission enabled",
oldJob: testingutil.MakeJob("job", "default").
Parallelism(4).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
newJob: testingutil.MakeJob("job", "default").
Parallelism(5).
Completions(6).
SetAnnotation(JobMinParallelismAnnotation, "3").
Obj(),
wantErr: nil,
},
}

for _, tc := range testcases {
Expand Down
2 changes: 1 addition & 1 deletion pkg/scheduler/flavorassigner/podSetReducer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
)

func TestPodSetReducer(t *testing.T) {
func TestSearch(t *testing.T) {
cases := map[string]struct {
podSets []kueue.PodSet
countLimit int32
Expand Down
5 changes: 5 additions & 0 deletions pkg/util/testingjobs/job/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ func (j *JobWrapper) OriginalNodeSelectorsAnnotation(content string) *JobWrapper
return j
}

// SetAnnotation sets the annotation key to content on the job and returns the
// wrapper so that calls can be chained.
func (j *JobWrapper) SetAnnotation(key, content string) *JobWrapper {
	// Assigning to an entry of a nil map panics, so allocate it on first use.
	if j.Annotations == nil {
		j.Annotations = make(map[string]string, 1)
	}
	j.Annotations[key] = content
	return j
}

// Toleration adds a toleration to the job.
func (j *JobWrapper) Toleration(t corev1.Toleration) *JobWrapper {
j.Spec.Template.Spec.Tolerations = append(j.Spec.Template.Spec.Tolerations, t)
Expand Down
Loading

0 comments on commit 7297247

Please sign in to comment.