diff --git a/pkg/multicast/multi_conn.go b/pkg/multicast/multi_conn.go index 1b7cb397..2e9243d1 100644 --- a/pkg/multicast/multi_conn.go +++ b/pkg/multicast/multi_conn.go @@ -25,6 +25,7 @@ type MultiConn struct { // NewMultiConn allocates a MultiConn. func NewMultiConn( address string, + readOnly bool, listenPacket func(network, address string) (net.PacketConn, error), ) (Conn, error) { addr, err := net.ResolveUDPAddr("udp4", address) @@ -67,42 +68,47 @@ func NewMultiConn( return nil, fmt.Errorf("no multicast-capable interfaces found") } - writeConns := make([]*net.UDPConn, len(enabledInterfaces)) - writeConnIPs := make([]*ipv4.PacketConn, len(enabledInterfaces)) + var writeConns []*net.UDPConn + var writeConnIPs []*ipv4.PacketConn - for i, intf := range enabledInterfaces { - tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10)) - if err != nil { - for j := 0; j < i; j++ { - writeConns[j].Close() //nolint:errcheck + if !readOnly { + writeConns = make([]*net.UDPConn, len(enabledInterfaces)) + writeConnIPs = make([]*ipv4.PacketConn, len(enabledInterfaces)) + + for i, intf := range enabledInterfaces { + tmp, err := listenPacket("udp4", "224.0.0.0:"+strconv.FormatInt(int64(addr.Port), 10)) + if err != nil { + for j := 0; j < i; j++ { + writeConns[j].Close() //nolint:errcheck + } + readConn.Close() //nolint:errcheck + return nil, err } - readConn.Close() //nolint:errcheck - return nil, err - } - writeConn := tmp.(*net.UDPConn) + writeConn := tmp.(*net.UDPConn) - writeConnIP := ipv4.NewPacketConn(writeConn) + writeConnIP := ipv4.NewPacketConn(writeConn) - err = writeConnIP.SetMulticastInterface(intf) - if err != nil { - for j := 0; j < i; j++ { - writeConns[j].Close() //nolint:errcheck + err = writeConnIP.SetMulticastInterface(intf) + if err != nil { + for j := 0; j < i; j++ { + writeConns[j].Close() //nolint:errcheck + } + readConn.Close() //nolint:errcheck + return nil, err } - readConn.Close() //nolint:errcheck - return nil, err - } - err = writeConnIP.SetMulticastTTL(multicastTTL) - if err != nil { - for j := 0; j < i; j++ { - writeConns[j].Close() //nolint:errcheck + err = writeConnIP.SetMulticastTTL(multicastTTL) + if err != nil { + for j := 0; j < i; j++ { + writeConns[j].Close() //nolint:errcheck + } + readConn.Close() //nolint:errcheck + return nil, err } - readConn.Close() //nolint:errcheck - return nil, err - } - writeConns[i] = writeConn - writeConnIPs[i] = writeConnIP + writeConns[i] = writeConn + writeConnIPs[i] = writeConnIP + } } return &MultiConn{ diff --git a/pkg/multicast/multi_conn_lin.go b/pkg/multicast/multi_conn_lin.go index 89755a4e..c450e7dd 100644 --- a/pkg/multicast/multi_conn_lin.go +++ b/pkg/multicast/multi_conn_lin.go @@ -24,6 +24,7 @@ type MultiConn struct { // NewMultiConn allocates a MultiConn. func NewMultiConn( address string, + readOnly bool, _ func(network, address string) (net.PacketConn, error), ) (Conn, error) { addr, err := net.ResolveUDPAddr("udp4", address) @@ -84,76 +85,82 @@ func NewMultiConn( return nil, fmt.Errorf("no multicast-capable interfaces found") } - writeSocks := make([]int, len(enabledInterfaces)) + var writeFiles []*os.File + var writeConns []net.PacketConn - for i, intf := range enabledInterfaces { - writeSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) - if err != nil { - for j := 0; j < i; j++ { - syscall.Close(writeSocks[j]) //nolint:errcheck + if !readOnly { + writeSocks := make([]int, len(enabledInterfaces)) + + for i, intf := range enabledInterfaces { + writeSock, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err } - syscall.Close(readSock) //nolint:errcheck - return nil, err + + err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var lsa syscall.SockaddrInet4 + lsa.Port = addr.Port + copy(lsa.Addr[:], addr.IP.To4()) + err = syscall.Bind(writeSock, &lsa) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + var mreqn syscall.IPMreqn + mreqn.Ifindex = int32(intf.Index) + + err = syscall.SetsockoptIPMreqn(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + err = syscall.SetsockoptInt(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL) + if err != nil { + syscall.Close(writeSock) //nolint:errcheck + for j := 0; j < i; j++ { + syscall.Close(writeSocks[j]) //nolint:errcheck + } + syscall.Close(readSock) //nolint:errcheck + return nil, err + } + + writeSocks[i] = writeSock } - err = syscall.SetsockoptInt(writeSock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { - syscall.Close(writeSock) //nolint:errcheck - for j := 0; j < i; j++ { - syscall.Close(writeSocks[j]) //nolint:errcheck - } - syscall.Close(readSock) //nolint:errcheck - return nil, err + writeFiles = make([]*os.File, len(writeSocks)) + writeConns = make([]net.PacketConn, len(writeSocks)) + + for i, writeSock := range writeSocks { + writeFiles[i] = os.NewFile(uintptr(writeSock), "") + writeConns[i], _ = net.FilePacketConn(writeFiles[i]) } - - var lsa syscall.SockaddrInet4 - lsa.Port = addr.Port - copy(lsa.Addr[:], addr.IP.To4()) - err = syscall.Bind(writeSock, &lsa) - if err != nil { - syscall.Close(writeSock) //nolint:errcheck - for j := 0; j < i; j++ { - syscall.Close(writeSocks[j]) //nolint:errcheck - } - syscall.Close(readSock) //nolint:errcheck - return nil, err - } - - var mreqn syscall.IPMreqn - mreqn.Ifindex = int32(intf.Index) - - err = syscall.SetsockoptIPMreqn(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, &mreqn) - if err != nil { - syscall.Close(writeSock) //nolint:errcheck - for j := 0; j < i; j++ { - syscall.Close(writeSocks[j]) //nolint:errcheck - } - syscall.Close(readSock) //nolint:errcheck - return nil, err - } - - err = syscall.SetsockoptInt(writeSock, syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, multicastTTL) - if err != nil { - syscall.Close(writeSock) //nolint:errcheck - for j := 0; j < i; j++ { - syscall.Close(writeSocks[j]) //nolint:errcheck - } - syscall.Close(readSock) //nolint:errcheck - return nil, err - } - - writeSocks[i] = writeSock } readFile := os.NewFile(uintptr(readSock), "") readConn, _ := net.FilePacketConn(readFile) - writeFiles := make([]*os.File, len(writeSocks)) - writeConns := make([]net.PacketConn, len(writeSocks)) - - for i, writeSock := range writeSocks { - writeFiles[i] = os.NewFile(uintptr(writeSock), "") - writeConns[i], _ = net.FilePacketConn(writeFiles[i]) - } return &MultiConn{ addr: addr, diff --git a/server_udp_listener.go b/server_udp_listener.go index 4d5871d1..6c6082d4 100644 --- a/server_udp_listener.go +++ b/server_udp_listener.go @@ -76,7 +76,7 @@ func newServerUDPListener( var listenIP net.IP if multicastEnable { var err error - pc, err = multicast.NewMultiConn(address, listenPacket) + pc, err = multicast.NewMultiConn(address, false, listenPacket) if err != nil { return nil, err }