Skip to content

Commit

Permalink
Simplify sampler interface (#3026)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored May 17, 2024
1 parent 0928176 commit ddd6d25
Show file tree
Hide file tree
Showing 24 changed files with 110 additions and 110 deletions.
4 changes: 2 additions & 2 deletions network/ip_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ func (i *ipTracker) GetGossipableIPs(

uniform.Initialize(uint64(len(i.gossipableIPs)))
for len(ips) < maxNumIPs {
index, err := uniform.Next()
if err != nil {
index, hasNext := uniform.Next()
if !hasNext {
return ips
}

Expand Down
4 changes: 2 additions & 2 deletions network/p2p/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ func (v *Validators) Sample(ctx context.Context, limit int) []ids.NodeID {

uniform.Initialize(uint64(len(v.validatorList)))
for len(sampled) < limit {
i, err := uniform.Next()
if err != nil {
i, hasNext := uniform.Next()
if !hasNext {
break
}

Expand Down
4 changes: 2 additions & 2 deletions network/peer/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ func (s *peerSet) Sample(n int, precondition func(Peer) bool) []Peer {

peers := make([]Peer, 0, n)
for len(peers) < n {
index, err := sampler.Next()
if err != nil {
index, hasNext := sampler.Next()
if !hasNext {
// We have run out of peers to attempt to sample.
break
}
Expand Down
10 changes: 7 additions & 3 deletions snow/consensus/snowman/bootstrapper/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
package bootstrapper

import (
"errors"

"github.com/ava-labs/avalanchego/utils/math"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"
)

var errUnexpectedSamplerFailure = errors.New("unexpected sampler failure")

// Sample keys from [elements] uniformly by weight without replacement. The
// returned set will have size less than or equal to [maxSize]. This function
// will error if the sum of all weights overflows.
Expand Down Expand Up @@ -36,9 +40,9 @@ func Sample[T comparable](elements map[T]uint64, maxSize int) (set.Set[T], error
}

maxSize = int(min(uint64(maxSize), totalWeight))
indices, err := sampler.Sample(maxSize)
if err != nil {
return nil, err
indices, ok := sampler.Sample(maxSize)
if !ok {
return nil, errUnexpectedSamplerFailure
}

sampledElements := set.NewSet[T](maxSize)
Expand Down
3 changes: 1 addition & 2 deletions snow/validators/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"

safemath "github.com/ava-labs/avalanchego/utils/math"
Expand Down Expand Up @@ -396,7 +395,7 @@ func TestSample(t *testing.T) {
require.Equal([]ids.NodeID{nodeID0}, sampled)

_, err = m.Sample(subnetID, 2)
require.ErrorIs(err, sampler.ErrOutOfRange)
require.ErrorIs(err, errInsufficientWeight)

nodeID1 := ids.GenerateTestNodeID()
require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, math.MaxInt64-1))
Expand Down
7 changes: 4 additions & 3 deletions snow/validators/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
errDuplicateValidator = errors.New("duplicate validator")
errMissingValidator = errors.New("missing validator")
errTotalWeightNotUint64 = errors.New("total weight is not a uint64")
errInsufficientWeight = errors.New("insufficient weight")
)

// newSet returns a new, empty set of validators.
Expand Down Expand Up @@ -257,9 +258,9 @@ func (s *vdrSet) sample(size int) ([]ids.NodeID, error) {
s.samplerInitialized = true
}

indices, err := s.sampler.Sample(size)
if err != nil {
return nil, err
indices, ok := s.sampler.Sample(size)
if !ok {
return nil, errInsufficientWeight
}

list := make([]ids.NodeID, size)
Expand Down
3 changes: 1 addition & 2 deletions snow/validators/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/sampler"
"github.com/ava-labs/avalanchego/utils/set"

safemath "github.com/ava-labs/avalanchego/utils/math"
Expand Down Expand Up @@ -343,7 +342,7 @@ func TestSetSample(t *testing.T) {
require.Equal([]ids.NodeID{nodeID0}, sampled)

_, err = s.Sample(2)
require.ErrorIs(err, sampler.ErrOutOfRange)
require.ErrorIs(err, errInsufficientWeight)

nodeID1 := ids.GenerateTestNodeID()
require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1))
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/uniform.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ package sampler
type Uniform interface {
Initialize(sampleRange uint64)
// Sample returns length numbers in the range [0,sampleRange). If there
// aren't enough numbers in the range, an error is returned. If length is
// aren't enough numbers in the range, false is returned. If length is
// negative the implementation may panic.
Sample(length int) ([]uint64, error)
Sample(length int) ([]uint64, bool)

Next() (uint64, bool)
Reset()
Next() (uint64, error)
}

// NewUniform returns a new sampler
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/uniform_best.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ samplerLoop:

start := s.clock.Time()
for i := 0; i < s.benchmarkIterations; i++ {
if _, err := sampler.Sample(sampleSize); err != nil {
if _, ok := sampler.Sample(sampleSize); !ok {
continue samplerLoop
}
}
Expand Down
16 changes: 8 additions & 8 deletions utils/sampler/uniform_replacer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,34 @@ func (s *uniformReplacer) Initialize(length uint64) {
s.drawsCount = 0
}

func (s *uniformReplacer) Sample(count int) ([]uint64, error) {
func (s *uniformReplacer) Sample(count int) ([]uint64, bool) {
s.Reset()

results := make([]uint64, count)
for i := 0; i < count; i++ {
ret, err := s.Next()
if err != nil {
return nil, err
ret, hasNext := s.Next()
if !hasNext {
return nil, false
}
results[i] = ret
}
return results, nil
return results, true
}

func (s *uniformReplacer) Reset() {
clear(s.drawn)
s.drawsCount = 0
}

func (s *uniformReplacer) Next() (uint64, error) {
func (s *uniformReplacer) Next() (uint64, bool) {
if s.drawsCount >= s.length {
return 0, ErrOutOfRange
return 0, false
}

draw := s.rng.Uint64Inclusive(s.length-1-s.drawsCount) + s.drawsCount
ret := s.drawn.get(draw, draw)
s.drawn[draw] = s.drawn.get(s.drawsCount, s.drawsCount)
s.drawsCount++

return ret, nil
return ret, true
}
16 changes: 8 additions & 8 deletions utils/sampler/uniform_resample.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,28 @@ func (s *uniformResample) Initialize(length uint64) {
s.drawn = make(map[uint64]struct{})
}

func (s *uniformResample) Sample(count int) ([]uint64, error) {
func (s *uniformResample) Sample(count int) ([]uint64, bool) {
s.Reset()

results := make([]uint64, count)
for i := 0; i < count; i++ {
ret, err := s.Next()
if err != nil {
return nil, err
ret, hasNext := s.Next()
if !hasNext {
return nil, false
}
results[i] = ret
}
return results, nil
return results, true
}

func (s *uniformResample) Reset() {
clear(s.drawn)
}

func (s *uniformResample) Next() (uint64, error) {
func (s *uniformResample) Next() (uint64, bool) {
i := uint64(len(s.drawn))
if i >= s.length {
return 0, ErrOutOfRange
return 0, false
}

for {
Expand All @@ -53,6 +53,6 @@ func (s *uniformResample) Next() (uint64, error) {
continue
}
s.drawn[draw] = struct{}{}
return draw, nil
return draw, true
}
}
32 changes: 16 additions & 16 deletions utils/sampler/uniform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ func UniformInitializeMaxUint64Test(t *testing.T, s Uniform) {
s.Initialize(math.MaxUint64)

for {
val, err := s.Next()
require.NoError(t, err)
val, hasNext := s.Next()
require.True(t, hasNext)

if val > math.MaxInt64 {
break
Expand All @@ -95,17 +95,17 @@ func UniformInitializeMaxUint64Test(t *testing.T, s Uniform) {
func UniformOutOfRangeTest(t *testing.T, s Uniform) {
s.Initialize(0)

_, err := s.Sample(1)
require.ErrorIs(t, err, ErrOutOfRange)
_, ok := s.Sample(1)
require.False(t, ok)
}

func UniformEmptyTest(t *testing.T, s Uniform) {
require := require.New(t)

s.Initialize(1)

val, err := s.Sample(0)
require.NoError(err)
val, ok := s.Sample(0)
require.True(ok)
require.Empty(val)
}

Expand All @@ -114,8 +114,8 @@ func UniformSingletonTest(t *testing.T, s Uniform) {

s.Initialize(1)

val, err := s.Sample(1)
require.NoError(err)
val, ok := s.Sample(1)
require.True(ok)
require.Equal([]uint64{0}, val)
}

Expand All @@ -124,8 +124,8 @@ func UniformDistributionTest(t *testing.T, s Uniform) {

s.Initialize(3)

val, err := s.Sample(3)
require.NoError(err)
val, ok := s.Sample(3)
require.True(ok)

slices.Sort(val)
require.Equal([]uint64{0, 1, 2}, val)
Expand All @@ -134,8 +134,8 @@ func UniformDistributionTest(t *testing.T, s Uniform) {
func UniformOverSampleTest(t *testing.T, s Uniform) {
s.Initialize(3)

_, err := s.Sample(4)
require.ErrorIs(t, err, ErrOutOfRange)
_, ok := s.Sample(4)
require.False(t, ok)
}

func UniformLazilySample(t *testing.T, s Uniform) {
Expand All @@ -146,15 +146,15 @@ func UniformLazilySample(t *testing.T, s Uniform) {
for j := 0; j < 2; j++ {
sampled := map[uint64]bool{}
for i := 0; i < 3; i++ {
val, err := s.Next()
require.NoError(err)
val, hasNext := s.Next()
require.True(hasNext)
require.False(sampled[val])

sampled[val] = true
}

_, err := s.Next()
require.ErrorIs(err, ErrOutOfRange)
_, hasNext := s.Next()
require.False(hasNext)

s.Reset()
}
Expand Down
6 changes: 1 addition & 5 deletions utils/sampler/weighted.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@

package sampler

import "errors"

var ErrOutOfRange = errors.New("out of range")

// Weighted defines how to sample a specified valued based on a provided
// weighted distribution
type Weighted interface {
Initialize(weights []uint64) error
Sample(sampleValue uint64) (int, error)
Sample(sampleValue uint64) (int, bool)
}

// NewWeighted returns a new sampler
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/weighted_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func (s *weightedArray) Initialize(weights []uint64) error {
return nil
}

func (s *weightedArray) Sample(value uint64) (int, error) {
func (s *weightedArray) Sample(value uint64) (int, bool) {
if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value {
return 0, ErrOutOfRange
return 0, false
}
minIndex := 0
maxIndex := len(s.arr) - 1
Expand All @@ -98,7 +98,7 @@ func (s *weightedArray) Sample(value uint64) (int, error) {
currentElem := s.arr[index]
currentWeight := currentElem.cumulativeWeight
if previousWeight <= value && value < currentWeight {
return currentElem.index, nil
return currentElem.index, true
}

if value < previousWeight {
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_best.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ samplerLoop:

start := s.clock.Time()
for _, sample := range samples {
if _, err := sampler.Sample(sample); err != nil {
if _, ok := sampler.Sample(sample); !ok {
continue samplerLoop
}
}
Expand Down
6 changes: 3 additions & 3 deletions utils/sampler/weighted_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ func (s *weightedHeap) Initialize(weights []uint64) error {
return nil
}

func (s *weightedHeap) Sample(value uint64) (int, error) {
func (s *weightedHeap) Sample(value uint64) (int, bool) {
if len(s.heap) == 0 || s.heap[0].cumulativeWeight <= value {
return 0, ErrOutOfRange
return 0, false
}

index := 0
for {
currentElement := s.heap[index]
currentWeight := currentElement.weight
if value < currentWeight {
return currentElement.index, nil
return currentElement.index, true
}
value -= currentWeight

Expand Down
Loading

0 comments on commit ddd6d25

Please sign in to comment.