Skip to content

Commit

Permalink
Adding support for RTA_VIA
Browse files Browse the repository at this point in the history
Signed-off-by: Steve Shaw <shaw38@gmail.com>
  • Loading branch information
fach committed Dec 3, 2020
1 parent fb953eb commit 5da81c2
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 0 deletions.
9 changes: 9 additions & 0 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type Route struct {
MPLSDst *int
NewDst Destination
Encap Encap
Via Destination
MTU int
Window int
Rtt int
Expand Down Expand Up @@ -79,6 +80,9 @@ func (r Route) String() string {
if r.Encap != nil {
elems = append(elems, fmt.Sprintf("Encap: %s", r.Encap))
}
if r.Via != nil {
elems = append(elems, fmt.Sprintf("Via: %s", r.Via))
}
elems = append(elems, fmt.Sprintf("Src: %s", r.Src))
if len(r.MultiPath) > 0 {
elems = append(elems, fmt.Sprintf("Gw: %s", r.MultiPath))
Expand Down Expand Up @@ -107,6 +111,7 @@ func (r Route) Equal(x Route) bool {
r.Flags == x.Flags &&
(r.MPLSDst == x.MPLSDst || (r.MPLSDst != nil && x.MPLSDst != nil && *r.MPLSDst == *x.MPLSDst)) &&
(r.NewDst == x.NewDst || (r.NewDst != nil && r.NewDst.Equal(x.NewDst))) &&
(r.Via == x.Via || (r.Via != nil && r.Via.Equal(x.Via))) &&
(r.Encap == x.Encap || (r.Encap != nil && r.Encap.Equal(x.Encap)))
}

Expand Down Expand Up @@ -136,6 +141,7 @@ type NexthopInfo struct {
Flags int
NewDst Destination
Encap Encap
Via Destination
}

func (n *NexthopInfo) String() string {
Expand All @@ -147,6 +153,9 @@ func (n *NexthopInfo) String() string {
if n.Encap != nil {
elems = append(elems, fmt.Sprintf("Encap: %s", n.Encap))
}
if n.Via != nil {
elems = append(elems, fmt.Sprintf("Via: %s", n.Via))
}
elems = append(elems, fmt.Sprintf("Weight: %d", n.Hops+1))
elems = append(elems, fmt.Sprintf("Gw: %s", n.Gw))
elems = append(elems, fmt.Sprintf("Flags: %s", n.ListFlags()))
Expand Down
85 changes: 85 additions & 0 deletions route_linux.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package netlink

import (
"bytes"
"encoding/binary"
"fmt"
"net"
"strings"
Expand Down Expand Up @@ -446,6 +448,62 @@ func (e *SEG6LocalEncap) Equal(x Encap) bool {
return true
}

type Via struct {
AddrFamily int
Addr net.IP
}

func (v *Via) Equal(x Destination) bool {
o, ok := x.(*Via)
if !ok {
return false
}
if v.AddrFamily == x.Family() && v.Addr.Equal(o.Addr) {
return true
}
return false
}

func (v *Via) String() string {
return fmt.Sprintf("Family: %d, Address: %s", v.AddrFamily, v.Addr.String())
}

func (v *Via) Family() int {
return v.AddrFamily
}

func (v *Via) Encode() ([]byte, error) {
buf := &bytes.Buffer{}
err := binary.Write(buf, native, uint16(v.AddrFamily))
if err != nil {
return nil, err
}
err = binary.Write(buf, native, v.Addr)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func (v *Via) Decode(b []byte) error {
native := nl.NativeEndian()
if len(b) < 6 {
return fmt.Errorf("decoding failed: buffer too small (%d bytes)", len(b))
}
v.AddrFamily = int(native.Uint16(b[0:2]))
if v.AddrFamily == nl.FAMILY_V4 {
v.Addr = net.IP(b[2:6])
return nil
} else if v.AddrFamily == nl.FAMILY_V6 {
if len(b) < 18 {
return fmt.Errorf("decoding failed: buffer too small (%d bytes)", len(b))
}
v.Addr = net.IP(b[2:])
return nil
}
return fmt.Errorf("decoding failed: address family %d unknown", v.AddrFamily)
}

// RouteAdd will add a route to the system.
// Equivalent to: `ip route add $route`
func RouteAdd(route *Route) error {
Expand Down Expand Up @@ -567,6 +625,14 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_GATEWAY, gwData))
}

if route.Via != nil {
buf, err := route.Via.Encode()
if err != nil {
return fmt.Errorf("failed to encode RTA_VIA: %v", err)
}
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_VIA, buf))
}

if len(route.MultiPath) > 0 {
buf := []byte{}
for _, nh := range route.MultiPath {
Expand Down Expand Up @@ -609,6 +675,13 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
}
children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf))
}
if nh.Via != nil {
buf, err := nh.Via.Encode()
if err != nil {
return err
}
children = append(children, nl.NewRtAttr(unix.RTA_VIA, buf))
}
rtnh.Children = children
buf = append(buf, rtnh.Serialize()...)
}
Expand Down Expand Up @@ -907,6 +980,12 @@ func deserializeRoute(m []byte) (Route, error) {
encapType = attr
case unix.RTA_ENCAP:
encap = attr
case unix.RTA_VIA:
d := &Via{}
if err := d.Decode(attr.Value); err != nil {
return nil, nil, err
}
info.Via = d
}
}

Expand Down Expand Up @@ -944,6 +1023,12 @@ func deserializeRoute(m []byte) (Route, error) {
return route, err
}
route.NewDst = d
case unix.RTA_VIA:
v := &Via{}
if err := v.Decode(attr.Value); err != nil {
return route, err
}
route.Via = v
case unix.RTA_ENCAP_TYPE:
encapType = attr
case unix.RTA_ENCAP:
Expand Down
66 changes: 66 additions & 0 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1339,3 +1339,69 @@ func TestMTURouteAddDel(t *testing.T) {
t.Fatal("Route not removed properly")
}
}

func TestRouteViaAddDel(t *testing.T) {
minKernelRequired(t, 5, 4)
tearDown := setUpNetlinkTest(t)
defer tearDown()

_, err := RouteList(nil, FAMILY_V4)
if err != nil {
t.Fatal(err)
}

link, err := LinkByName("lo")
if err != nil {
t.Fatal(err)
}

if err := LinkSetUp(link); err != nil {
t.Fatal(err)
}

route := &Route{
LinkIndex: link.Attrs().Index,
Dst: &net.IPNet{
IP: net.IPv4(192, 168, 0, 0),
Mask: net.CIDRMask(24, 32),
},
MultiPath: []*NexthopInfo{
{
LinkIndex: link.Attrs().Index,
Via: &Via{
AddrFamily: FAMILY_V6,
Addr: net.ParseIP("2001::1"),
},
},
},
}

if err := RouteAdd(route); err != nil {
t.Fatalf("route: %v, err: %v", route, err)
}

routes, err := RouteList(link, FAMILY_V4)
if err != nil {
t.Fatal(err)
}
if len(routes) != 1 {
t.Fatal("Route not added properly")
}

got := routes[0].Via
want := route.MultiPath[0].Via
if !want.Equal(got) {
t.Fatalf("Route Via attribute does not match; got: %s, want: %s", got, want)
}

if err := RouteDel(route); err != nil {
t.Fatal(err)
}
routes, err = RouteList(link, FAMILY_V4)
if err != nil {
t.Fatal(err)
}
if len(routes) != 0 {
t.Fatal("Route not removed properly")
}
}

0 comments on commit 5da81c2

Please sign in to comment.