Skip to content

Commit

Permalink
Merge pull request #104 from libp2p/check-peer-id-on-inbound
Browse files Browse the repository at this point in the history
add the peer ID to SecureInbound
  • Loading branch information
marten-seemann authored Sep 8, 2021
2 parents 123a116 + a05bdd7 commit 72fe0a0
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 18 deletions.
2 changes: 1 addition & 1 deletion p2p/security/noise/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) {
initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID)
}()

respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn)
respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "")
<-done

if initErr != nil {
Expand Down
10 changes: 6 additions & 4 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}
}

// We can re-use this buffer for all handshake messages as it's size
// We can re-use this buffer for all handshake messages as its size
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with it's length, we need to account for
// Also, since we prefix every noise handshake message with its length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
Expand Down Expand Up @@ -242,8 +242,10 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
return err
}

// if we know who we're trying to reach, make sure we have the right peer
if s.initiator && s.remoteID != id {
// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}
Expand Down
5 changes: 3 additions & 2 deletions p2p/security/noise/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
}

// SecureInbound runs the Noise handshake as the responder.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, "", false)
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, false)
}

// SecureOutbound runs the Noise handshake as the initiator.
Expand Down
67 changes: 56 additions & 11 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ import (
"context"
"encoding/binary"
"errors"
"golang.org/x/crypto/poly1305"
"io"
"math/rand"
"net"
"testing"
"time"

crypto "github.com/libp2p/go-libp2p-core/crypto"
"github.com/stretchr/testify/assert"

"golang.org/x/crypto/poly1305"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"

Expand Down Expand Up @@ -79,7 +82,7 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
}()

respConn, respErr := respTransport.SecureInbound(context.TODO(), resp)
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

if initErr != nil {
Expand Down Expand Up @@ -161,24 +164,66 @@ func TestKeys(t *testing.T) {
}
}

func TestPeerIDMismatchFailsHandshake(t *testing.T) {
func TestPeerIDMatch(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

var initErr error
done := make(chan struct{})
go func() {
defer close(done)
_, initErr = initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id")
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
assert.NoError(t, err)
assert.Equal(t, conn.RemotePeer(), respTransport.localID)
b := make([]byte, 6)
_, err = conn.Read(b)
assert.NoError(t, err)
assert.Equal(t, b, []byte("foobar"))
}()

_, _ = respTransport.SecureInbound(context.TODO(), resp)
<-done
conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID)
require.NoError(t, err)
require.Equal(t, conn.RemotePeer(), initTransport.localID)
_, err = conn.Write([]byte("foobar"))
require.NoError(t, err)
}

if initErr == nil {
t.Fatal("expected initiator to fail with peer ID mismatch error")
}
func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id")
errChan <- err
}()

_, err := respTransport.SecureInbound(context.TODO(), resp, "")
require.Error(t, err)

initErr := <-errChan
require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error")
require.Contains(t, initErr.Error(), "but remote key matches")
}

func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

done := make(chan struct{})
go func() {
defer close(done)
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
assert.NoError(t, err)
_, err = conn.Read([]byte{0})
assert.Error(t, err)
}()

_, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id")
require.Error(t, err, "expected responder to fail with peer ID mismatch error")
<-done
}

func makeLargePlaintext(size int) []byte {
Expand Down

0 comments on commit 72fe0a0

Please sign in to comment.