diff --git a/nl/xfrm_linux.go b/nl/xfrm_linux.go index cdb318b..6cfd8f9 100644 --- a/nl/xfrm_linux.go +++ b/nl/xfrm_linux.go @@ -78,10 +78,14 @@ const ( XFRMA_PROTO /* __u8 */ XFRMA_ADDRESS_FILTER /* struct xfrm_address_filter */ XFRMA_PAD - XFRMA_OFFLOAD_DEV /* struct xfrm_state_offload */ - XFRMA_SET_MARK /* __u32 */ - XFRMA_SET_MARK_MASK /* __u32 */ - XFRMA_IF_ID /* __u32 */ + XFRMA_OFFLOAD_DEV /* struct xfrm_state_offload */ + XFRMA_SET_MARK /* __u32 */ + XFRMA_SET_MARK_MASK /* __u32 */ + XFRMA_IF_ID /* __u32 */ + XFRMA_MTIMER_THRESH /* __u32 in seconds for input SA */ + XFRMA_SA_DIR /* __u8 */ + XFRMA_NAT_KEEPALIVE_INTERVAL /* __u32 in seconds for NAT keepalive */ + XFRMA_SA_PCPU /* __u32 */ XFRMA_MAX = iota - 1 ) diff --git a/xfrm_linux.go b/xfrm_linux.go index dd38ed8..b603e4c 100644 --- a/xfrm_linux.go +++ b/xfrm_linux.go @@ -48,6 +48,14 @@ const ( XFRM_MODE_MAX ) +// SADir is an enum representing an ipsec template direction. +type SADir uint8 + +const ( + XFRM_SA_DIR_IN SADir = iota + 1 + XFRM_SA_DIR_OUT +) + func (m Mode) String() string { switch m { case XFRM_MODE_TRANSPORT: diff --git a/xfrm_state_linux.go b/xfrm_state_linux.go index 2f46146..092ffe9 100644 --- a/xfrm_state_linux.go +++ b/xfrm_state_linux.go @@ -113,7 +113,9 @@ type XfrmState struct { Statistics XfrmStateStats Mark *XfrmMark OutputMark *XfrmMark + SADir SADir Ifid int + Pcpunum *uint32 Auth *XfrmStateAlgo Crypt *XfrmStateAlgo Aead *XfrmStateAlgo @@ -126,8 +128,8 @@ type XfrmState struct { } func (sa XfrmState) String() string { - return fmt.Sprintf("Dst: %v, Src: %v, Proto: %s, Mode: %s, SPI: 0x%x, ReqID: 0x%x, ReplayWindow: %d, Mark: %v, OutputMark: %v, Ifid: %d, Auth: %v, Crypt: %v, Aead: %v, Encap: %v, ESN: %t, DontEncapDSCP: %t, OSeqMayWrap: %t, Replay: %v", - sa.Dst, sa.Src, sa.Proto, sa.Mode, sa.Spi, sa.Reqid, sa.ReplayWindow, sa.Mark, sa.OutputMark, sa.Ifid, sa.Auth, sa.Crypt, sa.Aead, sa.Encap, sa.ESN, sa.DontEncapDSCP, sa.OSeqMayWrap, sa.Replay) + return fmt.Sprintf("Dst: %v, Src: %v, Proto: %s, Mode: %s, SPI: 0x%x, ReqID: 0x%x, ReplayWindow: %d, Mark: %v, OutputMark: %v, SADir: %d, Ifid: %d, Pcpunum: %d, Auth: %v, Crypt: %v, Aead: %v, Encap: %v, ESN: %t, DontEncapDSCP: %t, OSeqMayWrap: %t, Replay: %v", + sa.Dst, sa.Src, sa.Proto, sa.Mode, sa.Spi, sa.Reqid, sa.ReplayWindow, sa.Mark, sa.OutputMark, sa.SADir, sa.Ifid, *sa.Pcpunum, sa.Auth, sa.Crypt, sa.Aead, sa.Encap, sa.ESN, sa.DontEncapDSCP, sa.OSeqMayWrap, sa.Replay) } func (sa XfrmState) Print(stats bool) string { if !stats { @@ -333,11 +335,21 @@ func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error { req.AddData(out) } + if state.SADir != 0 { + saDir := nl.NewRtAttr(nl.XFRMA_SA_DIR, nl.Uint8Attr(uint8(state.SADir))) + req.AddData(saDir) + } + if state.Ifid != 0 { ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid))) req.AddData(ifId) } + if state.Pcpunum != nil { + pcpuNum := nl.NewRtAttr(nl.XFRMA_SA_PCPU, nl.Uint32Attr(uint32(*state.Pcpunum))) + req.AddData(pcpuNum) + } + _, err := req.Execute(unix.NETLINK_XFRM, 0) return err } @@ -459,6 +471,11 @@ func (h *Handle) xfrmStateGetOrDelete(state *XfrmState, nlProto int) (*XfrmState req.AddData(ifId) } + if state.Pcpunum != nil { + pcpuNum := nl.NewRtAttr(nl.XFRMA_SA_PCPU, nl.Uint32Attr(uint32(*state.Pcpunum))) + req.AddData(pcpuNum) + } + resType := nl.XFRM_MSG_NEWSA if nlProto == nl.XFRM_MSG_DELSA { resType = 0 @@ -581,8 +598,13 @@ func parseXfrmState(m []byte, family int) (*XfrmState, error) { if state.OutputMark.Mask == 0xffffffff { state.OutputMark.Mask = 0 } + case nl.XFRMA_SA_DIR: + state.SADir = SADir(attr.Value[0]) case nl.XFRMA_IF_ID: state.Ifid = int(native.Uint32(attr.Value)) + case nl.XFRMA_SA_PCPU: + pcpuNum := native.Uint32(attr.Value) + state.Pcpunum = &pcpuNum case nl.XFRMA_REPLAY_VAL: if state.Replay == nil { state.Replay = new(XfrmReplayState) diff --git a/xfrm_state_linux_test.go b/xfrm_state_linux_test.go index 22031e3..eafa83b 100644 --- a/xfrm_state_linux_test.go +++ b/xfrm_state_linux_test.go @@ -225,6 +225,72 @@ func TestXfrmStateWithIfid(t *testing.T) { } } +func TestXfrmStateWithSADir(t *testing.T) { + minKernelRequired(t, 4, 19) + defer setUpNetlinkTest(t)() + + state := getBaseState() + state.SADir = XFRM_SA_DIR_IN + if err := XfrmStateAdd(state); err != nil { + t.Fatal(err) + } + s, err := XfrmStateGet(state) + if err != nil { + t.Fatal(err) + } + if !compareStates(state, s) { + t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) + } + if err = XfrmStateDel(s); err != nil { + t.Fatal(err) + } +} + +func TestXfrmStateWithPcpunumWithoutSADir(t *testing.T) { + minKernelRequired(t, 4, 19) + defer setUpNetlinkTest(t)() + + state := getBaseState() + pcpuNum := uint32(1) + state.Pcpunum = &pcpuNum + if err := XfrmStateAdd(state); err != nil { + t.Fatal(err) + } + s, err := XfrmStateGet(state) + if err != nil { + t.Fatal(err) + } + if !compareStates(state, s) { + t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) + } + if err = XfrmStateDel(s); err != nil { + t.Fatal(err) + } +} + +func TestXfrmStateWithPcpunumWithSADir(t *testing.T) { + minKernelRequired(t, 4, 19) + defer setUpNetlinkTest(t)() + + state := getBaseState() + state.SADir = XFRM_SA_DIR_IN + pcpuNum := uint32(1) + state.Pcpunum = &pcpuNum + if err := XfrmStateAdd(state); err != nil { + t.Fatal(err) + } + s, err := XfrmStateGet(state) + if err != nil { + t.Fatal(err) + } + if !compareStates(state, s) { + t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s) + } + if err = XfrmStateDel(s); err != nil { + t.Fatal(err) + } +} + func TestXfrmStateWithOutputMark(t *testing.T) { minKernelRequired(t, 4, 14) defer setUpNetlinkTest(t)()