Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rule: fix 32-bit platforms don't support adding rules with a mark 0xF0000000 #983

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions route_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ type RouteGetOptions struct {
VrfName string
SrcAddr net.IP
UID *uint32
Mark int
Mark uint32
FIBMatch bool
}

Expand Down Expand Up @@ -1557,7 +1557,7 @@ func (h *Handle) RouteGetWithOptions(destination net.IP, options *RouteGetOption

if options.Mark > 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(options.Mark))
native.PutUint32(b, options.Mark)

req.AddData(nl.NewRtAttr(unix.RTA_MARK, b))
}
Expand Down
80 changes: 65 additions & 15 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2292,27 +2292,41 @@ func TestRouteFWMarkOption(t *testing.T) {
}

// a table different than unix.RT_TABLE_MAIN
testtable := 1000
testTable0 := 254
testTable1 := 1000
testTable2 := 1001

gw1 := net.IPv4(192, 168, 1, 254)
gw2 := net.IPv4(192, 168, 2, 254)
gw0 := net.IPv4(192, 168, 1, 254)
gw1 := net.IPv4(192, 168, 2, 254)
gw2 := net.IPv4(192, 168, 3, 254)

// add default route via gw1 (in main route table by default)
// add default route via gw0 (in main route table by default)
defaultRouteMain := Route{
Dst: nil,
Gw: gw1,
Dst: nil,
Gw: gw0,
Table: testTable0,
}
if err := RouteAdd(&defaultRouteMain); err != nil {
t.Fatal(err)
}

// add default route via gw1 in test route table
defaultRouteTest1 := Route{
Dst: nil,
Gw: gw1,
Table: testTable1,
}
if err := RouteAdd(&defaultRouteTest1); err != nil {
t.Fatal(err)
}

// add default route via gw2 in test route table
defaultRouteTest := Route{
defaultRouteTest2 := Route{
Dst: nil,
Gw: gw2,
Table: testtable,
Table: testTable2,
}
if err := RouteAdd(&defaultRouteTest); err != nil {
if err := RouteAdd(&defaultRouteTest2); err != nil {
t.Fatal(err)
}

Expand All @@ -2324,34 +2338,70 @@ func TestRouteFWMarkOption(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(routes) != 2 || routes[0].Table == routes[1].Table {
if len(routes) != 3 || routes[0].Table == routes[1].Table || routes[1].Table == routes[2].Table ||
routes[0].Table == routes[2].Table {
t.Fatal("Routes not added properly")
}

// add a rule that fwmark match should result in route lookup of test table
fwmark := 1000
fwmark1 := uint32(0xAFFFFFFF)
fwmark2 := uint32(0xBFFFFFFF)

rule := NewRule()
rule.Mark = fwmark
rule.Mask = 0xFFFFFFFF
rule.Table = testtable
rule.Mark = fwmark1
rule.Mask = &[]uint32{0xFFFFFFFF}[0]

rule.Table = testTable1
if err := RuleAdd(rule); err != nil {
t.Fatal(err)
}

rule = NewRule()
rule.Mark = fwmark2
rule.Mask = &[]uint32{0xFFFFFFFF}[0]
rule.Table = testTable2
if err := RuleAdd(rule); err != nil {
t.Fatal(err)
}

rules, err := RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark1}, RT_FILTER_MARK)
if err != nil {
t.Fatal(err)
}
if len(rules) != 1 || rules[0].Table != testTable1 || rules[0].Mark != fwmark1 {
t.Fatal("Rules not added properly")
}

rules, err = RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark2}, RT_FILTER_MARK)
if err != nil {
t.Fatal(err)
}
if len(rules) != 1 || rules[0].Table != testTable2 || rules[0].Mark != fwmark2 {
t.Fatal("Rules not added properly")
}

dstIP := net.IPv4(10, 1, 1, 1)

// check getting route without FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{})
if err != nil {
t.Fatal(err)
}
if len(routes) != 1 || !routes[0].Gw.Equal(gw0) {
t.Fatal(routes)
}

// check getting route with FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark1})
if err != nil {
t.Fatal(err)
}
if len(routes) != 1 || !routes[0].Gw.Equal(gw1) {
t.Fatal(routes)
}

// check getting route with FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark})
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark2})
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ type Rule struct {
Priority int
Family int
Table int
Mark int
Mask int
Mark uint32
Mask *uint32
Tos uint
TunID uint
Goto int
Expand Down Expand Up @@ -51,8 +51,8 @@ func NewRule() *Rule {
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: -1,
Mark: -1,
Mask: -1,
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}
Expand Down
25 changes: 18 additions & 7 deletions rule_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
native.PutUint32(b, uint32(rule.Priority))
req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
}
if rule.Mark >= 0 {
if rule.Mark != 0 || rule.Mask != nil {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mark))
native.PutUint32(b, rule.Mark)
req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
}
if rule.Mask >= 0 {
if rule.Mask != nil {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mask))
native.PutUint32(b, *rule.Mask)
req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
}
if rule.Flow >= 0 {
Expand Down Expand Up @@ -242,9 +242,10 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
}
case nl.FRA_FWMARK:
rule.Mark = int(native.Uint32(attrs[j].Value[0:4]))
rule.Mark = native.Uint32(attrs[j].Value[0:4])
case nl.FRA_FWMASK:
rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
mask := native.Uint32(attrs[j].Value[0:4])
rule.Mask = &mask
case nl.FRA_TUN_ID:
rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
case nl.FRA_IIFNAME:
Expand Down Expand Up @@ -297,7 +298,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
continue
case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
continue
case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask:
case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
continue
}
}
Expand All @@ -321,3 +322,13 @@ func (pr *RuleUIDRange) toRtAttrData() []byte {
native.PutUint32(b[1], pr.End)
return bytes.Join(b, []byte{})
}

func ptrEqual(a, b *uint32) bool {
if a == b {
return true
}
if (a == nil) || (b == nil) {
return false
}
return *a == *b
}
Loading
Loading