diff --git a/dht.go b/dht.go index c8fb3c77..175b20f3 100644 --- a/dht.go +++ b/dht.go @@ -378,8 +378,10 @@ func makeDHT(h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { func (dht *IpfsDHT) lookupCheck(ctx context.Context, p peer.ID) error { // lookup request to p requesting for its own peer.ID peerids, err := dht.protoMessenger.GetClosestPeers(ctx, p, p) - // p should return at least its own peerid - if err == nil && len(peerids) == 0 { + // p is expected to return at least 1 peer id, unless our routing table has + // less than bucketSize peers, in which case we aren't picky about who we + // add to the routing table. + if err == nil && len(peerids) == 0 && dht.routingTable.Size() >= dht.bucketSize { return fmt.Errorf("peer %s failed to return its closest peers, got %d", p, len(peerids)) } return err diff --git a/dht_test.go b/dht_test.go index 0e4fe926..9527d380 100644 --- a/dht_test.go +++ b/dht_test.go @@ -23,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/routing" + "github.com/libp2p/go-msgio" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multihash" @@ -1488,42 +1489,82 @@ func TestInvalidServer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - a := setupDHT(ctx, t, false) - b := setupDHT(ctx, t, true) + s0 := setupDHT(ctx, t, false, BucketSize(2)) // server + s1 := setupDHT(ctx, t, false, BucketSize(2)) // server + m0 := setupDHT(ctx, t, false, BucketSize(2)) // misbehabing server + m1 := setupDHT(ctx, t, false, BucketSize(2)) // misbehabing server + + // make m0 and m1 advertise all dht server protocols, but hang on all requests + for _, proto := range s0.serverProtocols { + for _, m := range []*IpfsDHT{m0, m1} { + // Hang on every request. + m.host.SetStreamHandler(proto, func(s network.Stream) { + r := msgio.NewVarintReaderSize(s, network.MessageSizeMax) + msgbytes, err := r.ReadMsg() + if err != nil { + t.Fatal(err) + } + var req pb.Message + err = req.Unmarshal(msgbytes) + if err != nil { + t.Fatal(err) + } - // make b advertise all dht server protocols - for _, proto := range a.serverProtocols { - // Hang on every request. - b.host.SetStreamHandler(proto, func(s network.Stream) { - defer s.Reset() // nolint - <-ctx.Done() - }) + // answer with an empty response message + resp := pb.NewMessage(req.GetType(), nil, req.GetClusterLevel()) + + // send out response msg + err = net.WriteMsg(s, resp) + if err != nil { + t.Fatal(err) + } + }) + } } - connectNoSync(t, ctx, a, b) + // connect s0 and m0 + connectNoSync(t, ctx, s0, m0) - c := testCaseCids[0] + // add a provider (p) for a key (k) to s0 + k := testCaseCids[0] p := peer.ID("TestPeer") - a.ProviderStore().AddProvider(ctx, c.Hash(), peer.AddrInfo{ID: p}) + s0.ProviderStore().AddProvider(ctx, k.Hash(), peer.AddrInfo{ID: p}) time.Sleep(time.Millisecond * 5) // just in case... - provs, err := b.FindProviders(ctx, c) + // find the provider for k from m0 + provs, err := m0.FindProviders(ctx, k) if err != nil { t.Fatal(err) } - if len(provs) == 0 { t.Fatal("Expected to get a provider back") } - if provs[0].ID != p { t.Fatal("expected it to be our test peer") } - if a.routingTable.Find(b.self) != "" { - t.Fatal("DHT clients should not be added to routing tables") + + // verify that m0 and s0 contain each other in their routing tables + if s0.routingTable.Find(m0.self) == "" { + // m0 is added to s0 routing table even though it is misbehaving, because + // s0's routing table is not well populated, so s0 isn't picky about who it adds. + t.Fatal("Misbehaving DHT servers should be added to routing table if not well populated") } - if b.routingTable.Find(a.self) == "" { - t.Fatal("DHT server should have been added to the dht client's routing table") + if m0.routingTable.Find(s0.self) == "" { + t.Fatal("DHT server should have been added to the misbehaving server routing table") + } + + // connect s0 to both s1 and m1 + connectNoSync(t, ctx, s0, s1) + connectNoSync(t, ctx, s0, m1) + + // s1 should be added to s0's routing table. Then, because s0's routing table + // contains more than bucketSize (2) entries, lookupCheck is enabled and m1 + // shouldn't be added, because it fails the lookupCheck (hang on all requests). + if s0.routingTable.Find(s1.self) == "" { + t.Fatal("Well behaving DHT server should have been added to the server routing table") + } + if s0.routingTable.Find(m1.self) != "" { + t.Fatal("Misbehaving DHT servers should not be added to routing table if well populated") } } diff --git a/ext_test.go b/ext_test.go index 51340f33..36d3ca93 100644 --- a/ext_test.go +++ b/ext_test.go @@ -42,5 +42,7 @@ func TestInvalidRemotePeers(t *testing.T) { time.Sleep(100 * time.Millisecond) + // hosts[1] isn't added to the routing table because it isn't responding to + // the DHT request require.Equal(t, 0, d.routingTable.Size()) }