Skip to content

Commit

Permalink
fix(traverser): fix race condition for shutdown
Browse files Browse the repository at this point in the history
make sure that the traverser is finished in the request executor
  • Loading branch information
hannahhoward committed Jul 8, 2020
1 parent 642930f commit 856754b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
20 changes: 19 additions & 1 deletion ipldutil/traverser.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type Traverser interface {
Advance(reader io.Reader) error
// Error errors the traversal by returning the given error as the result of the next IPLD load
Error(err error)
// Shutdown cancels the traversal
Shutdown()
}

type state struct {
Expand All @@ -55,16 +57,20 @@ type nextResponse struct {

// Start initiates the traversal (run in a go routine because the regular
// selector traversal expects a call back)
func (tb TraversalBuilder) Start(ctx context.Context) Traverser {
func (tb TraversalBuilder) Start(parentCtx context.Context) Traverser {
ctx, cancel := context.WithCancel(parentCtx)
t := &traverser{
parentCtx: parentCtx,
ctx: ctx,
cancel: cancel,
root: tb.Root,
selector: tb.Selector,
visitor: defaultVisitor,
chooser: defaultChooser,
awaitRequest: make(chan struct{}, 1),
stateChan: make(chan state, 1),
responses: make(chan nextResponse),
stopped: make(chan struct{}),
}
if tb.Visitor != nil {
t.visitor = tb.Visitor
Expand All @@ -79,7 +85,9 @@ func (tb TraversalBuilder) Start(ctx context.Context) Traverser {
// traverser is a class to perform a selector traversal that stops every time a new block is loaded
// and waits for manual input (in the form of advance or error)
type traverser struct {
parentCtx context.Context
ctx context.Context
cancel func()
root ipld.Link
selector ipld.Node
visitor traversal.AdvVisitFn
Expand All @@ -91,6 +99,7 @@ type traverser struct {
awaitRequest chan struct{}
stateChan chan state
responses chan nextResponse
stopped chan struct{}
}

func (t *traverser) checkState() {
Expand Down Expand Up @@ -124,6 +133,7 @@ func (t *traverser) start() {
case t.awaitRequest <- struct{}{}:
}
go func() {
defer close(t.stopped)
loader := func(lnk ipld.Link, lnkCtx ipld.LinkContext) (io.Reader, error) {
select {
case <-t.ctx.Done():
Expand Down Expand Up @@ -166,6 +176,14 @@ func (t *traverser) start() {
}()
}

func (t *traverser) Shutdown() {
t.cancel()
select {
case <-t.parentCtx.Done():
case <-t.stopped:
}
}

// IsComplete returns true if a traversal is complete
func (t *traverser) IsComplete() (bool, error) {
t.checkState()
Expand Down
2 changes: 1 addition & 1 deletion requestmanager/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (re *requestExecutor) traverse() error {
Visitor: re.visitor,
Chooser: re.nodeStyleChooser,
}.Start(re.ctx)

defer traverser.Shutdown()
for {
isComplete, err := traverser.IsComplete()
if isComplete {
Expand Down
3 changes: 3 additions & 0 deletions responsemanager/runtraversal/runtraversal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ func (ft *fakeTraverser) Error(err error) {
ft.receivedOutcomes = append(ft.receivedOutcomes, traverseOutcome{true, err, nil})
}

// Shutdown cancels the traversal if still in progress
func (ft *fakeTraverser) Shutdown() {}

func (ft *fakeTraverser) verifyExpectations(t *testing.T) {
require.Equal(t, ft.expectedOutcomes, ft.receivedOutcomes)
}
Expand Down

0 comments on commit 856754b

Please sign in to comment.