Skip to content

Commit

Permalink
Make IPSet actually support IPs, and fix protocol errors for newer ke…
Browse files Browse the repository at this point in the history
…rnels
  • Loading branch information
Anonymous committed Mar 11, 2021
1 parent 3bf47fa commit d618bed
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 21 deletions.
8 changes: 4 additions & 4 deletions cmd/ipset-test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ var (
"destroy": {cmdDestroy, "creates a new ipset", 1},
"list": {cmdList, "list specific ipset", 1},
"listall": {cmdListAll, "list all ipsets", 0},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 1},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 1},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2},
}

timeoutVal *uint32
Expand Down Expand Up @@ -89,9 +89,9 @@ func printUsage() {
}

func cmdProtocol(_ []string) {
protocol, err := netlink.IpsetProtocol()
protocol, minProto, err := netlink.IpsetProtocol()
check(err)
log.Println("Protocol:", protocol)
log.Println("Protocol:", protocol, "min:", minProto)
}

func cmdCreate(args []string) {
Expand Down
60 changes: 45 additions & 15 deletions ipset_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ type IPSetEntry struct {

// IPSetResult is the result of a dump request for a set
type IPSetResult struct {
Nfgenmsg *nl.Nfgenmsg
Protocol uint8
Revision uint8
Family uint8
Flags uint8
SetName string
TypeName string
Nfgenmsg *nl.Nfgenmsg
Protocol uint8
ProtocolMinVersion uint8
Revision uint8
Family uint8
Flags uint8
SetName string
TypeName string
Comment string

HashSize uint32
NumEntries uint32
Expand All @@ -38,6 +40,7 @@ type IPSetResult struct {
SizeInMemory uint32
CadtFlags uint32
Timeout *uint32
LineNo uint32

Entries []IPSetEntry
}
Expand All @@ -52,7 +55,7 @@ type IpsetCreateOptions struct {
}

// IpsetProtocol returns the ipset protocol version from the kernel
func IpsetProtocol() (uint8, error) {
func IpsetProtocol() (uint8, uint8, error) {
return pkgHandle.IpsetProtocol()
}

Expand Down Expand Up @@ -86,20 +89,20 @@ func IpsetAdd(setname string, entry *IPSetEntry) error {
return pkgHandle.ipsetAddDel(nl.IPSET_CMD_ADD, setname, entry)
}

// IpsetDele deletes an entry from an existing ipset.
// IpsetDel deletes an entry from an existing ipset.
func IpsetDel(setname string, entry *IPSetEntry) error {
return pkgHandle.ipsetAddDel(nl.IPSET_CMD_DEL, setname, entry)
}

func (h *Handle) IpsetProtocol() (uint8, error) {
func (h *Handle) IpsetProtocol() (protocol uint8, minVersion uint8, err error) {
req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL)
msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0)

if err != nil {
return 0, err
return 0, 0, err
}

return ipsetUnserialize(msgs).Protocol, nil
response := ipsetUnserialize(msgs)
return response.Protocol, response.ProtocolMinVersion, nil
}

func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOptions) error {
Expand All @@ -112,7 +115,7 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(0)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(0)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(2))) // 2 == inet

data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil)

Expand Down Expand Up @@ -187,6 +190,11 @@ func (h *Handle) IpsetListAll() ([]IPSetResult, error) {
func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error {
req := h.newIpsetRequest(nlCmd)
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))

if entry.Comment != "" {
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_COMMENT, nl.ZeroTerminated(entry.Comment)))
}

data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil)

if !entry.Replace {
Expand All @@ -197,7 +205,12 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *entry.Timeout})
}
if entry.MAC != nil {
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER, entry.MAC))
nestedData := nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NET_BYTEORDER), entry.MAC)
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NESTED), nestedData.Serialize()))
}
if entry.IP != nil {
nestedData := nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NET_BYTEORDER), entry.IP)
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NESTED), nestedData.Serialize()))
}

data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0})
Expand Down Expand Up @@ -249,6 +262,8 @@ func (result *IPSetResult) unserialize(msg []byte) {
result.Protocol = attr.Value[0]
case nl.IPSET_ATTR_SETNAME:
result.SetName = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_COMMENT:
result.Comment = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_TYPENAME:
result.TypeName = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_REVISION:
Expand All @@ -261,6 +276,8 @@ func (result *IPSetResult) unserialize(msg []byte) {
result.parseAttrData(attr.Value)
case nl.IPSET_ATTR_ADT | nl.NLA_F_NESTED:
result.parseAttrADT(attr.Value)
case nl.IPSET_ATTR_PROTOCOL_MIN:
result.ProtocolMinVersion = attr.Value[0]
default:
log.Printf("unknown ipset attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK)
}
Expand All @@ -285,6 +302,17 @@ func (result *IPSetResult) parseAttrData(data []byte) {
result.SizeInMemory = attr.Uint32()
case nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER:
result.CadtFlags = attr.Uint32()
case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED:
for nested := range nl.ParseAttributes(attr.Value) {
switch nested.Type {
case nl.IPSET_ATTR_IP | nl.NLA_F_NET_BYTEORDER:
result.Entries = append(result.Entries, IPSetEntry{IP: nested.Value})
}
}
case nl.IPSET_ATTR_CADT_LINENO | nl.NLA_F_NET_BYTEORDER:
result.LineNo = attr.Uint32()
case nl.IPSET_ATTR_COMMENT:
result.Comment = nl.BytesToString(attr.Value)
default:
log.Printf("unknown ipset data attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK)
}
Expand Down Expand Up @@ -316,6 +344,8 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) {
entry.Packets = &val
case nl.IPSET_ATTR_ETHER:
entry.MAC = net.HardwareAddr(attr.Value)
case nl.IPSET_ATTR_IP:
entry.IP = net.IP(attr.Value)
case nl.IPSET_ATTR_COMMENT:
entry.Comment = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED:
Expand Down
115 changes: 113 additions & 2 deletions ipset_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ package netlink

import (
"bytes"
"github.com/vishvananda/netlink/nl"
"io/ioutil"
"net"
"testing"

"github.com/vishvananda/netlink/nl"
)

func TestParseIpsetProtocolResult(t *testing.T) {
Expand Down Expand Up @@ -85,3 +84,115 @@ func TestParseIpsetListResult(t *testing.T) {
t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String())
}
}

func TestIpsetCreateListAddDelDestroy(t *testing.T) {
tearDown := setUpNetlinkTest(t)
defer tearDown()
timeout := uint32(3)
err := IpsetCreate("my-test-ipset-1", "hash:ip", IpsetCreateOptions{
Replace: true,
Timeout: &timeout,
Counters: true,
Comments: false,
Skbinfo: false,
})
if err != nil {
t.Fatal(err)
}

err = IpsetCreate("my-test-ipset-2", "hash:net", IpsetCreateOptions{
Replace: true,
Timeout: &timeout,
Counters: false,
Comments: true,
Skbinfo: true,
})
if err != nil {
t.Fatal(err)
}

results, err := IpsetListAll()

if err != nil {
t.Fatal(err)
}

if len(results) != 2 {
t.Fatalf("expected 2 IPSets to be created, got %d", len(results))
}

if results[0].SetName != "my-test-ipset-1" {
t.Errorf("expected name to be 'my-test-ipset-1', but got '%s'", results[0].SetName)
}

if results[1].SetName != "my-test-ipset-2" {
t.Errorf("expected name to be 'my-test-ipset-2', but got '%s'", results[1].SetName)
}

if results[0].TypeName != "hash:ip" {
t.Errorf("expected type to be 'hash:ip', but got '%s'", results[0].TypeName)
}

if results[1].TypeName != "hash:net" {
t.Errorf("expected type to be 'hash:net', but got '%s'", results[1].TypeName)
}

if *results[0].Timeout != 3 {
t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout)
}

err = IpsetAdd("my-test-ipset-1", &IPSetEntry{
Comment: "test comment",
IP: net.ParseIP("10.99.99.99").To4(),
Replace: false,
})

if err != nil {
t.Fatal(err)
}

result, err := IpsetList("my-test-ipset-1")

if err != nil {
t.Fatal(err)
}

if len(result.Entries) != 1 {
t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries))
}
if result.Entries[0].IP.String() != "10.99.99.99" {
t.Fatalf("expected entry to be '10.99.99.99', got '%s'", result.Entries[0].IP.String())
}

if result.Entries[0].Comment != "test comment" {
// This is only supported in the kernel module from revision 2 or 4, so comments may be ignored.
t.Logf("expected comment to be 'test comment', got '%s'", result.Entries[0].Comment)
}

err = IpsetDel("my-test-ipset-1", &IPSetEntry{
Comment: "test comment",
IP: net.ParseIP("10.99.99.99").To4(),
})
if err != nil {
t.Fatal(err)
}

result, err = IpsetList("my-test-ipset-1")
if err != nil {
t.Fatal(err)
}

if len(result.Entries) != 0 {
t.Fatalf("expected 0 entries to exist, got %d", len(result.Entries))
}

err = IpsetDestroy("my-test-ipset-1")
if err != nil {
t.Fatal(err)
}

err = IpsetDestroy("my-test-ipset-2")
if err != nil {
t.Fatal(err)
}
}

0 comments on commit d618bed

Please sign in to comment.