diff --git a/notify.go b/notify.go deleted file mode 100644 index f560d398..00000000 --- a/notify.go +++ /dev/null @@ -1,75 +0,0 @@ -package pubsub - -import ( - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" -) - -var _ network.Notifiee = (*PubSubNotif)(nil) - -type PubSubNotif PubSub - -func (p *PubSubNotif) OpenedStream(n network.Network, s network.Stream) { -} - -func (p *PubSubNotif) ClosedStream(n network.Network, s network.Stream) { -} - -func (p *PubSubNotif) Connected(n network.Network, c network.Conn) { - // ignore transient connections - if c.Stat().Limited { - return - } - - go func() { - p.newPeersPrioLk.RLock() - p.newPeersMx.Lock() - p.newPeersPend[c.RemotePeer()] = struct{}{} - p.newPeersMx.Unlock() - p.newPeersPrioLk.RUnlock() - - select { - case p.newPeers <- struct{}{}: - default: - } - }() -} - -func (p *PubSubNotif) Disconnected(n network.Network, c network.Conn) { -} - -func (p *PubSubNotif) Listen(n network.Network, _ ma.Multiaddr) { -} - -func (p *PubSubNotif) ListenClose(n network.Network, _ ma.Multiaddr) { -} - -func (p *PubSubNotif) Initialize() { - isTransient := func(pid peer.ID) bool { - for _, c := range p.host.Network().ConnsToPeer(pid) { - if !c.Stat().Limited { - return false - } - } - - return true - } - - p.newPeersPrioLk.RLock() - p.newPeersMx.Lock() - for _, pid := range p.host.Network().Peers() { - if isTransient(pid) { - continue - } - - p.newPeersPend[pid] = struct{}{} - } - p.newPeersMx.Unlock() - p.newPeersPrioLk.RUnlock() - - select { - case p.newPeers <- struct{}{}: - default: - } -} diff --git a/peer_notify.go b/peer_notify.go new file mode 100644 index 00000000..44aceeef --- /dev/null +++ b/peer_notify.go @@ -0,0 +1,112 @@ +package pubsub + +import ( + "context" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" +) + +func (ps *PubSub) watchForNewPeers(ctx context.Context) { + // We don't bother subscribing to "connectivity" events because we always run identify after + // every new connection. + sub, err := ps.host.EventBus().Subscribe([]interface{}{ + &event.EvtPeerIdentificationCompleted{}, + &event.EvtPeerProtocolsUpdated{}, + }) + if err != nil { + log.Errorf("failed to subscribe to peer identification events: %v", err) + return + } + defer sub.Close() + + ps.newPeersPrioLk.RLock() + ps.newPeersMx.Lock() + for _, pid := range ps.host.Network().Peers() { + if ps.host.Network().Connectedness(pid) != network.Connected { + continue + } + ps.newPeersPend[pid] = struct{}{} + } + ps.newPeersMx.Unlock() + ps.newPeersPrioLk.RUnlock() + + select { + case ps.newPeers <- struct{}{}: + default: + } + + var supportsProtocol func(protocol.ID) bool + if ps.protoMatchFunc != nil { + var supportedProtocols []func(protocol.ID) bool + for _, proto := range ps.rt.Protocols() { + + supportedProtocols = append(supportedProtocols, ps.protoMatchFunc(proto)) + } + supportsProtocol = func(proto protocol.ID) bool { + for _, fn := range supportedProtocols { + if (fn)(proto) { + return true + } + } + return false + } + } else { + supportedProtocols := make(map[protocol.ID]struct{}) + for _, proto := range ps.rt.Protocols() { + supportedProtocols[proto] = struct{}{} + } + supportsProtocol = func(proto protocol.ID) bool { + _, ok := supportedProtocols[proto] + return ok + } + } + + for ctx.Err() == nil { + var ev any + select { + case <-ctx.Done(): + return + case ev = <-sub.Out(): + } + + var protos []protocol.ID + var peer peer.ID + switch ev := ev.(type) { + case event.EvtPeerIdentificationCompleted: + peer = ev.Peer + protos = ev.Protocols + case event.EvtPeerProtocolsUpdated: + peer = ev.Peer + protos = ev.Added + default: + continue + } + + // We don't bother checking connectivity (connected and non-"limited") here because + // we'll check when actually handling the new peer. + + for _, p := range protos { + if supportsProtocol(p) { + ps.notifyNewPeer(peer) + break + } + } + } + +} + +func (ps *PubSub) notifyNewPeer(peer peer.ID) { + ps.newPeersPrioLk.RLock() + ps.newPeersMx.Lock() + ps.newPeersPend[peer] = struct{}{} + ps.newPeersMx.Unlock() + ps.newPeersPrioLk.RUnlock() + + select { + case ps.newPeers <- struct{}{}: + default: + } +} diff --git a/pubsub.go b/pubsub.go index c4ecae65..24c297dd 100644 --- a/pubsub.go +++ b/pubsub.go @@ -327,14 +327,12 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option h.SetStreamHandler(id, ps.handleNewStream) } } - h.Network().Notify((*PubSubNotif)(ps)) + go ps.watchForNewPeers(ctx) ps.val.Start(ps) go ps.processLoop(ctx) - (*PubSubNotif)(ps).Initialize() - return ps, nil } @@ -687,6 +685,8 @@ func (p *PubSub) handlePendingPeers() { p.newPeersPrioLk.Unlock() for pid := range newPeers { + // Make sure we have a non-limited connection. We do this late because we may have + // disconnected in the meantime. if p.host.Network().Connectedness(pid) != network.Connected { continue }