diff --git a/dht.go b/dht.go
index 8f0f344a1..0465249de 100644
--- a/dht.go
+++ b/dht.go
@@ -46,6 +46,11 @@ const (
 	modeClient = 2
 )
 
+const (
+	kad1 protocol.ID = "/kad/1.0.0"
+	kad2 protocol.ID = "/kad/2.0.0"
+)
+
 // IpfsDHT is an implementation of Kademlia with S/Kademlia modifications.
 // It is used to implement the base Routing module.
 type IpfsDHT struct {
@@ -73,7 +78,11 @@ type IpfsDHT struct {
 
 	stripedPutLocks [256]sync.Mutex
 
-	protocols []protocol.ID // DHT protocols
+	// Primary DHT protocols - we query and respond to these protocols
+	protocols []protocol.ID
+
+	// DHT protocols we can respond to (may contain protocols in addition to the primary protocols)
+	serverProtocols []protocol.ID
 
 	auto bool
 	mode mode
@@ -109,12 +118,16 @@ var (
 // New creates a new DHT with the specified host and options.
 func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) {
 	var cfg config
-	if err := cfg.Apply(append([]Option{defaults}, options...)...); err != nil {
+	if err := cfg.apply(append([]Option{defaults}, options...)...); err != nil {
 		return nil, err
 	}
-	if cfg.disjointPaths == 0 {
-		cfg.disjointPaths = cfg.bucketSize / 2
+	if err := cfg.applyFallbacks(); err != nil {
+		return nil, err
+	}
+	if err := cfg.validate(); err != nil {
+		return nil, err
 	}
+
 	dht := makeDHT(ctx, h, cfg)
 	dht.autoRefresh = cfg.routingTable.autoRefresh
 	dht.rtRefreshPeriod = cfg.routingTable.refreshPeriod
@@ -175,7 +188,7 @@ func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
 // requests. If you need a peer to respond to DHT requests, use NewDHT instead.
 // NewDHTClient creates a new DHT object with the given peer as the 'local' host
 func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
-	dht, err := New(ctx, h, Datastore(dstore), Client(true))
+	dht, err := New(ctx, h, Datastore(dstore), Mode(ModeClient))
 	if err != nil {
 		panic(err)
 	}
@@ -196,6 +209,19 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) *IpfsDHT {
 		cmgr.UntagPeer(p, "kbucket")
 	}
 
+	protocols := []protocol.ID{cfg.protocolPrefix + kad2}
+	serverProtocols := []protocol.ID{cfg.protocolPrefix + kad2, cfg.protocolPrefix + kad1}
+
+	// check if custom test protocols were set
+	if len(cfg.testProtocols) > 0 {
+		protocols = make([]protocol.ID, len(cfg.testProtocols))
+		serverProtocols = make([]protocol.ID, len(cfg.testProtocols))
+		for i, p := range cfg.testProtocols {
+			protocols[i] = cfg.protocolPrefix + p
+			serverProtocols[i] = cfg.protocolPrefix + p
+		}
+	}
+
 	dht := &IpfsDHT{
 		datastore:    cfg.datastore,
 		self:         h.ID(),
@@ -205,7 +231,8 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) *IpfsDHT {
 		birth:        time.Now(),
 		rng:          rand.New(rand.NewSource(rand.Int63())),
 		routingTable: rt,
-		protocols:    cfg.protocols,
+		protocols:       protocols,
+		serverProtocols: serverProtocols,
 		bucketSize:   cfg.bucketSize,
 		alpha:        cfg.concurrency,
 		d:            cfg.disjointPaths,
@@ -483,22 +510,30 @@ func (dht *IpfsDHT) setMode(m mode) error {
 	}
 }
 
+// moveToServerMode advertises (via libp2p identify updates) that we are able to respond to DHT queries and sets the appropriate stream handlers.
+// Note: We may support responding to queries with protocols aside from our primary ones in order to support
+// interoperability with older versions of the DHT protocol.
 func (dht *IpfsDHT) moveToServerMode() error {
 	dht.mode = modeServer
-	for _, p := range dht.protocols {
+	for _, p := range dht.serverProtocols {
 		dht.host.SetStreamHandler(p, dht.handleNewStream)
 	}
 	return nil
 }
 
+// moveToClientMode stops advertising (and rescinds advertisements via libp2p identify updates) that we are able to
+// respond to DHT queries and removes the appropriate stream handlers. We also kill all inbound streams that were
+// utilizing the handled protocols.
+// Note: We may support responding to queries with protocols aside from our primary ones in order to support
+// interoperability with older versions of the DHT protocol.
 func (dht *IpfsDHT) moveToClientMode() error {
 	dht.mode = modeClient
-	for _, p := range dht.protocols {
+	for _, p := range dht.serverProtocols {
 		dht.host.RemoveStreamHandler(p)
 	}
 
 	pset := make(map[protocol.ID]bool)
-	for _, p := range dht.protocols {
+	for _, p := range dht.serverProtocols {
 		pset[p] = true
 	}
 
@@ -540,15 +575,6 @@ func (dht *IpfsDHT) Close() error {
 	return dht.proc.Close()
 }
 
-func (dht *IpfsDHT) protocolStrs() []string {
-	pstrs := make([]string, len(dht.protocols))
-	for idx, proto := range dht.protocols {
-		pstrs[idx] = string(proto)
-	}
-
-	return pstrs
-}
-
 func mkDsKey(s string) ds.Key {
 	return ds.NewKey(base32.RawStdEncoding.EncodeToString([]byte(s)))
 }
diff --git a/dht_net.go b/dht_net.go
index 1eada2550..2b38169f9 100644
--- a/dht_net.go
+++ b/dht_net.go
@@ -318,6 +318,9 @@ func (ms *messageSender) prep(ctx context.Context) error {
 		return nil
 	}
 
+	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
+	// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
+	// backwards compatibility reasons).
 	nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
 	if err != nil {
 		return err
diff --git a/dht_options.go b/dht_options.go
index 9028f8755..ecbce7496 100644
--- a/dht_options.go
+++ b/dht_options.go
@@ -23,12 +23,14 @@ const (
 	ModeServer
 )
 
+const DefaultPrefix protocol.ID = "/ipfs"
+
 // Options is a structure containing all the options that can be used when constructing a DHT.
 type config struct {
 	datastore ds.Batching
 	validator record.Validator
 	mode      ModeOpt
-	protocols []protocol.ID
+	protocolPrefix protocol.ID
 	bucketSize     int
 	disjointPaths  int
 	concurrency    int
@@ -42,12 +44,18 @@ type config struct {
 		autoRefresh      bool
 		latencyTolerance time.Duration
 	}
+
+	// internal parameters, not publicly exposed
+	protocols, serverProtocols []protocol.ID
+
+	// test parameters
+	testProtocols []protocol.ID
 }
 
-// Apply applies the given options to this Option
-func (o *config) Apply(opts ...Option) error {
+// apply applies the given options to this config
+func (c *config) apply(opts ...Option) error {
 	for i, opt := range opts {
-		if err := opt(o); err != nil {
+		if err := opt(c); err != nil {
 			return fmt.Errorf("dht option %d failed: %s", i, err)
 		}
 	}
@@ -57,6 +65,8 @@ func (o *config) Apply(opts ...Option) error {
 // Option DHT option type.
 type Option func(*config) error
 
+const defaultBucketSize = 20
+
 // defaults are the default DHT options. This option will be automatically
 // prepended to any options you pass to the DHT constructor.
 var defaults = func(o *config) error {
@@ -64,7 +74,7 @@ var defaults = func(o *config) error {
 		"pk": record.PublicKeyValidator{},
 	}
 	o.datastore = dssync.MutexWrap(ds.NewMapDatastore())
-	o.protocols = DefaultProtocols
+	o.protocolPrefix = DefaultPrefix
 	o.enableProviders = true
 	o.enableValues = true
 
@@ -74,17 +84,47 @@ var defaults = func(o *config) error {
 	o.routingTable.autoRefresh = true
 	o.maxRecordAge = time.Hour * 36
 
-	o.bucketSize = 20
+	o.bucketSize = defaultBucketSize
 	o.concurrency = 3
 
 	return nil
 }
 
+// applyFallbacks sets default DHT options. It is applied after Defaults and any options passed to the constructor in
+// order to allow for defaults that are based on other set options.
+func (c *config) applyFallbacks() error {
+	if c.disjointPaths == 0 {
+		c.disjointPaths = c.bucketSize / 2
+	}
+	return nil
+}
+
+func (c *config) validate() error {
+	if c.protocolPrefix == DefaultPrefix {
+		if c.bucketSize != defaultBucketSize {
+			return fmt.Errorf("protocol prefix %s must use bucket size %d", DefaultPrefix, defaultBucketSize)
+		}
+		if !c.enableProviders {
+			return fmt.Errorf("protocol prefix %s must have providers enabled", DefaultPrefix)
+		}
+		if !c.enableValues {
+			return fmt.Errorf("protocol prefix %s must have values enabled", DefaultPrefix)
+		}
+		if nsval, ok := c.validator.(record.NamespacedValidator); !ok {
+			return fmt.Errorf("protocol prefix %s must use a namespaced validator", DefaultPrefix)
+		} else if len(nsval) > 2 || nsval["pk"] == nil || nsval["ipns"] == nil {
+			return fmt.Errorf("protocol prefix %s must support only the /pk and /ipns namespaces", DefaultPrefix)
+		}
+		return nil
+	}
+	return nil
+}
+
 // RoutingTableLatencyTolerance sets the maximum acceptable latency for peers
 // in the routing table's cluster.
 func RoutingTableLatencyTolerance(latency time.Duration) Option {
-	return func(o *config) error {
-		o.routingTable.latencyTolerance = latency
+	return func(c *config) error {
+		c.routingTable.latencyTolerance = latency
 		return nil
 	}
 }
@@ -92,8 +132,8 @@ func RoutingTableLatencyTolerance(latency time.Duration) Option {
 // RoutingTableRefreshQueryTimeout sets the timeout for routing table refresh
 // queries.
 func RoutingTableRefreshQueryTimeout(timeout time.Duration) Option {
-	return func(o *config) error {
-		o.routingTable.refreshQueryTimeout = timeout
+	return func(c *config) error {
+		c.routingTable.refreshQueryTimeout = timeout
 		return nil
 	}
 }
@@ -105,8 +145,8 @@ func RoutingTableRefreshQueryTimeout(timeout time.Duration) Option {
 // RoutingTableRefreshPeriod sets the period for refreshing buckets in the
 // routing table. The DHT will refresh buckets every period by:
 // 1. Then searching for a random key in each bucket that hasn't been queried in
 //    the last refresh period.
 func RoutingTableRefreshPeriod(period time.Duration) Option {
-	return func(o *config) error {
-		o.routingTable.refreshPeriod = period
+	return func(c *config) error {
+		c.routingTable.refreshPeriod = period
 		return nil
 	}
 }
@@ -115,20 +155,8 @@ func RoutingTableRefreshPeriod(period time.Duration) Option {
 //
 // Defaults to an in-memory (temporary) map.
 func Datastore(ds ds.Batching) Option {
-	return func(o *config) error {
-		o.datastore = ds
-		return nil
-	}
-}
-
-// Client configures whether or not the DHT operates in client-only mode.
-//
-// Defaults to false.
-func Client(only bool) Option {
-	return func(o *config) error {
-		if only {
-			o.mode = ModeClient
-		}
+	return func(c *config) error {
+		c.datastore = ds
 		return nil
 	}
 }
@@ -137,8 +165,8 @@ func Client(only bool) Option {
 //
 // Defaults to ModeAuto.
 func Mode(m ModeOpt) Option {
-	return func(o *config) error {
-		o.mode = m
+	return func(c *config) error {
+		c.mode = m
 		return nil
 	}
 }
@@ -147,8 +175,8 @@ func Mode(m ModeOpt) Option {
 //
 // Defaults to a namespaced validator that can only validate public keys.
 func Validator(v record.Validator) Option {
-	return func(o *config) error {
-		o.validator = v
+	return func(c *config) error {
+		c.validator = v
 		return nil
 	}
 }
@@ -161,8 +189,8 @@ func Validator(v record.Validator) Option {
 // myValidator)`, all records with keys starting with `/ipns/` will be validated
 // with `myValidator`.
 func NamespacedValidator(ns string, v record.Validator) Option {
-	return func(o *config) error {
-		nsval, ok := o.validator.(record.NamespacedValidator)
+	return func(c *config) error {
+		nsval, ok := c.validator.(record.NamespacedValidator)
 		if !ok {
 			return fmt.Errorf("can only add namespaced validators to a NamespacedValidator")
 		}
@@ -171,12 +199,13 @@ func NamespacedValidator(ns string, v record.Validator) Option {
 	}
 }
 
-// Protocols sets the protocols for the DHT
+// ProtocolPrefix sets an application-specific prefix to be attached to all DHT protocols. For example,
+// /myapp/kad/1.0.0 instead of /ipfs/kad/1.0.0. Prefix should be of the form /myapp.
 //
-// Defaults to dht.DefaultProtocols
-func Protocols(protocols ...protocol.ID) Option {
-	return func(o *config) error {
-		o.protocols = protocols
+// Defaults to dht.DefaultPrefix
+func ProtocolPrefix(prefix protocol.ID) Option {
+	return func(c *config) error {
+		c.protocolPrefix = prefix
 		return nil
 	}
 }
@@ -185,8 +214,8 @@
 //
 // The default value is 20.
 func BucketSize(bucketSize int) Option {
-	return func(o *config) error {
-		o.bucketSize = bucketSize
+	return func(c *config) error {
+		c.bucketSize = bucketSize
 		return nil
 	}
 }
@@ -195,8 +224,8 @@
 //
 // The default value is 3.
 func Concurrency(alpha int) Option {
-	return func(o *config) error {
-		o.concurrency = alpha
+	return func(c *config) error {
+		c.concurrency = alpha
 		return nil
 	}
 }
@@ -205,8 +234,8 @@
 //
 // The default value is BucketSize/2.
 func DisjointPaths(d int) Option {
-	return func(o *config) error {
-		o.disjointPaths = d
+	return func(c *config) error {
+		c.disjointPaths = d
 		return nil
 	}
 }
@@ -218,8 +247,8 @@
 // until the year 2020 (a great time in the future). For that record to stick around
 // it must be rebroadcasted more frequently than once every 'MaxRecordAge'
 func MaxRecordAge(maxAge time.Duration) Option {
-	return func(o *config) error {
-		o.maxRecordAge = maxAge
+	return func(c *config) error {
+		c.maxRecordAge = maxAge
 		return nil
 	}
 }
@@ -228,8 +257,8 @@
 // table. This means that we will neither refresh the routing table periodically
 // nor when the routing table size goes below the minimum threshold.
 func DisableAutoRefresh() Option {
-	return func(o *config) error {
-		o.routingTable.autoRefresh = false
+	return func(c *config) error {
+		c.routingTable.autoRefresh = false
 		return nil
 	}
 }
@@ -241,8 +270,8 @@
 // WARNING: do not change this unless you're using a forked DHT (i.e., a private
 // network and/or distinct DHT protocols with the `Protocols` option).
 func DisableProviders() Option {
-	return func(o *config) error {
-		o.enableProviders = false
+	return func(c *config) error {
+		c.enableProviders = false
 		return nil
 	}
 }
@@ -255,8 +284,17 @@
 // WARNING: do not change this unless you're using a forked DHT (i.e., a private
 // network and/or distinct DHT protocols with the `Protocols` option).
 func DisableValues() Option {
-	return func(o *config) error {
-		o.enableValues = false
+	return func(c *config) error {
+		c.enableValues = false
+		return nil
+	}
+}
+
+// customProtocols is only to be used for testing. It sets the protocols that the DHT listens on and queries with to be
+// the ones passed in. The custom protocols are still augmented by the Prefix.
+func customProtocols(protos ...protocol.ID) Option {
+	return func(c *config) error {
+		c.testProtocols = protos
 		return nil
 	}
 }
diff --git a/dht_test.go b/dht_test.go
index 5fb42b654..ec41db4a5 100644
--- a/dht_test.go
+++ b/dht_test.go
@@ -17,7 +17,6 @@ import (
 	"github.com/libp2p/go-libp2p-core/network"
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-libp2p-core/peerstore"
-	"github.com/libp2p/go-libp2p-core/protocol"
 	"github.com/libp2p/go-libp2p-core/routing"
 	"github.com/multiformats/go-multihash"
 	"github.com/multiformats/go-multistream"
@@ -129,8 +128,11 @@ func (testAtomicPutValidator) Select(_ string, bs [][]byte) (int, error) {
 	return index, nil
 }
 
+var testPrefix = ProtocolPrefix("/test")
+
 func setupDHT(ctx context.Context, t *testing.T, client bool, options ...Option) *IpfsDHT {
 	baseOpts := []Option{
+		testPrefix,
 		NamespacedValidator("v", blankValidator{}),
 		DisableAutoRefresh(),
 	}
@@ -800,6 +802,7 @@ func TestRefreshBelowMinRTThreshold(t *testing.T) {
 	dhtA, err := New(
 		ctx,
 		bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
+		testPrefix,
 		Mode(ModeServer),
 		NamespacedValidator("v", blankValidator{}),
 	)
@@ -1559,6 +1562,9 @@ func TestProvideDisabled(t *testing.T) {
 			var (
 				optsA, optsB []Option
 			)
+			optsA = append(optsA, ProtocolPrefix("/provMaybeDisabled"))
+			optsB = append(optsB, ProtocolPrefix("/provMaybeDisabled"))
+
 			if !enabledA {
 				optsA = append(optsA, DisableProviders())
 			}
@@ -1614,10 +1620,9 @@ func TestProvideDisabled(t *testing.T) {
 }
 
 func TestHandleRemotePeerProtocolChanges(t *testing.T) {
-	proto := protocol.ID("/v1/dht")
 	ctx := context.Background()
 	os := []Option{
-		Protocols(proto),
+		testPrefix,
 		Mode(ModeServer),
 		NamespacedValidator("v", blankValidator{}),
 		DisableAutoRefresh(),
 	}
@@ -1657,7 +1662,7 @@ func TestGetSetPluggedProtocol(t *testing.T) {
 		defer cancel()
 
 		os := []Option{
-			Protocols("/esh/dht"),
+			ProtocolPrefix("/esh"),
 			Mode(ModeServer),
 			NamespacedValidator("v", blankValidator{}),
 			DisableAutoRefresh(),
 		}
@@ -1696,7 +1701,7 @@ func TestGetSetPluggedProtocol(t *testing.T) {
 		defer cancel()
 
 		dhtA, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), []Option{
-			Protocols("/esh/dht"),
+			ProtocolPrefix("/esh"),
 			Mode(ModeServer),
 			NamespacedValidator("v", blankValidator{}),
 			DisableAutoRefresh(),
 		}...)
 		if err != nil {
 			t.Fatal(err)
 		}
 
 		dhtB, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), []Option{
-			Protocols("/lsr/dht"),
+			ProtocolPrefix("/lsr"),
 			Mode(ModeServer),
 			NamespacedValidator("v", blankValidator{}),
 			DisableAutoRefresh(),
@@ -1818,3 +1823,107 @@ func TestDynamicModeSwitching(t *testing.T) {
 
 	assertDHTClient()
 }
+
+func TestProtocolUpgrade(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	os := []Option{
+		Mode(ModeServer),
+		NamespacedValidator("v", blankValidator{}),
+		DisableAutoRefresh(),
+		DisjointPaths(1),
+	}
+
+	// This test verifies that we can have a node serving both old and new DHTs that will respond as a server to the old
+	// DHT, but only act as a client of the new DHT. In its capacity as a server it should also only tell queriers
+	// about other DHT servers in the new DHT.
+
+	dhtA, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
+		append([]Option{testPrefix}, os...)...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	dhtB, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
+		append([]Option{testPrefix}, os...)...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	dhtC, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
+		append([]Option{testPrefix, customProtocols(kad1)}, os...)...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	connect(t, ctx, dhtA, dhtB)
+	connectNoSync(t, ctx, dhtA, dhtC)
+	wait(t, ctx, dhtC, dhtA)
+
+	if sz := dhtA.RoutingTable().Size(); sz != 1 {
+		t.Fatalf("Expected routing table to be of size %d got %d", 1, sz)
+	}
+
+	ctxT, cancel := context.WithTimeout(ctx, time.Second)
+	defer cancel()
+	if err := dhtB.PutValue(ctxT, "/v/bat", []byte("screech")); err != nil {
+		t.Fatal(err)
+	}
+
+	value, err := dhtC.GetValue(ctxT, "/v/bat")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(value) != "screech" {
+		t.Fatalf("Expected 'screech' got '%s'", string(value))
+	}
+
+	if err := dhtC.PutValue(ctxT, "/v/cat", []byte("meow")); err != nil {
+		t.Fatal(err)
+	}
+
+	value, err = dhtB.GetValue(ctxT, "/v/cat")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(value) != "meow" {
+		t.Fatalf("Expected 'meow' got '%s'", string(value))
+	}
+
+	// Add record into local DHT only
+	rec := record.MakePutRecord("/v/crow", []byte("caw"))
+	rec.TimeReceived = u.FormatRFC3339(time.Now())
+	err = dhtC.putLocal(string(rec.Key), rec)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	value, err = dhtB.GetValue(ctxT, "/v/crow")
+	switch err {
+	case nil:
+		t.Fatalf("should not have been able to find value for %s", "/v/crow")
+	case routing.ErrNotFound:
+	default:
+		t.Fatal(err)
+	}
+
+	// Add record into local DHT only
+	rec = record.MakePutRecord("/v/bee", []byte("buzz"))
+	rec.TimeReceived = u.FormatRFC3339(time.Now())
+	err = dhtB.putLocal(string(rec.Key), rec)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	value, err = dhtC.GetValue(ctxT, "/v/bee")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(value) != "buzz" {
+		t.Fatalf("Expected 'buzz' got '%s'", string(value))
+	}
+}
diff --git a/ext_test.go b/ext_test.go
index 24fff4436..832609e92 100644
--- a/ext_test.go
+++ b/ext_test.go
@@ -30,7 +30,7 @@ func TestHungRequest(t *testing.T) {
 	}
 	hosts := mn.Hosts()
 
-	os := []Option{DisableAutoRefresh()}
+	os := []Option{testPrefix, DisableAutoRefresh()}
 	d, err := New(ctx, hosts[0], os...)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -80,7 +80,7 @@ func TestGetFailures(t *testing.T) {
 	host1 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport))
 	host2 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport))
 
-	d, err := New(ctx, host1, DisableAutoRefresh(), Mode(ModeServer))
+	d, err := New(ctx, host1, testPrefix, DisableAutoRefresh(), Mode(ModeServer))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -207,7 +207,7 @@ func TestNotFound(t *testing.T) {
 	}
 	hosts := mn.Hosts()
 
-	os := []Option{DisableAutoRefresh()}
+	os := []Option{testPrefix, DisableAutoRefresh()}
 	d, err := New(ctx, hosts[0], os...)
 	if err != nil {
 		t.Fatal(err)
@@ -287,7 +287,7 @@ func TestLessThanKResponses(t *testing.T) {
 	}
 	hosts := mn.Hosts()
 
-	os := []Option{DisableAutoRefresh()}
+	os := []Option{testPrefix, DisableAutoRefresh()}
 	d, err := New(ctx, hosts[0], os...)
 	if err != nil {
 		t.Fatal(err)
@@ -357,7 +357,7 @@ func TestMultipleQueries(t *testing.T) {
 		t.Fatal(err)
 	}
 	hosts := mn.Hosts()
-	os := []Option{DisableAutoRefresh()}
+	os := []Option{testPrefix, DisableAutoRefresh()}
 	d, err := New(ctx, hosts[0], os...)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/opts/options.go b/opts/options.go
index 92804c02c..a32d8cbee 100644
--- a/opts/options.go
+++ b/opts/options.go
@@ -6,20 +6,10 @@ import (
 	"time"
 
 	ds "github.com/ipfs/go-datastore"
-	"github.com/libp2p/go-libp2p-core/protocol"
 	"github.com/libp2p/go-libp2p-kad-dht"
 	"github.com/libp2p/go-libp2p-record"
 )
 
-// Deprecated: The old format did not support more than one message per stream, and is not supported
-// or relevant with stream pooling. ProtocolDHT should be used instead.
-const ProtocolDHTOld = "/ipfs/dht"
-
-var (
-	ProtocolDHT      = dht.ProtocolDHT
-	DefaultProtocols = dht.DefaultProtocols
-)
-
 // Deprecated: use dht.RoutingTableLatencyTolerance
 func RoutingTableLatencyTolerance(latency time.Duration) dht.Option {
 	return dht.RoutingTableLatencyTolerance(latency)
 }
@@ -38,8 +28,16 @@ func RoutingTableRefreshPeriod(period time.Duration) dht.Option {
 // Deprecated: use dht.Datastore
 func Datastore(ds ds.Batching) dht.Option { return dht.Datastore(ds) }
 
-// Deprecated: use dht.Client
-func Client(only bool) dht.Option { return dht.Client(only) }
+// Client configures whether or not the DHT operates in client-only mode.
+//
+// Defaults to false (which is ModeAuto).
+// Deprecated: use dht.Mode(ModeClient)
+func Client(only bool) dht.Option {
+	if only {
+		return dht.Mode(dht.ModeClient)
+	}
+	return dht.Mode(dht.ModeAuto)
+}
 
 // Deprecated: use dht.Mode
 func Mode(m dht.ModeOpt) dht.Option { return dht.Mode(m) }
@@ -52,9 +50,6 @@ func NamespacedValidator(ns string, v record.Validator) dht.Option {
 	return dht.NamespacedValidator(ns, v)
 }
 
-// Deprecated: use dht.Protocols
-func Protocols(protocols ...protocol.ID) dht.Option { return dht.Protocols(protocols...) }
-
 // Deprecated: use dht.BucketSize
 func BucketSize(bucketSize int) dht.Option { return dht.BucketSize(bucketSize) }
 
diff --git a/records_test.go b/records_test.go
index ad14c6d60..092f5b3c0 100644
--- a/records_test.go
+++ b/records_test.go
@@ -318,6 +318,9 @@ func TestValuesDisabled(t *testing.T) {
 			var (
 				optsA, optsB []Option
 			)
+			optsA = append(optsA, ProtocolPrefix("/valuesMaybeDisabled"))
+			optsB = append(optsB, ProtocolPrefix("/valuesMaybeDisabled"))
+
 			if !enabledA {
 				optsA = append(optsA, DisableValues())
 			}
diff --git a/subscriber_notifee.go b/subscriber_notifee.go
index 3dd9a4223..ce62e906e 100644
--- a/subscriber_notifee.go
+++ b/subscriber_notifee.go
@@ -5,6 +5,7 @@ import (
 	"github.com/libp2p/go-libp2p-core/event"
 	"github.com/libp2p/go-libp2p-core/network"
+	"github.com/libp2p/go-libp2p-core/peer"
 
 	"github.com/libp2p/go-eventbus"
 
@@ -56,11 +57,11 @@ func newSubscriberNotifiee(dht *IpfsDHT) (*subscriberNotifee, error) {
 	dht.plk.Lock()
 	defer dht.plk.Unlock()
 	for _, p := range dht.host.Network().Peers() {
-		protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
+		valid, err := dht.validRTPeer(p)
 		if err != nil {
 			return nil, fmt.Errorf("could not check peerstore for protocol support: err: %s", err)
 		}
-		if len(protos) != 0 {
+		if valid {
 			dht.Update(dht.ctx, p)
 		}
 	}
@@ -110,25 +111,25 @@ func handlePeerIdentificationCompletedEvent(dht *IpfsDHT, e event.EvtPeerIdentif
 	}
 
 	// if the peer supports the DHT protocol, add it to our RT and kick a refresh if needed
-	protos, err := dht.peerstore.SupportsProtocols(e.Peer, dht.protocolStrs()...)
+	valid, err := dht.validRTPeer(e.Peer)
 	if err != nil {
 		logger.Errorf("could not check peerstore for protocol support: err: %s", err)
 		return
 	}
 
-	if len(protos) != 0 {
+	if valid {
 		dht.Update(dht.ctx, e.Peer)
 		fixLowPeers(dht)
 	}
 }
 
 func handlePeerProtocolsUpdatedEvent(dht *IpfsDHT, e event.EvtPeerProtocolsUpdated) {
-	protos, err := dht.peerstore.SupportsProtocols(e.Peer, dht.protocolStrs()...)
+	valid, err := dht.validRTPeer(e.Peer)
 	if err != nil {
 		logger.Errorf("could not check peerstore for protocol support: err: %s", err)
 		return
 	}
 
-	if len(protos) > 0 {
+	if valid {
 		dht.routingTable.Update(e.Peer)
 	} else {
 		dht.routingTable.Remove(e.Peer)
@@ -158,6 +159,23 @@ func handleLocalReachabilityChangedEvent(dht *IpfsDHT, e event.EvtLocalReachabil
 	}
 }
 
+// validRTPeer returns true if the peer supports the DHT protocol and false otherwise. Supporting the DHT protocol
+// means supporting the primary protocols; we do not want to add peers that only speak obsolete secondary protocols
+// to our routing table.
+func (dht *IpfsDHT) validRTPeer(p peer.ID) (bool, error) {
+	pstrs := make([]string, len(dht.protocols))
+	for idx, proto := range dht.protocols {
+		pstrs[idx] = string(proto)
+	}
+
+	protos, err := dht.peerstore.SupportsProtocols(p, pstrs...)
+	if err != nil {
+		return false, err
+	}
+
+	return len(protos) > 0, nil
+}
+
 // fixLowPeers tries to get more peers into the routing table if we're below the threshold
 func fixLowPeers(dht *IpfsDHT) {
 	if dht.routingTable.Size() > minRTRefreshThreshold {
 		return
 	}
 
 	// Passively add peers we already know about
 	for _, p := range dht.host.Network().Peers() {
 		// Don't bother probing, we do that on connect.
-		protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
-		if err == nil && len(protos) != 0 {
+		valid, _ := dht.validRTPeer(p)
+		if valid {
 			dht.Update(dht.Context(), p)
 		}
 	}
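Reviewer note: to sanity-check the new options surface end to end, here is a minimal, hedged sketch of how a caller might construct a DHT against this branch. It only leans on symbols that appear in this diff (`ProtocolPrefix`, `Mode`/`ModeServer`, `NamespacedValidator`, `DisableAutoRefresh`, `RoutingTable`, `Close`); the host construction via `libp2p.New` and the `exampleValidator` type are illustrative assumptions, not part of the change.

```go
package main

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p"
	dht "github.com/libp2p/go-libp2p-kad-dht"
)

// exampleValidator is a stand-in record validator, analogous to the
// blankValidator used by the tests in this diff. It accepts everything.
type exampleValidator struct{}

func (exampleValidator) Validate(string, []byte) error        { return nil }
func (exampleValidator) Select(string, [][]byte) (int, error) { return 0, nil }

func main() {
	ctx := context.Background()

	// Illustrative host; any libp2p host works here.
	h, err := libp2p.New(ctx)
	if err != nil {
		panic(err)
	}

	// With this change, a forked network configures a prefix instead of raw
	// protocol IDs: the node queries <prefix>/kad/2.0.0 and, as a server,
	// also answers <prefix>/kad/1.0.0 for older peers.
	d, err := dht.New(ctx, h,
		dht.ProtocolPrefix("/myapp"), // replaces the removed Protocols(...) option
		dht.Mode(dht.ModeServer),     // replaces the removed Client(bool) option
		dht.NamespacedValidator("v", exampleValidator{}),
		dht.DisableAutoRefresh(),
	)
	if err != nil {
		panic(err)
	}
	defer d.Close()

	fmt.Println("routing table size:", d.RoutingTable().Size())
}
```

Note that on the default "/ipfs" prefix the new `validate()` step also enforces the stock bucket size, providers/values enabled, and a namespaced validator limited to /pk and /ipns, so applications that need different settings must pick their own prefix as above.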