From 4714f4b51d5b8b178902e56526b635e178aefc2d Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 6 Nov 2023 10:25:58 -0800 Subject: [PATCH] chore: switch to typed atomics This just means we can be less concerned about struct field ordering and other details. --- flow_test.go | 5 ++--- meter.go | 6 +++--- meter_test.go | 1 + sweeper.go | 11 +++++------ 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/flow_test.go b/flow_test.go index edfdec4..a66a6ed 100644 --- a/flow_test.go +++ b/flow_test.go @@ -3,7 +3,6 @@ package flow import ( "math" "sync" - "sync/atomic" "testing" "time" ) @@ -106,7 +105,7 @@ func TestUnregister(t *testing.T) { mockClock.Add(62 * time.Second) - if atomic.LoadUint64(&m.accumulator) != 0 { + if m.accumulator.Load() != 0 { t.Error("expected meter to be paused") } @@ -131,7 +130,7 @@ func TestUnregister(t *testing.T) { if actual.Total != 120 { t.Errorf("expected total 120, got %d", actual.Total) } - if atomic.LoadUint64(&m.accumulator) == 0 { + if m.accumulator.Load() == 0 { t.Error("expected meter to be active") } diff --git a/meter.go b/meter.go index 597efc4..eb0748b 100644 --- a/meter.go +++ b/meter.go @@ -31,7 +31,7 @@ func (s Snapshot) String() string { // Meter is a meter for monitoring a flow. type Meter struct { - accumulator uint64 + accumulator atomic.Uint64 // managed by the sweeper loop. registered bool @@ -42,7 +42,7 @@ type Meter struct { // Mark updates the total. func (m *Meter) Mark(count uint64) { - if count > 0 && atomic.AddUint64(&m.accumulator, count) == count { + if count > 0 && m.accumulator.Add(count) == count { // The accumulator is 0 so we probably need to register. We may // already _be_ registered however, if we are, the registration // loop will notice that `m.registered` is set and ignore us. @@ -60,7 +60,7 @@ func (m *Meter) Snapshot() Snapshot { // Reset sets accumulator, total and rate to zero. func (m *Meter) Reset() { globalSweeper.snapshotMu.Lock() - atomic.StoreUint64(&m.accumulator, 0) + m.accumulator.Store(0) m.snapshot.Rate = 0 m.snapshot.Total = 0 globalSweeper.snapshotMu.Unlock() diff --git a/meter_test.go b/meter_test.go index 8442ebd..9bcff22 100644 --- a/meter_test.go +++ b/meter_test.go @@ -3,6 +3,7 @@ package flow import ( "testing" "time" + "unsafe" ) func TestResetMeter(t *testing.T) { diff --git a/sweeper.go b/sweeper.go index 8b9a262..ec291ad 100644 --- a/sweeper.go +++ b/sweeper.go @@ -3,7 +3,6 @@ package flow import ( "math" "sync" - "sync/atomic" "time" "github.com/benbjohnson/clock" @@ -100,7 +99,7 @@ func (sw *sweeper) update() { // Calculate the bandwidth for all active meters. for i, m := range sw.meters[:sw.activeMeters] { - total := atomic.LoadUint64(&m.accumulator) + total := m.accumulator.Load() diff := total - m.snapshot.Total instant := timeMultiplier * float64(diff) @@ -124,7 +123,7 @@ func (sw *sweeper) update() { // Ok, so we are idle... // Mark this as idle by zeroing the accumulator. - swappedTotal := atomic.SwapUint64(&m.accumulator, 0) + swappedTotal := m.accumulator.Swap(0) // So..., are we really idle? if swappedTotal > total { @@ -134,7 +133,7 @@ func (sw *sweeper) update() { // First, add back what we removed. If we can do this // fast enough, we can put it back before anyone // notices. - currentTotal := atomic.AddUint64(&m.accumulator, swappedTotal) + currentTotal := m.accumulator.Add(swappedTotal) // Did we make it? if currentTotal == swappedTotal { @@ -150,7 +149,7 @@ func (sw *sweeper) update() { // `^uint64(total - 1)` is the two's complement of // `total`. It's the "correct" way to subtract // atomically in go. - atomic.AddUint64(&m.accumulator, ^uint64(m.snapshot.Total-1)) + m.accumulator.Add(^uint64(m.snapshot.Total - 1)) } // Reset the rate, keep the total. @@ -163,7 +162,7 @@ func (sw *sweeper) update() { // 1. We don't do this on register to avoid having to take the snapshot lock. // 2. We skip calculating the bandwidth for this round so we get an _accurate_ bandwidth calculation. for _, m := range sw.meters[sw.activeMeters:] { - total := atomic.AddUint64(&m.accumulator, m.snapshot.Total) + total := m.accumulator.Add(m.snapshot.Total) if total > m.snapshot.Total { m.snapshot.LastUpdate = now }