From 6166a1207759728014aafa8c4b2ffe1110839afc Mon Sep 17 00:00:00 2001 From: lnykww Date: Thu, 21 Mar 2019 20:13:28 +0800 Subject: [PATCH] reuse port --- listener.go | 23 ++++-- reuse.go | 190 +++++++++++++++++++++++++++++++++++++++++++++++ reuse_test.go | 199 ++++++++++++++++++++++++++++++++++++++++++++++++++ transport.go | 74 +++++++++---------- 4 files changed, 442 insertions(+), 44 deletions(-) create mode 100644 reuse.go create mode 100644 reuse_test.go diff --git a/listener.go b/listener.go index ef8b019..2cf36f0 100644 --- a/listener.go +++ b/listener.go @@ -16,9 +16,8 @@ var quicListenAddr = quic.ListenAddr // A listener listens for QUIC connections. type listener struct { - quicListener quic.Listener - transport tpt.Transport - + quicListener quic.Listener + transport *transport privKey ic.PrivKey localPeer peer.ID localMultiaddr ma.Multiaddr @@ -26,7 +25,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID, key ic.PrivKey, tlsConf *tls.Config) (tpt.Listener, error) { +func newListener(addr ma.Multiaddr, t *transport, localPeer peer.ID, key ic.PrivKey, tlsConf *tls.Config) (tpt.Listener, error) { lnet, host, err := manet.DialArgs(addr) if err != nil { return nil, err @@ -35,7 +34,7 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID, if err != nil { return nil, err } - conn, err := net.ListenUDP(lnet, laddr) + conn, err := t.connManagers.Listen(lnet, laddr) if err != nil { return nil, err } @@ -49,7 +48,7 @@ func newListener(addr ma.Multiaddr, transport tpt.Transport, localPeer peer.ID, } return &listener{ quicListener: ln, - transport: transport, + transport: t, privKey: key, localPeer: localPeer, localMultiaddr: localMultiaddr, @@ -99,6 +98,18 @@ func (l *listener) setupConn(sess quic.Session) (tpt.Conn, error) { // Close closes the listener. func (l *listener) Close() error { + lnet, host, err := manet.DialArgs(l.localMultiaddr) + if err != nil { + return err + } + laddr, err := net.ResolveUDPAddr(lnet, host) + if err != nil { + return err + } + err = l.transport.connManagers.Close(lnet, laddr) + if err != nil { + return err + } return l.quicListener.Close() } diff --git a/reuse.go b/reuse.go new file mode 100644 index 0000000..9204a71 --- /dev/null +++ b/reuse.go @@ -0,0 +1,190 @@ +package libp2pquic + +import ( + "errors" + "net" + "sync" + + srcs "github.com/lnykww/go-src-select" +) + +type ReuseConn struct { + net.PacketConn + lock sync.Mutex + ref int +} + +func NewReuseConn(conn net.PacketConn) *ReuseConn { + return &ReuseConn{ + PacketConn: conn, + ref: 1, + lock: sync.Mutex{}, + } +} + +func (rc *ReuseConn) Ref() error { + rc.lock.Lock() + defer rc.lock.Unlock() + if rc.ref == 0 { + return errors.New("conn closed") + } + rc.ref++ + return nil +} + +func (rc *ReuseConn) Close() error { + rc.lock.Lock() + defer rc.lock.Unlock() + var err error + switch rc.ref { + case 0: // cloesd, just return + return nil + case 1: // no reference, close the conn + err = rc.PacketConn.Close() + } + rc.ref-- + return err +} + +type Reuse struct { + lock sync.Mutex + unicast map[string]map[int]net.PacketConn + unspecific []net.PacketConn + connGlobal net.PacketConn + connGlobalOnce sync.Once +} + +func NewReuse() *Reuse { + return &Reuse{ + unicast: make(map[string]map[int]net.PacketConn), + unspecific: make([]net.PacketConn, 0), + } +} + +// getConnGlobal get the global random port socket, if not exist, create +// it first. +func (r *Reuse) getConnGlobal(network string) (net.PacketConn, error) { + var err error + r.connGlobalOnce.Do(func() { + var addr *net.UDPAddr + var conn net.PacketConn + var host string + switch network { + case "udp4": + host = "0.0.0.0:0" + case "udp6": + host = "[::]:0" + } + addr, err = net.ResolveUDPAddr(network, host) + if err != nil { + return + } + conn, err = net.ListenUDP(network, addr) + if err != nil { + return + } + + r.connGlobal = NewReuseConn(conn) + }) + if r.connGlobal == nil && err == nil { + err = errors.New("global socket init not done") + } + return r.connGlobal, err +} + +// rueseConn Assertion the type of the conn is ReuseConn and inc the ref +func (r *Reuse) reuseConn(conn net.PacketConn) error { + reuseConn, ok := conn.(*ReuseConn) + if !ok { + panic("type ReuseConn Assert failed: something wrong!") + } + return reuseConn.Ref() +} + +func (r *Reuse) dial(network string, raddr *net.UDPAddr) (net.PacketConn, error) { + // Find the source address which kernel use + sip, err := srcs.Select(raddr.IP) + if err != nil { + return r.getConnGlobal(network) + } + r.lock.Lock() + defer r.lock.Unlock() + + // If we has a listener on this address, use it to dial + if c, ok := r.unicast[sip.String()]; ok { + for _, v := range c { + return v, nil + } + } + + if len(r.unspecific) != 0 { + return r.unspecific[0], nil + } + + return r.getConnGlobal(network) +} + +func (r *Reuse) Dial(network string, raddr *net.UDPAddr) (net.PacketConn, error) { + conn, err := r.dial(network, raddr) + if err != nil { + return nil, err + } + // we are reuse a conn, reference it + if err = r.reuseConn(conn); err != nil { + return nil, err + } + return conn, nil +} + +func (r *Reuse) Listen(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + + reuseConn := NewReuseConn(conn) + + r.lock.Lock() + defer r.lock.Unlock() + + switch { + case laddr.IP.IsUnspecified(): + r.unspecific = append(r.unspecific, reuseConn) + default: + if _, ok := r.unicast[laddr.IP.String()]; !ok { + r.unicast[laddr.IP.String()] = make(map[int]net.PacketConn) + } + if _, ok := r.unicast[laddr.IP.String()][laddr.Port]; ok { + conn.Close() + return nil, errors.New("addr already listen") + } + r.unicast[laddr.IP.String()][laddr.Port] = reuseConn + } + return reuseConn, nil +} + +func (r *Reuse) Close(addr *net.UDPAddr) error { + r.lock.Lock() + defer r.lock.Unlock() + switch { + case addr.IP.IsUnspecified(): + for index, conn := range r.unspecific { + recAddr := conn.LocalAddr().(*net.UDPAddr) + if recAddr.IP.Equal(addr.IP) && recAddr.Port == addr.Port { + r.unspecific = append(r.unspecific[:index], r.unspecific[index+1:]...) + return nil + } + } + default: + if us, ok := r.unicast[addr.IP.String()]; ok { + if _, ok := us[addr.Port]; ok { + delete(us, addr.Port) + } + + if len(us) == 0 { + delete(r.unicast, addr.IP.String()) + } + } + } + return nil +} diff --git a/reuse_test.go b/reuse_test.go new file mode 100644 index 0000000..c0ca1ec --- /dev/null +++ b/reuse_test.go @@ -0,0 +1,199 @@ +package libp2pquic + +import ( + "net" + "strings" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Reuse", func() { + var reuse *Reuse + + BeforeEach(func() { + reuse = NewReuse() + }) + + It("IPv4 Reuse correct socket", func() { + host := "0.0.0.0" + network := "udp4" + daddr := "8.8.8.8" + laddr := "127.0.0.1" + + unspecific, err := net.ResolveUDPAddr(network, strings.Join([]string{host, ":1004"}, "")) + Expect(err).ToNot(HaveOccurred()) + unspecificConn, err := reuse.Listen(network, unspecific) + Expect(err).ToNot(HaveOccurred()) + + // Only has unspecific listener, dial to dest("8.8.8.8") + // expect use the unspecific socket + dest, err := net.ResolveUDPAddr(network, strings.Join([]string{daddr, ":1004"}, "")) + Expect(err).ToNot(HaveOccurred()) + dialConn, err := reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + Expect(dialConn.(*ReuseConn).PacketConn).To(Equal(unspecificConn.(*ReuseConn).PacketConn)) + + // Add 127.0.0.1 addr, dial to dest("8.8.8.8") + // expect use the unspecific socket + localAddr, err := net.ResolveUDPAddr(network, strings.Join([]string{laddr, ":1005"}, "")) + Expect(err).ToNot(HaveOccurred()) + localHostConn, err := reuse.Listen(network, localAddr) + Expect(err).ToNot(HaveOccurred()) + + dialConn, err = reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + Expect(dialConn.(*ReuseConn).PacketConn).To(Equal(unspecificConn.(*ReuseConn).PacketConn)) + + // dial to localhost expcet use the localHostConn + localhost, err := net.ResolveUDPAddr(network, strings.Join([]string{laddr, ":1006"}, "")) + Expect(err).ToNot(HaveOccurred()) + localHostDialConn, err := reuse.Dial(network, localhost) + Expect(err).ToNot(HaveOccurred()) + Expect(localHostDialConn.(*ReuseConn).PacketConn).To(Equal(localHostConn.(*ReuseConn).PacketConn)) + + // close the unspecific listener + reuse.Close(unspecific) + // dial to dest("8.8.8.8"), expect use the global conn + dialConnGlobal, err := reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + connLocalAddr, ok := dialConnGlobal.LocalAddr().(*net.UDPAddr) + Expect(ok).To(BeTrue()) + Expect(connLocalAddr.Port).NotTo(Equal(1004)) + Expect(connLocalAddr.IP.IsUnspecified()).To(BeTrue()) + + // dial to localhost also use the localHostConn + localHostDialConn2, err := reuse.Dial(network, localhost) + Expect(err).ToNot(HaveOccurred()) + Expect(localHostDialConn2.(*ReuseConn).PacketConn).To(Equal(localHostConn.(*ReuseConn).PacketConn)) + // close the localAddr listener + reuse.Close(localAddr) + + // dial to localhost expect use the global conn + dialConnGlobal2, err := reuse.Dial(network, localhost) + Expect(err).ToNot(HaveOccurred()) + connLocalAddr2, ok := dialConnGlobal2.LocalAddr().(*net.UDPAddr) + Expect(ok).To(BeTrue()) + Expect(connLocalAddr2.Port).NotTo(Equal(1004)) + Expect(connLocalAddr2.IP.IsUnspecified()).To(BeTrue()) + }) + + It("IPv6 Reuse correct socket", func() { + host := "[::]" + network := "udp6" + daddr := "[2001:4860:4860::8888]" + laddr := "[::1]" + + unspecific, err := net.ResolveUDPAddr(network, strings.Join([]string{host, ":1004"}, "")) + Expect(err).ToNot(HaveOccurred()) + unspecificConn, err := reuse.Listen(network, unspecific) + Expect(err).ToNot(HaveOccurred()) + + // Only has unspecific listener, dial to dest("2001:4860:4860::8888") + // expect use the unspecific socket + dest, err := net.ResolveUDPAddr(network, strings.Join([]string{daddr, ":1004"}, "")) + Expect(err).ToNot(HaveOccurred()) + dialConn, err := reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + Expect(dialConn.(*ReuseConn).PacketConn).To(Equal(unspecificConn.(*ReuseConn).PacketConn)) + + // Add [::1] addr, dial to dest("2001:4860:4860::8888") + // expect use the unspecific socket + localAddr, err := net.ResolveUDPAddr(network, strings.Join([]string{laddr, ":1005"}, "")) + Expect(err).ToNot(HaveOccurred()) + localHostConn, err := reuse.Listen(network, localAddr) + Expect(err).ToNot(HaveOccurred()) + + dialConn, err = reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + // for ipv6 will use localhost or unspecific connection + // if there is no default ipv6 route, will use localhost + // what ever never use globalConnection + Expect(reuse.connGlobal).To(BeNil()) + + // dial to localhost expcet use the localHostConn + localhost, err := net.ResolveUDPAddr(network, strings.Join([]string{laddr, ":1006"}, "")) + Expect(err).ToNot(HaveOccurred()) + localHostDialConn, err := reuse.Dial(network, localhost) + Expect(err).ToNot(HaveOccurred()) + Expect(localHostDialConn.(*ReuseConn).PacketConn).To(Equal(localHostConn.(*ReuseConn).PacketConn)) + + // close the unspecific listener + reuse.Close(unspecific) + + // dial to localhost also use the localHostConn + localHostDialConn2, err := reuse.Dial(network, localhost) + Expect(err).ToNot(HaveOccurred()) + Expect(localHostDialConn2.(*ReuseConn).PacketConn).To(Equal(localHostConn.(*ReuseConn).PacketConn)) + + // close the localAddr listener + reuse.Close(localAddr) + // dial to dest("2001:4860:4860::8888"), expect use the global conn + Expect(err).ToNot(HaveOccurred()) + dialConnGlobal, err := reuse.Dial(network, dest) + Expect(err).ToNot(HaveOccurred()) + connLocalAddr, ok := dialConnGlobal.LocalAddr().(*net.UDPAddr) + Expect(ok).To(BeTrue()) + Expect(connLocalAddr.Port).NotTo(Equal(1004)) + Expect(connLocalAddr.IP.IsUnspecified()).To(BeTrue()) + + }) + + It("ReuseConn test", func() { + network := "udp4" + addr1, err := net.ResolveUDPAddr(network, "127.0.0.1:4444") + Expect(err).ToNot(HaveOccurred()) + addr2, err := net.ResolveUDPAddr(network, "127.0.0.1:4445") + Expect(err).ToNot(HaveOccurred()) + conn1, err := net.ListenUDP(network, addr1) + Expect(err).ToNot(HaveOccurred()) + conn2, err := net.ListenUDP(network, addr2) + Expect(err).ToNot(HaveOccurred()) + + reuseConn1 := NewReuseConn(conn1) + reuseConn2 := NewReuseConn(conn2) + + TestData := "ReuseConnTest" + + sendData := func() { + n, err := conn1.WriteTo([]byte(TestData), addr2) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(len(TestData))) + } + + go sendData() + + reuseConn2.SetReadDeadline(time.Now().Add(5 * time.Second)) + data := make([]byte, len(TestData)) + _, _, err = reuseConn2.ReadFrom(data) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data[:])).To(Equal(TestData)) + + err = reuseConn2.Ref() + Expect(err).ToNot(HaveOccurred()) + + err = reuseConn2.Close() + Expect(err).ToNot(HaveOccurred()) + + go sendData() + + reuseConn2.SetReadDeadline(time.Now().Add(5 * time.Second)) + data = make([]byte, len(TestData)) + _, _, err = reuseConn2.ReadFrom(data) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data[:])).To(Equal(TestData)) + + err = reuseConn2.Close() + Expect(err).ToNot(HaveOccurred()) + reuseConn2.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, _, err = reuseConn2.ReadFrom(data) + Expect(strings.Contains(err.Error(), "use of closed network connection")).To(BeTrue()) + + err = reuseConn1.Close() + Expect(err).ToNot(HaveOccurred()) + reuseConn1.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, _, err = reuseConn1.ReadFrom(data) + Expect(strings.Contains(err.Error(), "use of closed network connection")).To(BeTrue()) + }) +}) diff --git a/transport.go b/transport.go index 212ddf4..e3fb902 100644 --- a/transport.go +++ b/transport.go @@ -5,9 +5,7 @@ import ( "crypto/tls" "crypto/x509" "errors" - "fmt" "net" - "sync" ic "github.com/libp2p/go-libp2p-crypto" peer "github.com/libp2p/go-libp2p-peer" @@ -31,47 +29,36 @@ var quicConfig = &quic.Config{ KeepAlive: true, } -type connManager struct { - connIPv4Once sync.Once - connIPv4 net.PacketConn - - connIPv6Once sync.Once - connIPv6 net.PacketConn +type connManagers struct { + reuses map[string]*Reuse } -func (c *connManager) GetConnForAddr(network string) (net.PacketConn, error) { - switch network { - case "udp4": - var err error - c.connIPv4Once.Do(func() { - c.connIPv4, err = c.createConn(network, "0.0.0.0:0") - }) - return c.connIPv4, err - case "udp6": - var err error - c.connIPv6Once.Do(func() { - c.connIPv6, err = c.createConn(network, ":0") - }) - return c.connIPv6, err - default: - return nil, fmt.Errorf("unsupported network: %s", network) +func (c *connManagers) Listen(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + if reuse, ok := c.reuses[network]; ok { + return reuse.Listen(network, laddr) } + return nil, errors.New("invalid network: must be either udp4 or udp6") } -func (c *connManager) createConn(network, host string) (net.PacketConn, error) { - addr, err := net.ResolveUDPAddr(network, host) - if err != nil { - return nil, err +func (c *connManagers) Dial(network string, raddr *net.UDPAddr) (net.PacketConn, error) { + if reuse, ok := c.reuses[network]; ok { + return reuse.Dial(network, raddr) + } + return nil, errors.New("invalid network: must be either udp4 or udp6") +} +func (c *connManagers) Close(network string, laddr *net.UDPAddr) error { + if reuse, ok := c.reuses[network]; ok { + return reuse.Close(laddr) } - return net.ListenUDP(network, addr) + return errors.New("invalid network: must be either udp4 or udp6") } // The Transport implements the tpt.Transport interface for QUIC connections. type transport struct { - privKey ic.PrivKey - localPeer peer.ID - tlsConf *tls.Config - connManager *connManager + privKey ic.PrivKey + localPeer peer.ID + tlsConf *tls.Config + connManagers *connManagers } var _ tpt.Transport = &transport{} @@ -87,11 +74,18 @@ func NewTransport(key ic.PrivKey) (tpt.Transport, error) { return nil, err } + connManagers := &connManagers{ + reuses: make(map[string]*Reuse), + } + + connManagers.reuses["udp4"] = NewReuse() + connManagers.reuses["udp6"] = NewReuse() + return &transport{ - privKey: key, - localPeer: localPeer, - tlsConf: tlsConf, - connManager: &connManager{}, + privKey: key, + localPeer: localPeer, + tlsConf: tlsConf, + connManagers: connManagers, }, nil } @@ -101,7 +95,11 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } - pconn, err := t.connManager.GetConnForAddr(network) + udpAddr, err := net.ResolveUDPAddr(network, host) + if err != nil { + return nil, err + } + pconn, err := t.connManagers.Dial(network, udpAddr) if err != nil { return nil, err }