Skip to content

Commit

Permalink
refactor: use ptr.Deref instead of GetDefaultIfNil & remove ContainsO…
Browse files Browse the repository at this point in the history
…bjectKey
  • Loading branch information
KevFan committed Aug 30, 2023
1 parent b2bd120 commit 6d5c580
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 93 deletions.
3 changes: 2 additions & 1 deletion controllers/ratelimitpolicy_cluster_envoy_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/go-logr/logr"
limitadorv1alpha1 "github.com/kuadrant/limitador-operator/api/v1alpha1"
"golang.org/x/exp/slices"
istioapinetworkingv1alpha3 "istio.io/api/networking/v1alpha3"
istioclientnetworkingv1alpha3 "istio.io/client-go/pkg/apis/networking/v1alpha3"
"k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (r *RateLimitPolicyReconciler) reconcileRateLimitingClusterEnvoyFilter(ctx
rlpRefs := gw.PolicyRefs()
rlpKey := client.ObjectKeyFromObject(rlp)
// Add the RLP key to the reference list. Only if it does not exist (it should not)
if !common.ContainsObjectKey(rlpRefs, rlpKey) {
if !slices.Contains(rlpRefs, rlpKey) {
rlpRefs = append(gw.PolicyRefs(), rlpKey)
}
ef, err := r.gatewayRateLimitingClusterEnvoyFilter(ctx, gw.Gateway, rlpRefs)
Expand Down
3 changes: 2 additions & 1 deletion controllers/ratelimitpolicy_wasm_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"github.com/go-logr/logr"
"golang.org/x/exp/slices"
istioextensionsv1alpha1 "istio.io/api/extensions/v1alpha1"
istioclientgoextensionv1alpha1 "istio.io/client-go/pkg/apis/extensions/v1alpha1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -57,7 +58,7 @@ func (r *RateLimitPolicyReconciler) reconcileWASMPluginConf(ctx context.Context,
rlpRefs := gw.PolicyRefs()
rlpKey := client.ObjectKeyFromObject(rlp)
// Add the RLP key to the reference list. Only if it does not exist (it should not)
if !common.ContainsObjectKey(rlpRefs, rlpKey) {
if !slices.Contains(rlpRefs, rlpKey) {
rlpRefs = append(gw.PolicyRefs(), rlpKey)
}
wp, err := r.gatewayWASMPlugin(ctx, gw, rlpRefs)
Expand Down
9 changes: 0 additions & 9 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package common

import (
"fmt"
"reflect"
"strings"

"golang.org/x/exp/slices"
Expand Down Expand Up @@ -46,14 +45,6 @@ type KuadrantPolicy interface {
GetRulesHostnames() []string
}

// GetDefaultIfNil returns the value of a pointer argument, or a default value if the pointer is nil.
func GetDefaultIfNil[T any](val *T, def T) T {
if reflect.ValueOf(val).IsNil() {
return def
}
return *val
}

// GetEmptySliceIfNil returns a provided slice, or an empty slice of the same type if the input slice is nil.
func GetEmptySliceIfNil[T any](val []T) []T {
if val == nil {
Expand Down
23 changes: 0 additions & 23 deletions pkg/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,6 @@ func TestFind(t *testing.T) {
}
}

func TestGetDefaultIfNil(t *testing.T) {
t.Run("when value is non-nil pointer type and default value is provided then return value", func(t *testing.T) {
val := "test"
def := "default"

result := GetDefaultIfNil(&val, def)

if result != val {
t.Errorf("Expected %v, but got %v", val, result)
}
})
t.Run("when value is nil pointer type and default value is provided then return default value", func(t *testing.T) {
var val *string
def := "default"

result := GetDefaultIfNil(val, def)

if result != def {
t.Errorf("Expected %v, but got %v", def, result)
}
})
}

func TestGetEmptySliceIfNil(t *testing.T) {
t.Run("when a non-nil slice is provided then return same slice", func(t *testing.T) {
value := []int{1, 2, 3}
Expand Down
16 changes: 9 additions & 7 deletions pkg/common/gatewayapi_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"reflect"
"strings"

"golang.org/x/exp/slices"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
gatewayapiv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2"
gatewayapiv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1"
Expand Down Expand Up @@ -204,12 +206,12 @@ func HTTPMethodToString(method *gatewayapiv1beta1.HTTPMethod) string {

func GetKuadrantNamespaceFromPolicyTargetRef(ctx context.Context, cli client.Client, policy KuadrantPolicy) (string, error) {
targetRef := policy.GetTargetRef()
gwNamespacedName := types.NamespacedName{Namespace: string(GetDefaultIfNil(targetRef.Namespace, policy.GetWrappedNamespace())), Name: string(targetRef.Name)}
gwNamespacedName := types.NamespacedName{Namespace: string(ptr.Deref(targetRef.Namespace, policy.GetWrappedNamespace())), Name: string(targetRef.Name)}

Check warning on line 209 in pkg/common/gatewayapi_utils.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/gatewayapi_utils.go#L209

Added line #L209 was not covered by tests
if IsTargetRefHTTPRoute(targetRef) {
route := &gatewayapiv1beta1.HTTPRoute{}
if err := cli.Get(
ctx,
types.NamespacedName{Namespace: string(GetDefaultIfNil(targetRef.Namespace, policy.GetWrappedNamespace())), Name: string(targetRef.Name)},
types.NamespacedName{Namespace: string(ptr.Deref(targetRef.Namespace, policy.GetWrappedNamespace())), Name: string(targetRef.Name)},

Check warning on line 214 in pkg/common/gatewayapi_utils.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/gatewayapi_utils.go#L214

Added line #L214 was not covered by tests
route,
); err != nil {
return "", err
Expand Down Expand Up @@ -316,7 +318,7 @@ func GatewaysMissingPolicyRef(gwList *gatewayapiv1beta1.GatewayList, policyKey c
for i := range gwList.Items {
gateway := gwList.Items[i]
gw := GatewayWrapper{&gateway, config}
if ContainsObjectKey(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && !gw.ContainsPolicy(policyKey) {
if slices.Contains(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && !gw.ContainsPolicy(policyKey) {
gateways = append(gateways, gw)
}
}
Expand All @@ -329,7 +331,7 @@ func GatewaysWithValidPolicyRef(gwList *gatewayapiv1beta1.GatewayList, policyKey
for i := range gwList.Items {
gateway := gwList.Items[i]
gw := GatewayWrapper{&gateway, config}
if ContainsObjectKey(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && gw.ContainsPolicy(policyKey) {
if slices.Contains(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && gw.ContainsPolicy(policyKey) {
gateways = append(gateways, gw)
}
}
Expand All @@ -342,7 +344,7 @@ func GatewaysWithInvalidPolicyRef(gwList *gatewayapiv1beta1.GatewayList, policyK
for i := range gwList.Items {
gateway := gwList.Items[i]
gw := GatewayWrapper{&gateway, config}
if !ContainsObjectKey(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && gw.ContainsPolicy(policyKey) {
if !slices.Contains(policyGwKeys, client.ObjectKeyFromObject(&gateway)) && gw.ContainsPolicy(policyKey) {
gateways = append(gateways, gw)
}
}
Expand Down Expand Up @@ -403,7 +405,7 @@ func (g GatewayWrapper) ContainsPolicy(policyKey client.ObjectKey) bool {
return false
}

return ContainsObjectKey(refs, policyKey)
return slices.Contains(refs, policyKey)
}

// AddPolicy tries to add a policy to the existing ref list.
Expand Down Expand Up @@ -434,7 +436,7 @@ func (g GatewayWrapper) AddPolicy(policyKey client.ObjectKey) bool {
return false
}

if ContainsObjectKey(refs, policyKey) {
if slices.Contains(refs, policyKey) {
return false
}

Expand Down
10 changes: 0 additions & 10 deletions pkg/common/k8s_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,6 @@ func ObjectKeyListDifference(a, b []client.ObjectKey) []client.ObjectKey {
return result
}

// ContainsObjectKey tells whether a contains x
func ContainsObjectKey(a []client.ObjectKey, x client.ObjectKey) bool {
for _, n := range a {
if x == n {
return true
}
}
return false
}

// FindObjectKey returns the smallest index i at which x == a[i],
// or len(a) if there is no such index.
func FindObjectKey(a []client.ObjectKey, x client.ObjectKey) int {
Expand Down
42 changes: 0 additions & 42 deletions pkg/common/k8s_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,48 +731,6 @@ func TestGetServicePortNumber(t *testing.T) {
}
}

func TestContainsObjectKey(t *testing.T) {
key1 := client.ObjectKey{Namespace: "ns1", Name: "obj1"}
key2 := client.ObjectKey{Namespace: "ns2", Name: "obj2"}
key3 := client.ObjectKey{Namespace: "ns3", Name: "obj3"}
key4 := client.ObjectKey{Namespace: "ns4", Name: "obj4"}

testCases := []struct {
name string
list []client.ObjectKey
key client.ObjectKey
expected bool
}{
{
name: "when list contains key then return true",
list: []client.ObjectKey{key1, key2, key3},
key: key2,
expected: true,
},
{
name: "when list does not contain key then return false",
list: []client.ObjectKey{key1, key2, key3},
key: key4,
expected: false,
},
{
name: "when list is empty then return false",
list: []client.ObjectKey{},
key: key4,
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := ContainsObjectKey(tc.list, tc.key)
if result != tc.expected {
t.Errorf("unexpected result: got %t, want %t", result, tc.expected)
}
})
}
}

func TestFindObjectKey(t *testing.T) {
key1 := client.ObjectKey{Namespace: "ns1", Name: "obj1"}
key2 := client.ObjectKey{Namespace: "ns2", Name: "obj2"}
Expand Down

0 comments on commit 6d5c580

Please sign in to comment.