mirror of
https://github.com/vishvananda/netlink.git
synced 2025-09-27 04:05:59 +08:00

Refactors test setup and teardown logic to use `t.Cleanup` instead of `defer`. This ensures that cleanup functions are correctly scoped to each subtest's lifecycle, improving test isolation and reliability. The `setUpNetlinkTest` helper function is also improved to correctly save and restore the original network namespace, ensuring that tests do not leak state. To support this, a `Close()` method that returns an error is added to the `Handle` struct, allowing for proper cleanup of underlying netlink sockets. The test helpers are updated to use this new method, preventing resource leaks between tests. Additionally, a bug in the `netns` tests is fixed where a large namespace ID could overflow a 32-bit integer, causing spurious failures on some systems.
314 lines
7.2 KiB
Go
314 lines
7.2 KiB
Go
//go:build linux
|
|
// +build linux
|
|
|
|
package netlink
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/vishvananda/netlink/nl"
|
|
"github.com/vishvananda/netns"
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
// tearDownNetlinkTest is the teardown callback returned by the setUp* helpers
// in this file. It takes no arguments and returns nothing; it is intended to
// be invoked by the caller once the test no longer needs the environment the
// helper prepared.
type tearDownNetlinkTest func()
|
|
|
|
// skipUnlessRoot skips the calling test or benchmark when the process is not
// running with UID 0. It is marked as a test helper so the skip is attributed
// to the caller.
func skipUnlessRoot(t testing.TB) {
	t.Helper()

	if uid := os.Getuid(); uid == 0 {
		// Running as root: nothing to skip.
		return
	}
	t.Skip("Test requires root privileges.")
}
|
|
|
|
func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
|
|
t.Helper()
|
|
file, err := ioutil.ReadFile("/proc/modules")
|
|
if err != nil {
|
|
t.Fatal("Failed to open /proc/modules", err)
|
|
}
|
|
|
|
foundRequiredMods := make(map[string]bool)
|
|
lines := strings.Split(string(file), "\n")
|
|
|
|
for _, name := range moduleNames {
|
|
foundRequiredMods[name] = false
|
|
for _, line := range lines {
|
|
n := strings.Split(line, " ")[0]
|
|
if n == name {
|
|
foundRequiredMods[name] = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
failed := false
|
|
for _, name := range moduleNames {
|
|
if found, _ := foundRequiredMods[name]; !found {
|
|
t.Logf("Test requires missing kmodule %q.", name)
|
|
failed = true
|
|
}
|
|
}
|
|
if failed {
|
|
t.SkipNow()
|
|
}
|
|
}
|
|
|
|
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
|
|
skipUnlessRoot(t)
|
|
// Lock the OS thread, then record original namespace
|
|
runtime.LockOSThread()
|
|
origNS, err := netns.Get()
|
|
if err != nil {
|
|
runtime.UnlockOSThread()
|
|
t.Fatal("Failed to get current namespace:", err)
|
|
}
|
|
// Create and enter a fresh namespace
|
|
ns, err := netns.New()
|
|
if err != nil {
|
|
// attempt to restore before failing
|
|
_ = netns.Set(origNS)
|
|
runtime.UnlockOSThread()
|
|
t.Fatal("Failed to create new namespace:", err)
|
|
}
|
|
// Reinitialize the package-level handle in this namespace
|
|
if pkgHandle != nil {
|
|
// ensure all sockets from the previous Handle are closed
|
|
_ = pkgHandle.Close()
|
|
}
|
|
pkgHandle = &Handle{}
|
|
|
|
return func() {
|
|
// Close the new namespace handle
|
|
ns.Close()
|
|
// Restore the original namespace
|
|
if err := netns.Set(origNS); err != nil {
|
|
t.Fatalf("Failed to restore original namespace: %v", err)
|
|
}
|
|
_ = origNS.Close()
|
|
// Unlock the OS thread
|
|
runtime.UnlockOSThread()
|
|
}
|
|
}
|
|
|
|
// setUpNamedNetlinkTest create a temporary named names space with a random name
|
|
func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
|
|
skipUnlessRoot(t)
|
|
|
|
origNS, err := netns.Get()
|
|
if err != nil {
|
|
t.Fatal("Failed saving orig namespace")
|
|
}
|
|
|
|
// create a random name
|
|
rnd := make([]byte, 4)
|
|
if _, err := rand.Read(rnd); err != nil {
|
|
t.Fatal("failed creating random ns name")
|
|
}
|
|
name := "netlinktest-" + hex.EncodeToString(rnd)
|
|
|
|
ns, err := netns.NewNamed(name)
|
|
if err != nil {
|
|
t.Fatal("Failed to create new ns", err)
|
|
}
|
|
|
|
runtime.LockOSThread()
|
|
cleanup := func() {
|
|
ns.Close()
|
|
netns.DeleteNamed(name)
|
|
netns.Set(origNS)
|
|
runtime.UnlockOSThread()
|
|
}
|
|
|
|
if err := netns.Set(ns); err != nil {
|
|
cleanup()
|
|
t.Fatal("Failed entering new namespace", err)
|
|
}
|
|
|
|
return name, cleanup
|
|
}
|
|
|
|
func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
|
|
skipUnlessRoot(t)
|
|
|
|
runtime.LockOSThread()
|
|
|
|
// Save the current namespace
|
|
origNS, err := netns.Get()
|
|
if err != nil {
|
|
runtime.UnlockOSThread()
|
|
t.Fatal("Failed to get current namespace:", err)
|
|
}
|
|
|
|
// Create and enter a fresh namespace
|
|
ns, err := netns.New()
|
|
if err != nil {
|
|
runtime.UnlockOSThread()
|
|
t.Fatal("Failed to create new netns:", err)
|
|
}
|
|
|
|
// Bring up the loopback interface
|
|
link, err := LinkByName("lo")
|
|
if err != nil {
|
|
t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
|
|
}
|
|
if err := LinkSetUp(link); err != nil {
|
|
t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
|
|
}
|
|
|
|
// Teardown: restore original namespace and thread state
|
|
return func() {
|
|
ns.Close()
|
|
if err := netns.Set(origNS); err != nil {
|
|
t.Fatalf("Failed to restore original namespace: %v", err)
|
|
}
|
|
_ = origNS.Close()
|
|
runtime.UnlockOSThread()
|
|
}
|
|
}
|
|
|
|
// setUpF writes value to the file at path, creating or truncating it first.
//
// Any failure to open, write or close the file aborts the calling test. Write
// and close errors are checked explicitly so that a rejected value is reported
// at the point it happens rather than silently ignored.
func setUpF(t *testing.T, path, value string) {
	t.Helper()

	file, err := os.Create(path)
	if err != nil {
		t.Fatalf("Failed to open %s: %s", path, err)
	}

	if _, err := file.WriteString(value); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = file.Close()
		t.Fatalf("Failed to write %q to %s: %s", value, path, err)
	}
	if err := file.Close(); err != nil {
		t.Fatalf("Failed to close %s: %s", path, err)
	}
}
|
|
|
|
func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
|
|
if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil {
|
|
t.Skip("Test requires MPLS support.")
|
|
}
|
|
f := setUpNetlinkTest(t)
|
|
setUpF(t, "/proc/sys/net/mpls/platform_labels", "1024")
|
|
setUpF(t, "/proc/sys/net/mpls/conf/lo/input", "1")
|
|
return f
|
|
}
|
|
|
|
func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
|
|
// check if SEG6 options are enabled in Kernel Config
|
|
cmd := exec.Command("uname", "-r")
|
|
var out bytes.Buffer
|
|
cmd.Stdout = &out
|
|
if err := cmd.Run(); err != nil {
|
|
t.Fatal("Failed to run: uname -r")
|
|
}
|
|
s := []string{"/boot/config-", strings.TrimRight(out.String(), "\n")}
|
|
filename := strings.Join(s, "")
|
|
|
|
grepKey := func(key, fname string) (string, error) {
|
|
cmd := exec.Command("grep", key, filename)
|
|
var out bytes.Buffer
|
|
cmd.Stdout = &out
|
|
err := cmd.Run() // "err != nil" if no line matched with grep
|
|
return strings.TrimRight(out.String(), "\n"), err
|
|
}
|
|
key := string("CONFIG_IPV6_SEG6_LWTUNNEL=y")
|
|
if _, err := grepKey(key, filename); err != nil {
|
|
msg := "Skipped test because it requires SEG6_LWTUNNEL support."
|
|
log.Println(msg)
|
|
t.Skip(msg)
|
|
}
|
|
// Add CONFIG_IPV6_SEG6_HMAC to support seg6_hamc
|
|
// key := string("CONFIG_IPV6_SEG6_HMAC=y")
|
|
|
|
return setUpNetlinkTest(t)
|
|
}
|
|
|
|
func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
|
|
skipUnlessKModuleLoaded(t, moduleNames...)
|
|
return setUpNetlinkTest(t)
|
|
}
|
|
func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
|
|
file, err := ioutil.ReadFile("/proc/modules")
|
|
if err != nil {
|
|
t.Fatal("Failed to open /proc/modules", err)
|
|
}
|
|
|
|
foundRequiredMods := make(map[string]bool)
|
|
lines := strings.Split(string(file), "\n")
|
|
|
|
for _, name := range moduleNames {
|
|
foundRequiredMods[name] = false
|
|
for _, line := range lines {
|
|
n := strings.Split(line, " ")[0]
|
|
if n == name {
|
|
foundRequiredMods[name] = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
failed := false
|
|
for _, name := range moduleNames {
|
|
if found, _ := foundRequiredMods[name]; !found {
|
|
t.Logf("Test requires missing kmodule %q.", name)
|
|
failed = true
|
|
}
|
|
}
|
|
if failed {
|
|
t.SkipNow()
|
|
}
|
|
|
|
return setUpNamedNetlinkTest(t)
|
|
}
|
|
|
|
func remountSysfs() error {
|
|
if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
|
|
return err
|
|
}
|
|
if err := unix.Unmount("/sys", unix.MNT_DETACH); err != nil {
|
|
return err
|
|
}
|
|
return unix.Mount("", "/sys", "sysfs", 0, "")
|
|
}
|
|
|
|
func minKernelRequired(t *testing.T, kernel, major int) {
|
|
t.Helper()
|
|
|
|
k, m, err := KernelVersion()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if k < kernel || k == kernel && m < major {
|
|
t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
|
|
k, m, kernel, major)
|
|
}
|
|
}
|
|
|
|
func KernelVersion() (kernel, major int, err error) {
|
|
uts := unix.Utsname{}
|
|
if err = unix.Uname(&uts); err != nil {
|
|
return
|
|
}
|
|
|
|
ba := make([]byte, 0, len(uts.Release))
|
|
for _, b := range uts.Release {
|
|
if b == 0 {
|
|
break
|
|
}
|
|
ba = append(ba, byte(b))
|
|
}
|
|
var rest string
|
|
if n, _ := fmt.Sscanf(string(ba), "%d.%d%s", &kernel, &major, &rest); n < 2 {
|
|
err = fmt.Errorf("can't parse kernel version in %q", string(ba))
|
|
}
|
|
return
|
|
}
|
|
|
|
func TestMain(m *testing.M) {
|
|
nl.EnableErrorMessageReporting = true
|
|
os.Exit(m.Run())
|
|
}
|