diff --git a/multistream.go b/multistream.go index ff57554..b3de4c0 100644 --- a/multistream.go +++ b/multistream.go @@ -235,9 +235,10 @@ loop: continue loop } - if err := delimWriteBuffered(rwc, []byte(tok)); err != nil { - return "", nil, err - } + // Ignore the error here. We want the handshake to finish, even if the + // other side has closed this rwc for writing. They may have sent us a + // message and closed. Future writers will get an error anyways. + _ = delimWriteBuffered(rwc, []byte(tok)) // hand off processing to the sub-protocol handler return tok, h.Handle, nil diff --git a/multistream_test.go b/multistream_test.go index 0de859e..5664862 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -688,6 +688,81 @@ func TestNegotiateFail(t *testing.T) { } } +type mockStream struct { + expectWrite [][]byte + toRead [][]byte +} + +func (s *mockStream) Close() error { + return nil +} + +func (s *mockStream) Write(p []byte) (n int, err error) { + if len(s.expectWrite) == 0 { + return 0, fmt.Errorf("no more writes expected") + } + + if !bytes.Equal(s.expectWrite[0], p) { + return 0, fmt.Errorf("unexpected write") + } + + s.expectWrite = s.expectWrite[1:] + return len(p), nil +} + +func (s *mockStream) Read(p []byte) (n int, err error) { + if len(s.toRead) == 0 { + return 0, fmt.Errorf("no more reads expected") + } + + if len(p) < len(s.toRead[0]) { + copy(p, s.toRead[0]) + s.toRead[0] = s.toRead[0][len(p):] + n = len(p) + } else { + copy(p, s.toRead[0]) + n = len(s.toRead[0]) + s.toRead = s.toRead[1:] + } + + return n, nil +} + +func TestNegotiatePeerSendsAndCloses(t *testing.T) { + // Tests the case where a peer will negotiate a protocol, send data, then close the stream immediately + var buf bytes.Buffer + err := delimWrite(&buf, []byte(ProtocolID)) + if err != nil { + t.Fatal(err) + } + delimtedProtocolID := make([]byte, buf.Len()) + copy(delimtedProtocolID, buf.Bytes()) + + err = delimWrite(&buf, []byte("foo")) + if err != nil { + t.Fatal(err) + } + err = delimWrite(&buf, []byte("somedata")) + if err != nil { + t.Fatal(err) + } + + s := &mockStream{ + // We mock the closed stream by only expecting a single write. The + // mockstream will error on any more writes (same as writing to a closed + // stream) + expectWrite: [][]byte{delimtedProtocolID}, + toRead: [][]byte{buf.Bytes()}, + } + + mux := NewMultistreamMuxer() + mux.AddHandler("foo", nil) + _, _, err = mux.Negotiate(s) + if err != nil { + t.Fatal("Negotiate should not fail here", err) + } +} + func TestSimopenClientServer(t *testing.T) { a, b := newPipe(t)