Skip to content

Commit

Permalink
set DHT mode to server (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinxie authored Aug 20, 2021
1 parent 1b4944e commit a1a4f02
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
14 changes: 7 additions & 7 deletions main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ func main() {
options = append(options, p2p.MasterKey("123"))
options = append(options, p2p.SecureIO())

host, err := p2p.NewHost(context.Background(), options...)
ctx := context.Background()
host, err := p2p.NewHost(ctx, options...)
if err != nil {
p2p.Logger().Panic("Error when instantiating a host.", zap.Error(err))
}
Expand All @@ -127,7 +128,7 @@ func main() {
HandleUnicastMsg := func(ctx context.Context, w io.Writer, data []byte) error {
return HandleMsg(ctx, data)
}
if err := host.AddBroadcastPubSub("measurement", HandleMsg); err != nil {
if err := host.AddBroadcastPubSub(ctx, "measurement", HandleMsg); err != nil {
p2p.Logger().Panic("Error when adding broadcast pubsub.", zap.Error(err))
}
if err := host.AddUnicastPubSub("measurement", HandleUnicastMsg); err != nil {
Expand All @@ -139,13 +140,12 @@ func main() {
if err != nil {
p2p.Logger().Panic("Error when parsing to the bootstrap node address", zap.Error(err))
}
if err := host.ConnectWithMultiaddr(context.Background(), ma); err != nil {
if err := host.ConnectWithMultiaddr(ctx, ma); err != nil {
p2p.Logger().Panic("Error when connecting to the bootstrap node", zap.Error(err))
}
host.JoinOverlay(context.Background())
host.JoinOverlay(ctx)
}

ctx := context.Background()
tick := time.Tick(time.Duration(frequency) * time.Millisecond)
for {
select {
Expand All @@ -154,8 +154,8 @@ func main() {
if broadcast {
err = host.Broadcast(ctx, "measurement", []byte(fmt.Sprintf("%s", host.HostIdentity())))
} else {
for _, neighbor := range host.Neighbors(context.Background()) {
host.Unicast(context.Background(), neighbor, "measurement", []byte(fmt.Sprintf("%s", host.HostIdentity())))
for _, neighbor := range host.Neighbors(ctx) {
host.Unicast(ctx, neighbor, "measurement", []byte(fmt.Sprintf("%s", host.HostIdentity())))
}
}
if err != nil {
Expand Down
24 changes: 13 additions & 11 deletions p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func NewHost(ctx context.Context, options ...Option) (*Host, error) {
if err != nil {
return nil, err
}
kad, err := dht.New(ctx, host, dht.ProtocolPrefix(ProtocolDHT))
kad, err := dht.New(ctx, host, dht.ProtocolPrefix(ProtocolDHT), dht.Mode(dht.ModeServer))
if err != nil {
}
if err := kad.Bootstrap(ctx); err != nil {
Expand Down Expand Up @@ -416,7 +416,7 @@ func (h *Host) AddUnicastPubSub(topic string, callback HandleUnicast) error {

// AddBroadcastPubSub adds a broadcast topic that the host will pay attention to. This need to be called before using
// Connect/JoinOverlay. Otherwise, pubsub may not be aware of the existing overlay topology
func (h *Host) AddBroadcastPubSub(topic string, callback HandleBroadcast) error {
func (h *Host) AddBroadcastPubSub(ctx context.Context, topic string, callback HandleBroadcast) error {
if _, ok := h.pubs[topic]; ok {
return nil
}
Expand Down Expand Up @@ -448,8 +448,9 @@ func (h *Host) AddBroadcastPubSub(topic string, callback HandleBroadcast) error
select {
case <-h.close:
return
case <-ctx.Done():
return
default:
ctx := context.Background()
msg, err := sub.Next(ctx)
if err != nil {
Logger().Error("Error when subscribing a broadcast message.", zap.Error(err))
Expand Down Expand Up @@ -479,6 +480,8 @@ func (h *Host) AddBroadcastPubSub(topic string, callback HandleBroadcast) error
select {
case <-h.close:
return
case <-ctx.Done():
return
default:
time.Sleep(h.cfg.BlockListCleanupInterval)
h.blacklists[topic].RemoveOldest()
Expand Down Expand Up @@ -582,17 +585,16 @@ func (h *Host) Info() core.PeerAddrInfo {

// Neighbors returns the closest peer addresses
func (h *Host) Neighbors(ctx context.Context) []core.PeerAddrInfo {
peers := h.host.Peerstore().Peers()
dedupedPeers := make(map[string]core.PeerID)
for _, p := range peers {
var (
dedup = make(map[string]bool)
neighbors = make([]core.PeerAddrInfo, 0)
)
for _, p := range h.host.Peerstore().Peers() {
idStr := p.Pretty()
if idStr == h.host.ID().Pretty() || idStr == "" {
if dedup[idStr] || idStr == h.host.ID().Pretty() || idStr == "" {
continue
}
dedupedPeers[idStr] = p
}
neighbors := make([]core.PeerAddrInfo, 0)
for _, p := range dedupedPeers {
dedup[idStr] = true
peer := h.kad.FindLocal(p)
if peer.ID != "" && len(peer.Addrs) > 0 && !h.unicastBlocklist.Blocked(peer.ID, time.Now()) {
neighbors = append(neighbors, peer)
Expand Down
7 changes: 5 additions & 2 deletions p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand All @@ -24,7 +25,7 @@ func TestBroadcast(t *testing.T) {
opts = append(opts, options...)
host, err := NewHost(ctx, opts...)
require.NoError(t, err)
require.NoError(t, host.AddBroadcastPubSub("test", func(ctx context.Context, data []byte) error {
require.NoError(t, host.AddBroadcastPubSub(ctx, "test", func(ctx context.Context, data []byte) error {
fmt.Print(string(data))
fmt.Printf(", received by %s\n", host.HostIdentity())
return nil
Expand All @@ -47,6 +48,7 @@ func TestBroadcast(t *testing.T) {
)
}

time.Sleep(100 * time.Millisecond)
for i := 0; i < n; i++ {
require.NoError(t, hosts[i].Close())
}
Expand Down Expand Up @@ -86,7 +88,7 @@ func TestUnicast(t *testing.T) {

for i, host := range hosts {
neighbors := host.Neighbors(ctx)
require.True(t, len(neighbors) > 0)
require.True(t, len(neighbors) >= n/3)

for _, neighbor := range neighbors {
require.NoError(
Expand All @@ -96,6 +98,7 @@ func TestUnicast(t *testing.T) {
}
}

time.Sleep(100 * time.Millisecond)
for i := 0; i < n; i++ {
require.NoError(t, hosts[i].Close())
}
Expand Down

0 comments on commit a1a4f02

Please sign in to comment.