diff --git a/.golangci.yml b/.golangci.yml
index d7a698e6..9a96986b 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,6 +1,7 @@
 linters:
   enable:
     - bodyclose
+    - errorlint
     - goconst
     - godot
     - gofmt
@@ -10,5 +11,4 @@ linters:
   disable:
     # Temporarily disabling so it can be addressed in a dedicated PR.
    - errcheck
-    - errorlint
    - goerr113
\ No newline at end of file
diff --git a/client_test.go b/client_test.go
index 5fd8422c..62153b23 100644
--- a/client_test.go
+++ b/client_test.go
@@ -3,6 +3,7 @@ package kafka
 import (
 	"bytes"
 	"context"
+	"errors"
 	"io"
 	"math/rand"
 	"net"
@@ -262,7 +263,7 @@ func TestClientProduceAndConsume(t *testing.T) {
 		for {
 			r, err := res.Records.ReadRecord()
 			if err != nil {
-				if err != io.EOF {
+				if !errors.Is(err, io.EOF) {
 					t.Fatal(err)
 				}
 				break
diff --git a/compress/snappy/go-xerial-snappy/snappy_test.go b/compress/snappy/go-xerial-snappy/snappy_test.go
index 02b02226..12ae72ff 100644
--- a/compress/snappy/go-xerial-snappy/snappy_test.go
+++ b/compress/snappy/go-xerial-snappy/snappy_test.go
@@ -2,6 +2,7 @@ package snappy
 
 import (
 	"bytes"
+	"errors"
 	"testing"
 )
@@ -92,7 +93,7 @@ func TestSnappyDecodeMalformedTruncatedHeader(t *testing.T) {
 	for i := 0; i < len(xerialHeader); i++ {
 		buf := make([]byte, i)
 		copy(buf, xerialHeader[:i])
-		if _, err := Decode(buf); err != ErrMalformed {
+		if _, err := Decode(buf); !errors.Is(err, ErrMalformed) {
 			t.Errorf("expected ErrMalformed got %v", err)
 		}
 	}
@@ -104,7 +105,7 @@ func TestSnappyDecodeMalformedTruncatedSize(t *testing.T) {
 	for _, size := range sizes {
 		buf := make([]byte, size)
 		copy(buf, xerialHeader)
-		if _, err := Decode(buf); err != ErrMalformed {
+		if _, err := Decode(buf); !errors.Is(err, ErrMalformed) {
 			t.Errorf("expected ErrMalformed got %v", err)
 		}
 	}
@@ -116,7 +117,7 @@ func TestSnappyDecodeMalformedBNoData(t *testing.T) {
 	copy(buf, xerialHeader)
 	// indicate that there's one byte of data to be read
 	buf[len(buf)-1] = 1
-	if _, err := Decode(buf); err != ErrMalformed {
+	if _, err := Decode(buf); !errors.Is(err, ErrMalformed) {
 		t.Errorf("expected ErrMalformed got %v", err)
 	}
 }
@@ -128,7 +129,7 @@ func TestSnappyMasterDecodeFailed(t *testing.T) {
 	buf[len(buf)-2] = 1
 	// A payload which will not decode
 	buf[len(buf)-1] = 1
-	if _, err := Decode(buf); err == ErrMalformed || err == nil {
+	if _, err := Decode(buf); errors.Is(err, ErrMalformed) || err == nil {
 		t.Errorf("unexpected err: %v", err)
 	}
 }
diff --git a/conn.go b/conn.go
index 73cd5e46..1e2d0388 100644
--- a/conn.go
+++ b/conn.go
@@ -1231,11 +1231,10 @@ func (c *Conn) writeRequest(apiKey apiKey, apiVersion apiVersion, correlationID
 
 func (c *Conn) readResponse(size int, res interface{}) error {
 	size, err := read(&c.rbuf, size, res)
-	switch err.(type) {
-	case Error:
-		var e error
-		if size, e = discardN(&c.rbuf, size, size); e != nil {
-			err = e
+	if err != nil {
+		var kafkaError Error
+		if errors.As(err, &kafkaError) {
+			size, err = discardN(&c.rbuf, size, size)
 		}
 	}
 	return expectZeroSize(size, err)
@@ -1294,9 +1293,8 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func
 	}
 
 	if err = read(deadline, size); err != nil {
-		switch err.(type) {
-		case Error:
-		default:
+		var kafkaError Error
+		if !errors.As(err, &kafkaError) {
 			c.conn.Close()
 		}
 	}
diff --git a/conn_test.go b/conn_test.go
index 7c418685..7cc08cf8 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -389,9 +389,11 @@ func testConnWrite(t *testing.T, conn *Conn) {
 
 func testConnCloseAndWrite(t *testing.T, conn *Conn) {
 	conn.Close()
-	switch _, err := conn.Write([]byte("Hello World!")); err.(type) {
-	case *net.OpError:
-	default:
+	_, err := conn.Write([]byte("Hello World!"))
+
+	// expect a network error
+	var netOpError *net.OpError
+	if !errors.As(err, &netOpError) {
 		t.Error(err)
 	}
 }
@@ -489,7 +491,7 @@ func testConnSeekDontCheck(t *testing.T, conn *Conn) {
 		t.Error("bad offset:", offset)
 	}
 
-	if _, err := conn.ReadMessage(1024); err != OffsetOutOfRange {
+	if _, err := conn.ReadMessage(1024); !errors.Is(err, OffsetOutOfRange) {
 		t.Error("unexpected error:", err)
 	}
 }
@@ -659,13 +661,15 @@ func waitForCoordinator(t *testing.T, conn *Conn, groupID string) {
 		_, err := conn.findCoordinator(findCoordinatorRequestV0{
 			CoordinatorKey: groupID,
 		})
-		switch err {
-		case nil:
+		if err != nil {
+			if errors.Is(err, GroupCoordinatorNotAvailable) {
+				time.Sleep(250 * time.Millisecond)
+				continue
+			} else {
+				t.Fatalf("unable to find coordinator for group: %v", err)
+			}
+		} else {
 			return
-		case GroupCoordinatorNotAvailable:
-			time.Sleep(250 * time.Millisecond)
-		default:
-			t.Fatalf("unable to find coordinator for group: %v", err)
 		}
 	}
@@ -690,15 +694,18 @@ func createGroup(t *testing.T, conn *Conn, groupID string) (generationID int32,
 				},
 			},
 		})
-		switch err {
-		case nil:
+		if err != nil {
+			if errors.Is(err, NotCoordinatorForGroup) {
+				time.Sleep(250 * time.Millisecond)
+				continue
+			} else {
+				t.Fatalf("bad joinGroup: %s", err)
+			}
+		} else {
 			return
-		case NotCoordinatorForGroup:
-			time.Sleep(250 * time.Millisecond)
-		default:
-			t.Fatalf("bad joinGroup: %s", err)
 		}
 	}
+
+	return
 }
@@ -742,12 +749,11 @@ func testConnFindCoordinator(t *testing.T, conn *Conn) {
 		}
 
 		response, err := conn.findCoordinator(findCoordinatorRequestV0{CoordinatorKey: groupID})
 		if err != nil {
-			switch err {
-			case GroupCoordinatorNotAvailable:
+			if errors.Is(err, GroupCoordinatorNotAvailable) {
 				continue
-			default:
-				t.Fatalf("bad findCoordinator: %s", err)
 			}
+
+			t.Fatalf("bad findCoordinator: %s", err)
 		}
 
 		if response.Coordinator.NodeID == 0 {
diff --git a/consumergroup.go b/consumergroup.go
index d8611689..b9d0a7e2 100644
--- a/consumergroup.go
+++ b/consumergroup.go
@@ -523,19 +523,21 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) {
 			return
 		case <-ticker.C:
 			ops, err := g.conn.readPartitions(topic)
-			switch err {
-			case nil, UnknownTopicOrPartition:
+			switch {
+			case err == nil, errors.Is(err, UnknownTopicOrPartition):
 				if len(ops) != oParts {
 					g.log(func(l Logger) {
 						l.Printf("Partition changes found, reblancing group: %v.", g.GroupID)
 					})
 					return
 				}
+
 			default:
 				g.logError(func(l Logger) {
 					l.Printf("Problem getting partitions while checking for changes, %v", err)
 				})
-				if _, ok := err.(Error); ok {
+				var kafkaError Error
+				if errors.As(err, &kafkaError) {
 					continue
 				}
 				// other errors imply that we lost the connection to the coordinator, so we
@@ -724,20 +726,24 @@ func (cg *ConsumerGroup) run() {
 		// to the next generation. it will be non-nil in the case of an error
 		// joining or syncing the group.
 		var backoff <-chan time.Time
-		switch err {
-		case nil:
+
+		switch {
+		case err == nil:
 			// no error...the previous generation finished normally.
 			continue
-		case ErrGroupClosed:
+
+		case errors.Is(err, ErrGroupClosed):
 			// the CG has been closed...leave the group and exit loop.
 			_ = cg.leaveGroup(memberID)
 			return
-		case RebalanceInProgress:
+
+		case errors.Is(err, RebalanceInProgress):
 			// in case of a RebalanceInProgress, don't leave the group or
 			// change the member ID, but report the error. the next attempt
 			// to join the group will then be subject to the rebalance
 			// timeout, so the broker will be responsible for throttling
 			// this loop.
+
 		default:
 			// leave the group and report the error if we had gotten far
 			// enough so as to have a member ID. also clear the member id
@@ -984,7 +990,7 @@ func (cg *ConsumerGroup) makeJoinGroupRequestV1(memberID string) (joinGroupReque
 	for _, balancer := range cg.config.GroupBalancers {
 		userData, err := balancer.UserData()
 		if err != nil {
-			return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %v", balancer.ProtocolName(), err)
+			return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %w", balancer.ProtocolName(), err)
 		}
 		request.GroupProtocols = append(request.GroupProtocols, joinGroupRequestGroupProtocolV1{
 			ProtocolName: balancer.ProtocolName(),
@@ -1050,7 +1056,7 @@ func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []joinGroupResponseMember
 		metadata := groupMetadata{}
 		reader := bufio.NewReader(bytes.NewReader(item.MemberMetadata))
 		if remain, err := (&metadata).readFrom(reader, len(item.MemberMetadata)); err != nil || remain != 0 {
-			return nil, fmt.Errorf("unable to read metadata for member, %v: %v", item.MemberID, err)
+			return nil, fmt.Errorf("unable to read metadata for member, %v: %w", item.MemberID, err)
 		}
 
 		members = append(members, GroupMember{
diff --git a/createtopics.go b/createtopics.go
index c903fe22..6767e07c 100644
--- a/createtopics.go
+++ b/createtopics.go
@@ -3,6 +3,7 @@ package kafka
 import (
 	"bufio"
 	"context"
+	"errors"
 	"fmt"
 	"net"
 	"time"
@@ -384,12 +385,14 @@ func (c *Conn) CreateTopics(topics ...TopicConfig) error {
 	_, err := c.createTopics(createTopicsRequestV0{
 		Topics: requestV0Topics,
 	})
+	if err != nil {
+		if errors.Is(err, TopicAlreadyExists) {
+			// ok
+			return nil
+		}
 
-	switch err {
-	case TopicAlreadyExists:
-		// ok
-		return nil
-	default:
 		return err
 	}
+
+	return nil
 }
diff --git a/dialer_test.go b/dialer_test.go
index 5aedb777..7bc9e58c 100644
--- a/dialer_test.go
+++ b/dialer_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -294,7 +295,7 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) {
 	defer cancel()
 
 	_, err := d.connectTLS(ctx, conn, d.TLS)
-	if context.DeadlineExceeded != err {
+	if !errors.Is(err, context.DeadlineExceeded) {
 		t.Errorf("expected err to be %v; got %v", context.DeadlineExceeded, err)
 		t.FailNow()
 	}
diff --git a/discard_test.go b/discard_test.go
index cb444f15..69b7e72c 100644
--- a/discard_test.go
+++ b/discard_test.go
@@ -3,6 +3,7 @@ package kafka
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"io"
 	"testing"
 )
@@ -52,7 +53,7 @@ func TestDiscardN(t *testing.T) {
 			scenario: "discard more than available",
 			function: func(t *testing.T, r *bufio.Reader, sz int) {
 				remain, err := discardN(r, sz, sz+1)
-				if err != errShortRead {
+				if !errors.Is(err, errShortRead) {
 					t.Errorf("Expected errShortRead, got %v", err)
 				}
 				if remain != 0 {
@@ -64,7 +65,7 @@ func TestDiscardN(t *testing.T) {
 			scenario: "discard returns error",
 			function: func(t *testing.T, r *bufio.Reader, sz int) {
 				remain, err := discardN(r, sz+2, sz+1)
-				if err != io.EOF {
+				if !errors.Is(err, io.EOF) {
 					t.Errorf("Expected EOF, got %v", err)
 				}
 				if remain != 2 {
@@ -76,7 +77,7 @@ func TestDiscardN(t *testing.T) {
 			scenario: "errShortRead doesn't mask error",
 			function: func(t *testing.T, r *bufio.Reader, sz int) {
 				remain, err := discardN(r, sz+1, sz+2)
-				if err != io.EOF {
+				if !errors.Is(err, io.EOF) {
 					t.Errorf("Expected EOF, got %v", err)
 				}
 				if remain != 1 {
diff --git a/example_consumergroup_test.go b/example_consumergroup_test.go
index aec5b190..b1dc9097 100644
--- a/example_consumergroup_test.go
+++ b/example_consumergroup_test.go
@@ -2,6 +2,7 @@ package kafka_test
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 
@@ -42,18 +43,20 @@ func ExampleGeneration_Start_consumerGroupParallelReaders() {
 			reader.SetOffset(offset)
 
 			for {
 				msg, err := reader.ReadMessage(ctx)
-				switch err {
-				case kafka.ErrGenerationEnded:
-					// generation has ended. commit offsets. in a real app,
-					// offsets would be committed periodically.
-					gen.CommitOffsets(map[string]map[int]int64{"my-topic": {partition: offset + 1}})
-					return
-				case nil:
-					fmt.Printf("received message %s/%d/%d : %s\n", msg.Topic, msg.Partition, msg.Offset, string(msg.Value))
-					offset = msg.Offset
-				default:
+				if err != nil {
+					if errors.Is(err, kafka.ErrGenerationEnded) {
+						// generation has ended. commit offsets. in a real app,
+						// offsets would be committed periodically.
+						gen.CommitOffsets(map[string]map[int]int64{"my-topic": {partition: offset + 1}})
+						return
+					}
+
 					fmt.Printf("error reading message: %+v\n", err)
 					return
 				}
+
+				fmt.Printf("received message %s/%d/%d : %s\n", msg.Topic, msg.Partition, msg.Offset, string(msg.Value))
+				offset = msg.Offset
 			}
 		})
diff --git a/message_test.go b/message_test.go
index aa9bc630..383cd226 100644
--- a/message_test.go
+++ b/message_test.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/hex"
+	"errors"
 	"fmt"
 	"io"
 	"math/rand"
@@ -551,7 +552,7 @@ func TestMessageSetReaderEmpty(t *testing.T) {
 	if headers != nil {
 		t.Errorf("expected nil headers, got %v", headers)
 	}
-	if err != RequestTimedOut {
+	if !errors.Is(err, RequestTimedOut) {
 		t.Errorf("expected RequestTimedOut, got %v", err)
 	}
 
diff --git a/protocol/protocol.go b/protocol/protocol.go
index e8f9ff32..c3445525 100644
--- a/protocol/protocol.go
+++ b/protocol/protocol.go
@@ -1,6 +1,7 @@
 package protocol
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -365,14 +366,15 @@ func parseVersion(s string) (int16, error) {
 }
 
 func dontExpectEOF(err error) error {
-	switch err {
-	case nil:
-		return nil
-	case io.EOF:
-		return io.ErrUnexpectedEOF
-	default:
+	if err != nil {
+		if errors.Is(err, io.EOF) {
+			return io.ErrUnexpectedEOF
+		}
+
 		return err
 	}
+
+	return nil
 }
 
 type Broker struct {
diff --git a/protocol/prototest/prototest.go b/protocol/prototest/prototest.go
index cf978900..8ba4e571 100644
--- a/protocol/prototest/prototest.go
+++ b/protocol/prototest/prototest.go
@@ -2,6 +2,7 @@ package prototest
 
 import (
 	"bytes"
+	"errors"
 	"io"
 	"reflect"
 	"time"
@@ -152,7 +153,7 @@ func deepEqualRecords(r1, r2 protocol.RecordReader) bool {
 		rec2, err2 := r2.ReadRecord()
 
 		if err1 != nil || err2 != nil {
-			return err1 == err2
+			return errors.Is(err1, err2)
 		}
 
 		if !deepEqualRecord(rec1, rec2) {
diff --git a/protocol/record_batch_test.go b/protocol/record_batch_test.go
index e48e568e..8a6c351b 100644
--- a/protocol/record_batch_test.go
+++ b/protocol/record_batch_test.go
@@ -146,7 +146,7 @@ func assertRecords(t *testing.T, r1, r2 RecordReader) {
 		rec2, err2 := r2.ReadRecord()
 
 		if err1 != nil || err2 != nil {
-			if err1 != err2 {
+			if !errors.Is(err1, err2) {
 				t.Error("errors mismatch:")
 				t.Log("expected:", err2)
 				t.Log("found:    ", err1)
diff --git a/read_test.go b/read_test.go
index fd83a479..30d17e48 100644
--- a/read_test.go
+++ b/read_test.go
@@ -3,6 +3,7 @@ package kafka
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"io/ioutil"
 	"math"
 	"reflect"
@@ -49,7 +50,7 @@ func TestReadVarIntFailing(t *testing.T) {
 	testCase := []byte{135, 135}
 	rd := bufio.NewReader(bytes.NewReader(testCase))
 	_, err := readVarInt(rd, len(testCase), &v)
-	if err != errShortRead {
+	if !errors.Is(err, errShortRead) {
 		t.Errorf("Expected error while parsing var int: %v", err)
 	}
 }
@@ -160,7 +161,7 @@ func TestReadNewBytes(t *testing.T) {
 		if remain != 0 {
 			t.Error("all bytes should have been consumed")
 		}
-		if err != errShortRead {
+		if !errors.Is(err, errShortRead) {
 			t.Error("should have returned errShortRead")
 		}
 		b, err = r.Peek(0)
diff --git a/reader.go b/reader.go
index dcf15fe8..e24f134f 100644
--- a/reader.go
+++ b/reader.go
@@ -306,7 +306,7 @@ func (r *Reader) run(cg *ConsumerGroup) {
 			if err == nil {
 				break
 			}
-			if err == r.stctx.Err() {
+			if errors.Is(err, r.stctx.Err()) {
 				return
 			}
 			r.stats.errors.observe(1)
@@ -832,9 +832,7 @@ func (r *Reader) FetchMessage(ctx context.Context) (Message, error) {
 
 			r.mutex.Unlock()
 
-			switch m.error {
-			case nil:
-			case io.EOF:
+			if errors.Is(m.error, io.EOF) {
 				// io.EOF is used as a marker to indicate that the stream
 				// has been closed, in case it was received from the inner
 				// reader we don't want to confuse the program and replace
@@ -1249,17 +1247,17 @@ func (r *reader) run(ctx context.Context, offset int64) {
 		})
 
 		conn, start, err := r.initialize(ctx, offset)
-		switch err {
-		case nil:
-		case OffsetOutOfRange:
-			// This would happen if the requested offset is passed the last
-			// offset on the partition leader. In that case we're just going
-			// to retry later hoping that enough data has been produced.
-			r.withErrorLogger(func(log Logger) {
-				log.Printf("error initializing the kafka reader for partition %d of %s: %s", r.partition, r.topic, OffsetOutOfRange)
-			})
-			continue
-		default:
+		if err != nil {
+			if errors.Is(err, OffsetOutOfRange) {
+				// This would happen if the requested offset is past the last
+				// offset on the partition leader. In that case we're just going
+				// to retry later hoping that enough data has been produced.
+				r.withErrorLogger(func(log Logger) {
+					log.Printf("error initializing the kafka reader for partition %d of %s: %s", r.partition, r.topic, err)
+				})
+				continue
+			}
+
 			// Perform a configured number of attempts before
 			// reporting first errors, this helps mitigate
 			// situations where the kafka server is temporarily
@@ -1292,18 +1290,21 @@ func (r *reader) run(ctx context.Context, offset int64) {
 			return
 		}
 
-		switch offset, err = r.read(ctx, offset, conn); err {
-		case nil:
+		offset, err = r.read(ctx, offset, conn)
+		switch {
+		case err == nil:
 			errcount = 0
 			continue
-		case io.EOF:
+
+		case errors.Is(err, io.EOF):
 			// done with this batch of messages...carry on. note that this
 			// block relies on the batch repackaging real io.EOF errors as
 			// io.UnexpectedEOF. otherwise, we would end up swallowing real
 			// errors here.
 			errcount = 0
 			continue
-		case UnknownTopicOrPartition:
+
+		case errors.Is(err, UnknownTopicOrPartition):
 			r.withErrorLogger(func(log Logger) {
 				log.Printf("failed to read from current broker for partition %d of %s at offset %d, topic or parition not found on this broker, %v", r.partition, r.topic, toHumanOffset(offset), r.brokers)
 			})
@@ -1314,7 +1315,8 @@ func (r *reader) run(ctx context.Context, offset int64) {
 			// topic/partition broker combo.
 			r.stats.rebalances.observe(1)
 			break readLoop
-		case NotLeaderForPartition:
+
+		case errors.Is(err, NotLeaderForPartition):
 			r.withErrorLogger(func(log Logger) {
 				log.Printf("failed to read from current broker for partition %d of %s at offset %d, not the leader", r.partition, r.topic, toHumanOffset(offset))
 			})
@@ -1326,7 +1328,7 @@
 			r.stats.rebalances.observe(1)
 			break readLoop
 
-		case RequestTimedOut:
+		case errors.Is(err, RequestTimedOut):
 			// Timeout on the kafka side, this can be safely retried.
 			errcount = 0
 			r.withLogger(func(log Logger) {
@@ -1335,7 +1337,7 @@
 			r.stats.timeouts.observe(1)
 			continue
 
-		case OffsetOutOfRange:
+		case errors.Is(err, OffsetOutOfRange):
 			first, last, err := r.readOffsets(conn)
 			if err != nil {
 				r.withErrorLogger(func(log Logger) {
@@ -1364,12 +1366,12 @@
 				})
 			}
 
-		case context.Canceled:
+		case errors.Is(err, context.Canceled):
 			// Another reader has taken over, we can safely quit.
 			conn.Close()
 			return
 
-		case errUnknownCodec:
+		case errors.Is(err, errUnknownCodec):
 			// The compression codec is either unsupported or has not been
 			// imported. This is a fatal error b/c the reader cannot
 			// proceed.
@@ -1377,7 +1379,8 @@
 			break readLoop
 
 		default:
-			if _, ok := err.(Error); ok {
+			var kafkaError Error
+			if errors.As(err, &kafkaError) {
 				r.sendError(ctx, err)
 			} else {
 				r.withErrorLogger(func(log Logger) {
diff --git a/reader_test.go b/reader_test.go
index 16050a34..266cf893 100644
--- a/reader_test.go
+++ b/reader_test.go
@@ -89,7 +89,7 @@ func testReaderReadCanceled(t *testing.T, ctx context.Context, r *Reader) {
 	ctx, cancel := context.WithCancel(ctx)
 	cancel()
 
-	if _, err := r.ReadMessage(ctx); err != context.Canceled {
+	if _, err := r.ReadMessage(ctx); !errors.Is(err, context.Canceled) {
 		t.Error(err)
 	}
 }
@@ -259,7 +259,7 @@ func testReaderOutOfRangeGetsCanceled(t *testing.T, ctx context.Context, r *Read
 	}
 
 	_, err := r.ReadMessage(ctx)
-	if err != context.DeadlineExceeded {
+	if !errors.Is(err, context.DeadlineExceeded) {
 		t.Error("bad error:", err)
 	}
 
@@ -305,15 +305,12 @@ func createTopic(t *testing.T, topic string, partitions int) {
 		},
 		Timeout: milliseconds(time.Second),
 	})
-	switch err {
-	case nil:
-		// ok
-	case TopicAlreadyExists:
-		// ok
-	default:
-		err = fmt.Errorf("createTopic, conn.createTopics: %w", err)
-		t.Error(err)
-		t.FailNow()
+	if err != nil {
+		if !errors.Is(err, TopicAlreadyExists) {
+			err = fmt.Errorf("createTopic, conn.createTopics: %w", err)
+			t.Error(err)
+			t.FailNow()
+		}
 	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
@@ -664,7 +661,7 @@ func testConsumerGroupSimple(t *testing.T, ctx context.Context, r *Reader) {
 
 func TestReaderSetOffsetWhenConsumerGroupsEnabled(t *testing.T) {
 	r := &Reader{config: ReaderConfig{GroupID: "not-zero"}}
-	if err := r.SetOffset(LastOffset); err != errNotAvailableWithGroup {
+	if err := r.SetOffset(LastOffset); !errors.Is(err, errNotAvailableWithGroup) {
 		t.Fatalf("expected %v; got %v", errNotAvailableWithGroup, err)
 	}
 }
@@ -687,7 +684,7 @@ func TestReaderReadLagReturnsZeroLagWhenConsumerGroupsEnabled(t *testing.T) {
 	r := &Reader{config: ReaderConfig{GroupID: "not-zero"}}
 	lag, err := r.ReadLag(context.Background())
 
-	if err != errNotAvailableWithGroup {
+	if !errors.Is(err, errNotAvailableWithGroup) {
 		t.Fatalf("expected %v; got %v", errNotAvailableWithGroup, err)
 	}
 
@@ -1943,13 +1940,10 @@ func createTopicWithCompaction(t *testing.T, topic string, partitions int) {
 			},
 		},
 	})
-	switch err {
-	case nil:
-		// ok
-	case TopicAlreadyExists:
-		// ok
-	default:
-		require.NoError(t, err)
+	if err != nil {
+		if !errors.Is(err, TopicAlreadyExists) {
+			require.NoError(t, err)
+		}
 	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
diff --git a/transport.go b/transport.go
index be6e602d..6ba2d638 100644
--- a/transport.go
+++ b/transport.go
@@ -605,7 +605,7 @@ func (p *connPool) discover(ctx context.Context, wake <-chan event) {
 		}
 		r, err := res.await(deadline)
 		cancel()
-		if err != nil && err == ctx.Err() {
+		if err != nil && errors.Is(err, ctx.Err()) {
 			return
 		}
 		ret, _ := r.(*meta.Response)
@@ -1286,14 +1286,14 @@ func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mec
 	for completed := false; !completed; {
 		challenge, err := saslAuthenticateRoundTrip(pc, state)
-		switch err {
-		case nil:
-		case io.EOF:
-			// the broker may communicate a failed exchange by closing the
-			// connection (esp. in the case where we're passing opaque sasl
-			// data over the wire since there's no protocol info).
-			return SASLAuthenticationFailed
-		default:
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				// the broker may communicate a failed exchange by closing the
+				// connection (esp. in the case where we're passing opaque sasl
+				// data over the wire since there's no protocol info).
+				return SASLAuthenticationFailed
+			}
+
 			return err
 		}
diff --git a/writer_test.go b/writer_test.go
index 3d0f07aa..04d01207 100644
--- a/writer_test.go
+++ b/writer_test.go
@@ -359,8 +359,9 @@ func testWriterMaxBytes(t *testing.T) {
 		t.Error("expected error")
 		return
 	} else if err != nil {
-		switch e := err.(type) {
-		case MessageTooLargeError:
+		var e MessageTooLargeError
+		switch {
+		case errors.As(err, &e):
 			if string(e.Message.Value) != string(firstMsg) {
 				t.Errorf("unxpected returned message. Expected: %s, Got %s", firstMsg, e.Message.Value)
 				return
@@ -373,6 +374,7 @@ func testWriterMaxBytes(t *testing.T) {
 				t.Errorf("unxpected returned message. Expected: %s, Got %s", secondMsg, e.Message.Value)
 				return
 			}
+
 		default:
 			t.Errorf("unexpected error: %s", err)
 			return