feat(mcast-backend): add config parsing and update test

Signed-off-by: Adam Rizkalla <ajarizzo@gmail.com>
This commit is contained in:
Adam Rizkalla
2025-03-10 20:13:08 +00:00
committed by Steffen Vogel
parent b0ea76d023
commit 672f352e5f
4 changed files with 80 additions and 37 deletions

View File

@@ -14,7 +14,6 @@ import (
signalingproto "cunicu.li/cunicu/pkg/proto/signaling"
"cunicu.li/cunicu/pkg/signaling"
"go.uber.org/zap"
"golang.org/x/net/ipv4"
"google.golang.org/protobuf/proto"
)
@@ -28,8 +27,7 @@ func init() { //nolint:gochecknoinits
type Backend struct {
signaling.SubscriptionsRegistry
send_conn net.PacketConn
recv_conn *net.UDPConn
conn *net.UDPConn
mcast_addr *net.UDPAddr
config BackendConfig
@@ -42,49 +40,38 @@ func NewBackend(cfg *signaling.BackendConfig, logger *log.Logger) (signaling.Bac
logger: logger,
}
//if err := b.config.Parse(cfg); err != nil {
// return nil, fmt.Errorf("failed to parse backend configuration: %w", err)
//}
var err error
// Parse multicast group
if b.mcast_addr, err = net.ResolveUDPAddr("udp", "224.0.0.1:9999"); err != nil {
if err = b.config.Parse(cfg); err != nil {
return nil, fmt.Errorf("failed to parse backend configuration: %w", err)
}
// Parse multicast group address
if b.mcast_addr, err = net.ResolveUDPAddr("udp", b.config.Target); err != nil {
return nil, fmt.Errorf("Error parsing multicast address: %w", err)
}
// Bind to any available local UDP port for sending to multicast group
if b.send_conn, err = net.ListenPacket("udp", ":0"); err != nil {
return nil, fmt.Errorf("Error binding to local address: %w", err)
}
p := ipv4.NewPacketConn(b.send_conn)
if err := p.JoinGroup(nil, b.mcast_addr); err != nil {
return nil, fmt.Errorf("Error joining multicast group: %w", err)
}
// Add listener for multicast group
if b.recv_conn, err = net.ListenMulticastUDP("udp", nil, b.mcast_addr); err != nil {
if b.conn, err = net.ListenMulticastUDP("udp", b.config.Options.Interface, b.mcast_addr); err != nil {
return nil, fmt.Errorf("Error adding multicast listener: %w", err)
}
// Enable multicast loopback
fd, _ := b.recv_conn.File()
syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, 1)
//syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, 1)
if b.config.Options.Loopback {
// Enable multicast loopback
fd, _ := b.conn.File()
syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, 1)
}
go func() {
buf := make([]byte, 4096)
for {
n, _, err := b.recv_conn.ReadFrom(buf)
n, err := b.conn.Read(buf)
if err != nil {
if err == net.ErrClosed {
break
}
b.logger.Error("Error reading from UDPConn", zap.Error(err))
break
//continue
continue
}
var env signalingproto.Envelope
@@ -135,7 +122,7 @@ func (b *Backend) Publish(ctx context.Context, kp *crypto.KeyPair, msg *signalin
return fmt.Errorf("Error marshaling protobuf: %w", err)
}
if _, err = b.send_conn.WriteTo(data, b.mcast_addr); err != nil {
if _, err = b.conn.WriteTo(data, b.mcast_addr); err != nil {
return fmt.Errorf("failed to publish message: %w", err)
}
@@ -143,7 +130,9 @@ func (b *Backend) Publish(ctx context.Context, kp *crypto.KeyPair, msg *signalin
}
func (b *Backend) Close() error {
//return fmt.Errorf("Close() called")
// NOTE: Do not close the connection; on certain OS (like Linux),
// The UDPConn.Read() will continue to block even if the connection
// is closed
//if err := b.conn.Close(); err != nil {
// return fmt.Errorf("failed to close multicast connection: %w", err)
//}

View File

@@ -22,7 +22,9 @@ func TestSuite(t *testing.T) {
var _ = Describe("Multicast backend", func() {
u := url.URL{
Scheme: "multicast",
Scheme: "multicast",
Host: "239.0.0.1:9999",
RawQuery: "interface=lo&loopback=true",
}
test.BackendTest(&u, 10)

View File

@@ -4,23 +4,31 @@
package mcast
import (
"fmt"
"net"
"cunicu.li/cunicu/pkg/signaling"
)
type BackendOptions struct {
Interface *net.Interface
Loopback bool
}
type BackendConfig struct {
signaling.BackendConfig
Target string
Loopback bool
Target string
Options BackendOptions
}
func (c *BackendConfig) Parse(cfg *signaling.BackendConfig) (err error) {
c.BackendConfig = *cfg
//c.Target, c.Loopback, err = ParseURL(c.BackendConfig.URI.String())
//if err != nil {
// return fmt.Errorf("failed to parse multicast URL: %w", err)
//}
c.Target, c.Options, err = ParseURL(c.BackendConfig.URI.String())
if err != nil {
return fmt.Errorf("failed to parse multicast URL: %w", err)
}
return nil
}

View File

@@ -3,3 +3,47 @@
// Package mcast implements a signaling backend using multicast
package mcast
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
)
var errInvalidAddress = errors.New("missing multicast address")
func ParseURL(urlStr string) (string, BackendOptions, error) {
o := BackendOptions{
Interface: nil,
Loopback: false,
}
u, err := url.Parse(urlStr)
if err != nil {
return "", o, err
}
q := u.Query()
if q.Has("interface") {
if o.Interface, err = net.InterfaceByName(q.Get("interface")); err != nil {
return "", o, fmt.Errorf("failed to parse 'interface' option: %w", err)
}
}
if q.Has("loopback") {
var err error
if o.Loopback, err = strconv.ParseBool(q.Get("loopback")); err != nil {
return "", o, fmt.Errorf("failed to parse 'loopback' option: %w", err)
}
}
if u.Host == "" {
return "", o, errInvalidAddress
}
return u.Host, o, nil
}