From e1869a8557ed4efad1a54b7faab03769179ad81e Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Wed, 13 Sep 2023 22:55:20 +0200 Subject: [PATCH] optimize multicast on Linux by listening on a single IP (#417) (https://github.com/bluenviron/mediamtx/issues/2133) --- client_udp_listener.go | 1 + pkg/multicast/multi_conn.go | 11 +- pkg/multicast/multi_conn_lin.go | 245 +++++++++++++++++++++++++++++++ pkg/multicast/read_from_lin.go | 32 ---- pkg/multicast/read_from_win.go | 19 --- pkg/multicast/single_conn.go | 11 +- pkg/multicast/single_conn_lin.go | 180 +++++++++++++++++++++++ server_udp_listener.go | 1 + 8 files changed, 435 insertions(+), 65 deletions(-) create mode 100644 pkg/multicast/multi_conn_lin.go delete mode 100644 pkg/multicast/read_from_lin.go delete mode 100644 pkg/multicast/read_from_win.go create mode 100644 pkg/multicast/single_conn_lin.go diff --git a/client_udp_listener.go b/client_udp_listener.go index 0d90030f..7897c287 100644 --- a/client_udp_listener.go +++ b/client_udp_listener.go @@ -107,6 +107,7 @@ func newClientUDPListener( err := pc.SetReadBuffer(udpKernelReadBufferSize) if err != nil { + pc.Close() return nil, err } diff --git a/pkg/multicast/multi_conn.go b/pkg/multicast/multi_conn.go index 52c756eb..1b7cb397 100644 --- a/pkg/multicast/multi_conn.go +++ b/pkg/multicast/multi_conn.go @@ -1,3 +1,6 @@ +//go:build !linux +// +build !linux + package multicast import ( @@ -64,12 +67,6 @@ func NewMultiConn( return nil, fmt.Errorf("no multicast-capable interfaces found") } - err = setupReadFrom(readConnIP) - if err != nil { - readConn.Close() //nolint:errcheck - return nil, err - } - writeConns := make([]*net.UDPConn, len(enabledInterfaces)) writeConnIPs := make([]*ipv4.PacketConn, len(enabledInterfaces)) @@ -174,5 +171,5 @@ func (c *MultiConn) WriteTo(b []byte, addr net.Addr) (int, error) { // ReadFrom implements Conn. func (c *MultiConn) ReadFrom(b []byte) (int, net.Addr, error) { - return readFrom(c.readConnIP, c.addr.IP, b) + return c.readConn.ReadFrom(b) } diff --git a/pkg/multicast/multi_conn_lin.go b/pkg/multicast/multi_conn_lin.go new file mode 100644 index 00000000..67e89008 --- /dev/null +++ b/pkg/multicast/multi_conn_lin.go @@ -0,0 +1,245 @@ +//go:build linux +// +build linux + +package multicast + +import ( + "fmt" + "net" + "os" + "syscall" + "time" +) + +// MultiConn is a multicast connection +// that works in parallel on all interfaces. +type MultiConn struct { + addr *net.UDPAddr + readFile *os.File + readConn net.PacketConn + writeFiles []*os.File + writeConns []net.PacketConn +} + +// NewMultiConn allocates a MultiConn. +func NewMultiConn( + address string, + _ func(network, address string) (net.PacketConn, error), +) (Conn, error) { + addr, err := net.ResolveUDPAddr("udp4", address) + if err != nil { + return nil, err + } + + readSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + return nil, err + } + + err = syscall.SetsockoptInt(readSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + const SO_REUSEPORT = 0x0f //nolint:revive + err = syscall.SetsockoptInt(readSock, syscall.SOL_SOCKET, SO_REUSEPORT, 1) + if err != nil { + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var lsa syscall.SockaddrInet4 + lsa.Port = addr.Port + copy(lsa.Addr[:], addr.IP.To4()) + err = syscall.Bind(readSock, &lsa) + if err != nil { + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + intfs, err := net.Interfaces() + if err != nil { + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var enabledInterfaces []*net.Interface //nolint:prealloc + for _, intf := range intfs { + if (intf.Flags & net.FlagMulticast) == 0 { + continue + } + cintf := intf + + var mreq syscall.IPMreq + copy(mreq.Multiaddr[:], addr.IP.To4()) + err = setIPMreqInterface(&mreq, &cintf) + if err != nil { + continue + } + + err = syscall.SetsockoptIPMreq(readSock, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, &mreq) + if err != nil { + continue + } + + enabledInterfaces = append(enabledInterfaces, &cintf) + } + + if enabledInterfaces == nil { + syscall.Close(readSock) //nolint:errcheck + return nil, fmt.Errorf("no multicast-capable interfaces found") + } + + writeSocks := make([]int, len(enabledInterfaces)) + + for i, intf := range enabledInterfaces { + writeSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + const SO_REUSEPORT = 0x0f //nolint:revive + err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, SO_REUSEPORT, 1) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var lsa syscall.SockaddrInet4 + lsa.Port = addr.Port + copy(lsa.Addr[:], addr.IP.To4()) + err = syscall.Bind(writeSock, &lsa) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var mreqn syscall.IPMreqn + mreqn.Ifindex = int32(intf.Index) + + err = syscall.SetsockoptIPMreqn(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptInt(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + writeSocks[i] = writeSock + } + + readFile := os.NewFile(uintptr(readSock), "") + readConn, _ := net.FilePacketConn(readFile) + writeFiles := make([]*os.File, len(writeSocks)) + writeConns := make([]net.PacketConn, len(writeSocks)) + + for i, writeSock := range writeSocks { + writeFiles[i] = os.NewFile(uintptr(writeSock), "") + writeConns[i], _ = net.FilePacketConn(writeFiles[i]) + } + + return &MultiConn{ + addr: addr, + readFile: readFile, + readConn: readConn, + writeFiles: writeFiles, + writeConns: writeConns, + }, nil +} + +// Close implements Conn. +func (c *MultiConn) Close() error { + for i, writeConn := range c.writeConns { + writeConn.Close() + c.writeFiles[i].Close() + } + c.readConn.Close() + c.readFile.Close() + return nil +} + +// SetReadBuffer implements Conn. +func (c *MultiConn) SetReadBuffer(bytes int) error { + return syscall.SetsockoptInt(int(c.readFile.Fd()), syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) +} + +// LocalAddr implements Conn. +func (c *MultiConn) LocalAddr() net.Addr { + return c.readConn.LocalAddr() +} + +// SetDeadline implements Conn. +func (c *MultiConn) SetDeadline(_ time.Time) error { + panic("unimplemented") +} + +// SetReadDeadline implements Conn. +func (c *MultiConn) SetReadDeadline(t time.Time) error { + return c.readConn.SetReadDeadline(t) +} + +// SetWriteDeadline implements Conn. +func (c *MultiConn) SetWriteDeadline(t time.Time) error { + var err error + for _, c := range c.writeConns { + err2 := c.SetWriteDeadline(t) + if err == nil { + err = err2 + } + } + return err +} + +// WriteTo implements Conn. +func (c *MultiConn) WriteTo(b []byte, addr net.Addr) (int, error) { + var n int + var err error + for _, c := range c.writeConns { + var err2 error + n, err2 = c.WriteTo(b, addr) + if err == nil { + err = err2 + } + } + return n, err +} + +// ReadFrom implements Conn. +func (c *MultiConn) ReadFrom(b []byte) (int, net.Addr, error) { + return c.readConn.ReadFrom(b) +} diff --git a/pkg/multicast/read_from_lin.go b/pkg/multicast/read_from_lin.go deleted file mode 100644 index 5faa4a50..00000000 --- a/pkg/multicast/read_from_lin.go +++ /dev/null @@ -1,32 +0,0 @@ -//go:build !windows -// +build !windows - -package multicast - -import ( - "net" - - "golang.org/x/net/ipv4" -) - -func setupReadFrom(c *ipv4.PacketConn) error { - return c.SetControlMessage(ipv4.FlagDst, true) -} - -func readFrom(c *ipv4.PacketConn, destIP net.IP, b []byte) (int, net.Addr, error) { - for { - n, cm, src, err := c.ReadFrom(b) - if err != nil { - return 0, nil, err - } - - // a multicast connection can receive packets - // addressed to groups joined by other connections. - // discard them. - if !cm.Dst.Equal(destIP) { - continue - } - - return n, src, nil - } -} diff --git a/pkg/multicast/read_from_win.go b/pkg/multicast/read_from_win.go deleted file mode 100644 index 7fc0bb02..00000000 --- a/pkg/multicast/read_from_win.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build windows -// +build windows - -package multicast - -import ( - "net" - - "golang.org/x/net/ipv4" -) - -func setupReadFrom(c *ipv4.PacketConn) error { - return nil -} - -func readFrom(c *ipv4.PacketConn, destIP net.IP, b []byte) (int, net.Addr, error) { - n, _, src, err := c.ReadFrom(b) - return n, src, err -} diff --git a/pkg/multicast/single_conn.go b/pkg/multicast/single_conn.go index 8873bfe3..4484ac43 100644 --- a/pkg/multicast/single_conn.go +++ b/pkg/multicast/single_conn.go @@ -1,3 +1,6 @@ +//go:build !linux +// +build !linux + package multicast import ( @@ -46,12 +49,6 @@ func NewSingleConn( return nil, err } - err = setupReadFrom(connIP) - if err != nil { - conn.Close() //nolint:errcheck - return nil, err - } - err = connIP.SetMulticastInterface(intf) if err != nil { conn.Close() //nolint:errcheck @@ -108,5 +105,5 @@ func (c *SingleConn) WriteTo(b []byte, addr net.Addr) (int, error) { // ReadFrom implements Conn. func (c *SingleConn) ReadFrom(b []byte) (int, net.Addr, error) { - return readFrom(c.connIP, c.addr.IP, b) + return c.conn.ReadFrom(b) } diff --git a/pkg/multicast/single_conn_lin.go b/pkg/multicast/single_conn_lin.go new file mode 100644 index 00000000..e3277e9c --- /dev/null +++ b/pkg/multicast/single_conn_lin.go @@ -0,0 +1,180 @@ +//go:build linux +// +build linux + +package multicast + +import ( + "fmt" + "net" + "os" + "syscall" + "time" +) + +const ( + // same size as GStreamer's rtspsrc + multicastTTL = 16 +) + +// https://cs.opensource.google/go/x/net/+/refs/tags/v0.15.0:ipv4/sys_asmreq.go;l=51 +func setIPMreqInterface(mreq *syscall.IPMreq, ifi *net.Interface) error { + if ifi == nil { + return nil + } + ifat, err := ifi.Addrs() + if err != nil { + return err + } + for _, ifa := range ifat { + switch ifa := ifa.(type) { + case *net.IPAddr: + if ip := ifa.IP.To4(); ip != nil { + copy(mreq.Interface[:], ip) + return nil + } + case *net.IPNet: + if ip := ifa.IP.To4(); ip != nil { + copy(mreq.Interface[:], ip) + return nil + } + } + } + return fmt.Errorf("no such interface") +} + +// SingleConn is a multicast connection +// that works on a single interface. +type SingleConn struct { + addr *net.UDPAddr + file *os.File + conn net.PacketConn +} + +// NewSingleConn allocates a SingleConn. +func NewSingleConn( + intf *net.Interface, + address string, + _ func(network, address string) (net.PacketConn, error), +) (Conn, error) { + addr, err := net.ResolveUDPAddr("udp4", address) + if err != nil { + return nil, err + } + + sock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + return nil, err + } + + err = syscall.SetsockoptInt(sock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + const SO_REUSEPORT = 0x0f //nolint:revive + err = syscall.SetsockoptInt(sock, syscall.SOL_SOCKET, SO_REUSEPORT, 1) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptString(sock, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, intf.Name) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + var lsa syscall.SockaddrInet4 + lsa.Port = addr.Port + copy(lsa.Addr[:], addr.IP.To4()) + err = syscall.Bind(sock, &lsa) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + var mreq syscall.IPMreq + copy(mreq.Multiaddr[:], addr.IP.To4()) + err = setIPMreqInterface(&mreq, intf) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptIPMreq(sock, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, &mreq) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + var mreqn syscall.IPMreqn + mreqn.Ifindex = int32(intf.Index) + + err = syscall.SetsockoptIPMreqn(sock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptInt(sock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL) + if err != nil { + syscall.Close(sock) //nolint:errcheck + return nil, err + } + + file := os.NewFile(uintptr(sock), "") + conn, err := net.FilePacketConn(file) + if err != nil { + file.Close() + return nil, err + } + + return &SingleConn{ + addr: addr, + file: file, + conn: conn, + }, nil +} + +// Close implements Conn. +func (c *SingleConn) Close() error { + c.conn.Close() + c.file.Close() + return nil +} + +// SetReadBuffer implements Conn. +func (c *SingleConn) SetReadBuffer(bytes int) error { + return syscall.SetsockoptInt(int(c.file.Fd()), syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) +} + +// LocalAddr implements Conn. +func (c *SingleConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// SetDeadline implements Conn. +func (c *SingleConn) SetDeadline(_ time.Time) error { + panic("unimplemented") +} + +// SetReadDeadline implements Conn. +func (c *SingleConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline implements Conn. +func (c *SingleConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// WriteTo implements Conn. +func (c *SingleConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.conn.WriteTo(b, addr) +} + +// ReadFrom implements Conn. +func (c *SingleConn) ReadFrom(b []byte) (int, net.Addr, error) { + return c.conn.ReadFrom(b) +} diff --git a/server_udp_listener.go b/server_udp_listener.go index b85c5791..4d5871d1 100644 --- a/server_udp_listener.go +++ b/server_udp_listener.go @@ -97,6 +97,7 @@ func newServerUDPListener( err := pc.SetReadBuffer(udpKernelReadBufferSize) if err != nil { + pc.Close() return nil, err }