mirror of
https://github.com/vishvananda/netlink.git
synced 2025-09-26 20:01:13 +08:00
netlink: enforce similar pid checks as in iproute2
iproute2's own netlink library asserts that the sockaddr sender pid has to be the one of the kernel [0]. It also doesn't bail out on pid mismatch but only skips the message instead. We've seen cases where the latter had a pid 0; in such case we should skip to the next nl message instead of hard bail out. [0] https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/lib/libnetlink.c rtnl_dump_filter_l(), __rtnl_talk_iov() Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
This commit is contained in:

committed by
Flavio Crisciani

parent
43af4161ea
commit
b1e9859792
@@ -328,10 +328,16 @@ func addrSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- AddrUpdate, done <-c
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
if cberr != nil {
|
||||
cberr(fmt.Errorf("Receive: %v", err))
|
||||
cberr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
if cberr != nil {
|
||||
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@@ -1777,13 +1777,19 @@ func linkSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- LinkUpdate, done <-c
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
if cberr != nil {
|
||||
cberr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
if cberr != nil {
|
||||
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.Header.Type == unix.NLMSG_DONE {
|
||||
continue
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -348,13 +349,19 @@ func neighSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- NeighUpdate, done <
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
if cberr != nil {
|
||||
cberr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
if cberr != nil {
|
||||
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.Header.Type == unix.NLMSG_DONE {
|
||||
continue
|
||||
|
@@ -26,6 +26,8 @@ const (
|
||||
// from kernel more verbose messages e.g. for statistics,
|
||||
// tc rules or filters, or other more memory requiring data.
|
||||
RECEIVE_BUFFER_SIZE = 65536
|
||||
// Kernel netlink pid
|
||||
PidKernel uint32 = 0
|
||||
)
|
||||
|
||||
// SupportedNlFamilies contains the list of netlink families this netlink package supports
|
||||
@@ -420,10 +422,13 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
|
||||
|
||||
done:
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if from.Pid != PidKernel {
|
||||
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.Header.Seq != req.Seq {
|
||||
if sharedSocket {
|
||||
@@ -432,7 +437,7 @@ done:
|
||||
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
|
||||
}
|
||||
if m.Header.Pid != pid {
|
||||
return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
|
||||
continue
|
||||
}
|
||||
if m.Header.Type == unix.NLMSG_DONE {
|
||||
break done
|
||||
@@ -617,22 +622,31 @@ func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
|
||||
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
|
||||
fd := int(atomic.LoadInt32(&s.fd))
|
||||
if fd < 0 {
|
||||
return nil, fmt.Errorf("Receive called on a closed socket")
|
||||
return nil, nil, fmt.Errorf("Receive called on a closed socket")
|
||||
}
|
||||
var fromAddr *unix.SockaddrNetlink
|
||||
var rb [RECEIVE_BUFFER_SIZE]byte
|
||||
nr, _, err := unix.Recvfrom(fd, rb[:], 0)
|
||||
nr, from, err := unix.Recvfrom(fd, rb[:], 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
fromAddr, ok := from.(*unix.SockaddrNetlink)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("Error converting to netlink sockaddr")
|
||||
}
|
||||
if nr < unix.NLMSG_HDRLEN {
|
||||
return nil, fmt.Errorf("Got short response from netlink")
|
||||
return nil, nil, fmt.Errorf("Got short response from netlink")
|
||||
}
|
||||
rb2 := make([]byte, nr)
|
||||
copy(rb2, rb[:nr])
|
||||
return syscall.ParseNetlinkMessage(rb2)
|
||||
nl, err := syscall.ParseNetlinkMessage(rb2)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return nl, fromAddr, nil
|
||||
}
|
||||
|
||||
// SetSendTimeout allows to set a send timeout on the socket
|
||||
|
@@ -73,7 +73,7 @@ func TestIfSocketCloses(t *testing.T) {
|
||||
go func(sk *NetlinkSocket, endCh chan error) {
|
||||
endCh <- nil
|
||||
for {
|
||||
_, err := sk.Receive()
|
||||
_, _, err := sk.Receive()
|
||||
// Receive returned because of a timeout and the FD == -1 means that the socket got closed
|
||||
if err == unix.EAGAIN && nlSock.GetFd() == -1 {
|
||||
endCh <- err
|
||||
|
@@ -1037,13 +1037,19 @@ func routeSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RouteUpdate, done <
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
if cberr != nil {
|
||||
cberr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
if cberr != nil {
|
||||
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.Header.Type == unix.NLMSG_DONE {
|
||||
continue
|
||||
|
@@ -141,10 +141,13 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
|
||||
},
|
||||
})
|
||||
s.Send(req)
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
return nil, errors.New("no message nor error from netlink")
|
||||
}
|
||||
|
@@ -54,11 +54,15 @@ func XfrmMonitor(ch chan<- XfrmMsg, done <-chan struct{}, errorChan chan<- error
|
||||
go func() {
|
||||
defer close(ch)
|
||||
for {
|
||||
msgs, err := s.Receive()
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
errorChan <- err
|
||||
return
|
||||
}
|
||||
if from.Pid != nl.PidKernel {
|
||||
errorChan <- fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
|
||||
return
|
||||
}
|
||||
for _, m := range msgs {
|
||||
switch m.Header.Type {
|
||||
case nl.XFRM_MSG_EXPIRE:
|
||||
|
Reference in New Issue
Block a user