diff --git a/client_resolver_test.go b/client_resolver_test.go index 738b078..e2652f1 100644 --- a/client_resolver_test.go +++ b/client_resolver_test.go @@ -8,14 +8,16 @@ import ( "time" mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestClient_newClient(t *testing.T) { tests := []struct { - name string - addrs []TCPAddress - newClientFunc func(*mqtt.ClientOptions) mqtt.Client + name string + addrs []TCPAddress + newClientFunc func(*mqtt.ClientOptions) mqtt.Client + onConnLostAssert func(*testing.T, error) }{ { name: "success_attempt_1", @@ -92,6 +94,9 @@ func TestClient_newClient(t *testing.T) { Port: 8888, }, }, + onConnLostAssert: func(t *testing.T, err error) { + assert.EqualError(t, err, "some error") + }, newClientFunc: func(o *mqtt.ClientOptions) mqtt.Client { if o.Servers[0].String() != "tcp://localhost:1883" { panic(o.Servers) @@ -123,15 +128,22 @@ func TestClient_newClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Client{ - options: defaultClientOptions(), + opts := defaultClientOptions() + if tt.onConnLostAssert != nil { + opts.onConnectionLostHandler = func(err error) { + tt.onConnLostAssert(t, err) + } } + + c := &Client{options: opts} newClientFunc.Store(tt.newClientFunc) got := c.newClient(tt.addrs, 0) got.(*mockClient).AssertExpectations(t) }) } + + newClientFunc.Store(mqtt.NewClient) } func TestClient_watchAddressUpdates(t *testing.T) {