From 72b9c537884e367675278dcb081468296e3f7bf7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 15 Nov 2019 14:39:26 +0800 Subject: [PATCH] make reuse work on Windows --- libp2pquic_suite_test.go | 2 +- netlink_other.go | 3 +- reuse.go => reuse_base.go | 68 ++++++--------------------------------- reuse_linux_test.go | 42 ++++++++++++++++++++++++ reuse_not_win.go | 66 +++++++++++++++++++++++++++++++++++++ reuse_test.go | 57 +++++++++++--------------------- reuse_win.go | 26 +++++++++++++++ 7 files changed, 165 insertions(+), 99 deletions(-) rename reuse.go => reuse_base.go (70%) create mode 100644 reuse_linux_test.go create mode 100644 reuse_not_win.go create mode 100644 reuse_win.go diff --git a/libp2pquic_suite_test.go b/libp2pquic_suite_test.go index ce48c3f..a2e1df4 100644 --- a/libp2pquic_suite_test.go +++ b/libp2pquic_suite_test.go @@ -28,7 +28,7 @@ var maxUnusedDurationOrig time.Duration func isGarbageCollectorRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector") + return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuseBase).runGarbageCollector") } var _ = BeforeEach(func() { diff --git a/netlink_other.go b/netlink_other.go index 58ad3ca..ccc073d 100644 --- a/netlink_other.go +++ b/netlink_other.go @@ -1,8 +1,9 @@ // +build !linux +// +build !windows package libp2pquic import "github.com/vishvananda/netlink/nl" -// nl.SupportedNlFamilies is the default netlink families used by the netlink package +// SupportedNlFamilies is the default netlink families used by the netlink package var SupportedNlFamilies = nl.SupportedNlFamilies diff --git a/reuse.go b/reuse_base.go similarity index 70% rename from reuse.go rename to reuse_base.go index c6415b4..347a9a6 100644 --- a/reuse.go +++ b/reuse_base.go @@ -4,11 +4,9 @@ import ( "net" "sync" "time" - - "github.com/vishvananda/netlink" ) -// Constants. Defined as variables to simplify testing. +// Constant. Defined as variables to simplify testing. var ( garbageCollectInterval = 30 * time.Second maxUnusedDuration = 10 * time.Second @@ -48,34 +46,24 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool { return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now) } -type reuse struct { +type reuseBase struct { mutex sync.Mutex garbageCollectorRunning bool - handle *netlink.Handle // Only set on Linux. nil on other systems. - unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn // global contains connections that are listening on 0.0.0.0 / :: global map[int]*reuseConn } -func newReuse() (*reuse, error) { - // On non-Linux systems, this will return ErrNotImplemented. - handle, err := netlink.NewHandle(SupportedNlFamilies...) - if err == netlink.ErrNotImplemented { - handle = nil - } else if err != nil { - return nil, err - } - return &reuse{ +func newReuseBase() reuseBase { + return reuseBase{ unicast: make(map[string]map[int]*reuseConn), global: make(map[int]*reuseConn), - handle: handle, - }, nil + } } -func (r *reuse) runGarbageCollector() { +func (r *reuseBase) runGarbageCollector() { ticker := time.NewTicker(garbageCollectInterval) defer ticker.Stop() @@ -114,52 +102,14 @@ func (r *reuse) runGarbageCollector() { } // must be called while holding the mutex -func (r *reuse) maybeStartGarbageCollector() { +func (r *reuseBase) maybeStartGarbageCollector() { if !r.garbageCollectorRunning { r.garbageCollectorRunning = true go r.runGarbageCollector() } } -// Get the source IP that the kernel would use for dialing. -// This only works on Linux. -// On other systems, this returns an empty slice of IP addresses. -func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) { - if r.handle == nil { - return nil, nil - } - - routes, err := r.handle.RouteGet(raddr.IP) - if err != nil { - return nil, err - } - - ips := make([]net.IP, 0, len(routes)) - for _, route := range routes { - ips = append(ips, route.Src) - } - return ips, nil -} - -func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { - ips, err := r.getSourceIPs(network, raddr) - if err != nil { - return nil, err - } - - r.mutex.Lock() - defer r.mutex.Unlock() - - conn, err := r.dialLocked(network, raddr, ips) - if err != nil { - return nil, err - } - conn.IncreaseCount() - r.maybeStartGarbageCollector() - return conn, nil -} - -func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) { +func (r *reuseBase) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*reuseConn, error) { for _, ip := range ips { // We already have at least one suitable connection... if conns, ok := r.unicast[ip.String()]; ok { @@ -194,7 +144,7 @@ func (r *reuse) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) (*r return rconn, nil } -func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) { +func (r *reuseBase) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) { conn, err := net.ListenUDP(network, laddr) if err != nil { return nil, err diff --git a/reuse_linux_test.go b/reuse_linux_test.go new file mode 100644 index 0000000..8bc401a --- /dev/null +++ b/reuse_linux_test.go @@ -0,0 +1,42 @@ +// +build linux + +package libp2pquic + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Reuse (on Linux)", func() { + var reuse *reuse + + BeforeEach(func() { + var err error + reuse, err = newReuse() + Expect(err).ToNot(HaveOccurred()) + }) + + Context("creating and reusing connections", func() { + AfterEach(func() { closeAllConns(reuse) }) + + It("reuses a connection it created for listening on a specific interface", func() { + raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + Expect(err).ToNot(HaveOccurred()) + ips, err := reuse.getSourceIPs("udp4", raddr) + Expect(err).ToNot(HaveOccurred()) + Expect(ips).ToNot(BeEmpty()) + // listen + addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0") + Expect(err).ToNot(HaveOccurred()) + lconn, err := reuse.Listen("udp4", addr) + Expect(err).ToNot(HaveOccurred()) + Expect(lconn.GetCount()).To(Equal(1)) + // dial + conn, err := reuse.Dial("udp4", raddr) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.GetCount()).To(Equal(2)) + }) + }) +}) diff --git a/reuse_not_win.go b/reuse_not_win.go new file mode 100644 index 0000000..fb36b83 --- /dev/null +++ b/reuse_not_win.go @@ -0,0 +1,66 @@ +// +build !windows + +package libp2pquic + +import ( + "net" + + "github.com/vishvananda/netlink" +) + +type reuse struct { + reuseBase + + handle *netlink.Handle // Only set on Linux. nil on other systems. +} + +func newReuse() (*reuse, error) { + handle, err := netlink.NewHandle(SupportedNlFamilies...) + if err == netlink.ErrNotImplemented { + handle = nil + } else if err != nil { + return nil, err + } + return &reuse{ + reuseBase: newReuseBase(), + handle: handle, + }, nil +} + +// Get the source IP that the kernel would use for dialing. +// This only works on Linux. +// On other systems, this returns an empty slice of IP addresses. +func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) { + if r.handle == nil { + return nil, nil + } + + routes, err := r.handle.RouteGet(raddr.IP) + if err != nil { + return nil, err + } + + ips := make([]net.IP, 0, len(routes)) + for _, route := range routes { + ips = append(ips, route.Src) + } + return ips, nil +} + +func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { + ips, err := r.getSourceIPs(network, raddr) + if err != nil { + return nil, err + } + + r.mutex.Lock() + defer r.mutex.Unlock() + + conn, err := r.dialLocked(network, raddr, ips) + if err != nil { + return nil, err + } + conn.IncreaseCount() + r.maybeStartGarbageCollector() + return conn, nil +} diff --git a/reuse_test.go b/reuse_test.go index 276fd5b..e773915 100644 --- a/reuse_test.go +++ b/reuse_test.go @@ -2,7 +2,6 @@ package libp2pquic import ( "net" - "runtime" "time" . "github.com/onsi/ginkgo" @@ -15,6 +14,24 @@ func (c *reuseConn) GetCount() int { return c.refCount } +func closeAllConns(reuse *reuse) { + reuse.mutex.Lock() + for _, conn := range reuse.global { + for conn.GetCount() > 0 { + conn.DecreaseCount() + } + } + for _, conns := range reuse.unicast { + for _, conn := range conns { + for conn.GetCount() > 0 { + conn.DecreaseCount() + } + } + } + reuse.mutex.Unlock() + Eventually(isGarbageCollectorRunning).Should(BeFalse()) +} + var _ = Describe("Reuse", func() { var reuse *reuse @@ -25,23 +42,7 @@ var _ = Describe("Reuse", func() { }) Context("creating and reusing connections", func() { - AfterEach(func() { - reuse.mutex.Lock() - for _, conn := range reuse.global { - for conn.GetCount() > 0 { - conn.DecreaseCount() - } - } - for _, conns := range reuse.unicast { - for _, conn := range conns { - for conn.GetCount() > 0 { - conn.DecreaseCount() - } - } - } - reuse.mutex.Unlock() - Eventually(isGarbageCollectorRunning).Should(BeFalse()) - }) + AfterEach(func() { closeAllConns(reuse) }) It("creates a new global connection when listening on 0.0.0.0", func() { addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") @@ -84,26 +85,6 @@ var _ = Describe("Reuse", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.GetCount()).To(Equal(2)) }) - - if runtime.GOOS == "linux" { - It("reuses a connection it created for listening on a specific interface", func() { - raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") - Expect(err).ToNot(HaveOccurred()) - ips, err := reuse.getSourceIPs("udp4", raddr) - Expect(err).ToNot(HaveOccurred()) - Expect(ips).ToNot(BeEmpty()) - // listen - addr, err := net.ResolveUDPAddr("udp4", ips[0].String()+":0") - Expect(err).ToNot(HaveOccurred()) - lconn, err := reuse.Listen("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(lconn.GetCount()).To(Equal(1)) - // dial - conn, err := reuse.Dial("udp4", raddr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(2)) - }) - } }) Context("garbage-collecting connections", func() { diff --git a/reuse_win.go b/reuse_win.go new file mode 100644 index 0000000..0f57c8e --- /dev/null +++ b/reuse_win.go @@ -0,0 +1,26 @@ +// +build windows + +package libp2pquic + +import "net" + +type reuse struct { + reuseBase +} + +func newReuse() (*reuse, error) { + return &reuse{reuseBase: newReuseBase()}, nil +} + +func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + conn, err := r.dialLocked(network, raddr, nil) + if err != nil { + return nil, err + } + conn.IncreaseCount() + r.maybeStartGarbageCollector() + return conn, nil +}