Skip to content

Commit

Permalink
feat: add pod labels with proper validation (#9364)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkunapuli authored May 16, 2024
1 parent 0a59c63 commit 2c9b9b9
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 8 deletions.
83 changes: 82 additions & 1 deletion master/internal/rm/kubernetesrm/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"path"
"reflect"
"regexp"
"strconv"
"strings"

Expand All @@ -30,6 +31,7 @@ import (
schedulingV1 "k8s.io/api/scheduling/v1"
"k8s.io/apimachinery/pkg/api/resource"
metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation"
)

const (
Expand Down Expand Up @@ -349,6 +351,54 @@ func (p *pod) createPriorityClass(name string, priority int32) error {
return err
}

const (
maxChars int = 63
fmtAlphaNumeric string = "A-Za-z0-9"
fmtAllowedChars string = fmtAlphaNumeric + `\.\-_`
defaultPodLabelValue string = "invalid_value"
)

var (
regDisallowedSpecialChars = regexp.MustCompile("[^" + fmtAllowedChars + "]")
regLeadingNonAlphaNumeric = regexp.MustCompile("^[^" + fmtAlphaNumeric + "]+")
regTrailingNonAlphaNumeric = regexp.MustCompile("[^" + fmtAlphaNumeric + "]+$")
)

func validatePodLabelValue(value string) (string, error) {
errs := validation.IsValidLabelValue(value)
if len(errs) == 0 {
return value, nil
}

// Label value is not valid; attempt to fix it.
// 0. Convert dis-allowed special characters to underscore.
fixedValue := regDisallowedSpecialChars.ReplaceAllString(value, "_")

// 1. Strip leading non-alphanumeric characters.
fixedValue = regLeadingNonAlphaNumeric.ReplaceAllString(fixedValue, "")

// 2. Truncate to 63 characters.
if len(fixedValue) > maxChars {
fixedValue = fixedValue[:maxChars]
}

// 3. Strip ending non-alphanumeric characters.
fixedValue = regTrailingNonAlphaNumeric.ReplaceAllString(fixedValue, "")

log.Debugf(
"conform to Kubernetes pod label value standards: reformatting %s to %s",
value, fixedValue,
)

// Final validation check, return error if still not valid for safety.
errs = validation.IsValidLabelValue(fixedValue)
if len(errs) != 0 {
return "", errors.New("pod label value is not valid")
}

return fixedValue, nil
}

func (p *pod) configurePodSpec(
volumes []k8sV1.Volume,
determinedInitContainers k8sV1.Container,
Expand All @@ -368,14 +418,45 @@ func (p *pod) configurePodSpec(
if podSpec.ObjectMeta.Labels == nil {
podSpec.ObjectMeta.Labels = make(map[string]string)
}
if p.submissionInfo.taskSpec.Owner != nil {
// Owner label will disappear if Owner is somehow nil.
labelValue, err := validatePodLabelValue(p.submissionInfo.taskSpec.Owner.Username)
if err != nil {
labelValue = defaultPodLabelValue
log.Warnf("unable to reformat username=%s to Kubernetes standards; using %s",
p.submissionInfo.taskSpec.Owner.Username, labelValue)
}
podSpec.ObjectMeta.Labels[userLabel] = labelValue
}

labelValue, err := validatePodLabelValue(p.submissionInfo.taskSpec.Workspace)
if err != nil {
labelValue = defaultPodLabelValue
log.Warnf("unable to reformat workspace=%s to Kubernetes standards; using %s",
p.submissionInfo.taskSpec.Workspace, labelValue)
}
podSpec.ObjectMeta.Labels[workspaceLabel] = labelValue

labelValue, err = validatePodLabelValue(p.req.ResourcePool)
if err != nil {
labelValue = defaultPodLabelValue
log.Warnf("unable to reformat resource_pool=%s to Kubernetes standards; using %s",
p.req.ResourcePool, labelValue)
}
podSpec.ObjectMeta.Labels[resourcePoolLabel] = labelValue

podSpec.ObjectMeta.Labels[taskTypeLabel] = string(p.submissionInfo.taskSpec.TaskType)
podSpec.ObjectMeta.Labels[taskIDLabel] = p.submissionInfo.taskSpec.TaskID
podSpec.ObjectMeta.Labels[containerIDLabel] = p.submissionInfo.taskSpec.ContainerID
podSpec.ObjectMeta.Labels[determinedLabel] = p.submissionInfo.taskSpec.AllocationID

// If map is not populated, labels will be missing and observability will be impacted.
for k, v := range p.submissionInfo.taskSpec.ExtraPodLabels {
podSpec.ObjectMeta.Labels[labelPrefix+k] = v
labelValue, err := validatePodLabelValue(v)
if err != nil {
labelValue = defaultPodLabelValue
}
podSpec.ObjectMeta.Labels[labelPrefix+k] = labelValue
}

p.modifyPodSpec(podSpec, scheduler)
Expand Down
46 changes: 39 additions & 7 deletions master/internal/rm/kubernetesrm/spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,37 @@ func TestAllPrintableCharactersInEnv(t *testing.T) {
require.Contains(t, actual, k8sV1.EnvVar{Name: "func", Value: "f(x)=x"})
}

func TestValidatePodLabelValues(t *testing.T) {
tests := []struct {
name string
input string
output string
}{
{"valid all alpha", "simpleCharacters", "simpleCharacters"},
{"valid all alphanumeric", "simple4Characters", "simple4Characters"},
{"valid contains non-alphanumeric", "simple-Characters.With_Other", "simple-Characters.With_Other"},
{"invalid chars", "letters contain *@ other chars -=%", "letters_contain____other_chars"},
{"invalid leading chars", "-%4-simpleCharacters0", "4-simpleCharacters0"},
{"invalid trailing chars", "simple-Characters4%-.#", "simple-Characters4"},
{
"invalid too many chars", "simpleCharactersGoesOnForWayTooLong36384042444648505254565860-_AndThenSome",
"simpleCharactersGoesOnForWayTooLong36384042444648505254565860",
},
{"invalid email-style input", "name@domain.com", "name_domain.com"},
{"invalid chars only", "-.*$%#$...", ""},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
testOutput, err := validatePodLabelValue(tt.input)
require.NoError(t, err)
require.Equal(t, tt.output, testOutput, tt.name+" failed")
})
}
}

func TestDeterminedLabels(t *testing.T) {
// fill out task spec
// Fill out task spec.
taskSpec := tasks.TaskSpec{
Owner: createUser(),
Workspace: "test-workspace",
Expand All @@ -244,12 +273,15 @@ func TestDeterminedLabels(t *testing.T) {
},
}

// define expectations
// Define expectations.
expectedLabels := map[string]string{
determinedLabel: taskSpec.AllocationID,
taskTypeLabel: string(taskSpec.TaskType),
taskIDLabel: taskSpec.TaskID,
containerIDLabel: taskSpec.ContainerID,
determinedLabel: taskSpec.AllocationID,
userLabel: taskSpec.Owner.Username,
workspaceLabel: taskSpec.Workspace,
resourcePoolLabel: p.req.ResourcePool,
taskTypeLabel: string(taskSpec.TaskType),
taskIDLabel: taskSpec.TaskID,
containerIDLabel: taskSpec.ContainerID,
}
for k, v := range taskSpec.ExtraPodLabels {
expectedLabels[labelPrefix+k] = v
Expand All @@ -258,7 +290,7 @@ func TestDeterminedLabels(t *testing.T) {
spec := p.configurePodSpec(make([]k8sV1.Volume, 1), k8sV1.Container{},
k8sV1.Container{}, make([]k8sV1.Container, 1), &k8sV1.Pod{}, "scheduler")

// confirm pod spec has required labels
// Confirm pod spec has required labels.
require.NotNil(t, spec)
require.Equal(t, expectedLabels, spec.ObjectMeta.Labels)
}

0 comments on commit 2c9b9b9

Please sign in to comment.