diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 5697856355..12009a5bd3 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -85,7 +85,7 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID) }() - respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn) + respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "") <-done if initErr != nil { diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index f872af9019..d70a52958a 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -60,9 +60,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error { } } - // We can re-use this buffer for all handshake messages as it's size + // We can re-use this buffer for all handshake messages as its size // will be the size of the maximum handshake message for the Noise XX pattern. - // Also, since we prefix every noise handshake message with it's length, we need to account for + // Also, since we prefix every noise handshake message with its length, we need to account for // it when we fetch the buffer from the pool maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize hbuf := pool.Get(maxMsgSize + LengthPrefixLength) @@ -242,8 +242,10 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati return err } - // if we know who we're trying to reach, make sure we have the right peer - if s.initiator && s.remoteID != id { + // check the peer ID for: + // * all outbound connection + // * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID) + if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) { // use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms. return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index a1daa7d3d5..c8d7a44f50 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -36,8 +36,9 @@ func New(privkey crypto.PrivKey) (*Transport, error) { } // SecureInbound runs the Noise handshake as the responder. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, "", false) +// If p is empty, connections from any peer are accepted. +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + return newSecureSession(t, ctx, insecure, p, false) } // SecureOutbound runs the Noise handshake as the initiator. diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 818a22980d..b65b9cb5c2 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -5,14 +5,17 @@ import ( "context" "encoding/binary" "errors" - "golang.org/x/crypto/poly1305" "io" "math/rand" "net" "testing" "time" - crypto "github.com/libp2p/go-libp2p-core/crypto" + "github.com/stretchr/testify/assert" + + "golang.org/x/crypto/poly1305" + + "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" @@ -79,7 +82,7 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) }() - respConn, respErr := respTransport.SecureInbound(context.TODO(), resp) + respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") <-done if initErr != nil { @@ -161,24 +164,66 @@ func TestKeys(t *testing.T) { } } -func TestPeerIDMismatchFailsHandshake(t *testing.T) { +func TestPeerIDMatch(t *testing.T) { initTransport := newTestTransport(t, crypto.Ed25519, 2048) respTransport := newTestTransport(t, crypto.Ed25519, 2048) init, resp := newConnPair(t) - var initErr error done := make(chan struct{}) go func() { defer close(done) - _, initErr = initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + assert.NoError(t, err) + assert.Equal(t, conn.RemotePeer(), respTransport.localID) + b := make([]byte, 6) + _, err = conn.Read(b) + assert.NoError(t, err) + assert.Equal(t, b, []byte("foobar")) }() - _, _ = respTransport.SecureInbound(context.TODO(), resp) - <-done + conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID) + require.NoError(t, err) + require.Equal(t, conn.RemotePeer(), initTransport.localID) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) +} - if initErr == nil { - t.Fatal("expected initiator to fail with peer ID mismatch error") - } +func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + init, resp := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") + errChan <- err + }() + + _, err := respTransport.SecureInbound(context.TODO(), resp, "") + require.Error(t, err) + + initErr := <-errChan + require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error") + require.Contains(t, initErr.Error(), "but remote key matches") +} + +func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + init, resp := newConnPair(t) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + assert.NoError(t, err) + _, err = conn.Read([]byte{0}) + assert.Error(t, err) + }() + + _, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id") + require.Error(t, err, "expected responder to fail with peer ID mismatch error") + <-done } func makeLargePlaintext(size int) []byte {