Skip to content

Commit

Permalink
Migrate to udpnat2 / Add PrepareConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 21, 2024
1 parent dc5d3e8 commit c45d71a
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 89 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ go 1.20
require (
github.com/go-ole/go-ole v1.3.0
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a
github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc h1:IvmeRstYX63O0QpLGJgVOaaM21ZIG0frJi6MT29Irtk=
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM=
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a h1:6qlFfBvLZT/MhDpUr4cKY6RxYTnaCcFgOrJEnf/0+io=
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce h1:5qVxlM/CSW1pTBiiD2ZOIi2ziE6EXdRlnT4H+enjbEk=
github.com/sagernet/sing v0.5.0-rc.4.0.20241021134838-8f165de804ce/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
Expand Down
3 changes: 2 additions & 1 deletion stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"net"
"net/netip"
"time"

"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
Expand All @@ -23,7 +24,7 @@ type StackOptions struct {
Tun Tun
TunOptions Options
EndpointIndependentNat bool
UDPTimeout int64
UDPTimeout time.Duration
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
Expand Down
22 changes: 17 additions & 5 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package tun
import (
"context"
"net/netip"
"os"
"time"

"github.com/sagernet/gvisor/pkg/tcpip"
Expand Down Expand Up @@ -32,7 +33,7 @@ type GVisor struct {
ctx context.Context
tun GVisorTun
endpointIndependentNat bool
udpTimeout int64
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
Expand Down Expand Up @@ -85,13 +86,18 @@ func (t *GVisor) Start() error {
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
pErr := t.handler.PrepareConnection(source, destination)
if pErr != nil {
r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid)
return
}
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
endpoint, err := r.CreateEndpoint(&wq)
if err != nil {
return
}
Expand All @@ -102,9 +108,15 @@ func (t *GVisor) Start() error {
endpoint.Abort()
return
}
source := M.SocksaddrFromNet(lAddr)
destination := M.SocksaddrFromNet(rAddr)
pErr := t.handler.PrepareConnection(source, destination)
if pErr != nil {
gWriteUnreachable(t.stack, r.Packet(), pErr)
r.Packet().DecRef()
return
}
go func() {
source := M.SocksaddrFromNet(lAddr)
destination := M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second)
t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
}()
Expand Down
75 changes: 39 additions & 36 deletions stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/netip"
"os"
"sync"
"time"
_ "unsafe"

"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
Expand All @@ -19,59 +21,60 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/common/udpnat2"
)

type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]

// cache
cacheProto tcpip.NetworkProtocolNumber
cacheID stack.TransportEndpointID
ctx context.Context
stack *stack.Stack
handler Handler
udpNat *udpnat.Service
}

func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler),
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
forwarder := &UDPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout)
return forwarder
}

func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
if source.IsIPv4() {
f.cacheProto = header.IPv4ProtocolNumber
} else {
f.cacheProto = header.IPv6ProtocolNumber
}
gBuffer := pkt.Data().ToBuffer()
sBuffer := buf.NewSize(int(gBuffer.Size()))
gBuffer.Apply(func(view *buffer.View) {
sBuffer.Write(view.AsSlice())
bufferRange := pkt.Data().AsRange()
bufferSlices := make([][]byte, bufferRange.Size())
rangeIterate(bufferRange, func(view *buffer.View) {
bufferSlices = append(bufferSlices, view.AsSlice())
})
f.cacheID = id
f.udpNat.NewPacketEx(
f.ctx,
source.AddrPort(),
sBuffer,
source,
destination,
f.newUDPConn,
)
f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
return true
}

func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{
//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/buffer.(Range).iterate
func rangeIterate(r stack.Range, fn func(*buffer.View))

func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := f.handler.PrepareConnection(source, destination)
if pErr != nil {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
return false, nil, nil, nil
}
var sourceNetwork tcpip.NetworkProtocolNumber
if source.Addr.Is4() {
sourceNetwork = header.IPv4ProtocolNumber
} else {
sourceNetwork = header.IPv6ProtocolNumber
}
writer := &UDPBackWriter{
stack: f.stack,
source: f.cacheID.RemoteAddress,
sourcePort: f.cacheID.RemotePort,
sourceNetwork: f.cacheProto,
source: AddressFromAddr(source.Addr),
sourcePort: source.Port,
sourceNetwork: sourceNetwork,
}
return true, f.ctx, writer, nil
}

type UDPBackWriter struct {
Expand Down
89 changes: 51 additions & 38 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"github.com/sagernet/sing/common/udpnat2"
)

var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
Expand All @@ -34,13 +34,13 @@ type System struct {
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout int64
udpTimeout time.Duration
tcpListener net.Listener
tcpListener6 net.Listener
tcpPort uint16
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service[netip.AddrPort]
udpNat *udpnat.Service
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
Expand Down Expand Up @@ -151,8 +151,8 @@ func (s *System) start() error {
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler)
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout)
return nil
}

Expand Down Expand Up @@ -354,7 +354,11 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
natPort := s.tcpNat.Lookup(source, destination)
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement rejects
return nil
}
packet.SetSourceIP(s.inet4Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet4ServerAddress)
Expand Down Expand Up @@ -385,7 +389,11 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
natPort := s.tcpNat.Lookup(source, destination)
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement rejects
return nil
}
packet.SetSourceIP(s.inet6Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet6ServerAddress)
Expand All @@ -409,56 +417,61 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
if !header.Valid() {
return E.New("ipv4: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr.IsGlobalUnicast() {
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
return nil
}

func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
if !header.Valid() {
return E.New("ipv6: udp: invalid packet")
}
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr.IsGlobalUnicast() {
return nil
}
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
return nil
}

func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(source, destination)
if pErr != nil {
// TODO: implement ICMP port unreachable
return false, nil, nil, nil
}
var writer N.PacketWriter
if source.IsIPv4() {
packet := userData.(clashtcpip.IPv4Packet)
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{
writer = &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
source.AddrPort(),
s.txChecksumOffload,
}
})
return nil
}

func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
if !header.Valid() {
return E.New("ipv6: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
return nil
}
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
} else {
packet := userData.(clashtcpip.IPv6Packet)
headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{
writer = &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
source.AddrPort(),
s.txChecksumOffload,
}
})
return nil
}
return true, s.ctx, writer, nil
}

func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
Expand Down
11 changes: 8 additions & 3 deletions stack_system_nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tun

import (
"context"
M "github.com/sagernet/sing/common/metadata"
"net/netip"
"sync"
"time"
Expand Down Expand Up @@ -68,12 +69,16 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
return session
}

func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 {
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) {
n.addrAccess.RLock()
port, loaded := n.addrMap[source]
n.addrAccess.RUnlock()
if loaded {
return port
return port, nil
}
pErr := handler.PrepareConnection(M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
if pErr != nil {
return 0, pErr
}
n.addrAccess.Lock()
nextPort := n.portIndex
Expand All @@ -92,5 +97,5 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1
LastActive: time.Now(),
}
n.portAccess.Unlock()
return nextPort
return nextPort, nil
}
Loading

0 comments on commit c45d71a

Please sign in to comment.