diff --git a/dht.go b/dht.go index 3a2550704..f1b2adbdb 100644 --- a/dht.go +++ b/dht.go @@ -115,8 +115,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er dht.enableValues = cfg.EnableValues // register for network notifs. - dht.host.Network().Notify((*netNotifiee)(dht)) - + dht.proc.Go((*subscriberNotifee)(dht).subscribe) + // handle providers dht.proc.AddChild(dht.providers.Process()) dht.Validator = cfg.Validator @@ -182,12 +182,8 @@ func makeDHT(ctx context.Context, h host.Host, cfg opts.Options) *IpfsDHT { triggerRtRefresh: make(chan chan<- error), } - // create a DHT proc with the given teardown - dht.proc = goprocess.WithTeardown(func() error { - // remove ourselves from network notifs. - dht.host.Network().StopNotify((*netNotifiee)(dht)) - return nil - }) + // create a DHT proc with the given context + dht.proc = goprocess.WithContext(ctx) // create a tagged context derived from the original context ctxTags := dht.newContextWithLocalTags(ctx) diff --git a/dht_net.go b/dht_net.go index 18249099f..ffecbacfc 100644 --- a/dht_net.go +++ b/dht_net.go @@ -142,8 +142,6 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { return false } - dht.updateFromMessage(ctx, mPeer, &req) - if resp == nil { continue } @@ -187,9 +185,6 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message return nil, err } - // update the peer (on valid msgs only) - dht.updateFromMessage(ctx, p, rpmes) - stats.Record(ctx, metrics.SentRequests.M(1), metrics.SentBytes.M(int64(pmes.Size())), @@ -230,15 +225,6 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message return nil } -func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error { - // Make sure that this node is actually a DHT server, not just a client. - protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) - if err == nil && len(protos) > 0 { - dht.Update(ctx, p) - } - return nil -} - func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { dht.smlk.Lock() ms, ok := dht.strmap[p] diff --git a/ext_test.go b/ext_test.go index b01d0b5dc..2967690f0 100644 --- a/ext_test.go +++ b/ext_test.go @@ -8,8 +8,11 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/routing" opts "github.com/libp2p/go-libp2p-kad-dht/opts" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" ggio "github.com/gogo/protobuf/io" u "github.com/ipfs/go-ipfs-util" @@ -67,25 +70,28 @@ func TestGetFailures(t *testing.T) { } ctx := context.Background() - mn, err := mocknet.FullMeshConnected(ctx, 2) - if err != nil { - t.Fatal(err) - } - hosts := mn.Hosts() - os := []opts.Option{opts.DisableAutoRefresh()} - d, err := New(ctx, hosts[0], os...) + host1 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + host2 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + + d, err := New(ctx, host1, opts.DisableAutoRefresh()) if err != nil { t.Fatal(err) } - d.Update(ctx, hosts[1].ID()) // Reply with failures to every message - hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) { + host2.SetStreamHandler(d.protocols[0], func(s network.Stream) { time.Sleep(400 * time.Millisecond) s.Close() }) + host1.Peerstore().AddAddrs(host2.ID(), host2.Addrs(), peerstore.ConnectedAddrTTL) + _, err = host1.Network().DialPeer(ctx, host2.ID()) + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + // This one should time out ctx1, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() @@ -104,7 +110,7 @@ func TestGetFailures(t *testing.T) { t.Log("Timeout test passed.") // Reply with failures to every message - hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) { + host2.SetStreamHandler(d.protocols[0], func(s network.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, network.MessageSizeMax) @@ -156,7 +162,7 @@ func TestGetFailures(t *testing.T) { Record: rec, } - s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), d.protocols[0]) + s, err := host2.NewStream(context.Background(), host1.ID(), d.protocols[0]) if err != nil { t.Fatal(err) } diff --git a/go.mod b/go.mod index 52a3daa84..7303a726e 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,8 @@ require ( github.com/ipfs/go-ipfs-util v0.0.1 github.com/ipfs/go-log v0.0.1 github.com/jbenet/goprocess v0.1.3 - github.com/libp2p/go-libp2p v0.5.2 + github.com/libp2p/go-eventbus v0.1.0 + github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a github.com/libp2p/go-libp2p-core v0.3.1 github.com/libp2p/go-libp2p-kbucket v0.2.3 github.com/libp2p/go-libp2p-peerstore v0.1.4 diff --git a/go.sum b/go.sum index e372d2c14..985359ee9 100644 --- a/go.sum +++ b/go.sum @@ -163,8 +163,8 @@ github.com/libp2p/go-flow-metrics v0.0.2 h1:U5TvqfoyR6GVRM+bC15Ux1ltar1kbj6Zw6xO github.com/libp2p/go-flow-metrics v0.0.2/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= github.com/libp2p/go-flow-metrics v0.0.3 h1:8tAs/hSdNvUiLgtlSy3mxwxWP4I9y/jlkPFT7epKdeM= github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= -github.com/libp2p/go-libp2p v0.5.2 h1:fjQUTyB7x/4XgO31OEWkJ5uFeHRgpoExlf0rXz5BO8k= -github.com/libp2p/go-libp2p v0.5.2/go.mod h1:o2r6AcpNl1eNGoiWhRtPji03NYOvZumeQ6u+X6gSxnM= +github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a h1:cxYryrTPI23R5InZb9Kc86dj819f7yVMapQPuj1Ti1s= +github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a/go.mod h1:8UlWMmxcKNxyY0ocYX8Ft4IZ0mMfr7b89v1qZdXxwrk= github.com/libp2p/go-libp2p-autonat v0.1.1 h1:WLBZcIRsjZlWdAZj9CiBSvU2wQXoUOiS1Zk1tM7DTJI= github.com/libp2p/go-libp2p-autonat v0.1.1/go.mod h1:OXqkeGOY2xJVWKAGV2inNF5aKN/djNA3fdpCWloIudE= github.com/libp2p/go-libp2p-blankhost v0.1.1/go.mod h1:pf2fvdLJPsC1FsVrNP3DUUvMzUts2dsLLBEpo1vW1ro= diff --git a/notif.go b/notif.go deleted file mode 100644 index 04000e31e..000000000 --- a/notif.go +++ /dev/null @@ -1,144 +0,0 @@ -package dht - -import ( - "context" - - "github.com/libp2p/go-libp2p-core/helpers" - "github.com/libp2p/go-libp2p-core/network" - - ma "github.com/multiformats/go-multiaddr" - mstream "github.com/multiformats/go-multistream" -) - -// netNotifiee defines methods to be used with the IpfsDHT -type netNotifiee IpfsDHT - -func (nn *netNotifiee) DHT() *IpfsDHT { - return (*IpfsDHT)(nn) -} - -func (nn *netNotifiee) Connected(n network.Network, v network.Conn) { - dht := nn.DHT() - select { - case <-dht.Process().Closing(): - return - default: - } - - p := v.RemotePeer() - protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) - if err == nil && len(protos) != 0 { - // We lock here for consistency with the lock in testConnection. - // This probably isn't necessary because (dis)connect - // notifications are serialized but it's nice to be consistent. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - refresh := dht.routingTable.Size() <= minRTRefreshThreshold - dht.Update(dht.Context(), p) - if refresh && dht.autoRefresh { - select { - case dht.triggerRtRefresh <- nil: - default: - } - } - } - return - } - - // Note: Unfortunately, the peerstore may not yet know that this peer is - // a DHT server. So, if it didn't return a positive response above, test - // manually. - go nn.testConnection(v) -} - -func (nn *netNotifiee) testConnection(v network.Conn) { - dht := nn.DHT() - p := v.RemotePeer() - - // Forcibly use *this* connection. Otherwise, if we have two connections, we could: - // 1. Test it twice. - // 2. Have it closed from under us leaving the second (open) connection untested. - s, err := v.NewStream() - if err != nil { - // Connection error - return - } - defer helpers.FullClose(s) - - selected, err := mstream.SelectOneOf(dht.protocolStrs(), s) - if err != nil { - // Doesn't support the protocol - return - } - // Remember this choice (makes subsequent negotiations faster) - dht.peerstore.AddProtocols(p, selected) - - // We lock here as we race with disconnect. If we didn't lock, we could - // finish processing a connect after handling the associated disconnect - // event and add the peer to the routing table after removing it. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - refresh := dht.routingTable.Size() <= minRTRefreshThreshold - dht.Update(dht.Context(), p) - if refresh && dht.autoRefresh { - select { - case dht.triggerRtRefresh <- nil: - default: - } - } - } -} - -func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { - dht := nn.DHT() - select { - case <-dht.Process().Closing(): - return - default: - } - - p := v.RemotePeer() - - // Lock and check to see if we're still connected. We lock to make sure - // we don't concurrently process a connect event. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - // We're still connected. - return - } - - dht.routingTable.Remove(p) - if dht.routingTable.Size() < minRTRefreshThreshold { - // TODO: Actively bootstrap. For now, just try to add the currently connected peers. - for _, p := range dht.host.Network().Peers() { - // Don't bother probing, we do that on connect. - protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) - if err == nil && len(protos) != 0 { - dht.Update(dht.Context(), p) - } - } - } - - dht.smlk.Lock() - defer dht.smlk.Unlock() - ms, ok := dht.strmap[p] - if !ok { - return - } - delete(dht.strmap, p) - - // Do this asynchronously as ms.lk can block for a while. - go func() { - ms.lk.Lock(context.Background()) - defer ms.lk.Unlock() - ms.invalidate() - }() -} - -func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} -func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} diff --git a/notify_test.go b/notify_test.go index 3a15a8e82..4c1046b66 100644 --- a/notify_test.go +++ b/notify_test.go @@ -16,8 +16,8 @@ func TestNotifieeMultipleConn(t *testing.T) { d1 := setupDHT(ctx, t, false) d2 := setupDHT(ctx, t, false) - nn1 := (*netNotifiee)(d1) - nn2 := (*netNotifiee)(d2) + nn1 := (*subscriberNotifee)(d1) + nn2 := (*subscriberNotifee)(d2) connect(t, ctx, d1, d2) c12 := d1.host.Network().ConnsToPeer(d2.self)[0] diff --git a/subscriber_notifee.go b/subscriber_notifee.go new file mode 100644 index 000000000..886e0c5a2 --- /dev/null +++ b/subscriber_notifee.go @@ -0,0 +1,143 @@ +package dht + +import ( + "github.com/libp2p/go-libp2p-core/event" + "github.com/libp2p/go-libp2p-core/network" + + "github.com/libp2p/go-eventbus" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/jbenet/goprocess" +) + +// subscriberNotifee implements network.Notifee and also manages the subscriber to the event bus. We consume peer +// identification events to trigger inclusion in the routing table, and we consume Disconnected events to eject peers +// from it. +type subscriberNotifee IpfsDHT + +func (nn *subscriberNotifee) DHT() *IpfsDHT { + return (*IpfsDHT)(nn) +} + +func (nn *subscriberNotifee) subscribe(proc goprocess.Process) { + dht := nn.DHT() + + dht.host.Network().Notify(nn) + defer dht.host.Network().StopNotify(nn) + + var err error + evts := []interface{}{ + &event.EvtPeerIdentificationCompleted{}, + } + + // subscribe to the EvtPeerIdentificationCompleted event which notifies us every time a peer successfully completes identification + sub, err := dht.host.EventBus().Subscribe(evts, eventbus.BufSize(256)) + if err != nil { + logger.Errorf("dht not subscribed to peer identification events; things will fail; err: %s", err) + } + defer sub.Close() + + dht.plk.Lock() + for _, p := range dht.host.Network().Peers() { + protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) + if err == nil && len(protos) != 0 { + dht.Update(dht.ctx, p) + } + } + dht.plk.Unlock() + + for { + select { + case evt, more := <-sub.Out(): + // we will not be getting any more events + if !more { + return + } + + // something has gone really wrong if we get an event for another type + ev, ok := evt.(event.EvtPeerIdentificationCompleted) + if !ok { + logger.Errorf("got wrong type from subscription: %T", ev) + return + } + + dht.plk.Lock() + if dht.host.Network().Connectedness(ev.Peer) != network.Connected { + dht.plk.Unlock() + continue + } + + // if the peer supports the DHT protocol, add it to our RT and kick a refresh if needed + protos, err := dht.peerstore.SupportsProtocols(ev.Peer, dht.protocolStrs()...) + if err == nil && len(protos) != 0 { + refresh := dht.routingTable.Size() <= minRTRefreshThreshold + dht.Update(dht.ctx, ev.Peer) + if refresh && dht.autoRefresh { + select { + case dht.triggerRtRefresh <- nil: + default: + } + } + } + dht.plk.Unlock() + + case <-proc.Closing(): + return + } + } +} + +func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { + dht := nn.DHT() + select { + case <-dht.Process().Closing(): + return + default: + } + + p := v.RemotePeer() + + // Lock and check to see if we're still connected. We lock to make sure + // we don't concurrently process a connect event. + dht.plk.Lock() + defer dht.plk.Unlock() + if dht.host.Network().Connectedness(p) == network.Connected { + // We're still connected. + return + } + + dht.routingTable.Remove(p) + + if dht.routingTable.Size() < minRTRefreshThreshold { + // TODO: Actively bootstrap. For now, just try to add the currently connected peers. + for _, p := range dht.host.Network().Peers() { + // Don't bother probing, we do that on connect. + protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) + if err == nil && len(protos) != 0 { + dht.Update(dht.Context(), p) + } + } + } + + dht.smlk.Lock() + defer dht.smlk.Unlock() + ms, ok := dht.strmap[p] + if !ok { + return + } + delete(dht.strmap, p) + + // Do this asynchronously as ms.lk can block for a while. + go func() { + ms.lk.Lock() + defer ms.lk.Unlock() + ms.invalidate() + }() +} + +func (nn *subscriberNotifee) Connected(n network.Network, v network.Conn) {} +func (nn *subscriberNotifee) OpenedStream(n network.Network, v network.Stream) {} +func (nn *subscriberNotifee) ClosedStream(n network.Network, v network.Stream) {} +func (nn *subscriberNotifee) Listen(n network.Network, a ma.Multiaddr) {} +func (nn *subscriberNotifee) ListenClose(n network.Network, a ma.Multiaddr) {}