multicast: add readOnly flag (#423)

This commit is contained in:
Alessandro Ros
2023-09-16 11:03:32 +02:00
committed by GitHub
parent e6f7c4dea4
commit 99773e19af
3 changed files with 103 additions and 90 deletions

View File

@@ -25,6 +25,7 @@ type MultiConn struct {
// NewMultiConn allocates a MultiConn. // NewMultiConn allocates a MultiConn.
func NewMultiConn( func NewMultiConn(
address string, address string,
readOnly bool,
listenPacket func(network, address string) (net.PacketConn, error), listenPacket func(network, address string) (net.PacketConn, error),
) (Conn, error) { ) (Conn, error) {
addr, err := net.ResolveUDPAddr("udp4", address) addr, err := net.ResolveUDPAddr("udp4", address)
@@ -67,42 +68,47 @@ func NewMultiConn(
return nil, fmt.Errorf("no multicast-capable interfaces found") return nil, fmt.Errorf("no multicast-capable interfaces found")
} }
writeConns := make([]*net.UDPConn, len(enabledInterfaces)) var writeConns []*net.UDPConn
writeConnIPs := make([]*ipv4.PacketConn, len(enabledInterfaces)) var writeConnIPs []*ipv4.PacketConn
for i, intf := range enabledInterfaces { if !readOnly {
tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10)) writeConns = make([]*net.UDPConn, len(enabledInterfaces))
if err != nil { writeConnIPs = make([]*ipv4.PacketConn, len(enabledInterfaces))
for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck for i, intf := range enabledInterfaces {
tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10))
if err != nil {
for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
} }
readConn.Close() //nolint:errcheck writeConn := tmp.(*net.UDPConn)
return nil, err
}
writeConn := tmp.(*net.UDPConn)
writeConnIP := ipv4.NewPacketConn(writeConn) writeConnIP := ipv4.NewPacketConn(writeConn)
err = writeConnIP.SetMulticastInterface(intf) err = writeConnIP.SetMulticastInterface(intf)
if err != nil { if err != nil {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
} }
readConn.Close() //nolint:errcheck
return nil, err
}
err = writeConnIP.SetMulticastTTL(multicastTTL) err = writeConnIP.SetMulticastTTL(multicastTTL)
if err != nil { if err != nil {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
writeConns[j].Close() //nolint:errcheck writeConns[j].Close() //nolint:errcheck
}
readConn.Close() //nolint:errcheck
return nil, err
} }
readConn.Close() //nolint:errcheck
return nil, err
}
writeConns[i] = writeConn writeConns[i] = writeConn
writeConnIPs[i] = writeConnIP writeConnIPs[i] = writeConnIP
}
} }
return &MultiConn{ return &MultiConn{

View File

@@ -24,6 +24,7 @@ type MultiConn struct {
// NewMultiConn allocates a MultiConn. // NewMultiConn allocates a MultiConn.
func NewMultiConn( func NewMultiConn(
address string, address string,
readOnly bool,
_ func(network, address string) (net.PacketConn, error), _ func(network, address string) (net.PacketConn, error),
) (Conn, error) { ) (Conn, error) {
addr, err := net.ResolveUDPAddr("udp4", address) addr, err := net.ResolveUDPAddr("udp4", address)
@@ -84,76 +85,82 @@ func NewMultiConn(
return nil, fmt.Errorf("no multicast-capable interfaces found") return nil, fmt.Errorf("no multicast-capable interfaces found")
} }
writeSocks := make([]int, len(enabledInterfaces)) var writeFiles []*os.File
var writeConns []net.PacketConn
for i, intf := range enabledInterfaces { if !readOnly {
writeSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) writeSocks := make([]int, len(enabledInterfaces))
if err != nil {
for j := 0; j < i; j++ { for i, intf := range enabledInterfaces {
syscall.Close(writeSocks[j]) //nolint:errcheck writeSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
if err != nil {
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
} }
syscall.Close(readSock) //nolint:errcheck
return nil, err err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
var lsa syscall.SockaddrInet4
lsa.Port = addr.Port
copy(lsa.Addr[:], addr.IP.To4())
err = syscall.Bind(writeSock, &lsa)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
var mreqn syscall.IPMreqn
mreqn.Ifindex = int32(intf.Index)
err = syscall.SetsockoptIPMreqn(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
err = syscall.SetsockoptInt(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
writeSocks[i] = writeSock
} }
err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) writeFiles = make([]*os.File, len(writeSocks))
if err != nil { writeConns = make([]net.PacketConn, len(writeSocks))
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ { for i, writeSock := range writeSocks {
syscall.Close(writeSocks[j]) //nolint:errcheck writeFiles[i] = os.NewFile(uintptr(writeSock), "")
} writeConns[i], _ = net.FilePacketConn(writeFiles[i])
syscall.Close(readSock) //nolint:errcheck
return nil, err
} }
var lsa syscall.SockaddrInet4
lsa.Port = addr.Port
copy(lsa.Addr[:], addr.IP.To4())
err = syscall.Bind(writeSock, &lsa)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
var mreqn syscall.IPMreqn
mreqn.Ifindex = int32(intf.Index)
err = syscall.SetsockoptIPMreqn(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
err = syscall.SetsockoptInt(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL)
if err != nil {
syscall.Close(writeSock) //nolint:errcheck
for j := 0; j < i; j++ {
syscall.Close(writeSocks[j]) //nolint:errcheck
}
syscall.Close(readSock) //nolint:errcheck
return nil, err
}
writeSocks[i] = writeSock
} }
readFile := os.NewFile(uintptr(readSock), "") readFile := os.NewFile(uintptr(readSock), "")
readConn, _ := net.FilePacketConn(readFile) readConn, _ := net.FilePacketConn(readFile)
writeFiles := make([]*os.File, len(writeSocks))
writeConns := make([]net.PacketConn, len(writeSocks))
for i, writeSock := range writeSocks {
writeFiles[i] = os.NewFile(uintptr(writeSock), "")
writeConns[i], _ = net.FilePacketConn(writeFiles[i])
}
return &MultiConn{ return &MultiConn{
addr: addr, addr: addr,

View File

@@ -76,7 +76,7 @@ func newServerUDPListener(
var listenIP net.IP var listenIP net.IP
if multicastEnable { if multicastEnable {
var err error var err error
pc, err = multicast.NewMultiConn(address, listenPacket) pc, err = multicast.NewMultiConn(address, false, listenPacket)
if err != nil { if err != nil {
return nil, err return nil, err
} }