From 4178fdabbdd599bd653a16471296e18aea17b264 Mon Sep 17 00:00:00 2001 From: Sukun Date: Fri, 27 Jan 2023 03:56:46 +0530 Subject: [PATCH] Add not supported protocols to returned errors (#97) * Add not supported protocols to returned errors * rename struct to ErrNotSupported --- client.go | 41 ++++++++++++++++++++++++++++------------- lazyClient.go | 2 +- multistream_test.go | 23 ++++++++++++++++++----- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 506e453..9a66501 100644 --- a/client.go +++ b/client.go @@ -13,9 +13,17 @@ import ( "strings" ) -// ErrNotSupported is the error returned when the muxer does not support -// the protocol specified for the handshake. -var ErrNotSupported = errors.New("protocol not supported") +// ErrNotSupported is the error returned when the muxer doesn't support +// the protocols tried for the handshake. +type ErrNotSupported[T StringLike] struct { + + // Slice of protocols that were not supported by the muxer + Protos []T +} + +func (e ErrNotSupported[T]) Error() string { + return fmt.Sprintf("protocols not supported: %v", e.Protos) +} // ErrNoProtocols is the error returned when the no protocols have been // specified. @@ -83,14 +91,18 @@ func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err // can continue negotiating the rest of the protocols normally. // // This saves us a round trip. - switch err := SelectProtoOrFail(protos[0], rwc); err { + switch err := SelectProtoOrFail(protos[0], rwc); err.(type) { case nil: return protos[0], nil - case ErrNotSupported: // try others + case ErrNotSupported[T]: // try others default: return "", err } - return selectProtosOrFail(protos[1:], rwc) + proto, err = selectProtosOrFail(protos[1:], rwc) + if _, ok := err.(ErrNotSupported[T]); ok { + return "", ErrNotSupported[T]{protos} + } + return proto, err } const simOpenProtocol = "/libp2p/simultaneous-connect" @@ -161,7 +173,11 @@ func clientOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) { case protos[0]: return tok, nil case "na": - return selectProtosOrFail(protos[1:], rwc) + proto, err := selectProtosOrFail(protos[1:], rwc) + if _, ok := err.(ErrNotSupported[T]); ok { + return "", ErrNotSupported[T]{protos} + } + return proto, err default: return "", fmt.Errorf("unexpected response: %s", tok) } @@ -170,15 +186,15 @@ func clientOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) { func selectProtosOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) { for _, p := range protos { err := trySelect(p, rwc) - switch err { + switch err := err.(type) { case nil: return p, nil - case ErrNotSupported: + case ErrNotSupported[T]: default: return "", err } } - return "", ErrNotSupported + return "", ErrNotSupported[T]{protos} } func simOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, bool, error) { @@ -255,12 +271,11 @@ func simOpenSelectServer[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, e if err = <-werrCh; err != nil { return "", err } - for { tok, err = ReadNextToken[T](rwc) if err == io.EOF { - return "", ErrNotSupported + return "", ErrNotSupported[T]{protos} } if err != nil { @@ -337,7 +352,7 @@ func readProto[T StringLike](proto T, r io.Reader) error { case proto: return nil case "na": - return ErrNotSupported + return ErrNotSupported[T]{[]T{proto}} default: return fmt.Errorf("unrecognized response: %s", tok) } diff --git a/lazyClient.go b/lazyClient.go index 13108cc..6145eaf 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -78,7 +78,7 @@ func (l *lazyClientConn[T]) doReadHandshake() { } if tok == "na" { - l.rerr = ErrNotSupported + l.rerr = ErrNotSupported[T]{[]T{proto}} return } if tok != proto { diff --git a/multistream_test.go b/multistream_test.go index 454e59d..0398a13 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -21,6 +21,19 @@ func newRwcStrict(t *testing.T, rwc io.ReadWriteCloser) io.ReadWriteCloser { return &rwcStrict{t: t, rwc: rwc} } +func cmpErrNotSupport(e1 error, e2 ErrNotSupported[string]) bool { + e, ok := e1.(ErrNotSupported[string]) + if !ok || len(e.Protos) != len(e2.Protos) { + return false + } + for i := 0; i < len(e.Protos); i++ { + if e.Protos[i] != e2.Protos[i] { + return false + } + } + return true +} + func (s *rwcStrict) Read(b []byte) (int, error) { if s.reading { s.t.Error("concurrent read") @@ -160,7 +173,7 @@ func TestProtocolNegotiationUnsupported(t *testing.T) { c := NewMSSelect(b, "/foo") c.Write([]byte("foo protocol data")) _, err := c.Read([]byte{0}) - if err != ErrNotSupported { + if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/foo"}}) { t.Fatalf("expected protocol /foo to be unsupported, got: %v", err) } c.Close() @@ -349,7 +362,7 @@ func TestSelectFails(t *testing.T) { go mux.Negotiate(a) _, err := SelectOneOf([]string{"/d", "/e"}, b) - if err != ErrNotSupported { + if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/d", "/e"}}) { t.Fatal("expected to not be supported") } } @@ -842,7 +855,7 @@ func TestSimopenClientServerFail(t *testing.T) { }() _, _, err := SelectWithSimopenOrFail([]string{"/b"}, b) - if err != ErrNotSupported { + if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/b"}}) { t.Fatal(err) } b.Close() @@ -936,7 +949,7 @@ func TestSimopenClientClientFail(t *testing.T) { done := make(chan struct{}) go func() { _, _, err := SelectWithSimopenOrFail([]string{"/a"}, b) - if err != ErrNotSupported { + if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/a"}}) { t.Error(err) } b.Close() @@ -944,7 +957,7 @@ func TestSimopenClientClientFail(t *testing.T) { }() _, _, err := SelectWithSimopenOrFail([]string{"/b"}, a) - if err != ErrNotSupported { + if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/b"}}) { t.Fatal(err) } a.Close()