From 856754b765a5a94d865a4ac3a07a80999271604a Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Tue, 7 Jul 2020 23:44:32 -0700 Subject: [PATCH] fix(traverser): fix race condition for shutdown make sure that the traverser is finished in the request executor --- ipldutil/traverser.go | 20 ++++++++++++++++++- requestmanager/executor/executor.go | 2 +- .../runtraversal/runtraversal_test.go | 3 +++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ipldutil/traverser.go b/ipldutil/traverser.go index 01f6698a..4ca32910 100644 --- a/ipldutil/traverser.go +++ b/ipldutil/traverser.go @@ -39,6 +39,8 @@ type Traverser interface { Advance(reader io.Reader) error // Error errors the traversal by returning the given error as the result of the next IPLD load Error(err error) + // Shutdown cancels the traversal + Shutdown() } type state struct { @@ -55,9 +57,12 @@ type nextResponse struct { // Start initiates the traversal (run in a go routine because the regular // selector traversal expects a call back) -func (tb TraversalBuilder) Start(ctx context.Context) Traverser { +func (tb TraversalBuilder) Start(parentCtx context.Context) Traverser { + ctx, cancel := context.WithCancel(parentCtx) t := &traverser{ + parentCtx: parentCtx, ctx: ctx, + cancel: cancel, root: tb.Root, selector: tb.Selector, visitor: defaultVisitor, @@ -65,6 +70,7 @@ func (tb TraversalBuilder) Start(ctx context.Context) Traverser { awaitRequest: make(chan struct{}, 1), stateChan: make(chan state, 1), responses: make(chan nextResponse), + stopped: make(chan struct{}), } if tb.Visitor != nil { t.visitor = tb.Visitor @@ -79,7 +85,9 @@ func (tb TraversalBuilder) Start(ctx context.Context) Traverser { // traverser is a class to perform a selector traversal that stops every time a new block is loaded // and waits for manual input (in the form of advance or error) type traverser struct { + parentCtx context.Context ctx context.Context + cancel func() root ipld.Link selector ipld.Node visitor traversal.AdvVisitFn @@ -91,6 +99,7 @@ type traverser struct { awaitRequest chan struct{} stateChan chan state responses chan nextResponse + stopped chan struct{} } func (t *traverser) checkState() { @@ -124,6 +133,7 @@ func (t *traverser) start() { case t.awaitRequest <- struct{}{}: } go func() { + defer close(t.stopped) loader := func(lnk ipld.Link, lnkCtx ipld.LinkContext) (io.Reader, error) { select { case <-t.ctx.Done(): @@ -166,6 +176,14 @@ func (t *traverser) start() { }() } +func (t *traverser) Shutdown() { + t.cancel() + select { + case <-t.parentCtx.Done(): + case <-t.stopped: + } +} + // IsComplete returns true if a traversal is complete func (t *traverser) IsComplete() (bool, error) { t.checkState() diff --git a/requestmanager/executor/executor.go b/requestmanager/executor/executor.go index e9ae1922..441c7f38 100644 --- a/requestmanager/executor/executor.go +++ b/requestmanager/executor/executor.go @@ -99,7 +99,7 @@ func (re *requestExecutor) traverse() error { Visitor: re.visitor, Chooser: re.nodeStyleChooser, }.Start(re.ctx) - + defer traverser.Shutdown() for { isComplete, err := traverser.IsComplete() if isComplete { diff --git a/responsemanager/runtraversal/runtraversal_test.go b/responsemanager/runtraversal/runtraversal_test.go index d4ca9f11..49b60530 100644 --- a/responsemanager/runtraversal/runtraversal_test.go +++ b/responsemanager/runtraversal/runtraversal_test.go @@ -96,6 +96,9 @@ func (ft *fakeTraverser) Error(err error) { ft.receivedOutcomes = append(ft.receivedOutcomes, traverseOutcome{true, err, nil}) } +// Shutdown cancels the traversal if still in progress +func (ft *fakeTraverser) Shutdown() {} + func (ft *fakeTraverser) verifyExpectations(t *testing.T) { require.Equal(t, ft.expectedOutcomes, ft.receivedOutcomes) }