Skip to content

Commit

Permalink
[batch/job] Partial admission
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed May 31, 2023
1 parent 8481a0a commit 8682007
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 10 deletions.
54 changes: 46 additions & 8 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package job

import (
"context"
"strconv"

batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -48,6 +49,10 @@ var (
FrameworkName = "batch/job"
)

const (
	// JobMinParallelismAnnotation is the annotation key read by minPodsCount.
	// When it is present and its value parses as an integer, that value is
	// reported as the MinCount of the job's pod set and the job is treated as
	// accepting partial admission.
	JobMinParallelismAnnotation = "kueue.x-k8s.io/job-min-parallelism"
)

func init() {
utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
SetupIndexes: SetupIndexes,
Expand Down Expand Up @@ -178,33 +183,46 @@ func (j *Job) PodSets() []kueue.PodSet {
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Count: j.podsCount(),
MinCount: j.minPodsCount(),
},
}
}

func (j *Job) RunWithPodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
func (j *Job) RunWithPodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
j.Spec.Suspend = pointer.Bool(false)
if len(nodeSelectors) == 0 {
if len(podSetsInfo) == 0 {
return
}

if j.Spec.Template.Spec.NodeSelector == nil {
j.Spec.Template.Spec.NodeSelector = nodeSelectors[0].NodeSelector
j.Spec.Template.Spec.NodeSelector = podSetsInfo[0].NodeSelector
} else {
for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
if j.minPodsCount() != nil {
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}
}

func (j *Job) RestorePodSetsInfo(nodeSelectors []jobframework.PodSetInfo) {
if len(nodeSelectors) == 0 || equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, nodeSelectors[0].NodeSelector) {
func (j *Job) RestorePodSetsInfo(podSetsInfo []jobframework.PodSetInfo) {
if len(podSetsInfo) == 0 {
return
}

// if the job accepts partial admission
if j.minPodsCount() != nil {
j.Spec.Parallelism = pointer.Int32(podSetsInfo[0].Count)
}

if equality.Semantic.DeepEqual(j.Spec.Template.Spec.NodeSelector, podSetsInfo[0].NodeSelector) {
return
}

j.Spec.Template.Spec.NodeSelector = map[string]string{}

for k, v := range nodeSelectors[0].NodeSelector {
for k, v := range podSetsInfo[0].NodeSelector {
j.Spec.Template.Spec.NodeSelector[k] = v
}
}
Expand Down Expand Up @@ -239,7 +257,18 @@ func (j *Job) EquivalentToWorkload(wl kueue.Workload) bool {
return false
}

if *j.Spec.Parallelism != wl.Spec.PodSets[0].Count {
ps0 := &wl.Spec.PodSets[0]

// if the job accepts partial admission
if mpc := j.minPodsCount(); mpc != nil {
if pointer.Int32Deref(ps0.MinCount, -1) != *mpc {
return false
}

if j.IsSuspended() && j.podsCount() != ps0.Count {
return false
}
} else if j.podsCount() != ps0.Count {
return false
}

Expand Down Expand Up @@ -271,6 +300,15 @@ func (j *Job) podsCount() int32 {
return podsCount
}

// minPodsCount returns the value of the JobMinParallelismAnnotation
// annotation as an int32 pointer. It returns nil when the annotation is
// missing or its value is not a base-10 integer that fits in 32 bits.
func (j *Job) minPodsCount() *int32 {
	strVal, found := j.GetAnnotations()[JobMinParallelismAnnotation]
	if !found {
		return nil
	}

	// Parse with a 32-bit size limit: an out-of-range value is rejected like
	// any other invalid annotation instead of being silently truncated by the
	// int32 conversion.
	iVal, err := strconv.ParseInt(strVal, 10, 32)
	if err != nil {
		return nil
	}
	return pointer.Int32(int32(iVal))
}

// SetupWithManager sets up the controller with the Manager. It indexes workloads
// based on the owning jobs.
func (r *JobReconciler) SetupWithManager(mgr ctrl.Manager) error {
Expand Down
235 changes: 235 additions & 0 deletions pkg/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ package job
import (
"testing"

"github.com/google/go-cmp/cmp"
batchv1 "k8s.io/api/batch/v1"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/util/pointer"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
utiltestingjob "sigs.k8s.io/kueue/pkg/util/testingjobs/job"
)

func TestPodsReady(t *testing.T) {
Expand Down Expand Up @@ -148,3 +153,233 @@ func TestPodsReady(t *testing.T) {
})
}
}

// TestPodSetsInfo verifies that RunWithPodSetsInfo produces the expected
// unsuspended job spec, and that RestorePodSetsInfo followed by Suspend
// brings the spec back to its original state.
func TestPodSetsInfo(t *testing.T) {
	testcases := map[string]struct {
		job                  *batchv1.Job
		runInfo, restoreInfo []jobframework.PodSetInfo
		wantUnsuspended      *batchv1.Job
	}{
		"append": {
			job: utiltestingjob.MakeJob("job", "ns").
				Parallelism(1).
				NodeSelector("orig-key", "orig-val").
				Obj(),
			runInfo: []jobframework.PodSetInfo{
				{
					NodeSelector: map[string]string{
						"new-key": "new-val",
					},
				},
			},
			wantUnsuspended: utiltestingjob.MakeJob("job", "ns").
				Parallelism(1).
				NodeSelector("orig-key", "orig-val").
				NodeSelector("new-key", "new-val").
				Suspend(false).
				Obj(),
			restoreInfo: []jobframework.PodSetInfo{
				{
					NodeSelector: map[string]string{
						"orig-key": "orig-val",
					},
				},
			},
		},
		"update": {
			job: utiltestingjob.MakeJob("job", "ns").
				Parallelism(1).
				NodeSelector("orig-key", "orig-val").
				Obj(),
			runInfo: []jobframework.PodSetInfo{
				{
					NodeSelector: map[string]string{
						"orig-key": "new-val",
					},
				},
			},
			wantUnsuspended: utiltestingjob.MakeJob("job", "ns").
				Parallelism(1).
				NodeSelector("orig-key", "new-val").
				Suspend(false).
				Obj(),
			restoreInfo: []jobframework.PodSetInfo{
				{
					NodeSelector: map[string]string{
						"orig-key": "orig-val",
					},
				},
			},
		},
		"parallelism": {
			job: utiltestingjob.MakeJob("job", "ns").
				Parallelism(5).
				SetAnnotation(JobMinParallelismAnnotation, "2").
				Obj(),
			runInfo: []jobframework.PodSetInfo{
				{
					Count: 2,
				},
			},
			wantUnsuspended: utiltestingjob.MakeJob("job", "ns").
				Parallelism(2).
				SetAnnotation(JobMinParallelismAnnotation, "2").
				Suspend(false).
				Obj(),
			restoreInfo: []jobframework.PodSetInfo{
				{
					Count: 5,
				},
			},
		},
	}
	for name, tc := range testcases {
		t.Run(name, func(t *testing.T) {
			origSpec := *tc.job.Spec.DeepCopy()
			job := Job{tc.job}

			job.RunWithPodSetsInfo(tc.runInfo)

			// cmp.Diff marks its first argument with "-" and its second with
			// "+", so the expected value goes first to match the message.
			if diff := cmp.Diff(tc.wantUnsuspended.Spec, job.Spec); diff != "" {
				t.Errorf("job spec mismatch after run (-want +got):\n%s", diff)
			}
			job.RestorePodSetsInfo(tc.restoreInfo)
			job.Suspend()
			if diff := cmp.Diff(origSpec, job.Spec); diff != "" {
				t.Errorf("job spec mismatch after restore (-want +got):\n%s", diff)
			}
		})
	}
}

// TestEquivalentToWorkload exercises Job.EquivalentToWorkload against
// workloads that differ from the job in pod set number, pod count, container
// image, or partial admission settings.
func TestEquivalentToWorkload(t *testing.T) {
	plainJob := &Job{utiltestingjob.MakeJob("job", "ns").
		Parallelism(2).
		Obj()}
	partialJob := &Job{utiltestingjob.MakeJob("job", "ns").
		SetAnnotation(JobMinParallelismAnnotation, "2").
		Parallelism(2).
		Obj()}

	// Same job as partialJob, but already unsuspended.
	runningPartialJob := &Job{partialJob.DeepCopy()}
	runningPartialJob.Spec.Suspend = pointer.Bool(false)

	mainPodSet := utiltesting.MakePodSet("main", 2).
		Containers(*plainJob.Spec.Template.Spec.Containers[0].DeepCopy()).
		Obj()
	plainWorkload := utiltesting.MakeWorkload("wl", "ns").PodSets(*mainPodSet).Obj()
	partialWorkload := plainWorkload.DeepCopy()
	partialWorkload.Spec.PodSets[0].MinCount = pointer.Int32(2)

	// derive returns a deep copy of base with mutate applied to it.
	derive := func(base *kueue.Workload, mutate func(*kueue.Workload)) *kueue.Workload {
		out := base.DeepCopy()
		mutate(out)
		return out
	}
	setCount := func(count int32) func(*kueue.Workload) {
		return func(w *kueue.Workload) { w.Spec.PodSets[0].Count = count }
	}

	cases := map[string]struct {
		wl         *kueue.Workload
		job        *Job
		wantResult bool
	}{
		"wrong podsets number": {
			job: plainJob,
			wl: derive(plainWorkload, func(w *kueue.Workload) {
				w.Spec.PodSets = append(w.Spec.PodSets, w.Spec.PodSets...)
			}),
		},
		"different pods count": {
			job: plainJob,
			wl:  derive(plainWorkload, setCount(3)),
		},
		"different container": {
			job: plainJob,
			wl: derive(plainWorkload, func(w *kueue.Workload) {
				w.Spec.PodSets[0].Template.Spec.Containers[0].Image = "other-image"
			}),
		},
		"equivalent": {
			job:        plainJob,
			wl:         plainWorkload.DeepCopy(),
			wantResult: true,
		},
		"partial admission bad count (suspended)": {
			job: partialJob,
			wl:  derive(partialWorkload, setCount(3)),
		},
		"partial admission different count (unsuspended)": {
			job:        runningPartialJob,
			wl:         derive(partialWorkload, setCount(3)),
			wantResult: true,
		},
		"partial admission bad minCount": {
			job: partialJob,
			wl: derive(partialWorkload, func(w *kueue.Workload) {
				w.Spec.PodSets[0].MinCount = pointer.Int32(3)
			}),
		},
		"equivalent partial admission": {
			job:        partialJob,
			wl:         partialWorkload.DeepCopy(),
			wantResult: true,
		},
	}
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			got := tc.job.EquivalentToWorkload(*tc.wl)
			if got != tc.wantResult {
				t.Fatalf("Unexpected result, wanted: %v", tc.wantResult)
			}
		})
	}
}

// TestPodSets verifies the pod sets reported by Job.PodSets, with and without
// the JobMinParallelismAnnotation annotation set on the job.
func TestPodSets(t *testing.T) {
	podTemplate := utiltestingjob.MakeJob("job", "ns").Spec.Template.DeepCopy()
	cases := map[string]struct {
		job         *batchv1.Job
		wantPodSets []kueue.PodSet
	}{
		"no partial admission": {
			job: utiltestingjob.MakeJob("job", "ns").Parallelism(3).Obj(),
			wantPodSets: []kueue.PodSet{
				{
					Name:     "main",
					Template: *podTemplate.DeepCopy(),
					Count:    3,
				},
			},
		},
		"partial admission": {
			job: utiltestingjob.MakeJob("job", "ns").Parallelism(3).SetAnnotation(JobMinParallelismAnnotation, "2").Obj(),
			wantPodSets: []kueue.PodSet{
				{
					Name:     "main",
					Template: *podTemplate.DeepCopy(),
					Count:    3,
					MinCount: pointer.Int32(2),
				},
			},
		},
	}
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			gotPodSets := (&Job{tc.job}).PodSets()
			if diff := cmp.Diff(tc.wantPodSets, gotPodSets); diff != "" {
				// The comparison covers whole pod sets, so the message says so.
				t.Errorf("pod sets mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
Loading

0 comments on commit 8682007

Please sign in to comment.