Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more unit tests for aggregator #878

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions warp/aggregator/aggregation_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func newSignatureAggregationJob(

// Execute aggregates signatures for the requested message
func (a *signatureAggregationJob) Execute(ctx context.Context) (*AggregateSignatureResult, error) {
log.Info("Fetching signature", "subnetID", a.subnetID, "height", a.height)
msgID := a.msg.ID()
log.Info("Fetching signature", "msgID", msgID, "subnetID", a.subnetID, "height", a.height)
validators, totalWeight, err := avalancheWarp.GetCanonicalValidatorSet(ctx, a.state, a.height, a.subnetID)
if err != nil {
return nil, fmt.Errorf("failed to get validator set: %w", err)
Expand Down Expand Up @@ -95,24 +96,31 @@ func (a *signatureAggregationJob) Execute(ctx context.Context) (*AggregateSignat
wg.Add(1)
go func() {
defer wg.Done()
log.Info("Fetching warp signature", "nodeID", signatureJob.nodeID, "index", i)
log.Info("Fetching warp signature", "msgID", msgID, "nodeID", signatureJob.nodeID, "index", i)
blsSignature, err := signatureJob.Execute(signatureFetchCtx)
if err != nil {
log.Info("Failed to fetch signature at index %d: %s", i, signatureJob)
return
}
log.Info("Retrieved warp signature", "nodeID", signatureJob.nodeID, "index", i, "signature", hexutil.Bytes(bls.SignatureToBytes(blsSignature)))
// Add the signature and check if we've reached the requested threshold
log.Info("Retrieved warp signature", "msgID", msgID, "nodeID", signatureJob.nodeID, "index", i, "signature", hexutil.Bytes(bls.SignatureToBytes(blsSignature)))

// Obtain signatureLock for aggregating the signature weight
signatureLock.Lock()
defer signatureLock.Unlock()

// Exit early if context was cancelled
if err := signatureFetchCtx.Err(); err != nil {
return
}

// Add the signature and check if we've reached the requested threshold
blsSignatures = append(blsSignatures, blsSignature)
bitSet.Add(i)
log.Info("Updated weight", "totalWeight", signatureWeight+signatureJob.weight, "addedWeight", signatureJob.weight)
signatureWeight += signatureJob.weight
log.Info("Updated weight", "msgID", msgID, "signatureWeight", signatureWeight, "addedWeight", signatureJob.weight)
// If the signature weight meets the requested threshold, cancel signature fetching
if err := avalancheWarp.VerifyWeight(signatureWeight, totalWeight, a.maxNeededQuorumNum, a.quorumDen); err == nil {
log.Info("Verify weight passed, exiting aggregation early", "maxNeededQuorumNum", a.maxNeededQuorumNum, "totalWeight", totalWeight, "signatureWeight", signatureWeight)
log.Info("Verify weight passed, exiting aggregation early", "msgID", msgID, "maxNeededQuorumNum", a.maxNeededQuorumNum, "totalWeight", totalWeight, "signatureWeight", signatureWeight)
signatureFetchCancel()
}
}()
Expand Down
64 changes: 63 additions & 1 deletion warp/aggregator/aggregation_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func executeSignatureAggregationTest(t testing.TB, test signatureAggregationTest
t.Helper()

res, err := test.job.Execute(test.ctx)
require.ErrorIs(t, err, test.expectedErr)
if test.expectedErr != nil {
require.ErrorIs(t, err, test.expectedErr)
return
}

Expand Down Expand Up @@ -211,6 +211,68 @@ func TestAggregateThresholdSignatures(t *testing.T) {
})
}

func TestAggregateThresholdSignaturesOverMaxNeeded(t *testing.T) {
ctx := context.Background()
aggregationJob := newSignatureAggregationJob(
&mockFetcher{
fetch: func(ctx context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// Allow bls signatures from all nodes even though we only need 3/5
for i, matchingNodeID := range nodeIDs {
if matchingNodeID == nodeID {
return blsSignatures[i], nil
}
}
}
return nil, errors.New("what do we say to the god of death")
},
},
pChainHeight,
subnetID,
60,
60,
100,
&validators.TestState{
GetSubnetIDF: getSubnetIDF,
GetCurrentHeightF: getCurrentHeightF,
GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
res := make(map[ids.NodeID]*validators.GetValidatorOutput)
for i := 0; i < len(nodeIDs); i++ {
res[nodeIDs[i]] = &validators.GetValidatorOutput{
NodeID: nodeIDs[i],
PublicKey: blsPublicKeys[i],
Weight: 100,
}
}
return res, nil
},
},
unsignedMsg,
)

signature := &avalancheWarp.BitSetSignature{
Signers: set.NewBits(0, 1, 2).Bytes(),
}
signedMessage, err := avalancheWarp.NewMessage(unsignedMsg, signature)
require.NoError(t, err)
aggregateSignature, err := bls.AggregateSignatures(blsSignatures)
require.NoError(t, err)
copy(signature.Signature[:], bls.SignatureToBytes(aggregateSignature))
expectedRes := &AggregateSignatureResult{
SignatureWeight: 300,
TotalWeight: 500,
Message: signedMessage,
}
executeSignatureAggregationTest(t, signatureAggregationTest{
ctx: ctx,
job: aggregationJob,
expectedRes: expectedRes,
})
}

func TestAggregateThresholdSignaturesInsufficientWeight(t *testing.T) {
ctx := context.Background()
aggregationJob := newSignatureAggregationJob(
Expand Down