netlink: enforce similar pid checks as in iproute2

iproute2's own netlink library asserts that the sockaddr sender pid
has to be the one of the kernel [0]. It also doesn't bail out on pid
mismatch but only skips the message instead. We've seen cases where
the latter had a pid 0; in such case we should skip to the next nl
message instead of hard bail out.

  [0] https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/libnetlink.c
      rtnl_dump_filter_l(), __rtnl_talk_iov()

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
This commit is contained in:
Daniel Borkmann
2019-07-15 19:54:20 +02:00
committed by Flavio Crisciani
parent 43af4161ea
commit b1e9859792
8 changed files with 62 additions and 16 deletions

View File

@@ -328,10 +328,16 @@ func addrSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- AddrUpdate, done <-c
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
if cberr != nil {
cberr(fmt.Errorf("Receive: %v", err))
cberr(err)
}
return
}
if from.Pid != nl.PidKernel {
if cberr != nil {
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
}
continue
}

View File

@@ -1777,13 +1777,19 @@ func linkSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- LinkUpdate, done <-c
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
if cberr != nil {
cberr(err)
}
return
}
if from.Pid != nl.PidKernel {
if cberr != nil {
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
}
continue
}
for _, m := range msgs {
if m.Header.Type == unix.NLMSG_DONE {
continue

View File

@@ -1,6 +1,7 @@
package netlink
import (
"fmt"
"net"
"syscall"
"unsafe"
@@ -348,13 +349,19 @@ func neighSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- NeighUpdate, done <
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
if cberr != nil {
cberr(err)
}
return
}
if from.Pid != nl.PidKernel {
if cberr != nil {
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
}
continue
}
for _, m := range msgs {
if m.Header.Type == unix.NLMSG_DONE {
continue

View File

@@ -26,6 +26,8 @@ const (
// from kernel more verbose messages e.g. for statistics,
// tc rules or filters, or other more memory requiring data.
RECEIVE_BUFFER_SIZE = 65536
// Kernel netlink pid
PidKernel uint32 = 0
)
// SupportedNlFamilies contains the list of netlink families this netlink package supports
@@ -420,10 +422,13 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
done:
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
return nil, err
}
if from.Pid != PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
}
for _, m := range msgs {
if m.Header.Seq != req.Seq {
if sharedSocket {
@@ -432,7 +437,7 @@ done:
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
}
if m.Header.Pid != pid {
return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
continue
}
if m.Header.Type == unix.NLMSG_DONE {
break done
@@ -617,22 +622,31 @@ func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
return nil
}
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return nil, fmt.Errorf("Receive called on a closed socket")
return nil, nil, fmt.Errorf("Receive called on a closed socket")
}
var fromAddr *unix.SockaddrNetlink
var rb [RECEIVE_BUFFER_SIZE]byte
nr, _, err := unix.Recvfrom(fd, rb[:], 0)
nr, from, err := unix.Recvfrom(fd, rb[:], 0)
if err != nil {
return nil, err
return nil, nil, err
}
fromAddr, ok := from.(*unix.SockaddrNetlink)
if !ok {
return nil, nil, fmt.Errorf("Error converting to netlink sockaddr")
}
if nr < unix.NLMSG_HDRLEN {
return nil, fmt.Errorf("Got short response from netlink")
return nil, nil, fmt.Errorf("Got short response from netlink")
}
rb2 := make([]byte, nr)
copy(rb2, rb[:nr])
return syscall.ParseNetlinkMessage(rb2)
nl, err := syscall.ParseNetlinkMessage(rb2)
if err != nil {
return nil, nil, err
}
return nl, fromAddr, nil
}
// SetSendTimeout allows to set a send timeout on the socket

View File

@@ -73,7 +73,7 @@ func TestIfSocketCloses(t *testing.T) {
go func(sk *NetlinkSocket, endCh chan error) {
endCh <- nil
for {
_, err := sk.Receive()
_, _, err := sk.Receive()
// Receive returned because of a timeout and the FD == -1 means that the socket got closed
if err == unix.EAGAIN && nlSock.GetFd() == -1 {
endCh <- err

View File

@@ -1037,13 +1037,19 @@ func routeSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RouteUpdate, done <
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
if cberr != nil {
cberr(err)
}
return
}
if from.Pid != nl.PidKernel {
if cberr != nil {
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
}
continue
}
for _, m := range msgs {
if m.Header.Type == unix.NLMSG_DONE {
continue

View File

@@ -141,10 +141,13 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
},
})
s.Send(req)
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
return nil, err
}
if from.Pid != nl.PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
}
if len(msgs) == 0 {
return nil, errors.New("no message nor error from netlink")
}

View File

@@ -54,11 +54,15 @@ func XfrmMonitor(ch chan<- XfrmMsg, done <-chan struct{}, errorChan chan<- error
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
msgs, from, err := s.Receive()
if err != nil {
errorChan <- err
return
}
if from.Pid != nl.PidKernel {
errorChan <- fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
return
}
for _, m := range msgs {
switch m.Header.Type {
case nl.XFRM_MSG_EXPIRE: