diff --git a/go.mod b/go.mod index f5dfa499d1d5..9a629725cdef 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/DataDog/zstd v1.5.2 github.com/Microsoft/go-winio v0.5.2 github.com/NYTimes/gziphandler v1.1.1 - github.com/ava-labs/coreth v0.12.9-rc.5 + github.com/ava-labs/coreth v0.12.9-rc.7 github.com/ava-labs/ledger-avalanche/go v0.0.0-20231102202641-ae2ebdaeac34 github.com/btcsuite/btcd/btcutil v1.1.3 github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 diff --git a/go.sum b/go.sum index 35141a6c1638..3478c17b2d6b 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,8 @@ github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/ava-labs/coreth v0.12.9-rc.5 h1:xYBgNm1uOPfUdUNm8+fS8ellHnEd4qfFNb6uZHo9tqI= -github.com/ava-labs/coreth v0.12.9-rc.5/go.mod h1:rECKQfGFDeodrwGPlJSvFUJDbVr30jSMIVjQLi6pNX4= +github.com/ava-labs/coreth v0.12.9-rc.7 h1:AlCmXnrJwo0NxlEXQHysQgRQSCA14PZW6iyJmeVYB34= +github.com/ava-labs/coreth v0.12.9-rc.7/go.mod h1:yrf2vEah4Fgj6sJ4UpHewo4DLolwdpf2bJuLRT80PGw= github.com/ava-labs/ledger-avalanche/go v0.0.0-20231102202641-ae2ebdaeac34 h1:mg9Uw6oZFJKytJxgxnl3uxZOs/SB8CVHg6Io4Tf99Zc= github.com/ava-labs/ledger-avalanche/go v0.0.0-20231102202641-ae2ebdaeac34/go.mod h1:pJxaT9bUgeRNVmNRgtCHb7sFDIRKy7CzTQVi8gGNT6g= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= diff --git a/network/p2p/client.go b/network/p2p/client.go index d950a4b0a227..1c3c9bee01da 100644 --- a/network/p2p/client.go +++ b/network/p2p/client.go @@ -42,10 +42,9 @@ type CrossChainAppResponseCallback func( type Client struct { handlerID uint64 handlerPrefix []byte - router *Router + router *router sender common.AppSender - // nodeSampler is used to select nodes to route AppRequestAny to - nodeSampler NodeSampler + options *clientOptions } // AppRequestAny issues an AppRequest to an arbitrary node decided by Client. @@ -56,7 +55,7 @@ func (c *Client) AppRequestAny( appRequestBytes []byte, onResponse AppResponseCallback, ) error { - sampled := c.nodeSampler.Sample(ctx, 1) + sampled := c.options.nodeSampler.Sample(ctx, 1) if len(sampled) != 1 { return ErrNoPeers } diff --git a/network/p2p/gossip/gossip_test.go b/network/p2p/gossip/gossip_test.go index eb4b23ecd9c8..d30fac0008e7 100644 --- a/network/p2p/gossip/gossip_test.go +++ b/network/p2p/gossip/gossip_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/network/p2p" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -117,10 +115,9 @@ func TestGossiperGossip(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - ctrl := gomock.NewController(t) - responseSender := common.NewMockSender(ctrl) - responseRouter := p2p.NewRouter(logging.NoLog{}, responseSender, prometheus.NewRegistry(), "") + responseSender := &common.SenderTest{} + responseNetwork := p2p.NewNetwork(logging.NoLog{}, responseSender, prometheus.NewRegistry(), "") responseBloom, err := NewBloomFilter(1000, 0.01) require.NoError(err) responseSet := testSet{ @@ -130,31 +127,30 @@ func TestGossiperGossip(t *testing.T) { for _, item := range tt.responder { require.NoError(responseSet.Add(item)) } - peers := &p2p.Peers{} - require.NoError(peers.Connected(context.Background(), ids.EmptyNodeID, nil)) handler, err := NewHandler[*testTx](responseSet, tt.config, prometheus.NewRegistry()) require.NoError(err) - _, err = responseRouter.RegisterAppProtocol(0x0, handler, peers) + _, err = responseNetwork.NewAppProtocol(0x0, handler) require.NoError(err) - requestSender := common.NewMockSender(ctrl) - requestRouter := p2p.NewRouter(logging.NoLog{}, requestSender, prometheus.NewRegistry(), "") - - gossiped := make(chan struct{}) - requestSender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { + requestSender := &common.SenderTest{ + SendAppRequestF: func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { go func() { - require.NoError(responseRouter.AppRequest(ctx, ids.EmptyNodeID, requestID, time.Time{}, request)) + require.NoError(responseNetwork.AppRequest(ctx, ids.EmptyNodeID, requestID, time.Time{}, request)) }() - }).AnyTimes() + return nil + }, + } - responseSender.EXPECT(). - SendAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeID ids.NodeID, requestID uint32, appResponseBytes []byte) { - require.NoError(requestRouter.AppResponse(ctx, nodeID, requestID, appResponseBytes)) - close(gossiped) - }).AnyTimes() + requestNetwork := p2p.NewNetwork(logging.NoLog{}, requestSender, prometheus.NewRegistry(), "") + require.NoError(requestNetwork.Connected(context.Background(), ids.EmptyNodeID, nil)) + + gossiped := make(chan struct{}) + responseSender.SendAppResponseF = func(ctx context.Context, nodeID ids.NodeID, requestID uint32, appResponseBytes []byte) error { + require.NoError(requestNetwork.AppResponse(ctx, nodeID, requestID, appResponseBytes)) + close(gossiped) + return nil + } bloom, err := NewBloomFilter(1000, 0.01) require.NoError(err) @@ -166,7 +162,7 @@ func TestGossiperGossip(t *testing.T) { require.NoError(requestSet.Add(item)) } - requestClient, err := requestRouter.RegisterAppProtocol(0x0, nil, peers) + requestClient, err := requestNetwork.NewAppProtocol(0x0, nil) require.NoError(err) config := Config{ diff --git a/network/p2p/network.go b/network/p2p/network.go new file mode 100644 index 000000000000..444c2e4b9408 --- /dev/null +++ b/network/p2p/network.go @@ -0,0 +1,188 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "encoding/binary" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/version" +) + +var ( + _ validators.Connector = (*Network)(nil) + _ common.AppHandler = (*Network)(nil) + _ NodeSampler = (*peerSampler)(nil) +) + +// ClientOption configures Client +type ClientOption interface { + apply(options *clientOptions) +} + +type clientOptionFunc func(options *clientOptions) + +func (o clientOptionFunc) apply(options *clientOptions) { + o(options) +} + +// WithValidatorSampling configures Client.AppRequestAny to sample validators +func WithValidatorSampling(validators *Validators) ClientOption { + return clientOptionFunc(func(options *clientOptions) { + options.nodeSampler = validators + }) +} + +// clientOptions holds client-configurable values +type clientOptions struct { + // nodeSampler is used to select nodes to route Client.AppRequestAny to + nodeSampler NodeSampler +} + +// NewNetwork returns an instance of Network +func NewNetwork( + log logging.Logger, + sender common.AppSender, + metrics prometheus.Registerer, + namespace string, +) *Network { + return &Network{ + Peers: &Peers{}, + log: log, + sender: sender, + metrics: metrics, + namespace: namespace, + router: newRouter(log, sender, metrics, namespace), + } +} + +// Network exposes networking state and supports building p2p application +// protocols +type Network struct { + Peers *Peers + + log logging.Logger + sender common.AppSender + metrics prometheus.Registerer + namespace string + + router *router +} + +func (n *Network) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { + return n.router.AppRequest(ctx, nodeID, requestID, deadline, request) +} + +func (n *Network) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { + return n.router.AppResponse(ctx, nodeID, requestID, response) +} + +func (n *Network) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { + return n.router.AppRequestFailed(ctx, nodeID, requestID) +} + +func (n *Network) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error { + return n.router.AppGossip(ctx, nodeID, msg) +} + +func (n *Network) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error { + return n.router.CrossChainAppRequest(ctx, chainID, requestID, deadline, request) +} + +func (n *Network) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { + return n.router.CrossChainAppResponse(ctx, chainID, requestID, response) +} + +func (n *Network) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32) error { + return n.router.CrossChainAppRequestFailed(ctx, chainID, requestID) +} + +func (n *Network) Connected(_ context.Context, nodeID ids.NodeID, _ *version.Application) error { + n.Peers.add(nodeID) + return nil +} + +func (n *Network) Disconnected(_ context.Context, nodeID ids.NodeID) error { + n.Peers.remove(nodeID) + return nil +} + +// NewAppProtocol reserves an identifier for an application protocol handler and +// returns a Client that can be used to send messages for the corresponding +// protocol. +func (n *Network) NewAppProtocol(handlerID uint64, handler Handler, options ...ClientOption) (*Client, error) { + if err := n.router.addHandler(handlerID, handler); err != nil { + return nil, err + } + + client := &Client{ + handlerID: handlerID, + handlerPrefix: binary.AppendUvarint(nil, handlerID), + sender: n.sender, + router: n.router, + options: &clientOptions{ + nodeSampler: &peerSampler{ + peers: n.Peers, + }, + }, + } + + for _, option := range options { + option.apply(client.options) + } + + return client, nil +} + +// Peers contains metadata about the current set of connected peers +type Peers struct { + lock sync.RWMutex + set set.SampleableSet[ids.NodeID] +} + +func (p *Peers) add(nodeID ids.NodeID) { + p.lock.Lock() + defer p.lock.Unlock() + + p.set.Add(nodeID) +} + +func (p *Peers) remove(nodeID ids.NodeID) { + p.lock.Lock() + defer p.lock.Unlock() + + p.set.Remove(nodeID) +} + +func (p *Peers) has(nodeID ids.NodeID) bool { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.set.Contains(nodeID) +} + +// Sample returns a pseudo-random sample of up to limit Peers +func (p *Peers) Sample(limit int) []ids.NodeID { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.set.Sample(limit) +} + +type peerSampler struct { + peers *Peers +} + +func (p peerSampler) Sample(_ context.Context, limit int) []ids.NodeID { + return p.peers.Sample(limit) +} diff --git a/network/p2p/network_test.go b/network/p2p/network_test.go new file mode 100644 index 000000000000..590858a0c467 --- /dev/null +++ b/network/p2p/network_test.go @@ -0,0 +1,596 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package p2p + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/stretchr/testify/require" + + "go.uber.org/mock/gomock" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p/mocks" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/math" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/version" +) + +func TestAppRequestResponse(t *testing.T) { + handlerID := uint64(0x0) + request := []byte("request") + response := []byte("response") + nodeID := ids.GenerateTestNodeID() + chainID := ids.GenerateTestID() + + ctxKey := new(string) + ctxVal := new(string) + *ctxKey = "foo" + *ctxVal = "bar" + + tests := []struct { + name string + requestFunc func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) + }{ + { + name: "app request", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { + for range nodeIDs { + go func() { + require.NoError(t, network.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) + }() + } + + return nil + } + sender.SendAppResponseF = func(ctx context.Context, _ ids.NodeID, requestID uint32, response []byte) error { + go func() { + ctx = context.WithValue(ctx, ctxKey, ctxVal) + require.NoError(t, network.AppResponse(ctx, nodeID, requestID, response)) + }() + + return nil + } + handler.EXPECT(). + AppRequest(context.Background(), nodeID, gomock.Any(), request). + DoAndReturn(func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + return response, nil + }) + + callback := func(ctx context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { + defer wg.Done() + + require.NoError(t, err) + require.Equal(t, ctxVal, ctx.Value(ctxKey)) + require.Equal(t, nodeID, actualNodeID) + require.Equal(t, response, actualResponse) + } + + require.NoError(t, client.AppRequestAny(context.Background(), request, callback)) + }, + }, + { + name: "app request failed", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { + for range nodeIDs { + go func() { + require.NoError(t, network.AppRequestFailed(ctx, nodeID, requestID)) + }() + } + + return nil + } + + callback := func(_ context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { + defer wg.Done() + + require.ErrorIs(t, err, ErrAppRequestFailed) + require.Equal(t, nodeID, actualNodeID) + require.Nil(t, actualResponse) + } + + require.NoError(t, client.AppRequest(context.Background(), set.Of(nodeID), request, callback)) + }, + }, + { + name: "cross-chain app request", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + chainID := ids.GenerateTestID() + sender.SendCrossChainAppRequestF = func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { + go func() { + require.NoError(t, network.CrossChainAppRequest(ctx, chainID, requestID, time.Time{}, request)) + }() + } + sender.SendCrossChainAppResponseF = func(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) { + go func() { + ctx = context.WithValue(ctx, ctxKey, ctxVal) + require.NoError(t, network.CrossChainAppResponse(ctx, chainID, requestID, response)) + }() + } + handler.EXPECT(). + CrossChainAppRequest(context.Background(), chainID, gomock.Any(), request). + DoAndReturn(func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + return response, nil + }) + + callback := func(ctx context.Context, actualChainID ids.ID, actualResponse []byte, err error) { + defer wg.Done() + require.NoError(t, err) + require.Equal(t, ctxVal, ctx.Value(ctxKey)) + require.Equal(t, chainID, actualChainID) + require.Equal(t, response, actualResponse) + } + + require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) + }, + }, + { + name: "cross-chain app request failed", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.SendCrossChainAppRequestF = func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { + go func() { + require.NoError(t, network.CrossChainAppRequestFailed(ctx, chainID, requestID)) + }() + } + + callback := func(_ context.Context, actualChainID ids.ID, actualResponse []byte, err error) { + defer wg.Done() + + require.ErrorIs(t, err, ErrAppRequestFailed) + require.Equal(t, chainID, actualChainID) + require.Nil(t, actualResponse) + } + + require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) + }, + }, + { + name: "app gossip", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.SendAppGossipF = func(ctx context.Context, gossip []byte) error { + go func() { + require.NoError(t, network.AppGossip(ctx, nodeID, gossip)) + }() + + return nil + } + handler.EXPECT(). + AppGossip(context.Background(), nodeID, request). + DoAndReturn(func(context.Context, ids.NodeID, []byte) error { + defer wg.Done() + return nil + }) + + require.NoError(t, client.AppGossip(context.Background(), request)) + }, + }, + { + name: "app gossip specific", + requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { + sender.SendAppGossipSpecificF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], bytes []byte) error { + for n := range nodeIDs { + nodeID := n + go func() { + require.NoError(t, network.AppGossip(ctx, nodeID, bytes)) + }() + } + + return nil + } + handler.EXPECT(). + AppGossip(context.Background(), nodeID, request). + DoAndReturn(func(context.Context, ids.NodeID, []byte) error { + defer wg.Done() + return nil + }) + + require.NoError(t, client.AppGossipSpecific(context.Background(), set.Of(nodeID), request)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + ctrl := gomock.NewController(t) + + sender := &common.SenderTest{} + handler := mocks.NewMockHandler(ctrl) + n := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + require.NoError(n.Connected(context.Background(), nodeID, nil)) + client, err := n.NewAppProtocol(handlerID, handler) + require.NoError(err) + + wg := &sync.WaitGroup{} + wg.Add(1) + tt.requestFunc(t, n, client, sender, handler, wg) + wg.Wait() + }) + } +} + +func TestNetworkDropMessage(t *testing.T) { + unregistered := byte(0x0) + + tests := []struct { + name string + requestFunc func(network *Network) error + err error + }{ + { + name: "drop unregistered app request message", + requestFunc: func(network *Network) error { + return network.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty app request message", + requestFunc: func(network *Network) error { + return network.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{}) + }, + err: nil, + }, + { + name: "drop unregistered cross-chain app request message", + requestFunc: func(network *Network) error { + return network.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty cross-chain app request message", + requestFunc: func(network *Network) error { + return network.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{}) + }, + err: nil, + }, + { + name: "drop unregistered gossip message", + requestFunc: func(network *Network) error { + return network.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{unregistered}) + }, + err: nil, + }, + { + name: "drop empty gossip message", + requestFunc: func(network *Network) error { + return network.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{}) + }, + err: nil, + }, + { + name: "drop unrequested app request failed", + requestFunc: func(network *Network) error { + return network.AppRequestFailed(context.Background(), ids.GenerateTestNodeID(), 0) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested app response", + requestFunc: func(network *Network) error { + return network.AppResponse(context.Background(), ids.GenerateTestNodeID(), 0, nil) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested cross-chain request failed", + requestFunc: func(network *Network) error { + return network.CrossChainAppRequestFailed(context.Background(), ids.GenerateTestID(), 0) + }, + err: ErrUnrequestedResponse, + }, + { + name: "drop unrequested cross-chain response", + requestFunc: func(network *Network) error { + return network.CrossChainAppResponse(context.Background(), ids.GenerateTestID(), 0, nil) + }, + err: ErrUnrequestedResponse, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + + err := tt.requestFunc(network) + require.ErrorIs(err, tt.err) + }) + } +} + +// It's possible for the request id to overflow and wrap around. +// If there are still pending requests with the same request id, we should +// not attempt to issue another request until the previous one has cleared. +func TestAppRequestDuplicateRequestIDs(t *testing.T) { + require := require.New(t) + ctrl := gomock.NewController(t) + + handler := mocks.NewMockHandler(ctrl) + sender := &common.SenderTest{ + SendAppResponseF: func(context.Context, ids.NodeID, uint32, []byte) error { + return nil + }, + } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + nodeID := ids.GenerateTestNodeID() + + requestSent := &sync.WaitGroup{} + sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { + for range nodeIDs { + requestSent.Add(1) + go func() { + require.NoError(network.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) + requestSent.Done() + }() + } + + return nil + } + + timeout := &sync.WaitGroup{} + response := []byte("response") + handler.EXPECT().AppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, request []byte) ([]byte, error) { + timeout.Wait() + return response, nil + }).AnyTimes() + + require.NoError(network.Connected(context.Background(), nodeID, nil)) + client, err := network.NewAppProtocol(0x1, handler) + require.NoError(err) + + onResponse := func(ctx context.Context, nodeID ids.NodeID, got []byte, err error) { + require.NoError(err) + require.Equal(response, got) + } + + require.NoError(client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, onResponse)) + requestSent.Wait() + + // force the network to use the same requestID + network.router.requestID = 1 + timeout.Add(1) + err = client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, nil) + requestSent.Wait() + require.ErrorIs(err, ErrRequestPending) + + timeout.Done() +} + +// Sample should always return up to [limit] peers, and less if fewer than +// [limit] peers are available. +func TestPeersSample(t *testing.T) { + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + nodeID3 := ids.GenerateTestNodeID() + + tests := []struct { + name string + connected set.Set[ids.NodeID] + disconnected set.Set[ids.NodeID] + limit int + }{ + { + name: "no peers", + limit: 1, + }, + { + name: "one peer connected", + connected: set.Of(nodeID1), + limit: 1, + }, + { + name: "multiple peers connected", + connected: set.Of(nodeID1, nodeID2, nodeID3), + limit: 1, + }, + { + name: "peer connects and disconnects - 1", + connected: set.Of(nodeID1), + disconnected: set.Of(nodeID1), + limit: 1, + }, + { + name: "peer connects and disconnects - 2", + connected: set.Of(nodeID1, nodeID2), + disconnected: set.Of(nodeID2), + limit: 1, + }, + { + name: "peer connects and disconnects - 2", + connected: set.Of(nodeID1, nodeID2, nodeID3), + disconnected: set.Of(nodeID1, nodeID2), + limit: 1, + }, + { + name: "less than limit peers", + connected: set.Of(nodeID1, nodeID2, nodeID3), + limit: 4, + }, + { + name: "limit peers", + connected: set.Of(nodeID1, nodeID2, nodeID3), + limit: 3, + }, + { + name: "more than limit peers", + connected: set.Of(nodeID1, nodeID2, nodeID3), + limit: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + + for connected := range tt.connected { + require.NoError(network.Connected(context.Background(), connected, nil)) + } + + for disconnected := range tt.disconnected { + require.NoError(network.Disconnected(context.Background(), disconnected)) + } + + sampleable := set.Set[ids.NodeID]{} + sampleable.Union(tt.connected) + sampleable.Difference(tt.disconnected) + + sampled := network.Peers.Sample(tt.limit) + require.Len(sampled, math.Min(tt.limit, len(sampleable))) + require.Subset(sampleable, sampled) + }) + } +} + +func TestAppRequestAnyNodeSelection(t *testing.T) { + tests := []struct { + name string + peers []ids.NodeID + expected error + }{ + { + name: "no peers", + expected: ErrNoPeers, + }, + { + name: "has peers", + peers: []ids.NodeID{ids.GenerateTestNodeID()}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + sent := set.Set[ids.NodeID]{} + sender := &common.SenderTest{ + SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { + for nodeID := range nodeIDs { + sent.Add(nodeID) + } + return nil + }, + } + + n := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + for _, peer := range tt.peers { + require.NoError(n.Connected(context.Background(), peer, &version.Application{})) + } + + client, err := n.NewAppProtocol(1, nil) + require.NoError(err) + + err = client.AppRequestAny(context.Background(), []byte("foobar"), nil) + require.ErrorIs(err, tt.expected) + }) + } +} + +func TestNodeSamplerClientOption(t *testing.T) { + nodeID0 := ids.GenerateTestNodeID() + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + + tests := []struct { + name string + peers []ids.NodeID + option func(t *testing.T, n *Network) ClientOption + expected []ids.NodeID + expectedErr error + }{ + { + name: "default", + peers: []ids.NodeID{nodeID0, nodeID1, nodeID2}, + option: func(_ *testing.T, n *Network) ClientOption { + return clientOptionFunc(func(*clientOptions) {}) + }, + expected: []ids.NodeID{nodeID0, nodeID1, nodeID2}, + }, + { + name: "validator connected", + peers: []ids.NodeID{nodeID0, nodeID1}, + option: func(t *testing.T, n *Network) ClientOption { + state := &validators.TestState{ + GetCurrentHeightF: func(context.Context) (uint64, error) { + return 0, nil + }, + GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + return map[ids.NodeID]*validators.GetValidatorOutput{ + nodeID1: nil, + }, nil + }, + } + + validators := NewValidators(n.Peers, n.log, ids.Empty, state, 0) + return WithValidatorSampling(validators) + }, + expected: []ids.NodeID{nodeID1}, + }, + { + name: "validator disconnected", + peers: []ids.NodeID{nodeID0}, + option: func(t *testing.T, n *Network) ClientOption { + state := &validators.TestState{ + GetCurrentHeightF: func(context.Context) (uint64, error) { + return 0, nil + }, + GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + return map[ids.NodeID]*validators.GetValidatorOutput{ + nodeID1: nil, + }, nil + }, + } + + validators := NewValidators(n.Peers, n.log, ids.Empty, state, 0) + return WithValidatorSampling(validators) + }, + expectedErr: ErrNoPeers, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + done := make(chan struct{}) + sender := &common.SenderTest{ + SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { + require.Subset(tt.expected, nodeIDs.List()) + close(done) + return nil + }, + } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + ctx := context.Background() + for _, peer := range tt.peers { + require.NoError(network.Connected(ctx, peer, nil)) + } + + client, err := network.NewAppProtocol(0x0, nil, tt.option(t, network)) + require.NoError(err) + + if err = client.AppRequestAny(ctx, []byte("request"), nil); err != nil { + close(done) + } + + require.ErrorIs(tt.expectedErr, err) + <-done + }) + } +} diff --git a/network/p2p/peers.go b/network/p2p/peers.go deleted file mode 100644 index 47982aeb2dc4..000000000000 --- a/network/p2p/peers.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package p2p - -import ( - "context" - "sync" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/set" - "github.com/ava-labs/avalanchego/version" -) - -var ( - _ validators.Connector = (*Peers)(nil) - _ NodeSampler = (*Peers)(nil) -) - -// Peers contains a set of nodes that we are connected to. -type Peers struct { - lock sync.RWMutex - peers set.SampleableSet[ids.NodeID] -} - -func (p *Peers) Connected(_ context.Context, nodeID ids.NodeID, _ *version.Application) error { - p.lock.Lock() - defer p.lock.Unlock() - - p.peers.Add(nodeID) - - return nil -} - -func (p *Peers) Disconnected(_ context.Context, nodeID ids.NodeID) error { - p.lock.Lock() - defer p.lock.Unlock() - - p.peers.Remove(nodeID) - - return nil -} - -func (p *Peers) Sample(_ context.Context, limit int) []ids.NodeID { - p.lock.RLock() - defer p.lock.RUnlock() - - return p.peers.Sample(limit) -} diff --git a/network/p2p/peers_test.go b/network/p2p/peers_test.go deleted file mode 100644 index 9835cf065b0b..000000000000 --- a/network/p2p/peers_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package p2p - -import ( - "context" - "testing" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/stretchr/testify/require" - - "go.uber.org/mock/gomock" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/ava-labs/avalanchego/utils/math" - "github.com/ava-labs/avalanchego/utils/set" -) - -// Sample should always return up to [limit] peers, and less if fewer than -// [limit] peers are available. -func TestPeersSample(t *testing.T) { - nodeID1 := ids.GenerateTestNodeID() - nodeID2 := ids.GenerateTestNodeID() - nodeID3 := ids.GenerateTestNodeID() - - tests := []struct { - name string - connected set.Set[ids.NodeID] - disconnected set.Set[ids.NodeID] - limit int - }{ - { - name: "no peers", - limit: 1, - }, - { - name: "one peer connected", - connected: set.Of(nodeID1), - limit: 1, - }, - { - name: "multiple peers connected", - connected: set.Of(nodeID1, nodeID2, nodeID3), - limit: 1, - }, - { - name: "peer connects and disconnects - 1", - connected: set.Of(nodeID1), - disconnected: set.Of(nodeID1), - limit: 1, - }, - { - name: "peer connects and disconnects - 2", - connected: set.Of(nodeID1, nodeID2), - disconnected: set.Of(nodeID2), - limit: 1, - }, - { - name: "peer connects and disconnects - 2", - connected: set.Of(nodeID1, nodeID2, nodeID3), - disconnected: set.Of(nodeID1, nodeID2), - limit: 1, - }, - { - name: "less than limit peers", - connected: set.Of(nodeID1, nodeID2, nodeID3), - limit: 4, - }, - { - name: "limit peers", - connected: set.Of(nodeID1, nodeID2, nodeID3), - limit: 3, - }, - { - name: "more than limit peers", - connected: set.Of(nodeID1, nodeID2, nodeID3), - limit: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - peers := &Peers{} - - for connected := range tt.connected { - require.NoError(peers.Connected(context.Background(), connected, nil)) - } - - for disconnected := range tt.disconnected { - require.NoError(peers.Disconnected(context.Background(), disconnected)) - } - - sampleable := set.Set[ids.NodeID]{} - sampleable.Union(tt.connected) - sampleable.Difference(tt.disconnected) - - sampled := peers.Sample(context.Background(), tt.limit) - require.Len(sampled, math.Min(tt.limit, len(sampleable))) - require.Subset(sampleable, sampled) - }) - } -} - -func TestAppRequestAnyNodeSelection(t *testing.T) { - tests := []struct { - name string - peers []ids.NodeID - expected error - }{ - { - name: "no peers", - expected: ErrNoPeers, - }, - { - name: "has peers", - peers: []ids.NodeID{ids.GenerateTestNodeID()}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - ctrl := gomock.NewController(t) - mockAppSender := common.NewMockSender(ctrl) - - expectedCalls := 0 - if tt.expected == nil { - expectedCalls = 1 - } - mockAppSender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(expectedCalls) - - r := NewRouter(logging.NoLog{}, mockAppSender, prometheus.NewRegistry(), "") - peers := &Peers{} - for _, peer := range tt.peers { - require.NoError(peers.Connected(context.Background(), peer, nil)) - } - - client, err := r.RegisterAppProtocol(1, nil, peers) - require.NoError(err) - - err = client.AppRequestAny(context.Background(), []byte("foobar"), nil) - require.ErrorIs(err, tt.expected) - }) - } -} diff --git a/network/p2p/router.go b/network/p2p/router.go index b689b7ae1a17..110e9b6de627 100644 --- a/network/p2p/router.go +++ b/network/p2p/router.go @@ -26,7 +26,7 @@ var ( ErrExistingAppProtocol = errors.New("existing app protocol") ErrUnrequestedResponse = errors.New("unrequested response") - _ common.AppHandler = (*Router)(nil) + _ common.AppHandler = (*router)(nil) ) type metrics struct { @@ -55,10 +55,10 @@ type meteredHandler struct { *metrics } -// Router routes incoming application messages to the corresponding registered +// router routes incoming application messages to the corresponding registered // app handler. App messages must be made using the registered handler's // corresponding Client. -type Router struct { +type router struct { log logging.Logger sender common.AppSender metrics prometheus.Registerer @@ -71,14 +71,14 @@ type Router struct { requestID uint32 } -// NewRouter returns a new instance of Router -func NewRouter( +// newRouter returns a new instance of Router +func newRouter( log logging.Logger, sender common.AppSender, metrics prometheus.Registerer, namespace string, -) *Router { - return &Router{ +) *router { + return &router{ log: log, sender: sender, metrics: metrics, @@ -91,15 +91,12 @@ func NewRouter( } } -// RegisterAppProtocol reserves an identifier for an application protocol and -// returns a Client that can be used to send messages for the corresponding -// protocol. -func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSampler NodeSampler) (*Client, error) { +func (r *router) addHandler(handlerID uint64, handler Handler) error { r.lock.Lock() defer r.lock.Unlock() if _, ok := r.handlers[handlerID]; ok { - return nil, fmt.Errorf("failed to register handler id %d: %w", handlerID, ErrExistingAppProtocol) + return fmt.Errorf("failed to register handler id %d: %w", handlerID, ErrExistingAppProtocol) } appRequestTime, err := metric.NewAverager( @@ -109,7 +106,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register app request metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register app request metric for handler_%d: %w", handlerID, err) } appRequestFailedTime, err := metric.NewAverager( @@ -119,7 +116,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register app request failed metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register app request failed metric for handler_%d: %w", handlerID, err) } appResponseTime, err := metric.NewAverager( @@ -129,7 +126,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register app response metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register app response metric for handler_%d: %w", handlerID, err) } appGossipTime, err := metric.NewAverager( @@ -139,7 +136,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register app gossip metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register app gossip metric for handler_%d: %w", handlerID, err) } crossChainAppRequestTime, err := metric.NewAverager( @@ -149,7 +146,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register cross-chain app request metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register cross-chain app request metric for handler_%d: %w", handlerID, err) } crossChainAppRequestFailedTime, err := metric.NewAverager( @@ -159,7 +156,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register cross-chain app request failed metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register cross-chain app request failed metric for handler_%d: %w", handlerID, err) } crossChainAppResponseTime, err := metric.NewAverager( @@ -169,7 +166,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp r.metrics, ) if err != nil { - return nil, fmt.Errorf("failed to register cross-chain app response metric for handler_%d: %w", handlerID, err) + return fmt.Errorf("failed to register cross-chain app response metric for handler_%d: %w", handlerID, err) } r.handlers[handlerID] = &meteredHandler{ @@ -190,13 +187,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp }, } - return &Client{ - handlerID: handlerID, - handlerPrefix: binary.AppendUvarint(nil, handlerID), - sender: r.sender, - router: r, - nodeSampler: nodeSampler, - }, nil + return nil } // AppRequest routes an AppRequest to a Handler based on the handler prefix. The @@ -204,7 +195,7 @@ func (r *Router) RegisterAppProtocol(handlerID uint64, handler Handler, nodeSamp // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { +func (r *router) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error { start := time.Now() parsedMsg, handler, ok := r.parse(request) if !ok { @@ -232,7 +223,7 @@ func (r *Router) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID ui // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { +func (r *router) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { start := time.Now() pending, ok := r.clearAppRequest(requestID) if !ok { @@ -250,7 +241,7 @@ func (r *Router) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, reques // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { +func (r *router) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { start := time.Now() pending, ok := r.clearAppRequest(requestID) if !ok { @@ -268,7 +259,7 @@ func (r *Router) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID u // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) AppGossip(ctx context.Context, nodeID ids.NodeID, gossip []byte) error { +func (r *router) AppGossip(ctx context.Context, nodeID ids.NodeID, gossip []byte) error { start := time.Now() parsedMsg, handler, ok := r.parse(gossip) if !ok { @@ -292,7 +283,7 @@ func (r *Router) AppGossip(ctx context.Context, nodeID ids.NodeID, gossip []byte // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) CrossChainAppRequest( +func (r *router) CrossChainAppRequest( ctx context.Context, chainID ids.ID, requestID uint32, @@ -325,7 +316,7 @@ func (r *Router) CrossChainAppRequest( // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32) error { +func (r *router) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32) error { start := time.Now() pending, ok := r.clearCrossChainAppRequest(requestID) if !ok { @@ -343,7 +334,7 @@ func (r *Router) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, // // Any error condition propagated outside Handler application logic is // considered fatal -func (r *Router) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { +func (r *router) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error { start := time.Now() pending, ok := r.clearCrossChainAppRequest(requestID) if !ok { @@ -365,7 +356,7 @@ func (r *Router) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requ // - A boolean indicating that parsing succeeded. // // Invariant: Assumes [r.lock] isn't held. -func (r *Router) parse(msg []byte) ([]byte, *meteredHandler, bool) { +func (r *router) parse(msg []byte) ([]byte, *meteredHandler, bool) { handlerID, bytesRead := binary.Uvarint(msg) if bytesRead <= 0 { return nil, nil, false @@ -379,7 +370,7 @@ func (r *Router) parse(msg []byte) ([]byte, *meteredHandler, bool) { } // Invariant: Assumes [r.lock] isn't held. -func (r *Router) clearAppRequest(requestID uint32) (pendingAppRequest, bool) { +func (r *router) clearAppRequest(requestID uint32) (pendingAppRequest, bool) { r.lock.Lock() defer r.lock.Unlock() @@ -389,7 +380,7 @@ func (r *Router) clearAppRequest(requestID uint32) (pendingAppRequest, bool) { } // Invariant: Assumes [r.lock] isn't held. -func (r *Router) clearCrossChainAppRequest(requestID uint32) (pendingCrossChainAppRequest, bool) { +func (r *router) clearCrossChainAppRequest(requestID uint32) (pendingCrossChainAppRequest, bool) { r.lock.Lock() defer r.lock.Unlock() diff --git a/network/p2p/router_test.go b/network/p2p/router_test.go deleted file mode 100644 index 924a72b0b70a..000000000000 --- a/network/p2p/router_test.go +++ /dev/null @@ -1,360 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package p2p - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/stretchr/testify/require" - - "go.uber.org/mock/gomock" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/network/p2p/mocks" - "github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/ava-labs/avalanchego/utils/set" -) - -func TestAppRequestResponse(t *testing.T) { - handlerID := uint64(0x0) - request := []byte("request") - response := []byte("response") - nodeID := ids.GenerateTestNodeID() - chainID := ids.GenerateTestID() - - ctxKey := new(string) - ctxVal := new(string) - *ctxKey = "foo" - *ctxVal = "bar" - - tests := []struct { - name string - requestFunc func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) - }{ - { - name: "app request", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { - for range nodeIDs { - go func() { - require.NoError(t, router.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) - }() - } - }).AnyTimes() - sender.EXPECT().SendAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, _ ids.NodeID, requestID uint32, response []byte) { - go func() { - ctx = context.WithValue(ctx, ctxKey, ctxVal) - require.NoError(t, router.AppResponse(ctx, nodeID, requestID, response)) - }() - }).AnyTimes() - handler.EXPECT(). - AppRequest(context.Background(), nodeID, gomock.Any(), request). - DoAndReturn(func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { - return response, nil - }) - - callback := func(ctx context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { - defer wg.Done() - - require.NoError(t, err) - require.Equal(t, ctxVal, ctx.Value(ctxKey)) - require.Equal(t, nodeID, actualNodeID) - require.Equal(t, response, actualResponse) - } - - require.NoError(t, client.AppRequestAny(context.Background(), request, callback)) - }, - }, - { - name: "app request failed", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { - for range nodeIDs { - go func() { - require.NoError(t, router.AppRequestFailed(ctx, nodeID, requestID)) - }() - } - }) - - callback := func(_ context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { - defer wg.Done() - - require.ErrorIs(t, err, ErrAppRequestFailed) - require.Equal(t, nodeID, actualNodeID) - require.Nil(t, actualResponse) - } - - require.NoError(t, client.AppRequest(context.Background(), set.Of(nodeID), request, callback)) - }, - }, - { - name: "cross-chain app request", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - chainID := ids.GenerateTestID() - sender.EXPECT().SendCrossChainAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { - go func() { - require.NoError(t, router.CrossChainAppRequest(ctx, chainID, requestID, time.Time{}, request)) - }() - }).AnyTimes() - sender.EXPECT().SendCrossChainAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) { - go func() { - ctx = context.WithValue(ctx, ctxKey, ctxVal) - require.NoError(t, router.CrossChainAppResponse(ctx, chainID, requestID, response)) - }() - }).AnyTimes() - handler.EXPECT(). - CrossChainAppRequest(context.Background(), chainID, gomock.Any(), request). - DoAndReturn(func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { - return response, nil - }) - - callback := func(ctx context.Context, actualChainID ids.ID, actualResponse []byte, err error) { - defer wg.Done() - require.NoError(t, err) - require.Equal(t, ctxVal, ctx.Value(ctxKey)) - require.Equal(t, chainID, actualChainID) - require.Equal(t, response, actualResponse) - } - - require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) - }, - }, - { - name: "cross-chain app request failed", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.EXPECT().SendCrossChainAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { - go func() { - require.NoError(t, router.CrossChainAppRequestFailed(ctx, chainID, requestID)) - }() - }) - - callback := func(_ context.Context, actualChainID ids.ID, actualResponse []byte, err error) { - defer wg.Done() - - require.ErrorIs(t, err, ErrAppRequestFailed) - require.Equal(t, chainID, actualChainID) - require.Nil(t, actualResponse) - } - - require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) - }, - }, - { - name: "app gossip", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.EXPECT().SendAppGossip(gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, gossip []byte) { - go func() { - require.NoError(t, router.AppGossip(ctx, nodeID, gossip)) - }() - }).AnyTimes() - handler.EXPECT(). - AppGossip(context.Background(), nodeID, request). - DoAndReturn(func(context.Context, ids.NodeID, []byte) error { - defer wg.Done() - return nil - }) - - require.NoError(t, client.AppGossip(context.Background(), request)) - }, - }, - { - name: "app gossip specific", - requestFunc: func(t *testing.T, router *Router, client *Client, sender *common.MockSender, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.EXPECT().SendAppGossipSpecific(gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], gossip []byte) { - for n := range nodeIDs { - nodeID := n - go func() { - require.NoError(t, router.AppGossip(ctx, nodeID, gossip)) - }() - } - }).AnyTimes() - handler.EXPECT(). - AppGossip(context.Background(), nodeID, request). - DoAndReturn(func(context.Context, ids.NodeID, []byte) error { - defer wg.Done() - return nil - }) - - require.NoError(t, client.AppGossipSpecific(context.Background(), set.Of(nodeID), request)) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - ctrl := gomock.NewController(t) - - sender := common.NewMockSender(ctrl) - handler := mocks.NewMockHandler(ctrl) - router := NewRouter(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - peers := &Peers{} - require.NoError(peers.Connected(context.Background(), nodeID, nil)) - client, err := router.RegisterAppProtocol(handlerID, handler, peers) - require.NoError(err) - - wg := &sync.WaitGroup{} - wg.Add(1) - tt.requestFunc(t, router, client, sender, handler, wg) - wg.Wait() - }) - } -} - -func TestRouterDropMessage(t *testing.T) { - unregistered := byte(0x0) - - tests := []struct { - name string - requestFunc func(router *Router) error - err error - }{ - { - name: "drop unregistered app request message", - requestFunc: func(router *Router) error { - return router.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{unregistered}) - }, - err: nil, - }, - { - name: "drop empty app request message", - requestFunc: func(router *Router) error { - return router.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{}) - }, - err: nil, - }, - { - name: "drop unregistered cross-chain app request message", - requestFunc: func(router *Router) error { - return router.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{unregistered}) - }, - err: nil, - }, - { - name: "drop empty cross-chain app request message", - requestFunc: func(router *Router) error { - return router.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{}) - }, - err: nil, - }, - { - name: "drop unregistered gossip message", - requestFunc: func(router *Router) error { - return router.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{unregistered}) - }, - err: nil, - }, - { - name: "drop empty gossip message", - requestFunc: func(router *Router) error { - return router.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{}) - }, - err: nil, - }, - { - name: "drop unrequested app request failed", - requestFunc: func(router *Router) error { - return router.AppRequestFailed(context.Background(), ids.GenerateTestNodeID(), 0) - }, - err: ErrUnrequestedResponse, - }, - { - name: "drop unrequested app response", - requestFunc: func(router *Router) error { - return router.AppResponse(context.Background(), ids.GenerateTestNodeID(), 0, nil) - }, - err: ErrUnrequestedResponse, - }, - { - name: "drop unrequested cross-chain request failed", - requestFunc: func(router *Router) error { - return router.CrossChainAppRequestFailed(context.Background(), ids.GenerateTestID(), 0) - }, - err: ErrUnrequestedResponse, - }, - { - name: "drop unrequested cross-chain response", - requestFunc: func(router *Router) error { - return router.CrossChainAppResponse(context.Background(), ids.GenerateTestID(), 0, nil) - }, - err: ErrUnrequestedResponse, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - - router := NewRouter(logging.NoLog{}, nil, prometheus.NewRegistry(), "") - - err := tt.requestFunc(router) - require.ErrorIs(err, tt.err) - }) - } -} - -// It's possible for the request id to overflow and wrap around. -// If there are still pending requests with the same request id, we should -// not attempt to issue another request until the previous one has cleared. -func TestAppRequestDuplicateRequestIDs(t *testing.T) { - require := require.New(t) - ctrl := gomock.NewController(t) - - handler := mocks.NewMockHandler(ctrl) - sender := common.NewMockSender(ctrl) - router := NewRouter(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - nodeID := ids.GenerateTestNodeID() - - requestSent := &sync.WaitGroup{} - sender.EXPECT().SendAppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Do(func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) { - for range nodeIDs { - requestSent.Add(1) - go func() { - require.NoError(router.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) - requestSent.Done() - }() - } - }).AnyTimes() - - timeout := &sync.WaitGroup{} - response := []byte("response") - handler.EXPECT().AppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, request []byte) ([]byte, error) { - timeout.Wait() - return response, nil - }).AnyTimes() - sender.EXPECT().SendAppResponse(gomock.Any(), gomock.Any(), gomock.Any(), response) - - peers := &Peers{} - require.NoError(peers.Connected(context.Background(), nodeID, nil)) - client, err := router.RegisterAppProtocol(0x1, handler, peers) - require.NoError(err) - - require.NoError(client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, nil)) - requestSent.Wait() - - // force the router to use the same requestID - router.requestID = 1 - timeout.Add(1) - err = client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, nil) - requestSent.Wait() - require.ErrorIs(err, ErrRequestPending) - - timeout.Done() -} diff --git a/network/p2p/validators.go b/network/p2p/validators.go index edad9b890430..a780c87f0d8c 100644 --- a/network/p2p/validators.go +++ b/network/p2p/validators.go @@ -22,11 +22,18 @@ var ( ) type ValidatorSet interface { - Has(ctx context.Context, nodeID ids.NodeID) bool + Has(ctx context.Context, nodeID ids.NodeID) bool // TODO return error } -func NewValidators(log logging.Logger, subnetID ids.ID, validators validators.State, maxValidatorSetStaleness time.Duration) *Validators { +func NewValidators( + peers *Peers, + log logging.Logger, + subnetID ids.ID, + validators validators.State, + maxValidatorSetStaleness time.Duration, +) *Validators { return &Validators{ + peers: peers, log: log, subnetID: subnetID, validators: validators, @@ -36,6 +43,7 @@ func NewValidators(log logging.Logger, subnetID ids.ID, validators validators.St // Validators contains a set of nodes that are staking. type Validators struct { + peers *Peers log logging.Logger subnetID ids.ID validators validators.State @@ -71,20 +79,33 @@ func (v *Validators) refresh(ctx context.Context) { v.lastUpdated = time.Now() } +// Sample returns a random sample of connected validators func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID { v.lock.Lock() defer v.lock.Unlock() v.refresh(ctx) - return v.validatorIDs.Sample(limit) + validatorIDs := v.validatorIDs.Sample(limit) + sampled := validatorIDs[:0] + + for _, validatorID := range validatorIDs { + if !v.peers.has(validatorID) { + continue + } + + sampled = append(sampled, validatorID) + } + + return sampled } +// Has returns if nodeID is a connected validator func (v *Validators) Has(ctx context.Context, nodeID ids.NodeID) bool { v.lock.Lock() defer v.lock.Unlock() v.refresh(ctx) - return v.validatorIDs.Contains(nodeID) + return v.peers.has(nodeID) && v.validatorIDs.Contains(nodeID) } diff --git a/network/p2p/validators_test.go b/network/p2p/validators_test.go index 5db06f7a2efa..e721b4a978af 100644 --- a/network/p2p/validators_test.go +++ b/network/p2p/validators_test.go @@ -9,11 +9,14 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -151,9 +154,8 @@ func TestValidatorsSample(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - ctrl := gomock.NewController(t) - subnetID := ids.GenerateTestID() + ctrl := gomock.NewController(t) mockValidators := validators.NewMockState(ctrl) calls := make([]*gomock.Call, 0) @@ -177,7 +179,12 @@ func TestValidatorsSample(t *testing.T) { } gomock.InOrder(calls...) - v := NewValidators(logging.NoLog{}, subnetID, mockValidators, tt.maxStaleness) + network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + ctx := context.Background() + require.NoError(network.Connected(ctx, nodeID1, nil)) + require.NoError(network.Connected(ctx, nodeID2, nil)) + + v := NewValidators(network.Peers, network.log, subnetID, mockValidators, tt.maxStaleness) for _, call := range tt.calls { v.lastUpdated = call.time sampled := v.Sample(context.Background(), call.limit)