Skip to content

Commit

Permalink
Add ptls.Dialer to provide some common configuration for tls.Dial ope…
Browse files Browse the repository at this point in the history
…rations
  • Loading branch information
joshuatcasey committed Sep 11, 2024
1 parent 62d1715 commit 99b6bad
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package webhookcachefiller

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
Expand Down Expand Up @@ -77,6 +76,7 @@ func New(
withInformer pinnipedcontroller.WithInformerOptionFunc,
clock clock.Clock,
log plog.Logger,
dialer ptls.Dialer,
) controllerlib.Controller {
return controllerlib.New(
controllerlib.Config{
Expand All @@ -90,6 +90,7 @@ func New(
configMapInformer: configMapInformer,
clock: clock,
log: log.WithName(controllerName),
dialer: dialer,
},
},
withInformer(
Expand Down Expand Up @@ -125,6 +126,7 @@ type webhookCacheFillerController struct {
client conciergeclientset.Interface
clock clock.Clock
log plog.Logger
dialer ptls.Dialer
}

// Sync implements controllerlib.Syncer.
Expand Down Expand Up @@ -428,11 +430,11 @@ func (c *webhookCacheFillerController) validateConnection(
return conditions, nil
}

conn, err := tls.Dial("tcp", endpointHostPort.Endpoint(), ptls.Default(certPool))
err := c.dialer.IsReachableAndTLSValidationSucceeds(endpointHostPort.Endpoint(), certPool, logger)

if err != nil {
errText := "cannot dial server"
msg := fmt.Sprintf("%s: %s", errText, err.Error())
msg := fmt.Sprintf("%s: %s", errText, err)
conditions = append(conditions, &metav1.Condition{
Type: typeWebhookConnectionValid,
Status: metav1.ConditionFalse,
Expand All @@ -442,13 +444,6 @@ func (c *webhookCacheFillerController) validateConnection(
return conditions, fmt.Errorf("%s: %w", errText, err)
}

// this error should never be significant
err = conn.Close()
if err != nil {
// no unit test for this failure
logger.Error("error closing dialer", err)
}

conditions = append(conditions, successfulWebhookConnectionValidCondition())
return conditions, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,8 @@ func TestController(t *testing.T) {
kubeInformers.Core().V1().ConfigMaps(),
controllerlib.WithInformer,
frozenClock,
logger)
logger,
ptls.NewDialer())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -2177,7 +2178,8 @@ func TestControllerFilterSecret(t *testing.T) {
configMapInformer,
observableInformers.WithInformer,
frozenClock,
logger)
logger,
ptls.NewDialer())

unrelated := &corev1.Secret{}
filter := observableInformers.GetFilterForInformer(secretInformer)
Expand Down Expand Up @@ -2238,7 +2240,8 @@ func TestControllerFilterConfigMap(t *testing.T) {
configMapInformer,
observableInformers.WithInformer,
frozenClock,
logger)
logger,
ptls.NewDialer())

unrelated := &corev1.ConfigMap{}
filter := observableInformers.GetFilterForInformer(configMapInformer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package githubupstreamwatcher

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -117,7 +116,7 @@ type gitHubWatcherController struct {
secretInformer corev1informers.SecretInformer
configMapInformer corev1informers.ConfigMapInformer
clock clock.Clock
dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error)
dialer ptls.Dialer
validatedCache GitHubValidatedAPICacheI
}

Expand All @@ -132,7 +131,7 @@ func New(
log plog.Logger,
withInformer pinnipedcontroller.WithInformerOptionFunc,
clock clock.Clock,
dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error),
dialer ptls.Dialer,
validatedCache *cache.Expiring,
) controllerlib.Controller {
c := gitHubWatcherController{
Expand All @@ -144,7 +143,7 @@ func New(
secretInformer: secretInformer,
configMapInformer: configMapInformer,
clock: clock,
dialFunc: dialFunc,
dialer: dialer,
validatedCache: NewGitHubValidatedAPICache(validatedCache),
}

Expand Down Expand Up @@ -471,7 +470,7 @@ func (c *gitHubWatcherController) validateGitHubConnection(
apiAddress := apiHostPort.Endpoint()

if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) {
conn, tlsDialErr := c.dialFunc("tcp", apiAddress, ptls.Default(caBundle.CertPool()))
tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(apiAddress, caBundle.CertPool(), c.log)
if tlsDialErr != nil {
return &metav1.Condition{
Type: GitHubConnectionValid,
Expand All @@ -481,8 +480,6 @@ func (c *gitHubWatcherController) validateGitHubConnection(
apiAddress, *specifiedHost, buildDialErrorMessage(tlsDialErr)),
}, nil, tlsDialErr
}
// Any error should be ignored. We have performed a successful Dial, so no need to requeue this Sync.
_ = conn.Close()
}

c.validatedCache.MarkAsValidated(apiAddress, caBundle.Hash())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package githubupstreamwatcher
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -41,6 +40,7 @@ import (
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
"go.pinniped.dev/internal/controller/tlsconfigutil"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/net/phttp"
Expand All @@ -61,6 +61,32 @@ var (
githubIDPKind = idpv1alpha1.SchemeGroupVersion.WithKind("GitHubIdentityProvider")
)

type fakeGithubDialer struct {
t *testing.T
realAddress string
realCertPool *x509.CertPool
}

func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(address string, _ *x509.CertPool, logger ptls.ErrorOnlyLogger) error {
require.Equal(f.t, "api.github.com:443", address)

return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(f.realAddress, f.realCertPool, logger)
}

var _ ptls.Dialer = (*fakeGithubDialer)(nil)

type allowNoDials struct {
t *testing.T
}

func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ string, _ *x509.CertPool, _ ptls.ErrorOnlyLogger) error {
f.t.Errorf("this test should not perform dial")
f.t.FailNow()
return nil
}

var _ ptls.Dialer = (*allowNoDials)(nil)

func TestController(t *testing.T) {
require.Equal(t, 6, countExpectedConditions)

Expand Down Expand Up @@ -406,7 +432,7 @@ func TestController(t *testing.T) {
name string
githubIdentityProviders []runtime.Object
secretsAndConfigMaps []runtime.Object
mockDialer func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error)
mockDialer func(*testing.T) ptls.Dialer
preexistingValidatedCache []GitHubValidatedAPICacheKey
wantErr string
wantLogs []string
Expand Down Expand Up @@ -555,15 +581,13 @@ func TestController(t *testing.T) {
return githubIDP
}(),
},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
require.Equal(t, "api.github.com:443", addr)
// don't actually dial github.com to avoid making external network calls in unit test
configClone := config.Clone()
configClone.RootCAs = goodServerCertPool
return tls.Dial(network, goodServerDomain, configClone)
return &fakeGithubDialer{
t: t,
realAddress: goodServerDomain,
realCertPool: goodServerCertPool,
}
},
wantResultingCache: []*upstreamgithub.ProviderConfig{
Expand Down Expand Up @@ -638,15 +662,13 @@ func TestController(t *testing.T) {
return githubIDP
}(),
},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
require.Equal(t, "api.github.com:443", addr)
// don't actually dial github.com to avoid making external network calls in unit test
configClone := config.Clone()
configClone.RootCAs = goodServerCertPool
return tls.Dial(network, goodServerDomain, configClone)
return &fakeGithubDialer{
t: t,
realAddress: goodServerDomain,
realCertPool: goodServerCertPool,
}
},
wantResultingCache: []*upstreamgithub.ProviderConfig{
Expand Down Expand Up @@ -721,15 +743,13 @@ func TestController(t *testing.T) {
return githubIDP
}(),
},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
require.Equal(t, "api.github.com:443", addr)
// don't actually dial github.com to avoid making external network calls in unit test
configClone := config.Clone()
configClone.RootCAs = goodServerCertPool
return tls.Dial(network, goodServerDomain, configClone)
return &fakeGithubDialer{
t: t,
realAddress: goodServerDomain,
realCertPool: goodServerCertPool,
}
},
wantResultingCache: []*upstreamgithub.ProviderConfig{
Expand Down Expand Up @@ -804,15 +824,13 @@ func TestController(t *testing.T) {
return githubIDP
}(),
},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
require.Equal(t, "api.github.com:443", addr)
// don't actually dial github.com to avoid making external network calls in unit test
configClone := config.Clone()
configClone.RootCAs = goodServerCertPool
return tls.Dial(network, goodServerDomain, configClone)
return &fakeGithubDialer{
t: t,
realAddress: goodServerDomain,
realCertPool: goodServerCertPool,
}
},
wantResultingCache: []*upstreamgithub.ProviderConfig{
Expand Down Expand Up @@ -887,15 +905,13 @@ func TestController(t *testing.T) {
return githubIDP
}(),
},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
require.Equal(t, "api.github.com:443", addr)
// don't actually dial github.com to avoid making external network calls in unit test
configClone := config.Clone()
configClone.RootCAs = goodServerCertPool
return tls.Dial(network, goodServerDomain, configClone)
return &fakeGithubDialer{
t: t,
realAddress: goodServerDomain,
realCertPool: goodServerCertPool,
}
},
wantResultingCache: []*upstreamgithub.ProviderConfig{
Expand Down Expand Up @@ -1379,14 +1395,10 @@ func TestController(t *testing.T) {
name: "happy path with previously validated address/CA Bundle does not validate again",
secretsAndConfigMaps: []runtime.Object{goodClientCredentialsSecret},
githubIdentityProviders: []runtime.Object{validFilledOutIDP},
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
mockDialer: func(t *testing.T) ptls.Dialer {
t.Helper()

return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
t.Errorf("this test should not perform dial")
t.FailNow()
return nil, nil
}
return &allowNoDials{t: t}
},
preexistingValidatedCache: []GitHubValidatedAPICacheKey{
{
Expand Down Expand Up @@ -2479,7 +2491,7 @@ func TestController(t *testing.T) {

gitHubIdentityProviderInformer := supervisorInformers.IDP().V1alpha1().GitHubIdentityProviders()

dialer := tls.Dial
var dialer ptls.Dialer = ptls.NewDialer()
if tt.mockDialer != nil {
dialer = tt.mockDialer(t)
}
Expand Down Expand Up @@ -2882,7 +2894,7 @@ func TestController_OnlyWantActions(t *testing.T) {
logger,
controllerlib.WithInformer,
frozenClockForLastTransitionTime,
tls.Dial,
ptls.NewDialer(),
cache.NewExpiring(),
)

Expand Down Expand Up @@ -3006,7 +3018,7 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) {
logger,
observableInformers.WithInformer,
clock.RealClock{},
tls.Dial,
ptls.NewDialer(),
cache.NewExpiring(),
)

Expand Down Expand Up @@ -3063,7 +3075,7 @@ func TestGitHubUpstreamWatcherControllerFilterConfigMaps(t *testing.T) {
logger,
observableInformers.WithInformer,
clock.RealClock{},
tls.Dial,
ptls.NewDialer(),
cache.NewExpiring(),
)

Expand Down Expand Up @@ -3120,7 +3132,7 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) {
logger,
observableInformers.WithInformer,
clock.RealClock{},
tls.Dial,
ptls.NewDialer(),
cache.NewExpiring(),
)

Expand Down
2 changes: 2 additions & 0 deletions internal/controllermanager/prepare_controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.pinniped.dev/internal/controller/serviceaccounttokencleanup"
"go.pinniped.dev/internal/controllerinit"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/deploymentref"
"go.pinniped.dev/internal/downward"
"go.pinniped.dev/internal/dynamiccert"
Expand Down Expand Up @@ -244,6 +245,7 @@ func PrepareControllers(c *Config) (controllerinit.RunnerBuilder, error) { //nol
controllerlib.WithInformer,
clock.RealClock{},
plog.New(),
ptls.NewDialer(),

Check warning on line 248 in internal/controllermanager/prepare_controllers.go

View check run for this annotation

Codecov / codecov/patch

internal/controllermanager/prepare_controllers.go#L248

Added line #L248 was not covered by tests
),
singletonWorker,
).
Expand Down
Loading

0 comments on commit 99b6bad

Please sign in to comment.