diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index ce82ad331d..dbefefa481 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -25,7 +25,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" - inat "github.com/libp2p/go-libp2p/p2p/net/nat" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -858,99 +857,19 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { finalAddrs = dedupAddrs(finalAddrs) - var natMappings []inat.Mapping - // natmgr is nil if we do not use nat option; - // h.natmgr.NAT() is nil if not ready, or no nat is available. - if h.natmgr != nil && h.natmgr.NAT() != nil { - natMappings = h.natmgr.NAT().Mappings() - } - - if len(natMappings) > 0 { + if h.natmgr != nil { // We have successfully mapped ports on our NAT. Use those // instead of observed addresses (mostly). - // First, generate a mapping table. - // protocol -> internal port -> external addr - ports := make(map[string]map[int]net.Addr) - for _, m := range natMappings { - addr, err := m.ExternalAddr() - if err != nil { - // mapping not ready yet. - continue - } - protoPorts, ok := ports[m.Protocol()] - if !ok { - protoPorts = make(map[int]net.Addr) - ports[m.Protocol()] = protoPorts - } - protoPorts[m.InternalPort()] = addr - } - // Next, apply this mapping to our addresses. for _, listen := range listenAddrs { - found := false - transport, rest := ma.SplitFunc(listen, func(c ma.Component) bool { - if found { - return true - } - switch c.Protocol().Code { - case ma.P_TCP, ma.P_UDP: - found = true - } - return false - }) - if !manet.IsThinWaist(transport) { - continue - } - - naddr, err := manet.ToNetAddr(transport) - if err != nil { - log.Error("error parsing net multiaddr %q: %s", transport, err) + extMaddr := h.natmgr.GetMapping(listen) + if extMaddr == nil { + // not mapped continue } - var ( - ip net.IP - iport int - protocol string - ) - switch naddr := naddr.(type) { - case *net.TCPAddr: - ip = naddr.IP - iport = naddr.Port - protocol = "tcp" - case *net.UDPAddr: - ip = naddr.IP - iport = naddr.Port - protocol = "udp" - default: - continue - } - - if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { - // We only map global unicast & unspecified addresses ports. - // Not broadcast, multicast, etc. - continue - } - - mappedAddr, ok := ports[protocol][iport] - if !ok { - // Not mapped. - continue - } - - mappedMaddr, err := manet.FromNetAddr(mappedAddr) - if err != nil { - log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err) - continue - } - - extMaddr := mappedMaddr - if rest != nil { - extMaddr = ma.Join(extMaddr, rest) - } - // if the router reported a sane address if !manet.IsIPUnspecified(extMaddr) { // Add in the mapped addr. diff --git a/p2p/host/basic/mock_nat_test.go b/p2p/host/basic/mock_nat_test.go new file mode 100644 index 0000000000..7714b25853 --- /dev/null +++ b/p2p/host/basic/mock_nat_test.go @@ -0,0 +1,92 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p/p2p/host/basic (interfaces: NAT) + +// Package basichost is a generated GoMock package. +package basichost + +import ( + netip "net/netip" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockNAT is a mock of NAT interface. +type MockNAT struct { + ctrl *gomock.Controller + recorder *MockNATMockRecorder +} + +// MockNATMockRecorder is the mock recorder for MockNAT. +type MockNATMockRecorder struct { + mock *MockNAT +} + +// NewMockNAT creates a new mock instance. +func NewMockNAT(ctrl *gomock.Controller) *MockNAT { + mock := &MockNAT{ctrl: ctrl} + mock.recorder = &MockNATMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNAT) EXPECT() *MockNATMockRecorder { + return m.recorder +} + +// AddMapping mocks base method. +func (m *MockNAT) AddMapping(arg0 string, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddMapping", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddMapping indicates an expected call of AddMapping. +func (mr *MockNATMockRecorder) AddMapping(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMapping", reflect.TypeOf((*MockNAT)(nil).AddMapping), arg0, arg1) +} + +// Close mocks base method. +func (m *MockNAT) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockNATMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNAT)(nil).Close)) +} + +// GetMapping mocks base method. +func (m *MockNAT) GetMapping(arg0 string, arg1 int) (netip.AddrPort, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMapping", arg0, arg1) + ret0, _ := ret[0].(netip.AddrPort) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetMapping indicates an expected call of GetMapping. +func (mr *MockNATMockRecorder) GetMapping(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMapping", reflect.TypeOf((*MockNAT)(nil).GetMapping), arg0, arg1) +} + +// RemoveMapping mocks base method. +func (m *MockNAT) RemoveMapping(arg0 string, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveMapping", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveMapping indicates an expected call of RemoveMapping. +func (mr *MockNATMockRecorder) RemoveMapping(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMapping", reflect.TypeOf((*MockNAT)(nil).RemoveMapping), arg0, arg1) +} diff --git a/p2p/host/basic/mocks.go b/p2p/host/basic/mocks.go new file mode 100644 index 0000000000..3ad4d4e90b --- /dev/null +++ b/p2p/host/basic/mocks.go @@ -0,0 +1,6 @@ +//go:build gomock || generate + +package basichost + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package basichost -destination mock_nat_test.go github.com/libp2p/go-libp2p/p2p/host/basic NAT" +type NAT nat diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 782c116d44..6ebe37b9b7 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "net/netip" "strconv" "sync" "time" @@ -12,24 +13,37 @@ import ( inat "github.com/libp2p/go-libp2p/p2p/net/nat" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ) // NATManager is a simple interface to manage NAT devices. +// It listens Listen and ListenClose notifications from the network.Network, +// and tries to obtain port mappings for those. type NATManager interface { - // NAT gets the NAT device managed by the NAT manager. - NAT() *inat.NAT - - // Ready receives a notification when the NAT device is ready for use. - Ready() <-chan struct{} - + GetMapping(ma.Multiaddr) ma.Multiaddr io.Closer } // NewNATManager creates a NAT manager. func NewNATManager(net network.Network) NATManager { - return newNatManager(net) + return newNATManager(net) +} + +type entry struct { + protocol string + port int +} + +type nat interface { + AddMapping(protocol string, port int) error + RemoveMapping(protocol string, port int) error + GetMapping(protocol string, port int) (netip.AddrPort, bool) + io.Closer } +// so we can mock it in tests +var discoverNAT = func(ctx context.Context) (nat, error) { return inat.DiscoverNAT(ctx) } + // natManager takes care of adding + removing port mappings to the nat. // Initialized with the host if it has a NATPortMap option enabled. // natManager receives signals from the network, and check on nat mappings: @@ -39,22 +53,23 @@ func NewNATManager(net network.Network) NATManager { type natManager struct { net network.Network natMx sync.RWMutex - nat *inat.NAT + nat nat - ready chan struct{} // closed once the nat is ready to process port mappings - syncFlag chan struct{} + syncFlag chan struct{} // cap: 1 + + tracked map[entry]bool // the bool is only used in doSync and has no meaning outside of that function refCount sync.WaitGroup ctxCancel context.CancelFunc } -func newNatManager(net network.Network) *natManager { +func newNATManager(net network.Network) *natManager { ctx, cancel := context.WithCancel(context.Background()) nmgr := &natManager{ net: net, - ready: make(chan struct{}), syncFlag: make(chan struct{}, 1), ctxCancel: cancel, + tracked: make(map[entry]bool), } nmgr.refCount.Add(1) go nmgr.background(ctx) @@ -69,36 +84,29 @@ func (nmgr *natManager) Close() error { return nil } -// Ready returns a channel which will be closed when the NAT has been found -// and is ready to be used, or the search process is done. -func (nmgr *natManager) Ready() <-chan struct{} { - return nmgr.ready -} - func (nmgr *natManager) background(ctx context.Context) { defer nmgr.refCount.Done() defer func() { nmgr.natMx.Lock() + defer nmgr.natMx.Unlock() + if nmgr.nat != nil { nmgr.nat.Close() } - nmgr.natMx.Unlock() }() discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - natInstance, err := inat.DiscoverNAT(discoverCtx) + natInstance, err := discoverNAT(discoverCtx) if err != nil { log.Info("DiscoverNAT error:", err) - close(nmgr.ready) return } nmgr.natMx.Lock() nmgr.nat = natInstance nmgr.natMx.Unlock() - close(nmgr.ready) // sign natManager up for network notifications // we need to sign up here to avoid missing some notifs @@ -127,10 +135,10 @@ func (nmgr *natManager) sync() { // doSync syncs the current NAT mappings, removing any outdated mappings and adding any // new mappings. func (nmgr *natManager) doSync() { - ports := map[string]map[int]bool{ - "tcp": {}, - "udp": {}, + for e := range nmgr.tracked { + nmgr.tracked[e] = false } + var newAddresses []entry for _, maddr := range nmgr.net.ListenAddresses() { // Strip the IP maIP, rest := ma.SplitFirst(maddr) @@ -144,10 +152,9 @@ func (nmgr *natManager) doSync() { continue } - // Only bother if we're listening on a - // unicast/unspecified IP. + // Only bother if we're listening on an unicast / unspecified IP. ip := net.IP(maIP.RawValue()) - if !(ip.IsGlobalUnicast() || ip.IsUnspecified()) { + if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { continue } @@ -166,74 +173,118 @@ func (nmgr *natManager) doSync() { default: continue } - port, err := strconv.ParseUint(proto.Value(), 10, 16) if err != nil { // bug in multiaddr panic(err) } - ports[protocol][int(port)] = false + e := entry{protocol: protocol, port: int(port)} + if _, ok := nmgr.tracked[e]; ok { + nmgr.tracked[e] = true + } else { + newAddresses = append(newAddresses, e) + } } var wg sync.WaitGroup defer wg.Wait() // Close old mappings - for _, m := range nmgr.nat.Mappings() { - mappedPort := m.InternalPort() - if _, ok := ports[m.Protocol()][mappedPort]; !ok { - // No longer need this mapping. - wg.Add(1) - go func(m inat.Mapping) { - defer wg.Done() - m.Close() - }(m) - } else { - // already mapped - ports[m.Protocol()][mappedPort] = true + for e, v := range nmgr.tracked { + if !v { + nmgr.nat.RemoveMapping(e.protocol, e.port) + delete(nmgr.tracked, e) } } // Create new mappings. - for proto, pports := range ports { - for port, mapped := range pports { - if mapped { - continue - } - wg.Add(1) - go func(proto string, port int) { - defer wg.Done() - _, err := nmgr.nat.NewMapping(proto, port) - if err != nil { - log.Errorf("failed to port-map %s port %d: %s", proto, port, err) - } - }(proto, port) + for _, e := range newAddresses { + if err := nmgr.nat.AddMapping(e.protocol, e.port); err != nil { + log.Errorf("failed to port-map %s port %d: %s", e.protocol, e.port, err) } + nmgr.tracked[e] = false } } -// NAT returns the natManager's nat object. this may be nil, if -// (a) the search process is still ongoing, or (b) the search process -// found no nat. Clients must check whether the return value is nil. -func (nmgr *natManager) NAT() *inat.NAT { +func (nmgr *natManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { nmgr.natMx.Lock() defer nmgr.natMx.Unlock() - return nmgr.nat -} -type nmgrNetNotifiee natManager + if nmgr.nat == nil { // NAT not yet initialized + return nil + } -func (nn *nmgrNetNotifiee) natManager() *natManager { - return (*natManager)(nn) -} + var found bool + var proto int // ma.P_TCP or ma.P_UDP + transport, rest := ma.SplitFunc(addr, func(c ma.Component) bool { + if found { + return true + } + proto = c.Protocol().Code + found = proto == ma.P_TCP || proto == ma.P_UDP + return false + }) + if !manet.IsThinWaist(transport) { + return nil + } -func (nn *nmgrNetNotifiee) Listen(n network.Network, addr ma.Multiaddr) { - nn.natManager().sync() -} + naddr, err := manet.ToNetAddr(transport) + if err != nil { + log.Error("error parsing net multiaddr %q: %s", transport, err) + return nil + } -func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { - nn.natManager().sync() + var ( + ip net.IP + port int + protocol string + ) + switch naddr := naddr.(type) { + case *net.TCPAddr: + ip = naddr.IP + port = naddr.Port + protocol = "tcp" + case *net.UDPAddr: + ip = naddr.IP + port = naddr.Port + protocol = "udp" + default: + return nil + } + + if !ip.IsGlobalUnicast() && !ip.IsUnspecified() { + // We only map global unicast & unspecified addresses ports, not broadcast, multicast, etc. + return nil + } + + extAddr, ok := nmgr.nat.GetMapping(protocol, port) + if !ok { + return nil + } + + var mappedAddr net.Addr + switch naddr.(type) { + case *net.TCPAddr: + mappedAddr = net.TCPAddrFromAddrPort(extAddr) + case *net.UDPAddr: + mappedAddr = net.UDPAddrFromAddrPort(extAddr) + } + mappedMaddr, err := manet.FromNetAddr(mappedAddr) + if err != nil { + log.Errorf("mapped addr can't be turned into a multiaddr %q: %s", mappedAddr, err) + return nil + } + extMaddr := mappedMaddr + if rest != nil { + extMaddr = ma.Join(extMaddr, rest) + } + return extMaddr } -func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {} -func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {} +type nmgrNetNotifiee natManager + +func (nn *nmgrNetNotifiee) natManager() *natManager { return (*natManager)(nn) } +func (nn *nmgrNetNotifiee) Listen(network.Network, ma.Multiaddr) { nn.natManager().sync() } +func (nn *nmgrNetNotifiee) ListenClose(n network.Network, addr ma.Multiaddr) { nn.natManager().sync() } +func (nn *nmgrNetNotifiee) Connected(network.Network, network.Conn) {} +func (nn *nmgrNetNotifiee) Disconnected(network.Network, network.Conn) {} diff --git a/p2p/host/basic/natmgr_test.go b/p2p/host/basic/natmgr_test.go new file mode 100644 index 0000000000..e507b45c82 --- /dev/null +++ b/p2p/host/basic/natmgr_test.go @@ -0,0 +1,108 @@ +package basichost + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" + + ma "github.com/multiformats/go-multiaddr" + + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/golang/mock/gomock" +) + +func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { + t.Helper() + ctrl := gomock.NewController(t) + mockNAT = NewMockNAT(ctrl) + origDiscoverNAT := discoverNAT + discoverNAT = func(ctx context.Context) (nat, error) { return mockNAT, nil } + return mockNAT, func() { + discoverNAT = origDiscoverNAT + ctrl.Finish() + } +} + +func TestMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + sw := swarmt.GenSwarm(t) + defer sw.Close() + m := newNATManager(sw) + require.Eventually(t, func() bool { + m.natMx.Lock() + defer m.natMx.Unlock() + return m.nat != nil + }, time.Second, time.Millisecond) + externalAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 4321) + // pretend that we have a TCP mapping + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + + // pretend that we have a QUIC mapping + mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) + + // pretend that there's no mapping + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(netip.AddrPort{}, false) + require.Nil(t, m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + + // make sure this works for WebSocket addresses as well + mockNAT.EXPECT().GetMapping("tcp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/tcp/4321/ws"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/tcp/1234/ws"))) + + // make sure this works for WebTransport addresses as well + mockNAT.EXPECT().GetMapping("udp", 1234).Return(externalAddr, true) + require.Equal(t, ma.StringCast("/ip4/1.2.3.4/udp/4321/quic-v1/webtransport"), m.GetMapping(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1/webtransport"))) +} + +func TestAddAndRemoveListeners(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + sw := swarmt.GenSwarm(t) + defer sw.Close() + m := newNATManager(sw) + require.Eventually(t, func() bool { + m.natMx.Lock() + defer m.natMx.Unlock() + return m.nat != nil + }, time.Second, time.Millisecond) + + added := make(chan struct{}, 1) + // add a TCP listener + mockNAT.EXPECT().AddMapping("tcp", 1234).Do(func(string, int) { added <- struct{}{} }) + require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/tcp/1234"))) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to AddMapping") + } + + // add a QUIC listener + mockNAT.EXPECT().AddMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) + require.NoError(t, sw.Listen(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1"))) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to AddMapping") + } + + // remove the QUIC listener + mockNAT.EXPECT().RemoveMapping("udp", 1234).Do(func(string, int) { added <- struct{}{} }) + sw.ListenClose(ma.StringCast("/ip4/0.0.0.0/udp/1234/quic-v1")) + select { + case <-added: + case <-time.After(time.Second): + t.Fatal("didn't receive call to RemoveMapping") + } + + // test shutdown + mockNAT.EXPECT().RemoveMapping("tcp", 1234).MaxTimes(1) + mockNAT.EXPECT().Close().MaxTimes(1) +} diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go deleted file mode 100644 index f9b508e4e2..0000000000 --- a/p2p/net/nat/mapping.go +++ /dev/null @@ -1,119 +0,0 @@ -package nat - -import ( - "fmt" - "net" - "sync" - "time" -) - -// Mapping represents a port mapping in a NAT. -type Mapping interface { - // NAT returns the NAT object this Mapping belongs to. - NAT() *NAT - - // Protocol returns the protocol of this port mapping. This is either - // "tcp" or "udp" as no other protocols are likely to be NAT-supported. - Protocol() string - - // InternalPort returns the internal device port. Mapping will continue to - // try to map InternalPort() to an external facing port. - InternalPort() int - - // ExternalPort returns the external facing port. If the mapping is not - // established, port will be 0 - ExternalPort() int - - // ExternalAddr returns the external facing address. If the mapping is not - // established, addr will be nil, and and ErrNoMapping will be returned. - ExternalAddr() (addr net.Addr, err error) - - // Close closes the port mapping - Close() error -} - -// keeps republishing -type mapping struct { - sync.Mutex // guards all fields - - nat *NAT - proto string - intport int - extport int - - cached net.IP - cacheTime time.Time - cacheLk sync.Mutex -} - -func (m *mapping) NAT() *NAT { - m.Lock() - defer m.Unlock() - return m.nat -} - -func (m *mapping) Protocol() string { - m.Lock() - defer m.Unlock() - return m.proto -} - -func (m *mapping) InternalPort() int { - m.Lock() - defer m.Unlock() - return m.intport -} - -func (m *mapping) ExternalPort() int { - m.Lock() - defer m.Unlock() - return m.extport -} - -func (m *mapping) setExternalPort(p int) { - m.Lock() - defer m.Unlock() - m.extport = p -} - -func (m *mapping) ExternalAddr() (net.Addr, error) { - m.cacheLk.Lock() - defer m.cacheLk.Unlock() - oport := m.ExternalPort() - if oport == 0 { - // dont even try right now. - return nil, ErrNoMapping - } - - if time.Since(m.cacheTime) >= CacheTime { - m.nat.natmu.Lock() - cval, err := m.nat.nat.GetExternalAddress() - m.nat.natmu.Unlock() - - if err != nil { - return nil, err - } - - m.cached = cval - m.cacheTime = time.Now() - } - switch m.Protocol() { - case "tcp": - return &net.TCPAddr{ - IP: m.cached, - Port: oport, - }, nil - case "udp": - return &net.UDPAddr{ - IP: m.cached, - Port: oport, - }, nil - default: - panic(fmt.Sprintf("invalid protocol %q", m.Protocol())) - } -} - -func (m *mapping) Close() error { - m.nat.removeMapping(m) - return nil -} diff --git a/p2p/net/nat/mock_nat_test.go b/p2p/net/nat/mock_nat_test.go new file mode 100644 index 0000000000..bb91bac247 --- /dev/null +++ b/p2p/net/nat/mock_nat_test.go @@ -0,0 +1,124 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-nat (interfaces: NAT) + +// Package nat is a generated GoMock package. +package nat + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockNAT is a mock of NAT interface. +type MockNAT struct { + ctrl *gomock.Controller + recorder *MockNATMockRecorder +} + +// MockNATMockRecorder is the mock recorder for MockNAT. +type MockNATMockRecorder struct { + mock *MockNAT +} + +// NewMockNAT creates a new mock instance. +func NewMockNAT(ctrl *gomock.Controller) *MockNAT { + mock := &MockNAT{ctrl: ctrl} + mock.recorder = &MockNATMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNAT) EXPECT() *MockNATMockRecorder { + return m.recorder +} + +// AddPortMapping mocks base method. +func (m *MockNAT) AddPortMapping(arg0 string, arg1 int, arg2 string, arg3 time.Duration) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPortMapping", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddPortMapping indicates an expected call of AddPortMapping. +func (mr *MockNATMockRecorder) AddPortMapping(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPortMapping", reflect.TypeOf((*MockNAT)(nil).AddPortMapping), arg0, arg1, arg2, arg3) +} + +// DeletePortMapping mocks base method. +func (m *MockNAT) DeletePortMapping(arg0 string, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePortMapping", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePortMapping indicates an expected call of DeletePortMapping. +func (mr *MockNATMockRecorder) DeletePortMapping(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePortMapping", reflect.TypeOf((*MockNAT)(nil).DeletePortMapping), arg0, arg1) +} + +// GetDeviceAddress mocks base method. +func (m *MockNAT) GetDeviceAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDeviceAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDeviceAddress indicates an expected call of GetDeviceAddress. +func (mr *MockNATMockRecorder) GetDeviceAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeviceAddress", reflect.TypeOf((*MockNAT)(nil).GetDeviceAddress)) +} + +// GetExternalAddress mocks base method. +func (m *MockNAT) GetExternalAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExternalAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExternalAddress indicates an expected call of GetExternalAddress. +func (mr *MockNATMockRecorder) GetExternalAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExternalAddress", reflect.TypeOf((*MockNAT)(nil).GetExternalAddress)) +} + +// GetInternalAddress mocks base method. +func (m *MockNAT) GetInternalAddress() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInternalAddress") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInternalAddress indicates an expected call of GetInternalAddress. +func (mr *MockNATMockRecorder) GetInternalAddress() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInternalAddress", reflect.TypeOf((*MockNAT)(nil).GetInternalAddress)) +} + +// Type mocks base method. +func (m *MockNAT) Type() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Type") + ret0, _ := ret[0].(string) + return ret0 +} + +// Type indicates an expected call of Type. +func (mr *MockNATMockRecorder) Type() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockNAT)(nil).Type)) +} diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index e2656f8bcc..68834ac877 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/netip" "sync" "time" @@ -19,18 +20,30 @@ var log = logging.Logger("nat") // MappingDuration is a default port mapping duration. // Port mappings are renewed every (MappingDuration / 3) -const MappingDuration = time.Second * 60 +const MappingDuration = time.Minute // CacheTime is the time a mapping will cache an external address for -const CacheTime = time.Second * 15 +const CacheTime = 15 * time.Second -// DiscoverNAT looks for a NAT device in the network and -// returns an object that can manage port mappings. +type entry struct { + protocol string + port int +} + +// so we can mock it in tests +var discoverGateway = nat.DiscoverGateway + +// DiscoverNAT looks for a NAT device in the network and returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { - natInstance, err := nat.DiscoverGateway(ctx) + natInstance, err := discoverGateway(ctx) if err != nil { return nil, err } + var extAddr netip.Addr + extIP, err := natInstance.GetExternalAddress() + if err == nil { + extAddr, _ = netip.AddrFromSlice(extIP) + } // Log the device addr. addr, err := natInstance.GetDeviceAddress() @@ -40,7 +53,20 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { log.Debug("DiscoverGateway address:", addr) } - return newNAT(natInstance), nil + ctx, cancel := context.WithCancel(context.Background()) + nat := &NAT{ + nat: natInstance, + extAddr: extAddr, + mappings: make(map[entry]int), + ctx: ctx, + ctxCancel: cancel, + } + nat.refCount.Add(1) + go func() { + defer nat.refCount.Done() + nat.background() + }() + return nat, nil } // NAT is an object that manages address port mappings in @@ -50,6 +76,8 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { type NAT struct { natmu sync.Mutex nat nat.NAT + // External IP of the NAT. Will be renewed periodically (every CacheTime). + extAddr netip.Addr refCount sync.WaitGroup ctx context.Context @@ -57,17 +85,7 @@ type NAT struct { mappingmu sync.RWMutex // guards mappings closed bool - mappings map[*mapping]struct{} -} - -func newNAT(realNAT nat.NAT) *NAT { - ctx, cancel := context.WithCancel(context.Background()) - return &NAT{ - nat: realNAT, - mappings: make(map[*mapping]struct{}), - ctx: ctx, - ctxCancel: cancel, - } + mappings map[entry]int } // Close shuts down all port mappings. NAT can no longer be used. @@ -81,99 +99,139 @@ func (nat *NAT) Close() error { return nil } -// Mappings returns a slice of all NAT mappings -func (nat *NAT) Mappings() []Mapping { +func (nat *NAT) GetMapping(protocol string, port int) (addr netip.AddrPort, found bool) { nat.mappingmu.Lock() - maps2 := make([]Mapping, 0, len(nat.mappings)) - for m := range nat.mappings { - maps2 = append(maps2, m) + defer nat.mappingmu.Unlock() + + if !nat.extAddr.IsValid() { + return netip.AddrPort{}, false } - nat.mappingmu.Unlock() - return maps2 + extPort, found := nat.mappings[entry{protocol: protocol, port: port}] + if !found { + return netip.AddrPort{}, false + } + return netip.AddrPortFrom(nat.extAddr, uint16(extPort)), true } -// NewMapping attempts to construct a mapping on protocol and internal port -// It will also periodically renew the mapping until the returned Mapping -// -- or its parent NAT -- is Closed. +// AddMapping attempts to construct a mapping on protocol and internal port. +// It blocks until a mapping was established. Once added, it periodically renews the mapping. // // May not succeed, and mappings may change over time; // NAT devices may not respect our port requests, and even lie. -// Clients should not store the mapped results, but rather always -// poll our object for the latest mappings. -func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { - if nat == nil { - return nil, fmt.Errorf("no nat available") - } - +func (nat *NAT) AddMapping(protocol string, port int) error { switch protocol { case "tcp", "udp": default: - return nil, fmt.Errorf("invalid protocol: %s", protocol) - } - - m := &mapping{ - intport: port, - nat: nat, - proto: protocol, + return fmt.Errorf("invalid protocol: %s", protocol) } nat.mappingmu.Lock() + defer nat.mappingmu.Unlock() + if nat.closed { - nat.mappingmu.Unlock() - return nil, errors.New("closed") + return errors.New("closed") } - nat.mappings[m] = struct{}{} - nat.refCount.Add(1) - nat.mappingmu.Unlock() - go nat.refreshMappings(m) // do it once synchronously, so first mapping is done right away, and before exiting, // allowing users -- in the optimistic case -- to use results right after. - nat.establishMapping(m) - return m, nil + extPort := nat.establishMapping(protocol, port) + nat.mappings[entry{protocol: protocol, port: port}] = extPort + return nil } -func (nat *NAT) removeMapping(m *mapping) { +// RemoveMapping removes a port mapping. +// It blocks until the NAT has removed the mapping. +func (nat *NAT) RemoveMapping(protocol string, port int) error { nat.mappingmu.Lock() - delete(nat.mappings, m) - nat.mappingmu.Unlock() - nat.natmu.Lock() - nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) - nat.natmu.Unlock() + defer nat.mappingmu.Unlock() + + switch protocol { + case "tcp", "udp": + e := entry{protocol: protocol, port: port} + if _, ok := nat.mappings[e]; ok { + delete(nat.mappings, e) + return nat.nat.DeletePortMapping(protocol, port) + } + return errors.New("unknown mapping") + default: + return fmt.Errorf("invalid protocol: %s", protocol) + } } -func (nat *NAT) refreshMappings(m *mapping) { - defer nat.refCount.Done() - t := time.NewTicker(MappingDuration / 3) +func (nat *NAT) background() { + const mappingUpdate = MappingDuration / 3 + + now := time.Now() + nextMappingUpdate := now.Add(mappingUpdate) + nextAddrUpdate := now.Add(CacheTime) + + t := time.NewTimer(minTime(nextMappingUpdate, nextAddrUpdate).Sub(now)) // don't use a ticker here. We don't know how long establishing the mappings takes. defer t.Stop() + var in []entry + var out []int // port numbers for { select { - case <-t.C: - nat.establishMapping(m) + case now := <-t.C: + if now.After(nextMappingUpdate) { + in = in[:0] + out = out[:0] + nat.mappingmu.Lock() + for e := range nat.mappings { + in = append(in, e) + } + nat.mappingmu.Unlock() + // Establishing the mapping involves network requests. + // Don't hold the mutex, just save the ports. + for _, e := range in { + out = append(out, nat.establishMapping(e.protocol, e.port)) + } + nat.mappingmu.Lock() + for i, p := range in { + if _, ok := nat.mappings[p]; !ok { + continue // entry might have been deleted + } + nat.mappings[p] = out[i] + } + nat.mappingmu.Unlock() + nextMappingUpdate = time.Now().Add(mappingUpdate) + } + if now.After(nextAddrUpdate) { + var extAddr netip.Addr + extIP, err := nat.nat.GetExternalAddress() + if err == nil { + extAddr, _ = netip.AddrFromSlice(extIP) + } + nat.extAddr = extAddr + nextAddrUpdate = time.Now().Add(CacheTime) + } + t.Reset(time.Until(minTime(nextAddrUpdate, nextMappingUpdate))) case <-nat.ctx.Done(): - m.Close() + nat.mappingmu.Lock() + for e := range nat.mappings { + delete(nat.mappings, e) + nat.nat.DeletePortMapping(e.protocol, e.port) + } + nat.mappingmu.Unlock() return } } } -func (nat *NAT) establishMapping(m *mapping) { - oldport := m.ExternalPort() - - log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort()) +func (nat *NAT) establishMapping(protocol string, internalPort int) (externalPort int) { + log.Debugf("Attempting port map: %s/%d", protocol, internalPort) const comment = "libp2p" nat.natmu.Lock() - newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration) + var err error + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, MappingDuration) if err != nil { // Some hardware does not support mappings with timeout, so try that - newport, err = nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, 0) + externalPort, err = nat.nat.AddPortMapping(protocol, internalPort, comment, 0) } nat.natmu.Unlock() - if err != nil || newport == 0 { - m.setExternalPort(0) // clear mapping + if err != nil || externalPort == 0 { // TODO: log.Event if err != nil { log.Warnf("failed to establish port mapping: %s", err) @@ -182,12 +240,16 @@ func (nat *NAT) establishMapping(m *mapping) { } // we do not close if the mapping failed, // because it may work again next time. - return + return 0 } - m.setExternalPort(newport) - log.Debugf("NAT Mapping: %d --> %d (%s)", m.ExternalPort(), m.InternalPort(), m.Protocol()) - if oldport != 0 && newport != oldport { - log.Debugf("failed to renew same port mapping: ch %d -> %d", oldport, newport) + log.Debugf("NAT Mapping: %d --> %d (%s)", externalPort, internalPort, protocol) + return externalPort +} + +func minTime(a, b time.Time) time.Time { + if a.Before(b) { + return a } + return b } diff --git a/p2p/net/nat/nat_test.go b/p2p/net/nat/nat_test.go new file mode 100644 index 0000000000..8fffb512c7 --- /dev/null +++ b/p2p/net/nat/nat_test.go @@ -0,0 +1,69 @@ +package nat + +import ( + "context" + "errors" + "net" + "net/netip" + "testing" + + "github.com/libp2p/go-nat" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package nat -destination mock_nat_test.go github.com/libp2p/go-nat NAT" + +func setupMockNAT(t *testing.T) (mockNAT *MockNAT, reset func()) { + t.Helper() + ctrl := gomock.NewController(t) + mockNAT = NewMockNAT(ctrl) + mockNAT.EXPECT().GetDeviceAddress().Return(nil, errors.New("nope")) // is only used for logging + origDiscoverGateway := discoverGateway + discoverGateway = func(ctx context.Context) (nat.NAT, error) { return mockNAT, nil } + return mockNAT, func() { + discoverGateway = origDiscoverGateway + ctrl.Finish() + } +} + +func TestAddMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) + nat, err := DiscoverNAT(context.Background()) + require.NoError(t, err) + + mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) + require.NoError(t, nat.AddMapping("tcp", 10000)) + + _, found := nat.GetMapping("tcp", 9999) + require.False(t, found, "didn't expect a port mapping for unmapped port") + _, found = nat.GetMapping("udp", 10000) + require.False(t, found, "didn't expect a port mapping for unmapped protocol") + mapped, found := nat.GetMapping("tcp", 10000) + require.True(t, found, "expected port mapping") + require.Equal(t, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 1234), mapped) +} + +func TestRemoveMapping(t *testing.T) { + mockNAT, reset := setupMockNAT(t) + defer reset() + + mockNAT.EXPECT().GetExternalAddress().Return(net.IPv4(1, 2, 3, 4), nil) + nat, err := DiscoverNAT(context.Background()) + require.NoError(t, err) + mockNAT.EXPECT().AddPortMapping("tcp", 10000, gomock.Any(), MappingDuration).Return(1234, nil) + require.NoError(t, nat.AddMapping("tcp", 10000)) + _, found := nat.GetMapping("tcp", 10000) + require.True(t, found, "expected port mapping") + + require.Error(t, nat.RemoveMapping("tcp", 9999), "expected error for unknown mapping") + mockNAT.EXPECT().DeletePortMapping("tcp", 10000) + require.NoError(t, nat.RemoveMapping("tcp", 10000)) + + _, found = nat.GetMapping("tcp", 10000) + require.False(t, found, "didn't expect port mapping for deleted mapping") +}