diff --git a/beacon-chain/cache/payload_attestation.go b/beacon-chain/cache/payload_attestation.go index 2a921559b2d7..359b6b03871e 100644 --- a/beacon-chain/cache/payload_attestation.go +++ b/beacon-chain/cache/payload_attestation.go @@ -107,6 +107,22 @@ func (p *PayloadAttestationCache) Add(att *eth.PayloadAttestationMessage, idx ui return nil } +// Get returns the aggregated PayloadAttestation for the given root and status +// if the root doesn't exist or status is invalid, the function returns nil. +func (p *PayloadAttestationCache) Get(root [32]byte, status primitives.PTCStatus) *eth.PayloadAttestation { + p.Lock() + defer p.Unlock() + + if p.root != root { + return nil + } + if status >= primitives.PAYLOAD_INVALID_STATUS { + return nil + } + + return eth.CopyPayloadAttestation(p.attestations[status]) +} + // Clear clears the internal map func (p *PayloadAttestationCache) Clear() { p.Lock() diff --git a/beacon-chain/cache/payload_attestation_test.go b/beacon-chain/cache/payload_attestation_test.go index 10c274222edf..1bb7067c9adf 100644 --- a/beacon-chain/cache/payload_attestation_test.go +++ b/beacon-chain/cache/payload_attestation_test.go @@ -93,3 +93,51 @@ func TestPayloadAttestationCache(t *testing.T) { indices = att.AggregationBits.BitIndices() require.DeepEqual(t, []int{int(idx)}, indices) } + +func TestPayloadAttestationCache_Get(t *testing.T) { + root := [32]byte{1, 2, 3} + wrongRoot := [32]byte{4, 5, 6} + status := primitives.PAYLOAD_PRESENT + invalidStatus := primitives.PAYLOAD_INVALID_STATUS + + cache := &PayloadAttestationCache{ + root: root, + attestations: [primitives.PAYLOAD_INVALID_STATUS]*eth.PayloadAttestation{ + { + Signature: []byte{1}, + }, + { + Signature: []byte{2}, + }, + { + Signature: []byte{3}, + }, + }, + } + + t.Run("valid root and status", func(t *testing.T) { + result := cache.Get(root, status) + require.NotNil(t, result, "Expected a non-nil result") + require.DeepEqual(t, cache.attestations[status], result) + }) + + t.Run("invalid root", func(t *testing.T) { + result := cache.Get(wrongRoot, status) + require.IsNil(t, result) + }) + + t.Run("status out of bound", func(t *testing.T) { + result := cache.Get(root, invalidStatus) + require.IsNil(t, result) + }) + + t.Run("no attestation", func(t *testing.T) { + emptyCache := &PayloadAttestationCache{ + root: root, + attestations: [primitives.PAYLOAD_INVALID_STATUS]*eth.PayloadAttestation{}, + } + + result := emptyCache.Get(root, status) + require.IsNil(t, result) + }) +} diff --git a/proto/prysm/v1alpha1/cloners.go b/proto/prysm/v1alpha1/cloners.go index 9cfeb6396616..98599e5b8df1 100644 --- a/proto/prysm/v1alpha1/cloners.go +++ b/proto/prysm/v1alpha1/cloners.go @@ -74,7 +74,7 @@ func CopySignedBeaconBlockEPBS(sigBlock *SignedBeaconBlockEpbs) *SignedBeaconBlo } } -// CopyBeaconBlockEPBS copies the provided CopyBeaconBlockEPBS. +// CopyBeaconBlockEPBS copies the provided BeaconBlockEPBS. func CopyBeaconBlockEPBS(block *BeaconBlockEpbs) *BeaconBlockEpbs { if block == nil { return nil @@ -88,7 +88,7 @@ func CopyBeaconBlockEPBS(block *BeaconBlockEpbs) *BeaconBlockEpbs { } } -// CopyBeaconBlockBodyEPBS copies the provided CopyBeaconBlockBodyEPBS. +// CopyBeaconBlockBodyEPBS copies the provided BeaconBlockBodyEPBS. func CopyBeaconBlockBodyEPBS(body *BeaconBlockBodyEpbs) *BeaconBlockBodyEpbs { if body == nil { return nil @@ -105,7 +105,7 @@ func CopyBeaconBlockBodyEPBS(body *BeaconBlockBodyEpbs) *BeaconBlockBodyEpbs { SyncAggregate: body.SyncAggregate.Copy(), BlsToExecutionChanges: CopySlice(body.BlsToExecutionChanges), SignedExecutionPayloadHeader: CopySignedExecutionPayloadHeader(body.SignedExecutionPayloadHeader), - PayloadAttestations: CopyPayloadAttestation(body.PayloadAttestations), + PayloadAttestations: CopyPayloadAttestations(body.PayloadAttestations), } } @@ -137,18 +137,14 @@ func CopyExecutionPayloadHeaderEPBS(payload *enginev1.ExecutionPayloadHeaderEPBS } } -// CopyPayloadAttestation copies the provided PayloadAttestation array. -func CopyPayloadAttestation(attestations []*PayloadAttestation) []*PayloadAttestation { +// CopyPayloadAttestations copies the provided PayloadAttestation array. +func CopyPayloadAttestations(attestations []*PayloadAttestation) []*PayloadAttestation { if attestations == nil { return nil } newAttestations := make([]*PayloadAttestation, len(attestations)) for i, att := range attestations { - newAttestations[i] = &PayloadAttestation{ - AggregationBits: bytesutil.SafeCopyBytes(att.AggregationBits), - Data: CopyPayloadAttestationData(att.Data), - Signature: bytesutil.SafeCopyBytes(att.Signature), - } + newAttestations[i] = CopyPayloadAttestation(att) } return newAttestations } @@ -164,3 +160,15 @@ func CopyPayloadAttestationData(data *PayloadAttestationData) *PayloadAttestatio PayloadStatus: data.PayloadStatus, } } + +// CopyPayloadAttestation copies the provided PayloadAttestation. +func CopyPayloadAttestation(a *PayloadAttestation) *PayloadAttestation { + if a == nil { + return nil + } + return &PayloadAttestation{ + AggregationBits: bytesutil.SafeCopyBytes(a.AggregationBits), + Data: CopyPayloadAttestationData(a.Data), + Signature: bytesutil.SafeCopyBytes(a.Signature), + } +}