Files
gortsplib/pkg/multicast/multi_conn.go
2025-08-11 10:49:54 +02:00

187 lines
4.0 KiB
Go

//go:build !linux
package multicast
import (
"fmt"
"net"
"strconv"
"syscall"
"time"
"golang.org/x/net/ipv4"
)
// multiConn is a multicast connection that operates on every
// multicast-capable interface at once: a single socket receives,
// and a dedicated socket per interface transmits.
type multiConn struct {
	// addr is the resolved address passed to NewMultiConn.
	addr *net.UDPAddr

	// readConn is the socket used for receiving; readConnIP is the
	// ipv4.PacketConn wrapper through which it joined the group.
	readConn   *net.UDPConn
	readConnIP *ipv4.PacketConn

	// writeConns holds one sending socket per joined interface, and
	// writeConnIPs the matching ipv4.PacketConn wrappers (same index).
	// Both are nil when the connection was created read-only.
	writeConns   []*net.UDPConn
	writeConnIPs []*ipv4.PacketConn
}
// NewMultiConn allocates a multi-interface multicast connection.
//
// address is the "host:port" of the IPv4 multicast group. A read socket is
// opened through listenPacket on "224.0.0.0:<port>" and joins the group on
// every interface that has the multicast flag set; interfaces on which the
// join fails are skipped. If no interface can be joined, an error is returned.
//
// Unless readOnly is true, one additional socket per joined interface is
// opened for writing, bound to that interface and configured with
// multicastTTL.
//
// listenPacket must return a *net.UDPConn; any other type is reported as an
// error. On every error path all sockets opened so far are closed.
func NewMultiConn(
	address string,
	readOnly bool,
	listenPacket func(network, address string) (net.PacketConn, error),
) (Conn, error) {
	addr, err := net.ResolveUDPAddr("udp4", address)
	if err != nil {
		return nil, err
	}

	listenAddr := "224.0.0.0:" + strconv.Itoa(addr.Port)

	// listenUDP opens a socket through the caller-supplied factory and
	// verifies that it is a *net.UDPConn instead of panicking on a bad type.
	listenUDP := func() (*net.UDPConn, error) {
		pc, err := listenPacket("udp4", listenAddr)
		if err != nil {
			return nil, err
		}

		conn, ok := pc.(*net.UDPConn)
		if !ok {
			if pc != nil {
				pc.Close() //nolint:errcheck
			}
			return nil, fmt.Errorf("listenPacket returned %T instead of *net.UDPConn", pc)
		}

		return conn, nil
	}

	readConn, err := listenUDP()
	if err != nil {
		return nil, err
	}

	intfs, err := net.Interfaces()
	if err != nil {
		readConn.Close() //nolint:errcheck
		return nil, err
	}

	readConnIP := ipv4.NewPacketConn(readConn)

	var enabledInterfaces []*net.Interface //nolint:prealloc

	for i := range intfs {
		intf := &intfs[i]

		if (intf.Flags & net.FlagMulticast) == 0 {
			continue
		}

		// A failed join only excludes this interface; it is not fatal.
		err = readConnIP.JoinGroup(intf, &net.UDPAddr{IP: addr.IP})
		if err != nil {
			continue
		}

		enabledInterfaces = append(enabledInterfaces, intf)
	}

	if len(enabledInterfaces) == 0 {
		readConn.Close() //nolint:errcheck
		return nil, fmt.Errorf("no multicast-capable interfaces found")
	}

	var writeConns []*net.UDPConn
	var writeConnIPs []*ipv4.PacketConn

	if !readOnly {
		writeConns = make([]*net.UDPConn, 0, len(enabledInterfaces))
		writeConnIPs = make([]*ipv4.PacketConn, 0, len(enabledInterfaces))

		// closeAll releases every fully configured write socket and the
		// read socket. It is shared by all error paths below.
		closeAll := func() {
			for _, wc := range writeConns {
				wc.Close() //nolint:errcheck
			}
			readConn.Close() //nolint:errcheck
		}

		for _, intf := range enabledInterfaces {
			writeConn, err := listenUDP()
			if err != nil {
				closeAll()
				return nil, err
			}

			writeConnIP := ipv4.NewPacketConn(writeConn)

			err = writeConnIP.SetMulticastInterface(intf)
			if err == nil {
				err = writeConnIP.SetMulticastTTL(multicastTTL)
			}
			if err != nil {
				// writeConn is not in writeConns yet, so it has to be
				// closed explicitly to avoid leaking the socket.
				writeConn.Close() //nolint:errcheck
				closeAll()
				return nil, err
			}

			writeConns = append(writeConns, writeConn)
			writeConnIPs = append(writeConnIPs, writeConnIP)
		}
	}

	return &multiConn{
		addr:         addr,
		readConn:     readConn,
		readConnIP:   readConnIP,
		writeConns:   writeConns,
		writeConnIPs: writeConnIPs,
	}, nil
}
// Close implements Conn.
// Every write socket is closed first, then the read socket. Errors from
// the individual closes are ignored and nil is always returned.
func (c *multiConn) Close() error {
	for i := range c.writeConns {
		c.writeConns[i].Close() //nolint:errcheck
	}

	c.readConn.Close() //nolint:errcheck

	return nil
}
// SetReadBuffer implements Conn.
// The requested size is forwarded to the read socket only.
func (c *multiConn) SetReadBuffer(bytes int) error {
	err := c.readConn.SetReadBuffer(bytes)
	return err
}
// SyscallConn implements Conn.
// It exposes the raw connection of the read socket.
func (c *multiConn) SyscallConn() (syscall.RawConn, error) {
	raw, err := c.readConn.SyscallConn()
	return raw, err
}
// LocalAddr implements Conn.
// The address reported is the one the read socket is bound to.
func (c *multiConn) LocalAddr() net.Addr {
	local := c.readConn.LocalAddr()
	return local
}
// SetDeadline implements Conn.
// It applies t as both the read deadline and the write deadline, which is
// how net.Conn defines SetDeadline. The write deadline is still attempted
// when setting the read deadline fails; the first error encountered is
// returned.
func (c *multiConn) SetDeadline(t time.Time) error {
	err := c.SetReadDeadline(t)

	if err2 := c.SetWriteDeadline(t); err == nil {
		err = err2
	}

	return err
}
// SetReadDeadline implements Conn.
// The deadline is applied to the read socket.
func (c *multiConn) SetReadDeadline(t time.Time) error {
	err := c.readConn.SetReadDeadline(t)
	return err
}
// SetWriteDeadline implements Conn.
// The deadline is applied to every write socket, even after a failure;
// the first error encountered (if any) is returned.
func (c *multiConn) SetWriteDeadline(t time.Time) error {
	var firstErr error

	for _, wc := range c.writeConns {
		if e := wc.SetWriteDeadline(t); firstErr == nil {
			firstErr = e
		}
	}

	return firstErr
}
// WriteTo implements Conn.
// The payload is sent once through every write socket, continuing past
// failures. The returned byte count is the one reported by the last
// socket written to; the returned error is the first one encountered.
// With no write sockets, it returns 0 and a nil error.
func (c *multiConn) WriteTo(b []byte, addr net.Addr) (int, error) {
	var (
		written  int
		firstErr error
	)

	for _, wc := range c.writeConns {
		n, e := wc.WriteTo(b, addr)
		written = n

		if firstErr == nil {
			firstErr = e
		}
	}

	return written, firstErr
}
// ReadFrom implements Conn.
// Packets are received from the read socket.
func (c *multiConn) ReadFrom(b []byte) (int, net.Addr, error) {
	n, src, err := c.readConn.ReadFrom(b)
	return n, src, err
}