Skip to content

Commit

Permalink
Remove crosschain leftovers (#3309)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceyonur authored Aug 21, 2024
1 parent bfe9fbb commit a5acbaa
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 102 deletions.
4 changes: 1 addition & 3 deletions ids/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ package ids
type RequestID struct {
// The node this request came from
NodeID NodeID
// The chain this request came from
SourceChainID ID
// The chain the expected response should come from
DestinationChainID ID
ChainID ID
// The unique identifier for this request
RequestID uint32
// The message opcode
Expand Down
12 changes: 0 additions & 12 deletions message/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,6 @@ func GetChainID(m any) (ids.ID, error) {
return ids.ToID(chainIDBytes)
}

type sourceChainIDGetter interface {
GetSourceChainID() ids.ID
}

func GetSourceChainID(m any) (ids.ID, error) {
msg, ok := m.(sourceChainIDGetter)
if !ok {
return GetChainID(m)
}
return msg.GetSourceChainID(), nil
}

type requestIDGetter interface {
GetRequestId() uint32
}
Expand Down
64 changes: 21 additions & 43 deletions snow/networking/router/chain_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (cr *ChainRouter) Initialize(
}

// RegisterRequest marks that we should expect to receive a reply for a request
// issued by [requestingChainID] from the given node's [respondingChainID] and
// from the given node's [chainID] and
// the reply should have the given requestID.
//
// The type of message we expect is [op].
Expand All @@ -148,8 +148,7 @@ func (cr *ChainRouter) Initialize(
func (cr *ChainRouter) RegisterRequest(
ctx context.Context,
nodeID ids.NodeID,
requestingChainID ids.ID,
respondingChainID ids.ID,
chainID ids.ID,
requestID uint32,
op message.Op,
timeoutMsg message.InboundMessage,
Expand All @@ -159,8 +158,7 @@ func (cr *ChainRouter) RegisterRequest(
if cr.closing {
cr.log.Debug("dropping request",
zap.Stringer("nodeID", nodeID),
zap.Stringer("requestingChainID", requestingChainID),
zap.Stringer("respondingChainID", respondingChainID),
zap.Stringer("chainID", chainID),
zap.Uint32("requestID", requestID),
zap.Stringer("messageOp", op),
zap.Error(errClosing),
Expand All @@ -171,16 +169,11 @@ func (cr *ChainRouter) RegisterRequest(
// When we receive a response message type (Chits, Put, Accepted, etc.)
// we validate that we actually sent the corresponding request.
// Give this request a unique ID so we can do that validation.
//
// For cross-chain messages, the responding chain is the source of the
// response which is sent to the requester which is the destination,
// which is why we flip the two in request id generation.
uniqueRequestID := ids.RequestID{
NodeID: nodeID,
SourceChainID: respondingChainID,
DestinationChainID: requestingChainID,
RequestID: requestID,
Op: byte(op),
NodeID: nodeID,
ChainID: chainID,
RequestID: requestID,
Op: byte(op),
}
// Add to the set of unfulfilled requests
cr.timedRequests.Put(uniqueRequestID, requestEntry{
Expand All @@ -203,7 +196,7 @@ func (cr *ChainRouter) RegisterRequest(
// Register a timeout to fire if we don't get a reply in time.
cr.timeoutManager.RegisterRequest(
nodeID,
respondingChainID,
chainID,
shouldMeasureLatency,
uniqueRequestID,
func() {
Expand All @@ -217,7 +210,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
op := msg.Op()

m := msg.Message()
destinationChainID, err := message.GetChainID(m)
chainID, err := message.GetChainID(m)
if err != nil {
cr.log.Debug("dropping message with invalid field",
zap.Stringer("nodeID", nodeID),
Expand All @@ -230,19 +223,6 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
return
}

sourceChainID, err := message.GetSourceChainID(m)
if err != nil {
cr.log.Debug("dropping message with invalid field",
zap.Stringer("nodeID", nodeID),
zap.Stringer("messageOp", op),
zap.String("field", "SourceChainID"),
zap.Error(err),
)

msg.OnFinishedHandling()
return
}

requestID, ok := message.GetRequestID(m)
if !ok {
cr.log.Debug("dropping message with invalid field",
Expand All @@ -262,20 +242,20 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
cr.log.Debug("dropping message",
zap.Stringer("messageOp", op),
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", destinationChainID),
zap.Stringer("chainID", chainID),
zap.Error(errClosing),
)
msg.OnFinishedHandling()
return
}

// Get the chain, if it exists
chain, exists := cr.chainHandlers[destinationChainID]
chain, exists := cr.chainHandlers[chainID]
if !exists {
cr.log.Debug("dropping message",
zap.Stringer("messageOp", op),
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", destinationChainID),
zap.Stringer("chainID", chainID),
zap.Error(errUnknownChain),
)
msg.OnFinishedHandling()
Expand All @@ -286,7 +266,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
cr.log.Debug("dropping message",
zap.Stringer("messageOp", op),
zap.Stringer("nodeID", nodeID),
zap.Stringer("chainID", destinationChainID),
zap.Stringer("chainID", chainID),
zap.Error(errUnallowedNode),
)
msg.OnFinishedHandling()
Expand Down Expand Up @@ -321,7 +301,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
if expectedResponse, isFailed := message.FailedToResponseOps[op]; isFailed {
// Create the request ID of the request we sent that this message is in
// response to.
uniqueRequestID, req := cr.clearRequest(expectedResponse, nodeID, sourceChainID, destinationChainID, requestID)
uniqueRequestID, req := cr.clearRequest(expectedResponse, nodeID, chainID, requestID)
if req == nil {
// This was a duplicated response.
msg.OnFinishedHandling()
Expand Down Expand Up @@ -352,7 +332,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
return
}

uniqueRequestID, req := cr.clearRequest(op, nodeID, sourceChainID, destinationChainID, requestID)
uniqueRequestID, req := cr.clearRequest(op, nodeID, chainID, requestID)
if req == nil {
// We didn't request this message.
msg.OnFinishedHandling()
Expand All @@ -363,7 +343,7 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
latency := cr.clock.Time().Sub(req.time)

// Tell the timeout manager we got a response
cr.timeoutManager.RegisterResponse(nodeID, destinationChainID, uniqueRequestID, req.op, latency)
cr.timeoutManager.RegisterResponse(nodeID, chainID, uniqueRequestID, req.op, latency)

// Pass the response to the chain
chain.Push(
Expand Down Expand Up @@ -736,17 +716,15 @@ func (cr *ChainRouter) removeChain(ctx context.Context, chainID ids.ID) {
func (cr *ChainRouter) clearRequest(
op message.Op,
nodeID ids.NodeID,
sourceChainID ids.ID,
destinationChainID ids.ID,
chainID ids.ID,
requestID uint32,
) (ids.RequestID, *requestEntry) {
// Create the request ID of the request we sent that this message is (allegedly) in response to.
uniqueRequestID := ids.RequestID{
NodeID: nodeID,
SourceChainID: sourceChainID,
DestinationChainID: destinationChainID,
RequestID: requestID,
Op: byte(op),
NodeID: nodeID,
ChainID: chainID,
RequestID: requestID,
Op: byte(op),
}
// Mark that an outstanding request has been fulfilled
request, exists := cr.timedRequests.Get(uniqueRequestID)
Expand Down
13 changes: 1 addition & 12 deletions snow/networking/router/chain_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.StateSummaryFrontierOp,
message.InternalGetStateSummaryFrontierFailed(
Expand All @@ -653,7 +652,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AcceptedStateSummaryOp,
message.InternalGetAcceptedStateSummaryFailed(
Expand All @@ -672,7 +670,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AcceptedFrontierOp,
message.InternalGetAcceptedFrontierFailed(
Expand All @@ -691,7 +688,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AcceptedOp,
message.InternalGetAcceptedFailed(
Expand All @@ -710,7 +706,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AncestorsOp,
message.InternalGetAncestorsFailed(
Expand All @@ -730,7 +725,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.PutOp,
message.InternalGetFailed(
Expand All @@ -749,7 +743,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.ChitsOp,
message.InternalQueryFailed(
Expand All @@ -768,7 +761,6 @@ func TestRouterTimeout(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AppResponseOp,
message.InboundAppError(
Expand Down Expand Up @@ -856,7 +848,6 @@ func TestRouterHonorsRequestedEngine(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.StateSummaryFrontierOp,
message.InternalGetStateSummaryFrontierFailed(
Expand Down Expand Up @@ -885,7 +876,6 @@ func TestRouterHonorsRequestedEngine(t *testing.T) {
context.Background(),
nodeID,
ctx.ChainID,
ctx.ChainID,
requestID,
message.AcceptedStateSummaryOp,
message.InternalGetAcceptedStateSummaryFailed(
Expand Down Expand Up @@ -993,7 +983,6 @@ func TestRouterClearTimeouts(t *testing.T) {
context.Background(),
ids.EmptyNodeID,
ids.Empty,
ids.Empty,
requestID,
tt.responseOp,
tt.timeoutMsg,
Expand Down Expand Up @@ -1537,7 +1526,7 @@ func TestAppRequest(t *testing.T) {
}

ctx := context.Background()
chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
chainRouter.lock.Lock()
require.Equal(1, chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()
Expand Down
8 changes: 4 additions & 4 deletions snow/networking/router/mock_router.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions snow/networking/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ type InternalHandler interface {
RegisterRequest(
ctx context.Context,
nodeID ids.NodeID,
sourceChainID ids.ID,
destinationChainID ids.ID,
chainID ids.ID,
requestID uint32,
op message.Op,
failedMsg message.InboundMessage,
Expand Down
17 changes: 4 additions & 13 deletions snow/networking/router/traced_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ func (r *tracedRouter) Initialize(
func (r *tracedRouter) RegisterRequest(
ctx context.Context,
nodeID ids.NodeID,
requestingChainID ids.ID,
respondingChainID ids.ID,
chainID ids.ID,
requestID uint32,
op message.Op,
failedMsg message.InboundMessage,
Expand All @@ -76,8 +75,7 @@ func (r *tracedRouter) RegisterRequest(
r.router.RegisterRequest(
ctx,
nodeID,
requestingChainID,
respondingChainID,
chainID,
requestID,
op,
failedMsg,
Expand All @@ -87,13 +85,7 @@ func (r *tracedRouter) RegisterRequest(

func (r *tracedRouter) HandleInbound(ctx context.Context, msg message.InboundMessage) {
m := msg.Message()
destinationChainID, err := message.GetChainID(m)
if err != nil {
r.router.HandleInbound(ctx, msg)
return
}

sourceChainID, err := message.GetSourceChainID(m)
chainID, err := message.GetChainID(m)
if err != nil {
r.router.HandleInbound(ctx, msg)
return
Expand All @@ -102,8 +94,7 @@ func (r *tracedRouter) HandleInbound(ctx context.Context, msg message.InboundMes
ctx, span := r.tracer.Start(ctx, "tracedRouter.HandleInbound", oteltrace.WithAttributes(
attribute.Stringer("nodeID", msg.NodeID()),
attribute.Stringer("messageOp", msg.Op()),
attribute.Stringer("chainID", destinationChainID),
attribute.Stringer("sourceChainID", sourceChainID),
attribute.Stringer("chainID", chainID),
))
defer span.End()

Expand Down
Loading

0 comments on commit a5acbaa

Please sign in to comment.