Skip to content

Commit

Permalink
add CredentialFetcher to allow updating credentials on each newOption…
Browse files Browse the repository at this point in the history
…s creation
  • Loading branch information
ajatprabha committed Jul 20, 2023
1 parent fd45a82 commit 6da2bfe
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 4 deletions.
21 changes: 19 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {
opts.SetClientID(o.clientID)
}

setCredentials(o, opts)

opts.AddBroker(formatAddressWithProtocol(o)).
SetUsername(o.username).
SetPassword(o.password).
SetTLSConfig(o.tlsConfig).
SetAutoReconnect(o.autoReconnect).
SetCleanSession(o.cleanSession).
Expand All @@ -184,6 +184,23 @@ func toClientOptions(c *Client, o *clientOptions) *mqtt.ClientOptions {
return opts
}

func setCredentials(o *clientOptions, opts *mqtt.ClientOptions) {
if o.credentialFetcher != nil {
ctx, cancel := context.WithTimeout(context.Background(), o.credentialFetchTimeout)
defer cancel()

if c, err := o.credentialFetcher.Credentials(ctx); err == nil {
opts.SetUsername(c.Username)
opts.SetPassword(c.Password)

return
}
}

opts.SetUsername(o.username)
opts.SetPassword(o.password)
}

func formatAddressWithProtocol(opts *clientOptions) string {
if opts.tlsConfig != nil {
return fmt.Sprintf("tls://%s", opts.brokerAddress)
Expand Down
21 changes: 21 additions & 0 deletions client_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package courier

import "context"

// Credential is a <username,password> pair.
type Credential struct {
Username string
Password string
}

// CredentialFetcher is an interface that allows to fetch credentials for a client.
type CredentialFetcher interface {
Credentials(context.Context) (*Credential, error)
}

// WithCredentialFetcher sets the specified CredentialFetcher.
func WithCredentialFetcher(fetcher CredentialFetcher) ClientOption {
return optionFunc(func(o *clientOptions) {
o.credentialFetcher = fetcher
})
}
26 changes: 26 additions & 0 deletions client_credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package courier

import (
"context"
"testing"

"github.com/stretchr/testify/mock"
)

type mockCredentialFetcher struct {
mock.Mock
}

func newMockCredentialFetcher(t *testing.T) *mockCredentialFetcher {
m := &mockCredentialFetcher{}
m.Test(t)
return m
}

func (m *mockCredentialFetcher) Credentials(ctx context.Context) (*Credential, error) {
args := m.Called(ctx)
if c := args.Get(0); c != nil {
return c.(*Credential), args.Error(1)
}
return nil, args.Error(1)
}
7 changes: 5 additions & 2 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,16 @@ func WithExponentialStartOptions(options ...StartOption) ClientOption {
type clientOptions struct {
username, clientID, password,
brokerAddress string
resolver Resolver
resolver Resolver
credentialFetcher CredentialFetcher

tlsConfig *tls.Config

autoReconnect, maintainOrder, cleanSession bool

connectTimeout, writeTimeout, keepAlive,
maxReconnectInterval, gracefulShutdownPeriod time.Duration
maxReconnectInterval, gracefulShutdownPeriod,
credentialFetchTimeout time.Duration

startOptions *startOptions

Expand All @@ -230,6 +232,7 @@ func defaultClientOptions() *clientOptions {
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
Expand Down
7 changes: 7 additions & 0 deletions client_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func TestClientOptionSuite(t *testing.T) {
func (s *ClientOptionSuite) Test_apply() {
store := NewMemoryStore()
r := resolver{}
mc := newMockCredentialFetcher(s.T())

tests := []struct {
name string
Expand Down Expand Up @@ -101,6 +102,11 @@ func (s *ClientOptionSuite) Test_apply() {
option: WithResolver(r),
want: &clientOptions{resolver: r},
},
{
name: "WithCredentialFetcher",
option: WithCredentialFetcher(mc),
want: &clientOptions{credentialFetcher: mc},
},
{
name: "WithExponentialStartOptions",
option: WithExponentialStartOptions(WithMaxInterval(time.Second)),
Expand Down Expand Up @@ -192,6 +198,7 @@ func (s *ClientOptionSuite) Test_defaultOptions() {
maxReconnectInterval: 5 * time.Minute,
gracefulShutdownPeriod: 30 * time.Second,
keepAlive: 60 * time.Second,
credentialFetchTimeout: 10 * time.Second,
newEncoder: DefaultEncoderFunc,
newDecoder: DefaultDecoderFunc,
store: inMemoryPersistence,
Expand Down
31 changes: 31 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,37 @@ func TestNewClientWithResolverOption(t *testing.T) {
mr.AssertExpectations(t)
}

func TestNewClientWithCredentialFetcher(t *testing.T) {
mcf := newMockCredentialFetcher(t)
mcf.On("Credentials", mock.Anything).Return(&Credential{
Username: "username",
Password: "password",
}, nil)

newClientFunc.Store(func(opts *mqtt.ClientOptions) mqtt.Client {
assert.Equal(t, "username", opts.Username)
assert.Equal(t, "password", opts.Password)
return mqtt.NewClient(opts)
})
defer func() {
newClientFunc.Store(mqtt.NewClient)
}()

c, err := NewClient(append(defOpts, WithCredentialFetcher(mcf))...)

assert.NoError(t, c.Start())
mcf.AssertExpectations(t)

assert.Eventually(t, func() bool {
return c.IsConnected()
}, 10*time.Second, 250*time.Millisecond)

c.Stop()

assert.NoError(t, err)
mcf.AssertExpectations(t)
}

func TestNewClientWithExponentialStartOptions(t *testing.T) {
c, err := NewClient(append(defOpts, WithExponentialStartOptions(WithMaxInterval(10*time.Second)))...)
assert.NoError(t, err)
Expand Down
35 changes: 35 additions & 0 deletions docs/docs/sdk/SDK.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Package courier contains the client that can be used to interact with the courie
- [func WithCleanSession\(cleanSession bool\) ClientOption](#WithCleanSession)
- [func WithClientID\(clientID string\) ClientOption](#WithClientID)
- [func WithConnectTimeout\(duration time.Duration\) ClientOption](#WithConnectTimeout)
- [func WithCredentialFetcher\(fetcher CredentialFetcher\) ClientOption](#WithCredentialFetcher)
- [func WithCustomDecoder\(decoderFunc DecoderFunc\) ClientOption](#WithCustomDecoder)
- [func WithCustomEncoder\(encoderFunc EncoderFunc\) ClientOption](#WithCustomEncoder)
- [func WithExponentialStartOptions\(options ...StartOption\) ClientOption](#WithExponentialStartOptions)
Expand All @@ -52,6 +53,8 @@ Package courier contains the client that can be used to interact with the courie
- [func WithUsername\(username string\) ClientOption](#WithUsername)
- [func WithWriteTimeout\(duration time.Duration\) ClientOption](#WithWriteTimeout)
- [type ConnectionInformer](#ConnectionInformer)
- [type Credential](#Credential)
- [type CredentialFetcher](#CredentialFetcher)
- [type Decoder](#Decoder)
- [func DefaultDecoderFunc\(\_ context.Context, r io.Reader\) Decoder](#DefaultDecoderFunc)
- [type DecoderFunc](#DecoderFunc)
Expand Down Expand Up @@ -366,6 +369,15 @@ func WithConnectTimeout(duration time.Duration) ClientOption

WithConnectTimeout limits how long the client will wait when trying to open a connection to an MQTT server before timing out. A duration of 0 never times out. Default 15 seconds.

<a name="WithCredentialFetcher"></a>
### func [WithCredentialFetcher](https://github.com/gojek/courier-go/blob/main/client_credentials.go#L17)

```go
func WithCredentialFetcher(fetcher CredentialFetcher) ClientOption
```

WithCredentialFetcher sets the specified CredentialFetcher.

<a name="WithCustomDecoder"></a>
### func [WithCustomDecoder](https://github.com/gojek/courier-go/blob/main/client_options.go#L176)

Expand Down Expand Up @@ -542,6 +554,29 @@ type ConnectionInformer interface {
}
```

<a name="Credential"></a>
## type [Credential](https://github.com/gojek/courier-go/blob/main/client_credentials.go#L6-L9)

Credential is a \<username,password\> pair.

```go
type Credential struct {
Username string
Password string
}
```

<a name="CredentialFetcher"></a>
## type [CredentialFetcher](https://github.com/gojek/courier-go/blob/main/client_credentials.go#L12-L14)

CredentialFetcher is an interface that allows to fetch credentials for a client.

```go
type CredentialFetcher interface {
Credentials(context.Context) (*Credential, error)
}
```

<a name="Decoder"></a>
## type [Decoder](https://github.com/gojek/courier-go/blob/main/decoder.go#L16-L19)

Expand Down

0 comments on commit 6da2bfe

Please sign in to comment.