Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

make reuse work on Windows #83

Merged
merged 1 commit into from
Nov 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libp2pquic_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var maxUnusedDurationOrig time.Duration
func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuseBase).runGarbageCollector")
}

var _ = BeforeEach(func() {
Expand Down
3 changes: 2 additions & 1 deletion netlink_other.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// +build !linux
// +build !windows

package libp2pquic

import "github.com/vishvananda/netlink/nl"

// nl.SupportedNlFamilies is the default netlink families used by the netlink package
// SupportedNlFamilies is the default netlink families used by the netlink package
var SupportedNlFamilies = nl.SupportedNlFamilies
68 changes: 9 additions & 59 deletions reuse.go → reuse_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import (
"net"
"sync"
"time"

"github.com/vishvananda/netlink"
)

// Constants. Defined as variables to simplify testing.
// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
Expand Down Expand Up @@ -48,34 +46,24 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool {
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
}

type reuse struct {
type reuseBase struct {
mutex sync.Mutex

garbageCollectorRunning bool

handle *netlink.Handle // Only set on Linux. nil on other systems.

unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn
}

func newReuse() (*reuse, error) {
// On non-Linux systems, this will return ErrNotImplemented.
handle, err := netlink.NewHandle(SupportedNlFamilies...)
if err == netlink.ErrNotImplemented {
handle = nil
} else if err != nil {
return nil, err
}
return &reuse{
func newReuseBase() reuseBase {
return reuseBase{
unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn),
handle: handle,
}, nil
}
}

func (r *reuse) runGarbageCollector() {
func (r *reuseBase) runGarbageCollector() {
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()

Expand Down Expand Up @@ -114,52 +102,14 @@ func (r *reuse) runGarbageCollector() {
}

// must be called while holding the mutex
func (r *reuse) maybeStartGarbageCollector() {
func (r *reuseBase) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true
go r.runGarbageCollector()
}
}

// Get the source IP that the kernel would use for dialing.
// This only works on Linux.
// On other systems, this returns an empty slice of IP addresses.
func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) {
if r.handle == nil {
return nil, nil
}

routes, err := r.handle.RouteGet(raddr.IP)
if err != nil {
return nil, err
}

ips := make([]net.IP, 0, len(routes))
for _, route := range routes {
ips = append(ips, route.Src)
}
return ips, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
ips, err := r.getSourceIPs(network, raddr)
if err != nil {
return nil, err
}

r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ips)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}

func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) {
func (r *reuseBase) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) {
for _, ip := range ips {
// We already have at least one suitable connection...
if conns, ok := r.unicast[ip.String()]; ok {
Expand Down Expand Up @@ -194,7 +144,7 @@ func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*r
return rconn, nil
}

func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
func (r *reuseBase) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
Expand Down
42 changes: 42 additions & 0 deletions reuse_linux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// +build linux

package libp2pquic

import (
"net"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("Reuse (on Linux)", func() {
var reuse *reuse

BeforeEach(func() {
var err error
reuse, err = newReuse()
Expect(err).ToNot(HaveOccurred())
})

Context("creating and reusing connections", func() {
AfterEach(func() { closeAllConns(reuse) })

It("reuses a connection it created for listening on a specific interface", func() {
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
Expect(err).ToNot(HaveOccurred())
ips, err := reuse.getSourceIPs("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(ips).ToNot(BeEmpty())
// listen
addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0")
Expect(err).ToNot(HaveOccurred())
lconn, err := reuse.Listen("udp4", addr)
Expect(err).ToNot(HaveOccurred())
Expect(lconn.GetCount()).To(Equal(1))
// dial
conn, err := reuse.Dial("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})
})
})
66 changes: 66 additions & 0 deletions reuse_not_win.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// +build !windows

package libp2pquic

import (
"net"

"github.com/vishvananda/netlink"
)

type reuse struct {
reuseBase

handle *netlink.Handle // Only set on Linux. nil on other systems.
}

func newReuse() (*reuse, error) {
handle, err := netlink.NewHandle(SupportedNlFamilies...)
if err == netlink.ErrNotImplemented {
handle = nil
} else if err != nil {
return nil, err
}
return &reuse{
reuseBase: newReuseBase(),
handle: handle,
}, nil
}

// Get the source IP that the kernel would use for dialing.
// This only works on Linux.
// On other systems, this returns an empty slice of IP addresses.
func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) {
if r.handle == nil {
return nil, nil
}

routes, err := r.handle.RouteGet(raddr.IP)
if err != nil {
return nil, err
}

ips := make([]net.IP, 0, len(routes))
for _, route := range routes {
ips = append(ips, route.Src)
}
return ips, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
ips, err := r.getSourceIPs(network, raddr)
if err != nil {
return nil, err
}

r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, ips)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}
57 changes: 19 additions & 38 deletions reuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package libp2pquic

import (
"net"
"runtime"
"time"

. "github.com/onsi/ginkgo"
Expand All @@ -15,6 +14,24 @@ func (c *reuseConn) GetCount() int {
return c.refCount
}

func closeAllConns(reuse *reuse) {
reuse.mutex.Lock()
for _, conn := range reuse.global {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
for _, conns := range reuse.unicast {
for _, conn := range conns {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
}
reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
}

var _ = Describe("Reuse", func() {
var reuse *reuse

Expand All @@ -25,23 +42,7 @@ var _ = Describe("Reuse", func() {
})

Context("creating and reusing connections", func() {
AfterEach(func() {
reuse.mutex.Lock()
for _, conn := range reuse.global {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
for _, conns := range reuse.unicast {
for _, conn := range conns {
for conn.GetCount() > 0 {
conn.DecreaseCount()
}
}
}
reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
})
AfterEach(func() { closeAllConns(reuse) })

It("creates a new global connection when listening on 0.0.0.0", func() {
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
Expand Down Expand Up @@ -84,26 +85,6 @@ var _ = Describe("Reuse", func() {
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})

if runtime.GOOS == "linux" {
It("reuses a connection it created for listening on a specific interface", func() {
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
Expect(err).ToNot(HaveOccurred())
ips, err := reuse.getSourceIPs("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(ips).ToNot(BeEmpty())
// listen
addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0")
Expect(err).ToNot(HaveOccurred())
lconn, err := reuse.Listen("udp4", addr)
Expect(err).ToNot(HaveOccurred())
Expect(lconn.GetCount()).To(Equal(1))
// dial
conn, err := reuse.Dial("udp4", raddr)
Expect(err).ToNot(HaveOccurred())
Expect(conn.GetCount()).To(Equal(2))
})
}
})

Context("garbage-collecting connections", func() {
Expand Down
26 changes: 26 additions & 0 deletions reuse_win.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// +build windows

package libp2pquic

import "net"

type reuse struct {
reuseBase
}

func newReuse() (*reuse, error) {
return &reuse{reuseBase: newReuseBase()}, nil
}

func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) {
r.mutex.Lock()
defer r.mutex.Unlock()

conn, err := r.dialLocked(network, raddr, nil)
if err != nil {
return nil, err
}
conn.IncreaseCount()
r.maybeStartGarbageCollector()
return conn, nil
}