Make IPSet actually support IPs, and fix protocol errors for newer kernels

This commit is contained in:
Anonymous
2021-03-08 11:39:35 +00:00
committed by Alessandro Boch
parent 66fce01bfa
commit 3b8f3fd48d
3 changed files with 162 additions and 21 deletions

View File

@@ -26,8 +26,8 @@ var (
"destroy": {cmdDestroy, "creates a new ipset", 1}, "destroy": {cmdDestroy, "creates a new ipset", 1},
"list": {cmdList, "list specific ipset", 1}, "list": {cmdList, "list specific ipset", 1},
"listall": {cmdListAll, "list all ipsets", 0}, "listall": {cmdListAll, "list all ipsets", 0},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 1}, "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 1}, "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2},
} }
timeoutVal *uint32 timeoutVal *uint32
@@ -89,9 +89,9 @@ func printUsage() {
} }
func cmdProtocol(_ []string) { func cmdProtocol(_ []string) {
protocol, err := netlink.IpsetProtocol() protocol, minProto, err := netlink.IpsetProtocol()
check(err) check(err)
log.Println("Protocol:", protocol) log.Println("Protocol:", protocol, "min:", minProto)
} }
func cmdCreate(args []string) { func cmdCreate(args []string) {

View File

@@ -23,13 +23,15 @@ type IPSetEntry struct {
// IPSetResult is the result of a dump request for a set // IPSetResult is the result of a dump request for a set
type IPSetResult struct { type IPSetResult struct {
Nfgenmsg *nl.Nfgenmsg Nfgenmsg *nl.Nfgenmsg
Protocol uint8 Protocol uint8
Revision uint8 ProtocolMinVersion uint8
Family uint8 Revision uint8
Flags uint8 Family uint8
SetName string Flags uint8
TypeName string SetName string
TypeName string
Comment string
HashSize uint32 HashSize uint32
NumEntries uint32 NumEntries uint32
@@ -38,6 +40,7 @@ type IPSetResult struct {
SizeInMemory uint32 SizeInMemory uint32
CadtFlags uint32 CadtFlags uint32
Timeout *uint32 Timeout *uint32
LineNo uint32
Entries []IPSetEntry Entries []IPSetEntry
} }
@@ -52,7 +55,7 @@ type IpsetCreateOptions struct {
} }
// IpsetProtocol returns the ipset protocol version from the kernel // IpsetProtocol returns the ipset protocol version from the kernel
func IpsetProtocol() (uint8, error) { func IpsetProtocol() (uint8, uint8, error) {
return pkgHandle.IpsetProtocol() return pkgHandle.IpsetProtocol()
} }
@@ -86,20 +89,20 @@ func IpsetAdd(setname string, entry *IPSetEntry) error {
return pkgHandle.ipsetAddDel(nl.IPSET_CMD_ADD, setname, entry) return pkgHandle.ipsetAddDel(nl.IPSET_CMD_ADD, setname, entry)
} }
// IpsetDele deletes an entry from an existing ipset. // IpsetDel deletes an entry from an existing ipset.
func IpsetDel(setname string, entry *IPSetEntry) error { func IpsetDel(setname string, entry *IPSetEntry) error {
return pkgHandle.ipsetAddDel(nl.IPSET_CMD_DEL, setname, entry) return pkgHandle.ipsetAddDel(nl.IPSET_CMD_DEL, setname, entry)
} }
func (h *Handle) IpsetProtocol() (uint8, error) { func (h *Handle) IpsetProtocol() (protocol uint8, minVersion uint8, err error) {
req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL) req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL)
msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0) msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0)
if err != nil { if err != nil {
return 0, err return 0, 0, err
} }
response := ipsetUnserialize(msgs)
return ipsetUnserialize(msgs).Protocol, nil return response.Protocol, response.ProtocolMinVersion, nil
} }
func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOptions) error { func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOptions) error {
@@ -112,7 +115,7 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(0))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(0)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(0))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(2))) // 2 == inet
data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil)
@@ -187,6 +190,11 @@ func (h *Handle) IpsetListAll() ([]IPSetResult, error) {
func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error { func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error {
req := h.newIpsetRequest(nlCmd) req := h.newIpsetRequest(nlCmd)
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
if entry.Comment != "" {
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_COMMENT, nl.ZeroTerminated(entry.Comment)))
}
data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil)
if !entry.Replace { if !entry.Replace {
@@ -197,7 +205,12 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *entry.Timeout}) data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *entry.Timeout})
} }
if entry.MAC != nil { if entry.MAC != nil {
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER, entry.MAC)) nestedData := nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NET_BYTEORDER), entry.MAC)
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER|int(nl.NLA_F_NESTED), nestedData.Serialize()))
}
if entry.IP != nil {
nestedData := nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NET_BYTEORDER), entry.IP)
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NESTED), nestedData.Serialize()))
} }
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0}) data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0})
@@ -249,6 +262,8 @@ func (result *IPSetResult) unserialize(msg []byte) {
result.Protocol = attr.Value[0] result.Protocol = attr.Value[0]
case nl.IPSET_ATTR_SETNAME: case nl.IPSET_ATTR_SETNAME:
result.SetName = nl.BytesToString(attr.Value) result.SetName = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_COMMENT:
result.Comment = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_TYPENAME: case nl.IPSET_ATTR_TYPENAME:
result.TypeName = nl.BytesToString(attr.Value) result.TypeName = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_REVISION: case nl.IPSET_ATTR_REVISION:
@@ -261,6 +276,8 @@ func (result *IPSetResult) unserialize(msg []byte) {
result.parseAttrData(attr.Value) result.parseAttrData(attr.Value)
case nl.IPSET_ATTR_ADT | nl.NLA_F_NESTED: case nl.IPSET_ATTR_ADT | nl.NLA_F_NESTED:
result.parseAttrADT(attr.Value) result.parseAttrADT(attr.Value)
case nl.IPSET_ATTR_PROTOCOL_MIN:
result.ProtocolMinVersion = attr.Value[0]
default: default:
log.Printf("unknown ipset attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) log.Printf("unknown ipset attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK)
} }
@@ -285,6 +302,17 @@ func (result *IPSetResult) parseAttrData(data []byte) {
result.SizeInMemory = attr.Uint32() result.SizeInMemory = attr.Uint32()
case nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER: case nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER:
result.CadtFlags = attr.Uint32() result.CadtFlags = attr.Uint32()
case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED:
for nested := range nl.ParseAttributes(attr.Value) {
switch nested.Type {
case nl.IPSET_ATTR_IP | nl.NLA_F_NET_BYTEORDER:
result.Entries = append(result.Entries, IPSetEntry{IP: nested.Value})
}
}
case nl.IPSET_ATTR_CADT_LINENO | nl.NLA_F_NET_BYTEORDER:
result.LineNo = attr.Uint32()
case nl.IPSET_ATTR_COMMENT:
result.Comment = nl.BytesToString(attr.Value)
default: default:
log.Printf("unknown ipset data attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) log.Printf("unknown ipset data attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK)
} }
@@ -316,6 +344,8 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) {
entry.Packets = &val entry.Packets = &val
case nl.IPSET_ATTR_ETHER: case nl.IPSET_ATTR_ETHER:
entry.MAC = net.HardwareAddr(attr.Value) entry.MAC = net.HardwareAddr(attr.Value)
case nl.IPSET_ATTR_IP:
entry.IP = net.IP(attr.Value)
case nl.IPSET_ATTR_COMMENT: case nl.IPSET_ATTR_COMMENT:
entry.Comment = nl.BytesToString(attr.Value) entry.Comment = nl.BytesToString(attr.Value)
case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED: case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED:

View File

@@ -2,11 +2,10 @@ package netlink
import ( import (
"bytes" "bytes"
"github.com/vishvananda/netlink/nl"
"io/ioutil" "io/ioutil"
"net" "net"
"testing" "testing"
"github.com/vishvananda/netlink/nl"
) )
func TestParseIpsetProtocolResult(t *testing.T) { func TestParseIpsetProtocolResult(t *testing.T) {
@@ -85,3 +84,115 @@ func TestParseIpsetListResult(t *testing.T) {
t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String()) t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String())
} }
} }
func TestIpsetCreateListAddDelDestroy(t *testing.T) {
tearDown := setUpNetlinkTest(t)
defer tearDown()
timeout := uint32(3)
err := IpsetCreate("my-test-ipset-1", "hash:ip", IpsetCreateOptions{
Replace: true,
Timeout: &timeout,
Counters: true,
Comments: false,
Skbinfo: false,
})
if err != nil {
t.Fatal(err)
}
err = IpsetCreate("my-test-ipset-2", "hash:net", IpsetCreateOptions{
Replace: true,
Timeout: &timeout,
Counters: false,
Comments: true,
Skbinfo: true,
})
if err != nil {
t.Fatal(err)
}
results, err := IpsetListAll()
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatalf("expected 2 IPSets to be created, got %d", len(results))
}
if results[0].SetName != "my-test-ipset-1" {
t.Errorf("expected name to be 'my-test-ipset-1', but got '%s'", results[0].SetName)
}
if results[1].SetName != "my-test-ipset-2" {
t.Errorf("expected name to be 'my-test-ipset-2', but got '%s'", results[1].SetName)
}
if results[0].TypeName != "hash:ip" {
t.Errorf("expected type to be 'hash:ip', but got '%s'", results[0].TypeName)
}
if results[1].TypeName != "hash:net" {
t.Errorf("expected type to be 'hash:net', but got '%s'", results[1].TypeName)
}
if *results[0].Timeout != 3 {
t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout)
}
err = IpsetAdd("my-test-ipset-1", &IPSetEntry{
Comment: "test comment",
IP: net.ParseIP("10.99.99.99").To4(),
Replace: false,
})
if err != nil {
t.Fatal(err)
}
result, err := IpsetList("my-test-ipset-1")
if err != nil {
t.Fatal(err)
}
if len(result.Entries) != 1 {
t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries))
}
if result.Entries[0].IP.String() != "10.99.99.99" {
t.Fatalf("expected entry to be '10.99.99.99', got '%s'", result.Entries[0].IP.String())
}
if result.Entries[0].Comment != "test comment" {
// This is only supported in the kernel module from revision 2 or 4, so comments may be ignored.
t.Logf("expected comment to be 'test comment', got '%s'", result.Entries[0].Comment)
}
err = IpsetDel("my-test-ipset-1", &IPSetEntry{
Comment: "test comment",
IP: net.ParseIP("10.99.99.99").To4(),
})
if err != nil {
t.Fatal(err)
}
result, err = IpsetList("my-test-ipset-1")
if err != nil {
t.Fatal(err)
}
if len(result.Entries) != 0 {
t.Fatalf("expected 0 entries to exist, got %d", len(result.Entries))
}
err = IpsetDestroy("my-test-ipset-1")
if err != nil {
t.Fatal(err)
}
err = IpsetDestroy("my-test-ipset-2")
if err != nil {
t.Fatal(err)
}
}