Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cancel request and wait function #185

Merged
merged 8 commits into from
Aug 4, 2021
12 changes: 8 additions & 4 deletions graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ const (
RequestCancelled = ResponseStatusCode(35)
)

// RequestContextCancelledErr is an error message received on the error channel when the request context given by the user is cancelled/times out
type RequestContextCancelledErr struct{}
// RequestClientCancelledErr is an error message received on the error channel when the request is cancelled on by the client code,
// either by closing the passed request context or calling CancelRequest
type RequestClientCancelledErr struct{}

func (e RequestContextCancelledErr) Error() string {
return "Request Context Cancelled"
func (e RequestClientCancelledErr) Error() string {
return "Request Cancelled By Client"
masih marked this conversation as resolved.
Show resolved Hide resolved
}

// RequestFailedBusyErr is an error message received on the error channel when the peer is busy
Expand Down Expand Up @@ -369,4 +370,7 @@ type GraphExchange interface {

// CancelResponse cancels an in progress response
CancelResponse(peer.ID, RequestID) error

// CancelRequest cancels an in progress request
CancelRequest(context.Context, RequestID) error
}
7 changes: 6 additions & 1 deletion impl/graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (gs *GraphSync) RegisterIncomingRequestHook(hook graphsync.OnIncomingReques
return gs.incomingRequestHooks.Register(hook)
}

// RegisterIncomingRequestHook adds a hook that runs when a new incoming request is added
// RegisterIncomingRequestQueuedHook adds a hook that runs when a new incoming request is added
// to the responder's task queue.
func (gs *GraphSync) RegisterIncomingRequestQueuedHook(hook graphsync.OnIncomingRequestQueuedHook) graphsync.UnregisterHookFunc {
return gs.incomingRequestQueuedHooks.Register(hook)
Expand Down Expand Up @@ -296,6 +296,11 @@ func (gs *GraphSync) CancelResponse(p peer.ID, requestID graphsync.RequestID) er
return gs.responseManager.CancelResponse(p, requestID)
}

// CancelRequest cancels an in progress request
func (gs *GraphSync) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
return gs.requestManager.CancelRequest(ctx, requestID)
}

type graphSyncReceiver GraphSync

func (gsr *graphSyncReceiver) graphSync() *GraphSync {
Expand Down
4 changes: 2 additions & 2 deletions impl/graphsync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func TestNetworkDisconnect(t *testing.T) {

testutil.AssertReceive(ctx, t, networkError, &err, "should receive network error")
testutil.AssertReceive(ctx, t, errChan, &err, "should receive an error")
require.EqualError(t, err, graphsync.RequestContextCancelledErr{}.Error())
require.EqualError(t, err, graphsync.RequestClientCancelledErr{}.Error())
testutil.AssertReceive(ctx, t, receiverError, &err, "should receive an error on receiver side")
}

Expand Down Expand Up @@ -653,7 +653,7 @@ func TestConnectFail(t *testing.T) {
var err error
testutil.AssertReceive(ctx, t, reqNetworkError, &err, "should receive network error")
testutil.AssertReceive(ctx, t, errChan, &err, "should receive an error")
require.EqualError(t, err, graphsync.RequestContextCancelledErr{}.Error())
require.EqualError(t, err, graphsync.RequestClientCancelledErr{}.Error())
}

func TestGraphsyncRoundTripAlternatePersistenceAndNodes(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions ipldutil/traverser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
"github.com/ipld/go-ipld-prime/traversal/selector"
)

/* TODO: This traverser creates an extra go-routine and is quite complicated, in order to give calling code control of
a selector traversal. If it were implemented inside of go-ipld-primes traversal library, with access to private functions,
it could be done without an extra go-routine, avoiding the possibility of races and simplifying implementation. This has
been documented here: https://github.com/ipld/go-ipld-prime/issues/213 -- and when this issue is implemented, this traverser
can go away */

var defaultVisitor traversal.AdvVisitFn = func(traversal.Progress, ipld.Node, traversal.VisitReason) error { return nil }

// ContextCancelError is a sentinel that indicates the passed in context
Expand Down Expand Up @@ -137,6 +143,7 @@ func (t *traverser) writeDone(err error) {
func (t *traverser) start() {
select {
case <-t.ctx.Done():
close(t.stopped)
return
case t.awaitRequest <- struct{}{}:
}
Expand Down
17 changes: 17 additions & 0 deletions ipldutil/traverser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"testing"
"time"

blocks "github.com/ipfs/go-block-format"
ipld "github.com/ipld/go-ipld-prime"
Expand All @@ -21,6 +22,22 @@ import (
func TestTraverser(t *testing.T) {
ctx := context.Background()

t.Run("started with shutdown context, then shutdown", func(t *testing.T) {
cancelledCtx, cancel := context.WithCancel(ctx)
cancel()
testdata := testutil.NewTestIPLDTree()
ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any)
sel := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node()
traverser := TraversalBuilder{
Root: testdata.RootNodeLnk,
Selector: sel,
}.Start(cancelledCtx)
timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
traverser.Shutdown(timeoutCtx)
require.NoError(t, timeoutCtx.Err())
})

t.Run("traverses correctly, simple struct", func(t *testing.T) {
testdata := testutil.NewTestIPLDTree()
ssb := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any)
Expand Down
10 changes: 5 additions & 5 deletions requestmanager/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type ExecutionEnv struct {
type RequestExecution struct {
Ctx context.Context
P peer.ID
NetworkError chan error
TerminalError chan error
Request gsmsg.GraphSyncRequest
LastResponse *atomic.Value
DoNotSendCids *cid.Set
Expand All @@ -54,7 +54,7 @@ func (ee ExecutionEnv) Start(re RequestExecution) (chan graphsync.ResponseProgre
inProgressErr: make(chan error),
ctx: re.Ctx,
p: re.P,
networkError: re.NetworkError,
terminalError: re.TerminalError,
request: re.Request,
lastResponse: re.LastResponse,
doNotSendCids: re.DoNotSendCids,
Expand All @@ -73,7 +73,7 @@ type requestExecutor struct {
inProgressErr chan error
ctx context.Context
p peer.ID
networkError chan error
terminalError chan error
request gsmsg.GraphSyncRequest
lastResponse *atomic.Value
nodeStyleChooser traversal.LinkTargetNodePrototypeChooser
Expand Down Expand Up @@ -153,9 +153,9 @@ func (re *requestExecutor) run() {
}
}
select {
case networkError := <-re.networkError:
case terminalError := <-re.terminalError:
select {
case re.inProgressErr <- networkError:
case re.inProgressErr <- terminalError:
case <-re.env.Ctx.Done():
}
default:
Expand Down
67 changes: 52 additions & 15 deletions requestmanager/requestmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ type inProgressRequestStatus struct {
startTime time.Time
cancelFn func()
p peer.ID
networkError chan error
terminalError chan error
resumeMessages chan []graphsync.ExtensionData
pauseMessages chan struct{}
paused bool
lastResponse atomic.Value
onTerminated []chan error
}

// PeerHandler is an interface that can send requests to peers
Expand Down Expand Up @@ -234,8 +235,10 @@ func (rm *RequestManager) singleErrorResponse(err error) (chan graphsync.Respons
}

type cancelRequestMessage struct {
requestID graphsync.RequestID
isPause bool
requestID graphsync.RequestID
isPause bool
onTerminated chan error
terminalError error
}

func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
Expand All @@ -244,7 +247,7 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
cancelMessageChannel := rm.messages
for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil {
select {
case cancelMessageChannel <- &cancelRequestMessage{requestID, false}:
case cancelMessageChannel <- &cancelRequestMessage{requestID, false, nil, nil}:
cancelMessageChannel = nil
// clear out any remaining responses, in case and "incoming reponse"
// messages get processed before our cancel message
Expand All @@ -262,6 +265,12 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
}
}

// CancelRequest cancels the given request ID and waits for the request to terminate
func (rm *RequestManager) CancelRequest(ctx context.Context, requestID graphsync.RequestID) error {
terminated := make(chan error, 1)
return rm.sendSyncMessage(&cancelRequestMessage{requestID, false, terminated, graphsync.RequestClientCancelledErr{}}, terminated, ctx.Done())
}

type processResponseMessage struct {
p peer.ID
responses []gsmsg.GraphSyncResponse
Expand All @@ -288,7 +297,7 @@ type unpauseRequestMessage struct {
// Can also send extensions with unpause
func (rm *RequestManager) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
response := make(chan error, 1)
return rm.sendSyncMessage(&unpauseRequestMessage{requestID, extensions, response}, response)
return rm.sendSyncMessage(&unpauseRequestMessage{requestID, extensions, response}, response, nil)
}

type pauseRequestMessage struct {
Expand All @@ -299,18 +308,22 @@ type pauseRequestMessage struct {
// PauseRequest pauses an in progress request (may take 1 or more blocks to process)
func (rm *RequestManager) PauseRequest(requestID graphsync.RequestID) error {
response := make(chan error, 1)
return rm.sendSyncMessage(&pauseRequestMessage{requestID, response}, response)
return rm.sendSyncMessage(&pauseRequestMessage{requestID, response}, response, nil)
}

func (rm *RequestManager) sendSyncMessage(message requestManagerMessage, response chan error) error {
func (rm *RequestManager) sendSyncMessage(message requestManagerMessage, response chan error, done <-chan struct{}) error {
select {
case <-rm.ctx.Done():
return errors.New("Context Cancelled")
case <-done:
return errors.New("Context Cancelled")
case rm.messages <- message:
}
select {
case <-rm.ctx.Done():
return errors.New("Context Cancelled")
case <-done:
return errors.New("Context Cancelled")
dirkmc marked this conversation as resolved.
Show resolved Hide resolved
case err := <-response:
return err
}
Expand Down Expand Up @@ -374,9 +387,9 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
p := nrm.p
resumeMessages := make(chan []graphsync.ExtensionData, 1)
pauseMessages := make(chan struct{}, 1)
networkError := make(chan error, 1)
terminalError := make(chan error, 1)
requestStatus := &inProgressRequestStatus{
ctx: ctx, startTime: time.Now(), cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, networkError: networkError,
ctx: ctx, startTime: time.Now(), cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, terminalError: terminalError,
}
lastResponse := &requestStatus.lastResponse
lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged))
Expand All @@ -392,7 +405,7 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
Ctx: ctx,
P: p,
Request: request,
NetworkError: networkError,
TerminalError: terminalError,
LastResponse: lastResponse,
DoNotSendCids: doNotSendCids,
NodePrototypeChooser: hooksResult.CustomChooser,
Expand Down Expand Up @@ -421,14 +434,38 @@ func (trm *terminateRequestMessage) handle(rm *RequestManager) {
}
delete(rm.inProgressRequestStatuses, trm.requestID)
rm.asyncLoader.CleanupRequest(trm.requestID)
if ok {
for _, onTerminated := range ipr.onTerminated {
select {
case <-rm.ctx.Done():
case onTerminated <- nil:
}
}
}
}

func (crm *cancelRequestMessage) handle(rm *RequestManager) {
inProgressRequestStatus, ok := rm.inProgressRequestStatuses[crm.requestID]
if !ok {
if crm.onTerminated != nil {
select {
case crm.onTerminated <- errors.New("request not found"):
case <-rm.ctx.Done():
}
}
return
}

if crm.onTerminated != nil {
inProgressRequestStatus.onTerminated = append(inProgressRequestStatus.onTerminated, crm.onTerminated)
}
if crm.terminalError != nil {
select {
case inProgressRequestStatus.terminalError <- crm.terminalError:
default:
}
}

rm.sendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(crm.requestID))
if crm.isPause {
inProgressRequestStatus.paused = true
Expand Down Expand Up @@ -488,8 +525,8 @@ func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg
}
responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown)
select {
case requestStatus.networkError <- responseError:
case <-requestStatus.ctx.Done():
case requestStatus.terminalError <- responseError:
default:
}
rm.sendRequest(p, gsmsg.CancelRequest(response.RequestID()))
requestStatus.cancelFn()
Expand All @@ -505,8 +542,8 @@ func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncRespons
requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
responseError := rm.generateResponseErrorFromStatus(response.Status())
select {
case requestStatus.networkError <- responseError:
case <-requestStatus.ctx.Done():
case requestStatus.terminalError <- responseError:
default:
}
requestStatus.cancelFn()
}
Expand Down Expand Up @@ -542,7 +579,7 @@ func (rm *RequestManager) processBlockHooks(p peer.ID, response graphsync.Respon
_, isPause := result.Err.(hooks.ErrPaused)
select {
case <-rm.ctx.Done():
case rm.messages <- &cancelRequestMessage{response.RequestID(), isPause}:
case rm.messages <- &cancelRequestMessage{response.RequestID(), isPause, nil, nil}:
}
}
return result.Err
Expand Down
Loading