Skip to content

Commit

Permalink
Add not supported protocols to returned errors (#97)
Browse files Browse the repository at this point in the history
* Add not supported protocols to returned errors

* rename struct to ErrNotSupported
  • Loading branch information
sukunrt authored Jan 26, 2023
1 parent 07062ee commit 4178fda
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
41 changes: 28 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@ import (
"strings"
)

// ErrNotSupported is the error returned when the muxer does not support
// the protocol specified for the handshake.
var ErrNotSupported = errors.New("protocol not supported")
// ErrNotSupported is the error returned when the muxer doesn't support
// the protocols tried for the handshake.
type ErrNotSupported[T StringLike] struct {

// Slice of protocols that were not supported by the muxer
Protos []T
}

func (e ErrNotSupported[T]) Error() string {
return fmt.Sprintf("protocols not supported: %v", e.Protos)
}

// ErrNoProtocols is the error returned when the no protocols have been
// specified.
Expand Down Expand Up @@ -83,14 +91,18 @@ func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err
// can continue negotiating the rest of the protocols normally.
//
// This saves us a round trip.
switch err := SelectProtoOrFail(protos[0], rwc); err {
switch err := SelectProtoOrFail(protos[0], rwc); err.(type) {
case nil:
return protos[0], nil
case ErrNotSupported: // try others
case ErrNotSupported[T]: // try others
default:
return "", err
}
return selectProtosOrFail(protos[1:], rwc)
proto, err = selectProtosOrFail(protos[1:], rwc)
if _, ok := err.(ErrNotSupported[T]); ok {
return "", ErrNotSupported[T]{protos}
}
return proto, err
}

const simOpenProtocol = "/libp2p/simultaneous-connect"
Expand Down Expand Up @@ -161,7 +173,11 @@ func clientOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
case protos[0]:
return tok, nil
case "na":
return selectProtosOrFail(protos[1:], rwc)
proto, err := selectProtosOrFail(protos[1:], rwc)
if _, ok := err.(ErrNotSupported[T]); ok {
return "", ErrNotSupported[T]{protos}
}
return proto, err
default:
return "", fmt.Errorf("unexpected response: %s", tok)
}
Expand All @@ -170,15 +186,15 @@ func clientOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
func selectProtosOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
for _, p := range protos {
err := trySelect(p, rwc)
switch err {
switch err := err.(type) {
case nil:
return p, nil
case ErrNotSupported:
case ErrNotSupported[T]:
default:
return "", err
}
}
return "", ErrNotSupported
return "", ErrNotSupported[T]{protos}
}

func simOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, bool, error) {
Expand Down Expand Up @@ -255,12 +271,11 @@ func simOpenSelectServer[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, e
if err = <-werrCh; err != nil {
return "", err
}

for {
tok, err = ReadNextToken[T](rwc)

if err == io.EOF {
return "", ErrNotSupported
return "", ErrNotSupported[T]{protos}
}

if err != nil {
Expand Down Expand Up @@ -337,7 +352,7 @@ func readProto[T StringLike](proto T, r io.Reader) error {
case proto:
return nil
case "na":
return ErrNotSupported
return ErrNotSupported[T]{[]T{proto}}
default:
return fmt.Errorf("unrecognized response: %s", tok)
}
Expand Down
2 changes: 1 addition & 1 deletion lazyClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (l *lazyClientConn[T]) doReadHandshake() {
}

if tok == "na" {
l.rerr = ErrNotSupported
l.rerr = ErrNotSupported[T]{[]T{proto}}
return
}
if tok != proto {
Expand Down
23 changes: 18 additions & 5 deletions multistream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ func newRwcStrict(t *testing.T, rwc io.ReadWriteCloser) io.ReadWriteCloser {
return &rwcStrict{t: t, rwc: rwc}
}

func cmpErrNotSupport(e1 error, e2 ErrNotSupported[string]) bool {
e, ok := e1.(ErrNotSupported[string])
if !ok || len(e.Protos) != len(e2.Protos) {
return false
}
for i := 0; i < len(e.Protos); i++ {
if e.Protos[i] != e2.Protos[i] {
return false
}
}
return true
}

func (s *rwcStrict) Read(b []byte) (int, error) {
if s.reading {
s.t.Error("concurrent read")
Expand Down Expand Up @@ -160,7 +173,7 @@ func TestProtocolNegotiationUnsupported(t *testing.T) {
c := NewMSSelect(b, "/foo")
c.Write([]byte("foo protocol data"))
_, err := c.Read([]byte{0})
if err != ErrNotSupported {
if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/foo"}}) {
t.Fatalf("expected protocol /foo to be unsupported, got: %v", err)
}
c.Close()
Expand Down Expand Up @@ -349,7 +362,7 @@ func TestSelectFails(t *testing.T) {
go mux.Negotiate(a)

_, err := SelectOneOf([]string{"/d", "/e"}, b)
if err != ErrNotSupported {
if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/d", "/e"}}) {
t.Fatal("expected to not be supported")
}
}
Expand Down Expand Up @@ -842,7 +855,7 @@ func TestSimopenClientServerFail(t *testing.T) {
}()

_, _, err := SelectWithSimopenOrFail([]string{"/b"}, b)
if err != ErrNotSupported {
if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/b"}}) {
t.Fatal(err)
}
b.Close()
Expand Down Expand Up @@ -936,15 +949,15 @@ func TestSimopenClientClientFail(t *testing.T) {
done := make(chan struct{})
go func() {
_, _, err := SelectWithSimopenOrFail([]string{"/a"}, b)
if err != ErrNotSupported {
if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/a"}}) {
t.Error(err)
}
b.Close()
close(done)
}()

_, _, err := SelectWithSimopenOrFail([]string{"/b"}, a)
if err != ErrNotSupported {
if !cmpErrNotSupport(err, ErrNotSupported[string]{[]string{"/b"}}) {
t.Fatal(err)
}
a.Close()
Expand Down

0 comments on commit 4178fda

Please sign in to comment.