fix reading and writing multicast packets in case of multiple interfaces (#413)

(https://github.com/bluenviron/mediamtx/issues/2029)
This commit is contained in:
Alessandro Ros
2023-09-11 23:36:53 +02:00
committed by GitHub
parent 8fdc7193f2
commit 78198a588b
11 changed files with 432 additions and 83 deletions

View File

@@ -32,6 +32,7 @@ linters-settings:
- (*net.TCPConn).SetKeepAlive
- (*net.TCPConn).SetKeepAlivePeriod
- (*net.TCPConn).SetNoDelay
- (*net.UDPConn).Close
- (net.Listener).Close
- (net.PacketConn).Close
- (net.PacketConn).SetReadDeadline

View File

@@ -1303,6 +1303,7 @@ func (c *Client) doSetup(
err := cm.allocateUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
@@ -1398,10 +1399,11 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientServerPortsNotProvided{}
}
var readIP net.IP
if thRes.Source != nil {
cm.udpRTPListener.readIP = *thRes.Source
readIP = *thRes.Source
} else {
cm.udpRTPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
if serverPortsValid {
@@ -1414,12 +1416,7 @@ func (c *Client) doSetup(
Port: thRes.ServerPorts[0],
}
}
if thRes.Source != nil {
cm.udpRTCPListener.readIP = *thRes.Source
} else {
cm.udpRTCPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
cm.udpRTPListener.readIP = readIP
if serverPortsValid {
if !c.AnyPortEnable {
@@ -1431,6 +1428,7 @@ func (c *Client) doSetup(
Port: thRes.ServerPorts[1],
}
}
cm.udpRTCPListener.readIP = readIP
case TransportUDPMulticast:
if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast {
@@ -1445,8 +1443,16 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
}
var readIP net.IP
if thRes.Source != nil {
readIP = *thRes.Source
} else {
readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
err := cm.allocateUDPListeners(
true,
readIP,
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
)
@@ -1454,14 +1460,14 @@ func (c *Client) doSetup(
return nil, err
}
cm.udpRTPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
cm.udpRTPListener.readIP = readIP
cm.udpRTPListener.readPort = thRes.Ports[0]
cm.udpRTPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,
Port: thRes.Ports[0],
}
cm.udpRTCPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
cm.udpRTCPListener.readIP = readIP
cm.udpRTCPListener.readPort = thRes.Ports[1]
cm.udpRTCPListener.writeAddr = &net.UDPAddr{
IP: *thRes.Destination,

View File

@@ -1,6 +1,7 @@
package gortsplib
import (
"net"
"sync/atomic"
"time"
@@ -41,11 +42,17 @@ func (cm *clientMedia) close() {
}
}
func (cm *clientMedia) allocateUDPListeners(multicast bool, rtpAddress string, rtcpAddress string) error {
func (cm *clientMedia) allocateUDPListeners(
multicastEnable bool,
multicastSourceIP net.IP,
rtpAddress string,
rtcpAddress string,
) error {
if rtpAddress != ":0" {
l1, err := newClientUDPListener(
cm.c,
multicast,
multicastEnable,
multicastSourceIP,
rtpAddress,
)
if err != nil {
@@ -54,7 +61,8 @@ func (cm *clientMedia) allocateUDPListeners(multicast bool, rtpAddress string, r
l2, err := newClientUDPListener(
cm.c,
multicast,
multicastEnable,
multicastSourceIP,
rtcpAddress,
)
if err != nil {

View File

@@ -2,13 +2,14 @@ package gortsplib
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"strconv"
"sync/atomic"
"time"
"golang.org/x/net/ipv4"
"github.com/bluenviron/gortsplib/v4/pkg/multicast"
)
func int64Ptr(v int64) *int64 {
@@ -24,6 +25,35 @@ func randInRange(max int) (int, error) {
return int(n.Int64()), nil
}
func findMulticastInterfaceForSource(ip net.IP) (*net.Interface, error) {
if ip.Equal(net.ParseIP("127.0.0.1")) {
return nil, fmt.Errorf("IP 127.0.0.1 can't be used as source of a multicast stream. Use the LAN IP of your PC")
}
intfs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, intf := range intfs {
if (intf.Flags & net.FlagMulticast) == 0 {
continue
}
addrs, err := intf.Addrs()
if err == nil {
for _, addr := range addrs {
_, ipnet, err := net.ParseCIDR(addr.String())
if err == nil && ipnet.Contains(ip) {
return &intf, nil
}
}
}
}
return nil, fmt.Errorf("found no interface that is multicast-capable and can communicate with IP %v", ip)
}
type clientUDPListener struct {
c *Client
pc net.PacketConn
@@ -52,6 +82,7 @@ func newClientUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener
rtpListener, err := newClientUDPListener(
c,
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
)
if err != nil {
@@ -62,6 +93,7 @@ func newClientUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener
rtcpListener, err := newClientUDPListener(
c,
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
if err != nil {
@@ -73,36 +105,28 @@ func newClientUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener
}
}
type packetConn interface {
net.PacketConn
SetReadBuffer(int) error
}
func newClientUDPListener(
c *Client,
multicast bool,
multicastEnable bool,
multicastSourceIP net.IP,
address string,
) (*clientUDPListener, error) {
var pc *net.UDPConn
if multicast {
host, port, err := net.SplitHostPort(address)
var pc packetConn
if multicastEnable {
intf, err := findMulticastInterfaceForSource(multicastSourceIP)
if err != nil {
return nil, err
}
tmp, err := c.ListenPacket(restrictNetwork("udp", "224.0.0.0:"+port))
pc, err = multicast.NewSingleConn(intf, address, c.ListenPacket)
if err != nil {
return nil, err
}
p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(multicastTTL)
if err != nil {
return nil, err
}
err = joinMulticastGroupOnAtLeastOneInterface(p, net.ParseIP(host))
if err != nil {
return nil, err
}
pc = tmp.(*net.UDPConn)
} else {
tmp, err := c.ListenPacket(restrictNetwork("udp", address))
if err != nil {

View File

@@ -6,7 +6,4 @@ const (
// 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
udpMaxPayloadSize = 1472
// same size as GStreamer's rtspsrc
multicastTTL = 16
)

178
pkg/multicast/multi_conn.go Normal file
View File

@@ -0,0 +1,178 @@
package multicast
import (
"fmt"
"net"
"strconv"
"time"
"golang.org/x/net/ipv4"
)
// MultiConn is a multicast connection
// that works in parallel on all interfaces.
type MultiConn struct {
addr *net.UDPAddr
readConn *net.UDPConn
readConnIP *ipv4.PacketConn
writeConns []*net.UDPConn
writeConnIPs []*ipv4.PacketConn
}
// NewMultiConn allocates a MultiConn.
func NewMultiConn(
address string,
listenPacket func(network, address string) (net.PacketConn, error),
) (Conn, error) {
addr, err := net.ResolveUDPAddr("udp4", address)
if err != nil {
return nil, err
}
tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10))
if err != nil {
return nil, err
}
readConn := tmp.(*net.UDPConn)
intfs, err := net.Interfaces()
if err != nil {
readConn.Close() //nolint:errcheck
return nil, err
}
readConnIP := ipv4.NewPacketConn(readConn)
var enabledInterfaces []*net.Interface //nolint:prealloc
for _, intf := range intfs {
if (intf.Flags & net.FlagMulticast) == 0 {
continue
}
cintf := intf
err = readConnIP.JoinGroup(&cintf, &net.UDPAddr{IP: addr.IP})
if err != nil {
continue
}
enabledInterfaces = append(enabledInterfaces, &cintf)
}
if enabledInterfaces == nil {
readConn.Close() //nolint:errcheck
return nil, fmt.Errorf("no multicast-capable interfaces found")
}
err = setupReadFrom(readConnIP)
if err != nil {
readConn.Close() //nolint:errcheck
return nil, err
}
writeConns := make([]*net.UDPConn, len(enabledInterfaces))
writeConnIPs := make([]*ipv4.PacketConn, len(enabledInterfaces))
for i, intf := range enabledInterfaces {
tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10))
if err != nil {
for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
}
writeConn := tmp.(*net.UDPConn)
writeConnIP := ipv4.NewPacketConn(writeConn)
err = writeConnIP.SetMulticastInterface(intf)
if err != nil {
for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
}
err = writeConnIP.SetMulticastTTL(multicastTTL)
if err != nil {
for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
}
writeConns[i] = writeConn
writeConnIPs[i] = writeConnIP
}
return &MultiConn{
addr: addr,
readConn: readConn,
readConnIP: readConnIP,
writeConns: writeConns,
writeConnIPs: writeConnIPs,
}, nil
}
// Close implements Conn.
func (c *MultiConn) Close() error {
for _, c := range c.writeConns {
c.Close() //nolint:errcheck
}
c.readConn.Close() //nolint:errcheck
return nil
}
// SetReadBuffer implements Conn.
func (c *MultiConn) SetReadBuffer(bytes int) error {
return c.readConn.SetReadBuffer(bytes)
}
// LocalAddr implements Conn.
func (c *MultiConn) LocalAddr() net.Addr {
return c.readConn.LocalAddr()
}
// SetDeadline implements Conn.
func (c *MultiConn) SetDeadline(_ time.Time) error {
panic("unimplemented")
}
// SetReadDeadline implements Conn.
func (c *MultiConn) SetReadDeadline(t time.Time) error {
return c.readConn.SetReadDeadline(t)
}
// SetWriteDeadline implements Conn.
func (c *MultiConn) SetWriteDeadline(t time.Time) error {
var err error
for _, c := range c.writeConns {
err2 := c.SetWriteDeadline(t)
if err == nil {
err = err2
}
}
return err
}
// WriteTo implements Conn.
func (c *MultiConn) WriteTo(b []byte, addr net.Addr) (int, error) {
var n int
var err error
for _, c := range c.writeConns {
var err2 error
n, err2 = c.WriteTo(b, addr)
if err == nil {
err = err2
}
}
return n, err
}
// ReadFrom implements Conn.
func (c *MultiConn) ReadFrom(b []byte) (int, net.Addr, error) {
return readFrom(c.readConnIP, c.addr.IP, b)
}

View File

@@ -0,0 +1,12 @@
// Package multicast contains multicast connections.
package multicast
import (
"net"
)
// Conn is a Multicast connection.
type Conn interface {
net.PacketConn
SetReadBuffer(int) error
}

View File

@@ -0,0 +1,32 @@
//go:build !windows
// +build !windows
package multicast
import (
"net"
"golang.org/x/net/ipv4"
)
func setupReadFrom(c *ipv4.PacketConn) error {
return c.SetControlMessage(ipv4.FlagDst, true)
}
func readFrom(c *ipv4.PacketConn, destIP net.IP, b []byte) (int, net.Addr, error) {
for {
n, cm, src, err := c.ReadFrom(b)
if err != nil {
return 0, nil, err
}
// a multicast connection can receive packets
// addressed to groups joined by other connections.
// discard them.
if !cm.Dst.Equal(destIP) {
continue
}
return n, src, nil
}
}

View File

@@ -0,0 +1,19 @@
//go:build windows
// +build windows
package multicast
import (
"net"
"golang.org/x/net/ipv4"
)
func setupReadFrom(c *ipv4.PacketConn) error {
return nil
}
func readFrom(c *ipv4.PacketConn, destIP net.IP, b []byte) (int, net.Addr, error) {
n, _, src, err := c.ReadFrom(b)
return n, src, err
}

View File

@@ -0,0 +1,112 @@
package multicast
import (
"net"
"strconv"
"time"
"golang.org/x/net/ipv4"
)
const (
// same size as GStreamer's rtspsrc
multicastTTL = 16
)
// SingleConn is a multicast connection
// that works on a single interface.
type SingleConn struct {
addr *net.UDPAddr
conn *net.UDPConn
connIP *ipv4.PacketConn
}
// NewSingleConn allocates a SingleConn.
func NewSingleConn(
intf *net.Interface,
address string,
listenPacket func(network, address string) (net.PacketConn, error),
) (Conn, error) {
addr, err := net.ResolveUDPAddr("udp4", address)
if err != nil {
return nil, err
}
tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10))
if err != nil {
return nil, err
}
conn := tmp.(*net.UDPConn)
connIP := ipv4.NewPacketConn(conn)
err = connIP.JoinGroup(intf, &net.UDPAddr{IP: addr.IP})
if err != nil {
conn.Close() //nolint:errcheck
return nil, err
}
err = setupReadFrom(connIP)
if err != nil {
conn.Close() //nolint:errcheck
return nil, err
}
err = connIP.SetMulticastInterface(intf)
if err != nil {
conn.Close() //nolint:errcheck
return nil, err
}
err = connIP.SetMulticastTTL(multicastTTL)
if err != nil {
conn.Close() //nolint:errcheck
return nil, err
}
return &SingleConn{
addr: addr,
conn: conn,
connIP: connIP,
}, nil
}
// Close implements Conn.
func (c *SingleConn) Close() error {
return c.conn.Close()
}
// SetReadBuffer implements Conn.
func (c *SingleConn) SetReadBuffer(bytes int) error {
return c.conn.SetReadBuffer(bytes)
}
// LocalAddr implements Conn.
func (c *SingleConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// SetDeadline implements Conn.
func (c *SingleConn) SetDeadline(_ time.Time) error {
panic("unimplemented")
}
// SetReadDeadline implements Conn.
func (c *SingleConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline implements Conn.
func (c *SingleConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// WriteTo implements Conn.
func (c *SingleConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.conn.WriteTo(b, addr)
}
// ReadFrom implements Conn.
func (c *SingleConn) ReadFrom(b []byte) (int, net.Addr, error) {
return readFrom(c.connIP, c.addr.IP, b)
}

View File

@@ -1,39 +1,14 @@
package gortsplib
import (
"fmt"
"net"
"strconv"
"sync"
"time"
"golang.org/x/net/ipv4"
"github.com/bluenviron/gortsplib/v4/pkg/multicast"
)
func joinMulticastGroupOnAtLeastOneInterface(p *ipv4.PacketConn, listenIP net.IP) error {
intfs, err := net.Interfaces()
if err != nil {
return err
}
success := false
for _, intf := range intfs {
if (intf.Flags & net.FlagMulticast) != 0 {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: listenIP})
if err == nil {
success = true
}
}
}
if !success {
return fmt.Errorf("unable to activate multicast on any network interface")
}
return nil
}
type clientAddr struct {
ip [net.IPv6len]byte // use a fixed-size array to enable the equality operator
port int
@@ -94,43 +69,28 @@ func newServerUDPListenerMulticastPair(
func newServerUDPListener(
listenPacket func(network, address string) (net.PacketConn, error),
writeTimeout time.Duration,
multicast bool,
multicastEnable bool,
address string,
) (*serverUDPListener, error) {
var pc *net.UDPConn
var pc packetConn
var listenIP net.IP
if multicast {
host, port, err := net.SplitHostPort(address)
if multicastEnable {
var err error
pc, err = multicast.NewMultiConn(address, listenPacket)
if err != nil {
return nil, err
}
tmp, err := listenPacket(restrictNetwork("udp", "224.0.0.0:"+port))
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(multicastTTL)
if err != nil {
return nil, err
}
listenIP = net.ParseIP(host)
err = joinMulticastGroupOnAtLeastOneInterface(p, listenIP)
if err != nil {
return nil, err
}
pc = tmp.(*net.UDPConn)
} else {
tmp, err := listenPacket(restrictNetwork("udp", address))
if err != nil {
return nil, err
}
pc = tmp.(*net.UDPConn)
listenIP = tmp.LocalAddr().(*net.UDPAddr).IP
}