From 3c27c1c1e3b7e6253aea6df98aa65be4d18c080c Mon Sep 17 00:00:00 2001 From: Martynas Pumputis Date: Mon, 16 Jan 2017 17:58:20 +0000 Subject: [PATCH] Add XfrmAllocSpi --- nl/xfrm_state_linux.go | 25 ++++++++++++++++ nl/xfrm_state_linux_test.go | 27 +++++++++++++++++ xfrm_state_linux.go | 59 +++++++++++++++++++++++++++++++------ xfrm_state_test.go | 20 +++++++++++++ 4 files changed, 122 insertions(+), 9 deletions(-) diff --git a/nl/xfrm_state_linux.go b/nl/xfrm_state_linux.go index 856db3d..e5b9a09 100644 --- a/nl/xfrm_state_linux.go +++ b/nl/xfrm_state_linux.go @@ -8,6 +8,7 @@ const ( SizeofXfrmUsersaId = 0x18 SizeofXfrmStats = 0x0c SizeofXfrmUsersaInfo = 0xe0 + SizeofXfrmUserSpiInfo = 0xe8 SizeofXfrmAlgo = 0x44 SizeofXfrmAlgoAuth = 0x48 SizeofXfrmAlgoAEAD = 0x48 @@ -120,6 +121,30 @@ func (msg *XfrmUsersaInfo) Serialize() []byte { return (*(*[SizeofXfrmUsersaInfo]byte)(unsafe.Pointer(msg)))[:] } +// struct xfrm_userspi_info { +// struct xfrm_usersa_info info; +// __u32 min; +// __u32 max; +// }; + +type XfrmUserSpiInfo struct { + XfrmUsersaInfo XfrmUsersaInfo + Min uint32 + Max uint32 +} + +func (msg *XfrmUserSpiInfo) Len() int { + return SizeofXfrmUserSpiInfo +} + +func DeserializeXfrmUserSpiInfo(b []byte) *XfrmUserSpiInfo { + return (*XfrmUserSpiInfo)(unsafe.Pointer(&b[0:SizeofXfrmUserSpiInfo][0])) +} + +func (msg *XfrmUserSpiInfo) Serialize() []byte { + return (*(*[SizeofXfrmUserSpiInfo]byte)(unsafe.Pointer(msg)))[:] +} + // struct xfrm_algo { // char alg_name[64]; // unsigned int alg_key_len; /* in bits */ diff --git a/nl/xfrm_state_linux_test.go b/nl/xfrm_state_linux_test.go index eb31208..5ede308 100644 --- a/nl/xfrm_state_linux_test.go +++ b/nl/xfrm_state_linux_test.go @@ -118,6 +118,33 @@ func (msg *XfrmAlgo) serializeSafe() []byte { return b } +func (msg *XfrmUserSpiInfo) write(b []byte) { + native := NativeEndian() + msg.XfrmUsersaInfo.write(b[0:SizeofXfrmUsersaInfo]) + native.PutUint32(b[SizeofXfrmUsersaInfo:SizeofXfrmUsersaInfo+4], msg.Min) + native.PutUint32(b[SizeofXfrmUsersaInfo+4:SizeofXfrmUsersaInfo+8], msg.Max) +} + +func (msg *XfrmUserSpiInfo) serializeSafe() []byte { + b := make([]byte, SizeofXfrmUserSpiInfo) + msg.write(b) + return b +} + +func deserializeXfrmUserSpiInfoSafe(b []byte) *XfrmUserSpiInfo { + var msg = XfrmUserSpiInfo{} + binary.Read(bytes.NewReader(b[0:SizeofXfrmUserSpiInfo]), NativeEndian(), &msg) + return &msg +} + +func TestXfrmUserSpiInfoDeserializeSerialize(t *testing.T) { + var orig = make([]byte, SizeofXfrmUserSpiInfo) + rand.Read(orig) + safemsg := deserializeXfrmUserSpiInfoSafe(orig) + msg := DeserializeXfrmUserSpiInfo(orig) + testDeserializeSerialize(t, orig, safemsg, msg) +} + func deserializeXfrmAlgoSafe(b []byte) *XfrmAlgo { var msg = XfrmAlgo{} copy(msg.AlgName[:], b[0:64]) diff --git a/xfrm_state_linux.go b/xfrm_state_linux.go index 7b3ef9e..dda4970 100644 --- a/xfrm_state_linux.go +++ b/xfrm_state_linux.go @@ -72,6 +72,12 @@ func (h *Handle) XfrmStateAdd(state *XfrmState) error { return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_NEWSA) } +// XfrmStateAllocSpi will allocate an xfrm state in the system. +// Equivalent to: `ip xfrm state allocspi` +func XfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) { + return pkgHandle.xfrmStateAllocSpi(state) +} + // XfrmStateUpdate will update an xfrm state to the system. // Equivalent to: `ip xfrm state update $state` func XfrmStateUpdate(state *XfrmState) error { @@ -91,15 +97,7 @@ func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error { } req := h.newNetlinkRequest(nlProto, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK) - msg := &nl.XfrmUsersaInfo{} - msg.Family = uint16(nl.GetIPFamily(state.Dst)) - msg.Id.Daddr.FromIP(state.Dst) - msg.Saddr.FromIP(state.Src) - msg.Id.Proto = uint8(state.Proto) - msg.Mode = uint8(state.Mode) - msg.Id.Spi = nl.Swap32(uint32(state.Spi)) - msg.Reqid = uint32(state.Reqid) - msg.ReplayWindow = uint8(state.ReplayWindow) + msg := xfrmUsersaInfoFromXfrmState(state) limitsToLft(state.Limits, &msg.Lft) req.AddData(msg) @@ -134,6 +132,35 @@ func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error { return err } +func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) { + req := h.newNetlinkRequest(nl.XFRM_MSG_ALLOCSPI, + syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK) + + msg := &nl.XfrmUserSpiInfo{} + msg.XfrmUsersaInfo = *(xfrmUsersaInfoFromXfrmState(state)) + // 1-255 is reserved by IANA for future use + msg.Min = 0x100 + msg.Max = 0xffffffff + req.AddData(msg) + + if state.Mark != nil { + out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark)) + req.AddData(out) + } + + msgs, err := req.Execute(syscall.NETLINK_XFRM, 0) + if err != nil { + return nil, err + } + + s, err := parseXfrmState(msgs[0], FAMILY_ALL) + if err != nil { + return nil, err + } + + return s, err +} + // XfrmStateDel will delete an xfrm state from the system. Note that // the Algos are ignored when matching the state to delete. // Equivalent to: `ip xfrm state del $state` @@ -372,3 +399,17 @@ func limitsToLft(lmts XfrmStateLimits, lft *nl.XfrmLifetimeCfg) { func lftToLimits(lft *nl.XfrmLifetimeCfg, lmts *XfrmStateLimits) { *lmts = *(*XfrmStateLimits)(unsafe.Pointer(lft)) } + +func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo { + msg := &nl.XfrmUsersaInfo{} + msg.Family = uint16(nl.GetIPFamily(state.Dst)) + msg.Id.Daddr.FromIP(state.Dst) + msg.Saddr.FromIP(state.Src) + msg.Id.Proto = uint8(state.Proto) + msg.Mode = uint8(state.Mode) + msg.Id.Spi = nl.Swap32(uint32(state.Spi)) + msg.Reqid = uint32(state.Reqid) + msg.ReplayWindow = uint8(state.ReplayWindow) + + return msg +} diff --git a/xfrm_state_test.go b/xfrm_state_test.go index 9a199f0..7f58e53 100644 --- a/xfrm_state_test.go +++ b/xfrm_state_test.go @@ -59,6 +59,26 @@ func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) { } } +func TestXfrmStateAllocSpi(t *testing.T) { + setUpNetlinkTest(t)() + + state := getBaseState() + state.Spi = 0 + state.Auth = nil + state.Crypt = nil + rstate, err := XfrmStateAllocSpi(state) + if err != nil { + t.Fatal(err) + } + if rstate.Spi == 0 { + t.Fatalf("SPI is not allocated") + } + rstate.Spi = 0 + if !compareStates(state, rstate) { + t.Fatalf("State not properly allocated") + } +} + func TestXfrmStateFlush(t *testing.T) { setUpNetlinkTest(t)()