diff --git a/network/p2p/peer_tracker.go b/network/p2p/peer_tracker.go index 4d5c78a48f53..31a4fb61cbb8 100644 --- a/network/p2p/peer_tracker.go +++ b/network/p2p/peer_tracker.go @@ -25,8 +25,8 @@ import ( const ( bandwidthHalflife = 5 * time.Minute - // controls how eagerly we connect to new peers vs. using - // peers with known good response bandwidth. + // controls how eagerly we connect to new peers vs. using peers with known + // good response bandwidth. desiredMinResponsivePeers = 20 newPeerConnectFactor = 0.1 @@ -35,91 +35,107 @@ const ( randomPeerProbability = 0.2 ) -// information we track on a given peer -type peerInfo struct { - version *version.Application - bandwidth safemath.Averager -} - // Tracks the bandwidth of responses coming from peers, // preferring to contact peers with known good bandwidth, connecting // to new peers with an exponentially decaying probability. type PeerTracker struct { // Lock to protect concurrent access to the peer tracker - lock sync.Mutex - // All peers we are connected to - peers map[ids.NodeID]*peerInfo - // Peers that we're connected to that we've sent a request to - // since we most recently connected to them. + lock sync.RWMutex + // Peers that we're connected to that we haven't sent a request to since we + // most recently connected to them. + untrackedPeers set.Set[ids.NodeID] + // Peers that we're connected to that we've sent a request to since we most + // recently connected to them. trackedPeers set.Set[ids.NodeID] - // Peers that we're connected to that responded to the last request they were sent. + // Peers that we're connected to that responded to the last request they + // were sent. responsivePeers set.Set[ids.NodeID] - // Max heap that contains the average bandwidth of peers. 
- bandwidthHeap heap.Map[ids.NodeID, safemath.Averager] - averageBandwidth safemath.Averager - log logging.Logger - numTrackedPeers prometheus.Gauge - numResponsivePeers prometheus.Gauge - averageBandwidthMetric prometheus.Gauge + // Bandwidth of peers that we have measured. + peerBandwidth map[ids.NodeID]safemath.Averager + // Max heap that contains the average bandwidth of peers that do not have an + // outstanding request. + bandwidthHeap heap.Map[ids.NodeID, safemath.Averager] + // Average bandwidth is only used for metrics. + averageBandwidth safemath.Averager + + // The below fields are assumed to be constant and are not protected by the + // lock. + log logging.Logger + ignoredNodes set.Set[ids.NodeID] + minVersion *version.Application + metrics peerTrackerMetrics +} + +type peerTrackerMetrics struct { + numTrackedPeers prometheus.Gauge + numResponsivePeers prometheus.Gauge + averageBandwidth prometheus.Gauge } func NewPeerTracker( log logging.Logger, metricsNamespace string, registerer prometheus.Registerer, + ignoredNodes set.Set[ids.NodeID], + minVersion *version.Application, ) (*PeerTracker, error) { t := &PeerTracker{ - peers: make(map[ids.NodeID]*peerInfo), - trackedPeers: make(set.Set[ids.NodeID]), - responsivePeers: make(set.Set[ids.NodeID]), + peerBandwidth: make(map[ids.NodeID]safemath.Averager), bandwidthHeap: heap.NewMap[ids.NodeID, safemath.Averager](func(a, b safemath.Averager) bool { return a.Read() > b.Read() }), averageBandwidth: safemath.NewAverager(0, bandwidthHalflife, time.Now()), log: log, - numTrackedPeers: prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Name: "num_tracked_peers", - Help: "number of tracked peers", - }, - ), - numResponsivePeers: prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Name: "num_responsive_peers", - Help: "number of responsive peers", - }, - ), - averageBandwidthMetric: prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - 
Name: "average_bandwidth", - Help: "average sync bandwidth used by peers", - }, - ), + ignoredNodes: ignoredNodes, + minVersion: minVersion, + metrics: peerTrackerMetrics{ + numTrackedPeers: prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: metricsNamespace, + Name: "num_tracked_peers", + Help: "number of tracked peers", + }, + ), + numResponsivePeers: prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: metricsNamespace, + Name: "num_responsive_peers", + Help: "number of responsive peers", + }, + ), + averageBandwidth: prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: metricsNamespace, + Name: "average_bandwidth", + Help: "average sync bandwidth used by peers", + }, + ), + }, } err := utils.Err( - registerer.Register(t.numTrackedPeers), - registerer.Register(t.numResponsivePeers), - registerer.Register(t.averageBandwidthMetric), + registerer.Register(t.metrics.numTrackedPeers), + registerer.Register(t.metrics.numResponsivePeers), + registerer.Register(t.metrics.averageBandwidth), ) return t, err } -// Returns true if we're not connected to enough peers. -// Otherwise returns true probabilistically based on the number of tracked peers. -// Assumes p.lock is held. -func (p *PeerTracker) shouldTrackNewPeer() bool { +// Returns true if: +// - We have not observed the desired minimum number of responsive peers. +// - Otherwise, if any untracked peers remain, randomly with a frequency that +// decreases as the number of responsive peers increases. +// +// Assumes the read lock is held. +func (p *PeerTracker) shouldSelectUntrackedPeer() bool { numResponsivePeers := p.responsivePeers.Len() if numResponsivePeers < desiredMinResponsivePeers { return true } - if len(p.trackedPeers) >= len(p.peers) { - // already tracking all the peers - return false + if p.untrackedPeers.Len() == 0 { + return false // already tracking all peers } + // TODO danlaine: we should consider tuning this probability function. 
 // With [newPeerConnectFactor] as 0.1 the probabilities are: // @@ -136,150 +152,166 @@ func (p *PeerTracker) shouldTrackNewPeer() bool { return rand.Float64() < newPeerProbability // #nosec G404 } -// TODO get rid of minVersion -// Returns a peer that we're connected to. -// If we should track more peers, returns a random peer with version >= [minVersion], if any exist. -// Otherwise, with probability [randomPeerProbability] returns a random peer from [p.responsivePeers]. -// With probability [1-randomPeerProbability] returns the peer in [p.bandwidthHeap] with the highest bandwidth. -func (p *PeerTracker) GetAnyPeer(minVersion *version.Application) (ids.NodeID, bool) { - p.lock.Lock() - defer p.lock.Unlock() +// SelectPeer returns a peer that we could send a request to. +// +// If we should track more peers, returns a random untracked peer, if any exist. +// Otherwise, with probability [randomPeerProbability] returns a random peer +// from [p.responsivePeers]. +// With probability [1-randomPeerProbability] returns the peer in +// [p.bandwidthHeap] with the highest bandwidth. +// +// Returns false if there are no connected peers. 
+func (p *PeerTracker) SelectPeer() (ids.NodeID, bool) { + p.lock.RLock() + defer p.lock.RUnlock() - if p.shouldTrackNewPeer() { - for nodeID := range p.peers { - // if minVersion is specified and peer's version is less, skip - if minVersion != nil && p.peers[nodeID].version.Compare(minVersion) < 0 { - continue - } - // skip peers already tracked - if p.trackedPeers.Contains(nodeID) { - continue - } - p.log.Debug( - "tracking peer", - zap.Int("trackedPeers", len(p.trackedPeers)), + if p.shouldSelectUntrackedPeer() { + if nodeID, ok := p.untrackedPeers.Peek(); ok { + p.log.Debug("selecting peer", + zap.String("reason", "untracked"), zap.Stringer("nodeID", nodeID), + zap.Int("trackedPeers", p.trackedPeers.Len()), + zap.Int("responsivePeers", p.responsivePeers.Len()), ) return nodeID, true } } - var ( - nodeID ids.NodeID - ok bool - ) - useRand := rand.Float64() < randomPeerProbability // #nosec G404 - if useRand { - nodeID, ok = p.responsivePeers.Peek() + useBandwidthHeap := rand.Float64() > randomPeerProbability // #nosec G404 + if useBandwidthHeap { + if nodeID, bandwidth, ok := p.bandwidthHeap.Peek(); ok { + p.log.Debug("selecting peer", + zap.String("reason", "bandwidth"), + zap.Stringer("nodeID", nodeID), + zap.Float64("bandwidth", bandwidth.Read()), + ) + return nodeID, true + } } else { - nodeID, _, ok = p.bandwidthHeap.Pop() + if nodeID, ok := p.responsivePeers.Peek(); ok { + p.log.Debug("selecting peer", + zap.String("reason", "responsive"), + zap.Stringer("nodeID", nodeID), + ) + return nodeID, true + } } - if !ok { - // if no nodes found in the bandwidth heap, return a tracked node at random - return p.trackedPeers.Peek() + + if nodeID, ok := p.trackedPeers.Peek(); ok { + p.log.Debug("selecting peer", + zap.String("reason", "tracked"), + zap.Stringer("nodeID", nodeID), + zap.Bool("checkedBandwidthHeap", useBandwidthHeap), + ) + return nodeID, true } - p.log.Debug( - "peer tracking: popping peer", - zap.Stringer("nodeID", nodeID), - zap.Bool("random", 
useRand), - ) - return nodeID, true + + // We're not connected to any peers. + return ids.EmptyNodeID, false } // Record that we sent a request to [nodeID]. -func (p *PeerTracker) TrackPeer(nodeID ids.NodeID) { +// +// Removes the peer's bandwidth averager from the bandwidth heap. +func (p *PeerTracker) RegisterRequest(nodeID ids.NodeID) { p.lock.Lock() defer p.lock.Unlock() + p.untrackedPeers.Remove(nodeID) p.trackedPeers.Add(nodeID) - p.numTrackedPeers.Set(float64(p.trackedPeers.Len())) + p.bandwidthHeap.Remove(nodeID) + + p.metrics.numTrackedPeers.Set(float64(p.trackedPeers.Len())) } // Record that we observed that [nodeID]'s bandwidth is [bandwidth]. +// // Adds the peer's bandwidth averager to the bandwidth heap. -func (p *PeerTracker) TrackBandwidth(nodeID ids.NodeID, bandwidth float64) { +func (p *PeerTracker) RegisterResponse(nodeID ids.NodeID, bandwidth float64) { + p.updateBandwidth(nodeID, bandwidth, true) +} + +// Record that a request to [nodeID] failed. +// +// Adds the peer's bandwidth averager to the bandwidth heap. 
+func (p *PeerTracker) RegisterFailure(nodeID ids.NodeID) { + p.updateBandwidth(nodeID, 0, false) +} + +func (p *PeerTracker) updateBandwidth(nodeID ids.NodeID, bandwidth float64, responsive bool) { p.lock.Lock() defer p.lock.Unlock() - peer := p.peers[nodeID] - if peer == nil { - // we're not connected to this peer, nothing to do here - p.log.Debug("tracking bandwidth for untracked peer", zap.Stringer("nodeID", nodeID)) + if !p.trackedPeers.Contains(nodeID) { + // This peer hasn't been sent a request since it last connected (or has since disconnected), nothing to do here + p.log.Debug("tracking bandwidth for untracked peer", + zap.Stringer("nodeID", nodeID), + ) return } now := time.Now() - if peer.bandwidth == nil { - peer.bandwidth = safemath.NewAverager(bandwidth, bandwidthHalflife, now) + peerBandwidth, ok := p.peerBandwidth[nodeID] + if ok { + peerBandwidth.Observe(bandwidth, now) } else { - peer.bandwidth.Observe(bandwidth, now) + peerBandwidth = safemath.NewAverager(bandwidth, bandwidthHalflife, now) + p.peerBandwidth[nodeID] = peerBandwidth } - p.bandwidthHeap.Push(nodeID, peer.bandwidth) + p.bandwidthHeap.Push(nodeID, peerBandwidth) + p.averageBandwidth.Observe(bandwidth, now) - if bandwidth == 0 { - p.responsivePeers.Remove(nodeID) - } else { + if responsive { p.responsivePeers.Add(nodeID) - // TODO danlaine: shouldn't we add the observation of 0 - // to the average bandwidth in the if statement? - p.averageBandwidth.Observe(bandwidth, now) - p.averageBandwidthMetric.Set(p.averageBandwidth.Read()) + } else { + p.responsivePeers.Remove(nodeID) } - p.numResponsivePeers.Set(float64(p.responsivePeers.Len())) + + p.metrics.numResponsivePeers.Set(float64(p.responsivePeers.Len())) + p.metrics.averageBandwidth.Set(p.averageBandwidth.Read()) } -// Connected should be called when [nodeID] connects to this node +// Connected should be called when [nodeID] connects to this node. 
 func (p *PeerTracker) Connected(nodeID ids.NodeID, nodeVersion *version.Application) { - p.lock.Lock() - defer p.lock.Unlock() - - peer := p.peers[nodeID] - if peer == nil { - p.peers[nodeID] = &peerInfo{ - version: nodeVersion, - } + // If this peer should be ignored, don't mark it as connected. + if p.ignoredNodes.Contains(nodeID) { return } - - // Peer is already connected, update the version if it has changed. - // Log a warning message since the consensus engine should never call Connected on a peer - // that we have already marked as Connected. - if nodeVersion.Compare(peer.version) != 0 { - p.peers[nodeID] = &peerInfo{ - version: nodeVersion, - bandwidth: peer.bandwidth, - } - p.log.Warn( - "updating node version of already connected peer", - zap.Stringer("nodeID", nodeID), - zap.Stringer("storedVersion", peer.version), - zap.Stringer("nodeVersion", nodeVersion), - ) - } else { - p.log.Warn( - "ignoring peer connected event for already connected peer with identical version", - zap.Stringer("nodeID", nodeID), - ) + // If minVersion is specified and peer's version is less, don't mark it as + // connected. + if p.minVersion != nil && nodeVersion.Compare(p.minVersion) < 0 { + return } + + p.lock.Lock() + defer p.lock.Unlock() + + p.untrackedPeers.Add(nodeID) } -// Disconnected should be called when [nodeID] disconnects from this node +// Disconnected should be called when [nodeID] disconnects from this node. func (p *PeerTracker) Disconnected(nodeID ids.NodeID) { p.lock.Lock() defer p.lock.Unlock() - p.bandwidthHeap.Remove(nodeID) + // Because of the checks performed in Connected, it's possible that this + // node was never marked as connected here. However, all of the below + // functions are noops if called with a peer that was never marked as + // connected. 
+ p.untrackedPeers.Remove(nodeID) p.trackedPeers.Remove(nodeID) - p.numTrackedPeers.Set(float64(p.trackedPeers.Len())) p.responsivePeers.Remove(nodeID) - p.numResponsivePeers.Set(float64(p.responsivePeers.Len())) - delete(p.peers, nodeID) + delete(p.peerBandwidth, nodeID) + p.bandwidthHeap.Remove(nodeID) + + p.metrics.numTrackedPeers.Set(float64(p.trackedPeers.Len())) + p.metrics.numResponsivePeers.Set(float64(p.responsivePeers.Len())) } // Returns the number of peers the node is connected to. func (p *PeerTracker) Size() int { - p.lock.Lock() - defer p.lock.Unlock() + p.lock.RLock() + defer p.lock.RUnlock() - return len(p.peers) + return p.untrackedPeers.Len() + p.trackedPeers.Len() } diff --git a/network/p2p/peer_tracker_test.go b/network/p2p/peer_tracker_test.go index 42b44edf71a8..01ebcfb840c4 100644 --- a/network/p2p/peer_tracker_test.go +++ b/network/p2p/peer_tracker_test.go @@ -16,7 +16,13 @@ import ( func TestPeerTracker(t *testing.T) { require := require.New(t) - p, err := NewPeerTracker(logging.NoLog{}, "", prometheus.NewRegistry()) + p, err := NewPeerTracker( + logging.NoLog{}, + "", + prometheus.NewRegistry(), + nil, + nil, + ) require.NoError(err) // Connect some peers @@ -38,25 +44,25 @@ func TestPeerTracker(t *testing.T) { // Expect requests to go to new peers until we have desiredMinResponsivePeers responsive peers. 
 for i := 0; i < desiredMinResponsivePeers+numExtraPeers/2; i++ { - peer, ok := p.GetAnyPeer(nil) + peer, ok := p.SelectPeer() require.True(ok) - require.NotNil(peer) + require.NotZero(peer) _, exists := responsivePeers[peer] require.Falsef(exists, "expected connecting to a new peer, but got the same peer twice: peer %s iteration %d", peer, i) responsivePeers[peer] = true - p.TrackPeer(peer) // mark the peer as having a message sent to it + p.RegisterRequest(peer) // mark the peer as having a message sent to it } // Mark some peers as responsive and others as not responsive i := 0 for peer := range responsivePeers { if i < desiredMinResponsivePeers { - p.TrackBandwidth(peer, 10) + p.RegisterResponse(peer, 10) } else { responsivePeers[peer] = false // remember which peers were not responsive - p.TrackBandwidth(peer, 0) + p.RegisterFailure(peer) } i++ } @@ -64,18 +70,18 @@ func TestPeerTracker(t *testing.T) { // Expect requests to go to responsive or new peers, so long as they are available numRequests := 50 for i := 0; i < numRequests; i++ { - peer, ok := p.GetAnyPeer(nil) + peer, ok := p.SelectPeer() require.True(ok) - require.NotNil(peer) + require.NotZero(peer) responsive, ok := responsivePeers[peer] if ok { require.Truef(responsive, "expected connecting to a responsive peer, but got a peer that was not responsive: peer %s iteration %d", peer, i) - p.TrackBandwidth(peer, 10) + p.RegisterResponse(peer, 10) } else { responsivePeers[peer] = false // remember that we connected to this peer - p.TrackPeer(peer) // mark the peer as having a message sent to it - p.TrackBandwidth(peer, 0) // mark the peer as non-responsive + p.RegisterRequest(peer) // mark the peer as having a message sent to it + p.RegisterFailure(peer) // mark the peer as non-responsive } } @@ -88,9 +94,9 @@ func TestPeerTracker(t *testing.T) { } // Requests should fall back on non-responsive peers when no other choice is left - peer, ok := p.GetAnyPeer(nil) + peer, ok := p.SelectPeer() require.True(ok) 
- require.NotNil(peer) + require.NotZero(peer) responsive, ok := responsivePeers[peer] require.True(ok) diff --git a/x/sync/client.go b/x/sync/client.go index b503a0be9982..7a71f1d4435e 100644 --- a/x/sync/client.go +++ b/x/sync/client.go @@ -17,7 +17,6 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/maybe" - "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/x/merkledb" pb "github.com/ava-labs/avalanchego/proto/pb/sync" @@ -67,22 +66,20 @@ type Client interface { } type client struct { - networkClient NetworkClient - stateSyncNodes []ids.NodeID - stateSyncNodeIdx uint32 - stateSyncMinVersion *version.Application - log logging.Logger - metrics SyncMetrics - tokenSize int + networkClient NetworkClient + stateSyncNodes []ids.NodeID + stateSyncNodeIdx uint32 + log logging.Logger + metrics SyncMetrics + tokenSize int } type ClientConfig struct { - NetworkClient NetworkClient - StateSyncNodeIDs []ids.NodeID - StateSyncMinVersion *version.Application - Log logging.Logger - Metrics SyncMetrics - BranchFactor merkledb.BranchFactor + NetworkClient NetworkClient + StateSyncNodeIDs []ids.NodeID + Log logging.Logger + Metrics SyncMetrics + BranchFactor merkledb.BranchFactor } func NewClient(config *ClientConfig) (Client, error) { @@ -90,12 +87,11 @@ func NewClient(config *ClientConfig) (Client, error) { return nil, err } return &client{ - networkClient: config.NetworkClient, - stateSyncNodes: config.StateSyncNodeIDs, - stateSyncMinVersion: config.StateSyncMinVersion, - log: config.Log, - metrics: config.Metrics, - tokenSize: merkledb.BranchFactorToTokenSize[config.BranchFactor], + networkClient: config.NetworkClient, + stateSyncNodes: config.StateSyncNodeIDs, + log: config.Log, + metrics: config.Metrics, + tokenSize: merkledb.BranchFactorToTokenSize[config.BranchFactor], }, nil } @@ -365,7 +361,7 @@ func (c *client) get(ctx context.Context, request []byte) 
(ids.NodeID, []byte, e c.metrics.RequestMade() if len(c.stateSyncNodes) == 0 { - nodeID, response, err = c.networkClient.RequestAny(ctx, c.stateSyncMinVersion, request) + nodeID, response, err = c.networkClient.RequestAny(ctx, request) } else { // Get the next nodeID to query using the [nodeIdx] offset. // If we're out of nodes, loop back to 0. diff --git a/x/sync/client_test.go b/x/sync/client_test.go index 5c84fb4bd1b9..fed60e93c1fa 100644 --- a/x/sync/client_test.go +++ b/x/sync/client_test.go @@ -21,7 +21,6 @@ import ( "github.com/ava-labs/avalanchego/trace" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/maybe" - "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/x/merkledb" pb "github.com/ava-labs/avalanchego/proto/pb/sync" @@ -98,10 +97,9 @@ func sendRangeProofRequest( networkClient.EXPECT().RequestAny( gomock.Any(), // ctx - gomock.Any(), // min version gomock.Any(), // request ).DoAndReturn( - func(_ context.Context, _ *version.Application, request []byte) (ids.NodeID, []byte, error) { + func(_ context.Context, request []byte) (ids.NodeID, []byte, error) { go func() { // Get response from server require.NoError(server.AppRequest(context.Background(), clientNodeID, 0, time.Now().Add(time.Hour), request)) @@ -398,10 +396,9 @@ func sendChangeProofRequest( networkClient.EXPECT().RequestAny( gomock.Any(), // ctx - gomock.Any(), // min version gomock.Any(), // request ).DoAndReturn( - func(_ context.Context, _ *version.Application, request []byte) (ids.NodeID, []byte, error) { + func(_ context.Context, request []byte) (ids.NodeID, []byte, error) { go func() { // Get response from server require.NoError(server.AppRequest(context.Background(), clientNodeID, 0, time.Now().Add(time.Hour), request)) @@ -763,7 +760,6 @@ func TestAppRequestSendFailed(t *testing.T) { networkClient.EXPECT().RequestAny( gomock.Any(), gomock.Any(), - gomock.Any(), ).Return(ids.EmptyNodeID, nil, errAppSendFailed).Times(2) _, err 
= client.GetChangeProof( diff --git a/x/sync/mock_network_client.go b/x/sync/mock_network_client.go index 428191492c4d..3156d145d82b 100644 --- a/x/sync/mock_network_client.go +++ b/x/sync/mock_network_client.go @@ -113,9 +113,9 @@ func (mr *MockNetworkClientMockRecorder) Request(arg0, arg1, arg2 any) *gomock.C } // RequestAny mocks base method. -func (m *MockNetworkClient) RequestAny(arg0 context.Context, arg1 *version.Application, arg2 []byte) (ids.NodeID, []byte, error) { +func (m *MockNetworkClient) RequestAny(arg0 context.Context, arg1 []byte) (ids.NodeID, []byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RequestAny", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "RequestAny", arg0, arg1) ret0, _ := ret[0].(ids.NodeID) ret1, _ := ret[1].([]byte) ret2, _ := ret[2].(error) @@ -123,7 +123,7 @@ func (m *MockNetworkClient) RequestAny(arg0 context.Context, arg1 *version.Appli } // RequestAny indicates an expected call of RequestAny. -func (mr *MockNetworkClientMockRecorder) RequestAny(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockNetworkClientMockRecorder) RequestAny(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestAny", reflect.TypeOf((*MockNetworkClient)(nil).RequestAny), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestAny", reflect.TypeOf((*MockNetworkClient)(nil).RequestAny), arg0, arg1) } diff --git a/x/sync/network_client.go b/x/sync/network_client.go index 9aea4483db92..15f59cc5885a 100644 --- a/x/sync/network_client.go +++ b/x/sync/network_client.go @@ -41,7 +41,6 @@ type NetworkClient interface { // the request should be retried. 
 RequestAny( ctx context.Context, - minVersion *version.Application, request []byte, ) (ids.NodeID, []byte, error) @@ -77,8 +76,6 @@ type NetworkClient interface { type networkClient struct { lock sync.Mutex log logging.Logger - // This node's ID - myNodeID ids.NodeID // requestID counter used to track outbound requests requestID uint32 // requestID => handler for the response/failure @@ -98,15 +95,21 @@ func NewNetworkClient( log logging.Logger, metricsNamespace string, registerer prometheus.Registerer, + minVersion *version.Application, ) (NetworkClient, error) { - peerTracker, err := p2p.NewPeerTracker(log, metricsNamespace, registerer) + peerTracker, err := p2p.NewPeerTracker( + log, + metricsNamespace, + registerer, + set.Of(myNodeID), + minVersion, + ) if err != nil { return nil, fmt.Errorf("failed to create peer tracker: %w", err) } return &networkClient{ appSender: appSender, - myNodeID: myNodeID, outstandingRequestHandlers: make(map[uint32]ResponseHandler), activeRequests: semaphore.NewWeighted(maxActiveRequests), peers: peerTracker, @@ -191,7 +194,6 @@ func (c *networkClient) getRequestHandler(requestID uint32) (ResponseHandler, bo // If [errAppSendFailed] is returned this should be considered fatal. func (c *networkClient) RequestAny( ctx context.Context, - minVersion *version.Application, request []byte, ) (ids.NodeID, []byte, error) { // Take a slot from total [activeRequests] and block until a slot becomes available. 
@@ -200,18 +202,34 @@ func (c *networkClient) RequestAny( } defer c.activeRequests.Release(1) - nodeID, ok := c.peers.GetAnyPeer(minVersion) - if !ok { - return ids.EmptyNodeID, nil, fmt.Errorf( - "no peers found matching version %s out of %d peers", - minVersion, c.peers.Size(), - ) + nodeID, responseChan, err := c.sendRequestAny(ctx, request) + if err != nil { + return ids.EmptyNodeID, nil, err } - response, err := c.request(ctx, nodeID, request) + response, err := c.awaitResponse(ctx, nodeID, responseChan) return nodeID, response, err } +// sendRequestAny selects a peer via [c.peers] and sends [request] to it while +// holding [c.lock]. Returns the selected peer and its response channel. +func (c *networkClient) sendRequestAny( + ctx context.Context, + request []byte, +) (ids.NodeID, chan []byte, error) { + c.lock.Lock() + defer c.lock.Unlock() + + nodeID, ok := c.peers.SelectPeer() + if !ok { + numPeers := c.peers.Size() + return ids.EmptyNodeID, nil, fmt.Errorf("no peers found from %d peers", numPeers) + } + + responseChan, err := c.sendRequestLocked(ctx, nodeID, request) + return nodeID, responseChan, err +} + // If [errAppSendFailed] is returned this should be considered fatal. func (c *networkClient) Request( ctx context.Context, @@ -225,40 +243,57 @@ func (c *networkClient) Request( } defer c.activeRequests.Release(1) - return c.request(ctx, nodeID, request) + responseChan, err := c.sendRequest(ctx, nodeID, request) + if err != nil { + return nil, err + } + + return c.awaitResponse(ctx, nodeID, responseChan) } -// Sends [request] to [nodeID] and returns the response. -// Returns an error if the request failed or [ctx] is canceled. -// If [errAppSendFailed] is returned this should be considered fatal. -// Blocks until a response is received or the [ctx] is canceled fails. -// Releases active requests semaphore if there was an error in sending the request. -// Assumes [nodeID] is never [c.myNodeID] since we guarantee -// [c.myNodeID] will not be added to [c.peers]. -// Assumes [c.lock] is not held and unlocks [c.lock] before returning.
-func (c *networkClient) request( +// sendRequest acquires [c.lock] and sends [request] to [nodeID]. +func (c *networkClient) sendRequest( ctx context.Context, nodeID ids.NodeID, request []byte, -) ([]byte, error) { +) (chan []byte, error) { c.lock.Lock() - c.log.Debug("sending request to peer", - zap.Stringer("nodeID", nodeID), - zap.Int("requestLen", len(request)), - ) - c.peers.TrackPeer(nodeID) + defer c.lock.Unlock() + return c.sendRequestLocked(ctx, nodeID, request) +} + +// Sends [request] to [nodeID] and returns a channel on which the response +// will be delivered. +// +// If [errAppSendFailed] is returned this should be considered fatal. +// +// Assumes [nodeID] is never this node's own ID since we guarantee that it will +// not be added to [c.peers]. +// +// Assumes [c.lock] is held by the caller; it is never released here. +func (c *networkClient) sendRequestLocked( + ctx context.Context, + nodeID ids.NodeID, + request []byte, +) (chan []byte, error) { requestID := c.requestID c.requestID++ - nodeIDs := set.Of(nodeID) + c.log.Debug("sending request to peer", + zap.Stringer("nodeID", nodeID), + zap.Uint32("requestID", requestID), + zap.Int("requestLen", len(request)), + ) + c.peers.RegisterRequest(nodeID) // Send an app request to the peer. + nodeIDs := set.Of(nodeID) if err := c.appSender.SendAppRequest(ctx, nodeIDs, requestID, request); err != nil { - c.lock.Unlock() - c.log.Fatal( - "failed to send app request", + // Don't unlock [c.lock] here: the caller holds it and releases it via defer. + c.log.Fatal("failed to send app request", zap.Stringer("nodeID", nodeID), + zap.Uint32("requestID", requestID), zap.Int("requestLen", len(request)), zap.Error(err), ) @@ -267,9 +302,24 @@ func (c *networkClient) request( handler := newResponseHandler() c.outstandingRequestHandlers[requestID] = handler + return handler.responseChan, nil +} - c.lock.Unlock() // unlock so response can be received - +// awaitResponse waits for the response from [nodeID] and returns it. +// +// Returns an error if the request failed or [ctx] is canceled. +// +// Blocks until a response is received, the request fails, or [ctx] is canceled.
+// +// Assumes [nodeID] is never this node's own ID since we guarantee that it will +// not be added to [c.peers]. +// +// Assumes [c.lock] is not held. +func (c *networkClient) awaitResponse( + ctx context.Context, + nodeID ids.NodeID, + responseChan chan []byte, +) ([]byte, error) { var ( response []byte responded bool @@ -277,22 +327,21 @@ func (c *networkClient) request( ) select { case <-ctx.Done(): - c.peers.TrackBandwidth(nodeID, 0) + c.peers.RegisterFailure(nodeID) return nil, ctx.Err() - case response, responded = <-handler.responseChan: + case response, responded = <-responseChan: } if !responded { - c.peers.TrackBandwidth(nodeID, 0) + c.peers.RegisterFailure(nodeID) return nil, errRequestFailed } elapsedSeconds := time.Since(startTime).Seconds() - bandwidth := float64(len(response))/elapsedSeconds + epsilon - c.peers.TrackBandwidth(nodeID, bandwidth) + bandwidth := float64(len(response)) / (elapsedSeconds + epsilon) + c.peers.RegisterResponse(nodeID, bandwidth) c.log.Debug("received response from peer", zap.Stringer("nodeID", nodeID), - zap.Uint32("requestID", requestID), zap.Int("responseLen", len(response)), ) return response, nil @@ -303,22 +352,12 @@ func (c *networkClient) Connected( nodeID ids.NodeID, nodeVersion *version.Application, ) error { - if nodeID == c.myNodeID { - c.log.Debug("skipping registering self as peer") - return nil - } - c.log.Debug("adding new peer", zap.Stringer("nodeID", nodeID)) c.peers.Connected(nodeID, nodeVersion) return nil } func (c *networkClient) Disconnected(_ context.Context, nodeID ids.NodeID) error { - if nodeID == c.myNodeID { - c.log.Debug("skipping deregistering self as peer") - return nil - } - c.log.Debug("disconnecting peer", zap.Stringer("nodeID", nodeID)) c.peers.Disconnected(nodeID) return nil