diff --git a/stack.go b/stack.go index 2191187..2e96e9d 100644 --- a/stack.go +++ b/stack.go @@ -35,9 +35,15 @@ func NewStack( ) (Stack, error) { switch stack { case "": - return NewSystem(options) + if WithGVisor { + return NewMixed(options) + } else { + return NewSystem(options) + } case "gvisor": return NewGVisor(options) + case "mixed": + return NewMixed(options) case "system": return NewSystem(options) case "lwip": diff --git a/stack_gvisor_stub.go b/stack_gvisor_stub.go index bd380f4..64c8a65 100644 --- a/stack_gvisor_stub.go +++ b/stack_gvisor_stub.go @@ -13,3 +13,9 @@ func NewGVisor( ) (Stack, error) { return nil, ErrGVisorNotIncluded } + +func NewMixed( + options StackOptions, +) (Stack, error) { + return nil, ErrGVisorNotIncluded +} diff --git a/stack_mixed.go b/stack_mixed.go new file mode 100644 index 0000000..7294901 --- /dev/null +++ b/stack_mixed.go @@ -0,0 +1,162 @@ +//go:build with_gvisor + +package tun + +import ( + "time" + "unsafe" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/link/channel" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/canceler" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +type Mixed struct { + *System + endpointIndependentNat bool + stack *stack.Stack + endpoint *channel.Endpoint +} + +func NewMixed( + options StackOptions, +) (Stack, error) { + system, err := NewSystem(options) + if err != nil { + return nil, err + } + return &Mixed{ + System: system.(*System), + endpointIndependentNat: options.EndpointIndependentNat, + }, nil +} + +func (m *Mixed) Start() error { + err := m.System.start() + if err != nil { + return err + } + endpoint := channel.New(1024, m.mtu, "") + ipStack, err := newGVisorStack(endpoint) + if err != nil { + return err + } + if !m.endpointIndependentNat { + udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { + var wq waiter.Queue + endpoint, err := request.CreateEndpoint(&wq) + if err != nil { + return + } + udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) + lAddr := udpConn.RemoteAddr() + rAddr := udpConn.LocalAddr() + if lAddr == nil || rAddr == nil { + endpoint.Abort() + return + } + gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} + go func() { + var metadata M.Metadata + metadata.Source = M.SocksaddrFromNet(lAddr) + metadata.Destination = M.SocksaddrFromNet(rAddr) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second) + hErr := m.handler.NewPacketConnection(ctx, conn, metadata) + if hErr != nil { + endpoint.Abort() + } + }() + }) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + } else { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) + } + m.stack = ipStack + m.endpoint = endpoint + go m.tunLoop() + return nil +} + +func (m *Mixed) tunLoop() { + if winTun, isWinTun := m.tun.(WinTun); isWinTun { + m.wintunLoop(winTun) + return + } + packetBuffer := make([]byte, m.mtu+PacketOffset) + for { + n, err := m.tun.Read(packetBuffer) + if err != nil { + return + } + if n < clashtcpip.IPv4PacketMinLength { + continue + } + packet := packetBuffer[PacketOffset:n] + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(packet) + case 6: + err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + } + } +} + +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv4TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMP: + return m.processIPv4ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv6TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMPv6: + return m.processIPv6ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) Close() error { + m.endpoint.Attach(nil) + m.stack.Close() + for _, endpoint := range m.stack.CleanupEndpoints() { + endpoint.Abort() + } + return m.System.Close() +}