diff --git a/go.mod b/go.mod index 2b318b7db1..6070c2debd 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/multiformats/go-multiaddr-dns v0.2.0 github.com/multiformats/go-multiaddr-net v0.1.3 github.com/multiformats/go-multistream v0.1.1 + github.com/stretchr/testify v1.4.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 ) diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index f64e1840f8..ca721eb54e 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "sync" "time" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -21,8 +22,6 @@ import ( inat "github.com/libp2p/go-libp2p-nat" logging "github.com/ipfs/go-log" - "github.com/jbenet/goprocess" - goprocessctx "github.com/jbenet/goprocess/context" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" @@ -68,6 +67,13 @@ const NATPortMap Option = iota // * uses an identity service to send + receive node information // * uses a nat service to establish NAT port mappings type BasicHost struct { + ctx context.Context + ctxCancel context.CancelFunc + // ensures we shutdown ONLY once + closeSync sync.Once + // keep track of resources we need to wait on before shutting down + refCount sync.WaitGroup + network network.Network mux *msmux.MultistreamMuxer ids *identify.IDService @@ -81,8 +87,6 @@ type BasicHost struct { negtimeout time.Duration - proc goprocess.Process - emitters struct { evtLocalProtocolsUpdated event.Emitter evtLocalAddrsUpdated event.Emitter @@ -128,6 +132,8 @@ type HostOpts struct { // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. func NewHost(ctx context.Context, net network.Network, opts *HostOpts) (*BasicHost, error) { + hostCtx, cancel := context.WithCancel(ctx) + h := &BasicHost{ network: net, mux: msmux.NewMultistreamMuxer(), @@ -136,6 +142,8 @@ func NewHost(ctx context.Context, net network.Network, opts *HostOpts) (*BasicHo maResolver: madns.DefaultResolver, eventbus: eventbus.NewBus(), addrChangeChan: make(chan struct{}, 1), + ctx: hostCtx, + ctxCancel: cancel, } var err error @@ -146,28 +154,12 @@ func NewHost(ctx context.Context, net network.Network, opts *HostOpts) (*BasicHo return nil, err } - h.proc = goprocessctx.WithContextAndTeardown(ctx, func() error { - if h.natmgr != nil { - h.natmgr.Close() - } - if h.cmgr != nil { - h.cmgr.Close() - } - _ = h.emitters.evtLocalProtocolsUpdated.Close() - _ = h.emitters.evtLocalAddrsUpdated.Close() - return h.Network().Close() - }) - if opts.MultistreamMuxer != nil { h.mux = opts.MultistreamMuxer } // we can't set this as a default above because it depends on the *BasicHost. - h.ids = identify.NewIDService( - goprocessctx.WithProcessClosing(ctx, h.proc), - h, - identify.UserAgent(opts.UserAgent), - ) + h.ids = identify.NewIDService(h, identify.UserAgent(opts.UserAgent)) if uint64(opts.NegotiationTimeout) != 0 { h.negtimeout = opts.NegotiationTimeout @@ -242,7 +234,8 @@ func New(net network.Network, opts ...interface{}) *BasicHost { // Start starts background tasks in the host func (h *BasicHost) Start() { - h.proc.Go(h.background) + h.refCount.Add(1) + go h.background() } // newConnHandler is the remote-opened conn handler for inet.Network @@ -343,7 +336,9 @@ func makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddresses return &evt } -func (h *BasicHost) background(p goprocess.Process) { +func (h *BasicHost) background() { + defer h.refCount.Done() + // periodically schedules an IdentifyPush to update our peers for changes // in our address set (if needed) ticker := time.NewTicker(10 * time.Second) @@ -356,7 +351,7 @@ func (h *BasicHost) background(p goprocess.Process) { select { case <-ticker.C: case <-h.addrChangeChan: - case <-p.Closing(): + case <-h.ctx.Done(): return } @@ -805,14 +800,26 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { // Close shuts down the Host's services (network, etc). func (h *BasicHost) Close() error { - // You're thinking of adding some teardown logic here, right? Well - // don't! Add any process teardown logic to the teardown function in the - // constructor. - // - // This: - // 1. May be called multiple times. - // 2. May _never_ be called if the host is stopped by the context. - return h.proc.Close() + h.closeSync.Do(func() { + h.ctxCancel() + if h.natmgr != nil { + h.natmgr.Close() + } + if h.cmgr != nil { + h.cmgr.Close() + } + if h.ids != nil { + h.ids.Close() + } + + _ = h.emitters.evtLocalProtocolsUpdated.Close() + _ = h.emitters.evtLocalAddrsUpdated.Close() + h.Network().Close() + + h.refCount.Wait() + }) + + return nil } type streamWrapper struct { diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index c0dee8e571..bc6df3edee 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -22,6 +22,7 @@ import ( swarmt "github.com/libp2p/go-libp2p-swarm/testing" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" + "github.com/stretchr/testify/require" ) func TestHostDoubleClose(t *testing.T) { @@ -80,6 +81,16 @@ func TestHostSimple(t *testing.T) { } } +func TestMultipleClose(t *testing.T) { + ctx := context.Background() + h := New(swarmt.GenSwarm(t, ctx)) + + require.NoError(t, h.Close()) + require.NoError(t, h.Close()) + require.NoError(t, h.Close()) + +} + func TestProtocolHandlerEvents(t *testing.T) { ctx := context.Background() h := New(swarmt.GenSwarm(t, ctx)) diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index a0e1ae25c3..5df656f60c 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "github.com/libp2p/go-eventbus" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/helpers" @@ -17,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-eventbus" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" ggio "github.com/gogo/protobuf/io" @@ -71,7 +71,12 @@ type IDService struct { Host host.Host UserAgent string - ctx context.Context + ctx context.Context + ctxCancel context.CancelFunc + // ensure we shutdown ONLY once + closeSync sync.Once + // track resources that need to be shut down before we shut down + refCount sync.WaitGroup // connections undergoing identification // for wait purposes @@ -94,7 +99,7 @@ type IDService struct { // NewIDService constructs a new *IDService and activates it by // attaching its stream handler to the given host.Host. -func NewIDService(ctx context.Context, h host.Host, opts ...Option) *IDService { +func NewIDService(h host.Host, opts ...Option) *IDService { var cfg config for _, opt := range opts { opt(&cfg) @@ -105,13 +110,15 @@ func NewIDService(ctx context.Context, h host.Host, opts ...Option) *IDService { userAgent = cfg.userAgent } + hostCtx, cancel := context.WithCancel(context.Background()) s := &IDService{ Host: h, UserAgent: userAgent, - ctx: ctx, + ctx: hostCtx, + ctxCancel: cancel, currid: make(map[network.Conn]chan struct{}), - observedAddrs: NewObservedAddrSet(ctx), + observedAddrs: NewObservedAddrSet(hostCtx), } // handle local protocol handler updates, and push deltas to peers. @@ -120,6 +127,7 @@ func NewIDService(ctx context.Context, h host.Host, opts ...Option) *IDService { if err != nil { log.Warningf("identify service not subscribed to local protocol handlers updates; err: %s", err) } else { + s.refCount.Add(1) go s.handleEvents() } @@ -143,14 +151,19 @@ func NewIDService(ctx context.Context, h host.Host, opts ...Option) *IDService { return s } +// Close shuts down the IDService +func (ids *IDService) Close() error { + ids.closeSync.Do(func() { + ids.ctxCancel() + ids.refCount.Wait() + }) + return nil +} + func (ids *IDService) handleEvents() { sub := ids.subscription - defer func() { - _ = sub.Close() - // drain the channel. - for range sub.Out() { - } - }() + defer ids.refCount.Done() + defer sub.Close() for { select { diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 234dbe4891..e6d50c95c4 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -38,8 +38,10 @@ func subtestIDService(t *testing.T) { h1p := h1.ID() h2p := h2.ID() - ids1 := identify.NewIDService(ctx, h1) - ids2 := identify.NewIDService(ctx, h2) + ids1 := identify.NewIDService(h1) + ids2 := identify.NewIDService(h2) + defer ids1.Close() + defer ids2.Close() testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing @@ -253,9 +255,14 @@ func TestLocalhostAddrFiltering(t *testing.T) { Addrs: p2addrs[1:], }) - _ = identify.NewIDService(ctx, p1) - ids2 := identify.NewIDService(ctx, p2) - ids3 := identify.NewIDService(ctx, p3) + ids1 := identify.NewIDService(p1) + ids2 := identify.NewIDService(p2) + ids3 := identify.NewIDService(p3) + defer func() { + ids1.Close() + ids2.Close() + ids3.Close() + }() conns := p2.Network().ConnsToPeer(id1) if len(conns) == 0 { @@ -291,8 +298,12 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) { h2.SetStreamHandler(protocol.TestingID, func(_ network.Stream) {}) - ids1 := identify.NewIDService(ctx, h1) - _ = identify.NewIDService(ctx, h2) + ids1 := identify.NewIDService(h1) + ids2 := identify.NewIDService(h2) + defer func() { + ids1.Close() + ids2.Close() + }() if err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}); err != nil { t.Fatal(err) @@ -404,8 +415,10 @@ func TestIdentifyDeltaWhileIdentifyingConn(t *testing.T) { defer h2.Close() defer h1.Close() - _ = identify.NewIDService(ctx, h1) - ids2 := identify.NewIDService(ctx, h2) + ids1 := identify.NewIDService(h1) + ids2 := identify.NewIDService(h2) + defer ids1.Close() + defer ids2.Close() // replace the original identify handler by one that blocks until we close the block channel. // this allows us to control how long identify runs.