Mimic ipset C code for determining correct default ipset revision

Signed-off-by: Benjamin Leggett <benjamin.leggett@solo.io>
This commit is contained in:
Benjamin Leggett
2024-11-18 17:34:34 -05:00
committed by Alessandro Boch
parent 2426b0576c
commit 1f4f72c917
2 changed files with 173 additions and 19 deletions

View File

@@ -147,9 +147,11 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename)))
cadtFlags := optionsToBitflag(options)
revision := options.Revision
if revision == 0 {
revision = getIpsetDefaultWithTypeName(typename)
revision = getIpsetDefaultRevision(typename, cadtFlags)
}
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(revision)))
@@ -181,18 +183,6 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout})
}
var cadtFlags uint32
if options.Comments {
cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT
}
if options.Counters {
cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS
}
if options.Skbinfo {
cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO
}
if cadtFlags != 0 {
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER, Value: cadtFlags})
}
@@ -395,14 +385,89 @@ func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest {
return req
}
func getIpsetDefaultWithTypeName(typename string) uint8 {
// NOTE: This can't just take typename into account, it also has to take desired
// feature support into account, on a per-set-type basis, to return the correct revision, see e.g.
// https://github.com/Olipro/ipset/blob/9f145b49100104d6570fe5c31a5236816ebb4f8f/kernel/net/netfilter/ipset/ip_set_hash_ipport.c#L30
//
// This means that whenever a new "type" of ipset is added, returning the "correct" default revision
// requires adding a new case here for that type, and consulting the ipset C code to figure out the correct
// combination of type name, feature bit flags, and revision ranges.
//
// Care should be taken as some types share the same revision ranges for the same features, and others do not.
// When in doubt, mimic the C code.
func getIpsetDefaultRevision(typename string, featureFlags uint32) uint8 {
switch typename {
case "hash:ip,port",
"hash:ip,port,ip",
"hash:ip,port,net",
"hash:ip,port,ip":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipport.c
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportip.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 5
}
if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 4
}
if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 3
}
if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 {
return 2
}
// the min revision this library supports for this type
return 1
case "hash:ip,port,net",
"hash:net,port":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ipportnet.c
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_netport.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 7
}
if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 6
}
if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 5
}
if (featureFlags & nl.IPSET_FLAG_WITH_COUNTERS) != 0 {
return 4
}
if (featureFlags & nl.IPSET_FLAG_NOMATCH) != 0 {
return 3
}
// the min revision this library supports for this type
return 2
case "hash:ip":
// Taken from
// - ipset/kernel/net/netfilter/ipset/ip_set_hash_ip.c
if (featureFlags & nl.IPSET_FLAG_WITH_SKBINFO) != 0 {
return 4
}
if (featureFlags & nl.IPSET_FLAG_WITH_FORCEADD) != 0 {
return 3
}
if (featureFlags & nl.IPSET_FLAG_WITH_COMMENT) != 0 {
return 2
}
// the min revision this library supports for this type
return 1
}
// can't map the correct revision for this type.
return 0
}
@@ -579,3 +644,19 @@ func parseIPSetEntry(data []byte) (entry IPSetEntry) {
}
return
}
func optionsToBitflag(options IpsetCreateOptions) uint32 {
var cadtFlags uint32
if options.Comments {
cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT
}
if options.Counters {
cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS
}
if options.Skbinfo {
cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO
}
return cadtFlags
}

View File

@@ -2,8 +2,8 @@ package netlink
import (
"bytes"
"io/ioutil"
"net"
"os"
"testing"
"github.com/vishvananda/netlink/nl"
@@ -11,7 +11,7 @@ import (
)
func TestParseIpsetProtocolResult(t *testing.T) {
msgBytes, err := ioutil.ReadFile("testdata/ipset_protocol_result")
msgBytes, err := os.ReadFile("testdata/ipset_protocol_result")
if err != nil {
t.Fatalf("reading test fixture failed: %v", err)
}
@@ -23,7 +23,7 @@ func TestParseIpsetProtocolResult(t *testing.T) {
}
func TestParseIpsetListResult(t *testing.T) {
msgBytes, err := ioutil.ReadFile("testdata/ipset_list_result")
msgBytes, err := os.ReadFile("testdata/ipset_list_result")
if err != nil {
t.Fatalf("reading test fixture failed: %v", err)
}
@@ -759,3 +759,76 @@ func TestIpsetMaxElements(t *testing.T) {
t.Fatalf("expected '%d' entry be created, got '%d'", maxElements, len(result.Entries))
}
}
func TestIpsetDefaultRevision(t *testing.T) {
testCases := []struct {
desc string
typename string
options IpsetCreateOptions
expectedRevision uint8
}{
{
desc: "Type-hash:ip,port",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: true,
Skbinfo: false,
},
expectedRevision: 3,
},
{
desc: "Type-hash:ip,port_nocomment",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: false,
},
expectedRevision: 2,
},
{
desc: "Type-hash:ip,port_skbinfo",
typename: "hash:ip,port",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: true,
},
expectedRevision: 5,
},
{
desc: "Type-hash:ip,port,net",
typename: "hash:ip,port,net",
options: IpsetCreateOptions{
Counters: true,
Comments: false,
Skbinfo: true,
},
expectedRevision: 7,
},
{
desc: "Type-hash:net,port_baseline_revision_no_opts",
typename: "hash:net,port",
options: IpsetCreateOptions{
Counters: false,
Comments: false,
Skbinfo: false,
},
expectedRevision: 2,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
cadtFlags := optionsToBitflag(tC.options)
defRev := getIpsetDefaultRevision(tC.typename, cadtFlags)
if defRev != tC.expectedRevision {
t.Fatalf("expected default revision of '%d', got '%d'", tC.expectedRevision, defRev)
}
})
}
}