From 01a06cdc5461af1bb223a6620bf9c697016a95bd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 5 Aug 2019 18:12:51 +0700 Subject: [PATCH] use a single handle for each reuse --- reuse.go | 28 +++++++++++++++++++++------- reuse_test.go | 4 +++- transport.go | 22 +++++++++++++++++----- 3 files changed, 41 insertions(+), 13 deletions(-) diff --git a/reuse.go b/reuse.go index 56aba1b..08fc14f 100644 --- a/reuse.go +++ b/reuse.go @@ -24,23 +24,37 @@ func (c *reuseConn) GetCount() int { return int(atomic.LoadInt32(&c.refCount)) type reuse struct { mutex sync.Mutex + handle *netlink.Handle // Only set on Linux. nil on other systems. + unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn // global contains connections that are listening on 0.0.0.0 / :: global map[int]*reuseConn } -func newReuse() *reuse { +func newReuse() (*reuse, error) { + // On non-Linux systems, this will return ErrNotImplemented. + handle, err := netlink.NewHandle() + if err == netlink.ErrNotImplemented { + handle = nil + } else if err != nil { + return nil, err + } return &reuse{ unicast: make(map[string]map[int]*reuseConn), global: make(map[int]*reuseConn), - } + handle: handle, + }, nil } +// Get the source IP that the kernel would use for dialing. +// This only works on Linux. +// On other systems, this returns an empty slice of IP addresses. func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, error) { - // Determine the source address that the kernel would use for this IP address. - // Note: This only works on Linux. - // On other OSes, this will return a netlink.ErrNotImplemetned. - routes, err := (&netlink.Handle{}).RouteGet(raddr.IP) + if r.handle == nil { + return nil, nil + } + + routes, err := r.handle.RouteGet(raddr.IP) if err != nil { return nil, err } @@ -54,7 +68,7 @@ func (r *reuse) getSourceIPs(network string, raddr *net.UDPAddr) ([]net.IP, erro func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { ips, err := r.getSourceIPs(network, raddr) - if err != nil && err != netlink.ErrNotImplemented { + if err != nil { return nil, err } diff --git a/reuse_test.go b/reuse_test.go index 2259a39..91668b0 100644 --- a/reuse_test.go +++ b/reuse_test.go @@ -12,7 +12,9 @@ var _ = Describe("Reuse", func() { var reuse *reuse BeforeEach(func() { - reuse = newReuse() + var err error + reuse, err = newReuse() + Expect(err).ToNot(HaveOccurred()) }) It("creates a new global connection when listening on 0.0.0.0", func() { diff --git a/transport.go b/transport.go index bd0adfc..a9ddd41 100644 --- a/transport.go +++ b/transport.go @@ -33,11 +33,19 @@ type connManager struct { reuseUDP6 *reuse } -func newConnManager() *connManager { - return &connManager{ - reuseUDP4: newReuse(), - reuseUDP6: newReuse(), +func newConnManager() (*connManager, error) { + reuseUDP4, err := newReuse() + if err != nil { + return nil, err + } + reuseUDP6, err := newReuse() + if err != nil { + return nil, err } + return &connManager{ + reuseUDP4: reuseUDP4, + reuseUDP6: reuseUDP6, + }, nil } func (c *connManager) getReuse(network string) (*reuse, error) { @@ -87,12 +95,16 @@ func NewTransport(key ic.PrivKey) (tpt.Transport, error) { if err != nil { return nil, err } + connManager, err := newConnManager() + if err != nil { + return nil, err + } return &transport{ privKey: key, localPeer: localPeer, identity: identity, - connManager: newConnManager(), + connManager: connManager, }, nil }