diff --git a/cmd/ipset-test/main.go b/cmd/ipset-test/main.go index 55b39ccc..84c2a278 100644 --- a/cmd/ipset-test/main.go +++ b/cmd/ipset-test/main.go @@ -26,8 +26,8 @@ var ( "destroy": {cmdDestroy, "creates a new ipset", 1}, "list": {cmdList, "list specific ipset", 1}, "listall": {cmdListAll, "list all ipsets", 0}, - "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 1}, - "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 1}, + "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2}, + "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2}, } timeoutVal *uint32 @@ -89,9 +89,9 @@ func printUsage() { } func cmdProtocol(_ []string) { - protocol, err := netlink.IpsetProtocol() + protocol, minProto, err := netlink.IpsetProtocol() check(err) - log.Println("Protocol:", protocol) + log.Println("Protocol:", protocol, "min:", minProto) } func cmdCreate(args []string) { diff --git a/ipset_linux.go b/ipset_linux.go index 5487fc1c..2adc2440 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -23,13 +23,15 @@ type IPSetEntry struct { // IPSetResult is the result of a dump request for a set type IPSetResult struct { - Nfgenmsg *nl.Nfgenmsg - Protocol uint8 - Revision uint8 - Family uint8 - Flags uint8 - SetName string - TypeName string + Nfgenmsg *nl.Nfgenmsg + Protocol uint8 + ProtocolMinVersion uint8 + Revision uint8 + Family uint8 + Flags uint8 + SetName string + TypeName string + Comment string HashSize uint32 NumEntries uint32 @@ -38,6 +40,7 @@ type IPSetResult struct { SizeInMemory uint32 CadtFlags uint32 Timeout *uint32 + LineNo uint32 Entries []IPSetEntry } @@ -52,7 +55,7 @@ type IpsetCreateOptions struct { } // IpsetProtocol returns the ipset protocol version from the kernel -func IpsetProtocol() (uint8, error) { +func IpsetProtocol() (uint8, uint8, error) { return pkgHandle.IpsetProtocol() } @@ -86,20 +89,20 @@ func IpsetAdd(setname string, entry *IPSetEntry) error { return pkgHandle.ipsetAddDel(nl.IPSET_CMD_ADD, setname, entry) } -// IpsetDele deletes an entry from an existing ipset. +// IpsetDel deletes an entry from an existing ipset. func IpsetDel(setname string, entry *IPSetEntry) error { return pkgHandle.ipsetAddDel(nl.IPSET_CMD_DEL, setname, entry) } -func (h *Handle) IpsetProtocol() (uint8, error) { +func (h *Handle) IpsetProtocol() (protocol uint8, minVersion uint8, err error) { req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL) msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0) if err != nil { - return 0, err + return 0, 0, err } - - return ipsetUnserialize(msgs).Protocol, nil + response := ipsetUnserialize(msgs) + return response.Protocol, response.ProtocolMinVersion, nil } func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOptions) error { @@ -112,7 +115,7 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(0))) - req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(0))) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(2))) // 2 == inet data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) @@ -187,6 +190,11 @@ func (h *Handle) IpsetListAll() ([]IPSetResult, error) { func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error { req := h.newIpsetRequest(nlCmd) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + + if entry.Comment != "" { + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_COMMENT, nl.ZeroTerminated(entry.Comment))) + } + data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) if !entry.Replace { @@ -197,7 +205,12 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *entry.Timeout}) } if entry.MAC != nil { - data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER, entry.MAC)) + nestedData := nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NET_BYTEORDER), entry.MAC) + data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NESTED), nestedData.Serialize())) + } + if entry.IP != nil { + nestedData := nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NET_BYTEORDER), entry.IP) + data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NESTED), nestedData.Serialize())) } data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0}) @@ -249,6 +262,8 @@ func (result *IPSetResult) unserialize(msg []byte) { result.Protocol = attr.Value[0] case nl.IPSET_ATTR_SETNAME: result.SetName = nl.BytesToString(attr.Value) + case nl.IPSET_ATTR_COMMENT: + result.Comment = nl.BytesToString(attr.Value) case nl.IPSET_ATTR_TYPENAME: result.TypeName = nl.BytesToString(attr.Value) case nl.IPSET_ATTR_REVISION: @@ -261,6 +276,8 @@ func (result *IPSetResult) unserialize(msg []byte) { result.parseAttrData(attr.Value) case nl.IPSET_ATTR_ADT | nl.NLA_F_NESTED: result.parseAttrADT(attr.Value) + case nl.IPSET_ATTR_PROTOCOL_MIN: + result.ProtocolMinVersion = attr.Value[0] default: log.Printf("unknown ipset attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) } @@ -285,6 +302,17 @@ func (result *IPSetResult) parseAttrData(data []byte) { result.SizeInMemory = attr.Uint32() case nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER: result.CadtFlags = attr.Uint32() + case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED: + for nested := range nl.ParseAttributes(attr.Value) { + switch nested.Type { + case nl.IPSET_ATTR_IP | nl.NLA_F_NET_BYTEORDER: + result.Entries = append(result.Entries, IPSetEntry{IP: nested.Value}) + } + } + case nl.IPSET_ATTR_CADT_LINENO | nl.NLA_F_NET_BYTEORDER: + result.LineNo = attr.Uint32() + case nl.IPSET_ATTR_COMMENT: + result.Comment = nl.BytesToString(attr.Value) default: log.Printf("unknown ipset data attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) } @@ -316,6 +344,8 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) { entry.Packets = &val case nl.IPSET_ATTR_ETHER: entry.MAC = net.HardwareAddr(attr.Value) + case nl.IPSET_ATTR_IP: + entry.IP = net.IP(attr.Value) case nl.IPSET_ATTR_COMMENT: entry.Comment = nl.BytesToString(attr.Value) case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED: diff --git a/ipset_linux_test.go b/ipset_linux_test.go index 865d0a75..df298c41 100644 --- a/ipset_linux_test.go +++ b/ipset_linux_test.go @@ -2,11 +2,10 @@ package netlink import ( "bytes" + "github.com/vishvananda/netlink/nl" "io/ioutil" "net" "testing" - - "github.com/vishvananda/netlink/nl" ) func TestParseIpsetProtocolResult(t *testing.T) { @@ -85,3 +84,115 @@ func TestParseIpsetListResult(t *testing.T) { t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String()) } } + +func TestIpsetCreateListAddDelDestroy(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + timeout := uint32(3) + err := IpsetCreate("my-test-ipset-1", "hash:ip", IpsetCreateOptions{ + Replace: true, + Timeout: &timeout, + Counters: true, + Comments: false, + Skbinfo: false, + }) + if err != nil { + t.Fatal(err) + } + + err = IpsetCreate("my-test-ipset-2", "hash:net", IpsetCreateOptions{ + Replace: true, + Timeout: &timeout, + Counters: false, + Comments: true, + Skbinfo: true, + }) + if err != nil { + t.Fatal(err) + } + + results, err := IpsetListAll() + + if err != nil { + t.Fatal(err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 IPSets to be created, got %d", len(results)) + } + + if results[0].SetName != "my-test-ipset-1" { + t.Errorf("expected name to be 'my-test-ipset-1', but got '%s'", results[0].SetName) + } + + if results[1].SetName != "my-test-ipset-2" { + t.Errorf("expected name to be 'my-test-ipset-2', but got '%s'", results[1].SetName) + } + + if results[0].TypeName != "hash:ip" { + t.Errorf("expected type to be 'hash:ip', but got '%s'", results[0].TypeName) + } + + if results[1].TypeName != "hash:net" { + t.Errorf("expected type to be 'hash:net', but got '%s'", results[1].TypeName) + } + + if *results[0].Timeout != 3 { + t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout) + } + + err = IpsetAdd("my-test-ipset-1", &IPSetEntry{ + Comment: "test comment", + IP: net.ParseIP("10.99.99.99").To4(), + Replace: false, + }) + + if err != nil { + t.Fatal(err) + } + + result, err := IpsetList("my-test-ipset-1") + + if err != nil { + t.Fatal(err) + } + + if len(result.Entries) != 1 { + t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries)) + } + if result.Entries[0].IP.String() != "10.99.99.99" { + t.Fatalf("expected entry to be '10.99.99.99', got '%s'", result.Entries[0].IP.String()) + } + + if result.Entries[0].Comment != "test comment" { + // This is only supported in the kernel module from revision 2 or 4, so comments may be ignored. + t.Logf("expected comment to be 'test comment', got '%s'", result.Entries[0].Comment) + } + + err = IpsetDel("my-test-ipset-1", &IPSetEntry{ + Comment: "test comment", + IP: net.ParseIP("10.99.99.99").To4(), + }) + if err != nil { + t.Fatal(err) + } + + result, err = IpsetList("my-test-ipset-1") + if err != nil { + t.Fatal(err) + } + + if len(result.Entries) != 0 { + t.Fatalf("expected 0 entries to exist, got %d", len(result.Entries)) + } + + err = IpsetDestroy("my-test-ipset-1") + if err != nil { + t.Fatal(err) + } + + err = IpsetDestroy("my-test-ipset-2") + if err != nil { + t.Fatal(err) + } +}