diff --git a/mtu_discoverer.go b/mtu_discoverer.go index 317b09292f6..fdf20157bef 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -2,6 +2,7 @@ package quic import ( "net" + "sync/atomic" "time" "github.com/quic-go/quic-go/internal/ackhandler" @@ -47,23 +48,25 @@ type mtuFinder struct { rttStats *utils.RTTStats inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight - current protocol.ByteCount + current *atomic.Int64 max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) } var _ mtuDiscoverer = &mtuFinder{} func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder { + var current atomic.Int64 + current.Store(int64(start)) return &mtuFinder{ inFlight: protocol.InvalidByteCount, - current: start, + current: ¤t, rttStats: rttStats, mtuIncreased: mtuIncreased, } } func (f *mtuFinder) done() bool { - return f.max-f.current <= maxMTUDiff+1 + return f.max-f.CurrentSize() <= maxMTUDiff+1 } func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) { @@ -82,7 +85,7 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { } func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { - size := (f.max + f.current) / 2 + size := (f.max + f.CurrentSize()) / 2 f.lastProbeTime = time.Now() f.inFlight = size return ackhandler.Frame{ @@ -92,7 +95,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { } func (f *mtuFinder) CurrentSize() protocol.ByteCount { - return f.current + return protocol.ByteCount(f.current.Load()) } type mtuFinderAckHandler mtuFinder @@ -105,7 +108,7 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { panic("OnAcked callback called although there's no MTU probe packet in flight") } h.inFlight = protocol.InvalidByteCount - h.current = size + h.current.Store(int64(size)) h.mtuIncreased(size) }