Skip to content

Commit

Permalink
Add checksum offload for Linux
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 15, 2023
1 parent 3195f6f commit 50d4d76
Show file tree
Hide file tree
Showing 14 changed files with 203 additions and 86 deletions.
2 changes: 1 addition & 1 deletion internal/clashtcpip/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ func (p ICMPPacket) SetChecksum(sum [2]byte) {
}

func (p ICMPPacket) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(ZeroChecksum)
p.SetChecksum(Checksum(0, p))
}
2 changes: 1 addition & 1 deletion internal/clashtcpip/icmpv6.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,6 @@ func (b ICMPv6Packet) Payload() []byte {
}

func (b ICMPv6Packet) ResetChecksum(psum uint32) {
b.SetChecksum(zeroChecksum)
b.SetChecksum(ZeroChecksum)
b.SetChecksum(Checksum(psum, b))
}
2 changes: 1 addition & 1 deletion internal/clashtcpip/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (p IPv4Packet) DecTimeToLive() {
}

func (p IPv4Packet) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(ZeroChecksum)
p.SetChecksum(Checksum(0, p[:p.HeaderLen()]))
}

Expand Down
2 changes: 1 addition & 1 deletion internal/clashtcpip/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (p TCPPacket) SetChecksum(sum [2]byte) {
}

func (p TCPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(ZeroChecksum)
p.SetChecksum(Checksum(psum, p))
}

Expand Down
2 changes: 1 addition & 1 deletion internal/clashtcpip/tcpip.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package clashtcpip

var zeroChecksum = [2]byte{0x00, 0x00}
var ZeroChecksum = [2]byte{0x00, 0x00}

var SumFnc = SumCompat

Expand Down
2 changes: 1 addition & 1 deletion internal/clashtcpip/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (p UDPPacket) SetChecksum(sum [2]byte) {
}

func (p UDPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(ZeroChecksum)
p.SetChecksum(Checksum(psum, p))
}

Expand Down
30 changes: 15 additions & 15 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,18 @@ func (m *Mixed) tunLoop() {
m.wintunLoop(winTun)
return
}
if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN {
batchSize := batchTUN.BatchSize()
if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
m.frontHeadroom = linuxTUN.FrontHeadroom()
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(batchTUN, batchSize)
m.batchLoop(linuxTUN, batchSize)
return
}
}
frontHeadroom := m.tun.FrontHeadroom()
packetBuffer := make([]byte, m.mtu+frontHeadroom+PacketOffset)
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer[frontHeadroom:])
n, err := m.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -111,8 +112,8 @@ func (m *Mixed) tunLoop() {
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if m.processPacket(packet) {
_, err = m.tun.Write(rawPacket)
if err != nil {
Expand Down Expand Up @@ -142,16 +143,15 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
}
}

func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom)
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -167,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
92 changes: 60 additions & 32 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type System struct {
udpNat *udpnat.Service[netip.AddrPort]
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
}

type Session struct {
Expand Down Expand Up @@ -144,17 +146,18 @@ func (s *System) tunLoop() {
s.wintunLoop(winTun)
return
}
if batchTUN, isBatchTUN := s.tun.(BatchTUN); isBatchTUN {
batchSize := batchTUN.BatchSize()
if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
s.frontHeadroom = linuxTUN.FrontHeadroom()
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(batchTUN, batchSize)
s.batchLoop(linuxTUN, batchSize)
return
}
}
frontHeadroom := s.tun.FrontHeadroom()
packetBuffer := make([]byte, s.mtu+frontHeadroom+PacketOffset)
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer[frontHeadroom:])
n, err := s.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -164,8 +167,8 @@ func (s *System) tunLoop() {
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if s.processPacket(packet) {
_, err = s.tun.Write(rawPacket)
if err != nil {
Expand Down Expand Up @@ -195,16 +198,15 @@ func (s *System) wintunLoop(winTun WinTun) {
}
}

func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := s.tun.FrontHeadroom()
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom)
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -220,13 +222,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down Expand Up @@ -352,8 +354,10 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(s.inet4ServerAddress)
header.SetDestinationPort(s.tcpPort)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
}
return nil
}

Expand All @@ -378,8 +382,9 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(s.inet6ServerAddress)
header.SetDestinationPort(s.tcpPort6)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
}
return nil
}

Expand Down Expand Up @@ -410,7 +415,13 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source}
return &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
Expand All @@ -436,7 +447,13 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source}
return &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
Expand Down Expand Up @@ -468,10 +485,11 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip
}

type systemUDPPacketWriter4 struct {
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}

func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
Expand All @@ -488,8 +506,13 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
} else {
//udpHdr.SetChecksum(clashtcpip.ZeroChecksum)
ipHdr.ResetChecksum()
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
Expand All @@ -499,10 +522,11 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
}

type systemUDPPacketWriter6 struct {
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}

func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
Expand All @@ -520,7 +544,11 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
udpHdr.ResetChecksum(ipHdr.PseudoSum())
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
} else {
//udpHdr.SetChecksum(clashtcpip.ZeroChecksum)
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
Expand Down
6 changes: 4 additions & 2 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
N.FrontHeadroom
Close() error
}

Expand All @@ -33,11 +32,13 @@ type WinTun interface {
ReadPacket() ([]byte, func(), error)
}

type BatchTUN interface {
type LinuxTUN interface {
Tun
N.FrontHeadroom
BatchSize() int
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
TXChecksumOffload() bool
}

type Options struct {
Expand All @@ -46,6 +47,7 @@ type Options struct {
Inet6Address []netip.Prefix
MTU uint32
GSO bool
TXChecksumOffload bool
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Expand Down
6 changes: 0 additions & 6 deletions tun_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"os"
"runtime"
"syscall"
"unsafe"

Expand Down Expand Up @@ -68,14 +67,9 @@ func New(options Options) (Tun, error) {
if !ok {
panic("create vectorised writer")
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
return nativeTun, nil
}

func (t *NativeTun) FrontHeadroom() int {
return 0
}

func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p)
}
Expand Down
Loading

0 comments on commit 50d4d76

Please sign in to comment.