From 9db34842c8de292e1768c58f60fa55d7008d0e55 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sun, 7 Sep 2025 15:39:02 +0200 Subject: [PATCH] move host resolution from headers to client/server (#883) --- client.go | 57 ++++++++++++++----- client_play_test.go | 7 ++- .../FuzzKeyMgmtUnmarshal/e466018bd7205d8b | 2 + pkg/headers/transport.go | 45 +++++++++------ pkg/headers/transport_test.go | 13 +++-- server_play_test.go | 10 ++-- server_session.go | 4 +- 7 files changed, 93 insertions(+), 45 deletions(-) create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/e466018bd7205d8b diff --git a/client.go b/client.go index 66fb8c3e..223eea83 100644 --- a/client.go +++ b/client.go @@ -1867,8 +1867,17 @@ func (c *Client) doSetup( } var remoteIP net.IP - if thRes.Source != nil { - remoteIP = *thRes.Source + if thRes.Source2 != nil { + if ip := net.ParseIP(*thRes.Source2); ip != nil { + remoteIP = ip + } else { + var addr *net.UDPAddr + addr, err = net.ResolveUDPAddr("udp", *thRes.Source2) + if err != nil { + return nil, fmt.Errorf("unable to solve source host: %w", err) + } + remoteIP = addr.IP + } } else { remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP } @@ -1902,21 +1911,41 @@ func (c *Client) doSetup( return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } - if thRes.Ports == nil { - return nil, liberrors.ErrClientTransportHeaderNoPorts{} - } - - if thRes.Destination == nil { - return nil, liberrors.ErrClientTransportHeaderNoDestination{} - } - var remoteIP net.IP - if thRes.Source != nil { - remoteIP = *thRes.Source + if thRes.Source2 != nil { + if ip := net.ParseIP(*thRes.Source2); ip != nil { + remoteIP = ip + } else { + var addr *net.UDPAddr + addr, err = net.ResolveUDPAddr("udp", *thRes.Source2) + if err != nil { + return nil, fmt.Errorf("unable to solve source host: %w", err) + } + remoteIP = addr.IP + } } else { remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP } + var destIP net.IP + if thRes.Destination2 == nil { + return nil, liberrors.ErrClientTransportHeaderNoDestination{} + } + if ip := net.ParseIP(*thRes.Destination2); ip != nil { + destIP = ip + } else { + var addr *net.UDPAddr + addr, err = net.ResolveUDPAddr("udp", *thRes.Destination2) + if err != nil { + return nil, fmt.Errorf("unable to solve destination host: %w", err) + } + destIP = addr.IP + } + + if thRes.Ports == nil { + return nil, liberrors.ErrClientTransportHeaderNoPorts{} + } + var intf *net.Interface intf, err = interfaceOfConn(c.nconn) if err != nil { @@ -1927,8 +1956,8 @@ func (c *Client) doSetup( c, true, intf, - net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), - net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), + net.JoinHostPort(destIP.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), + net.JoinHostPort(destIP.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), ) if err != nil { return nil, err diff --git a/client_play_test.go b/client_play_test.go index 63d5b596..34e131b1 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -28,6 +28,10 @@ import ( "github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio" ) +func stringPtr(v string) *string { + return &v +} + func ipPtr(v net.IP) *net.IP { return &v } @@ -461,8 +465,7 @@ func TestClientPlay(t *testing.T) { v := headers.TransportDeliveryMulticast th.Delivery = &v th.Protocol = headers.TransportProtocolUDP - v2 := net.ParseIP("224.1.0.1") - th.Destination = &v2 + th.Destination2 = stringPtr("224.1.0.1") th.Ports = &[2]int{25000 + i*2, 25001 + i*2} l1s[i], err2 = net.ListenPacket("udp", net.JoinHostPort("224.0.0.0", strconv.FormatInt(int64(th.Ports[0]), 10))) diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/e466018bd7205d8b b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/e466018bd7205d8b new file mode 100644 index 00000000..6d28ca41 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/e466018bd7205d8b @@ -0,0 +1,2 @@ +go test fuzz v1 +string("prot=mikey;uri") diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index 9cfafd44..39d5384f 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -140,11 +140,21 @@ type Transport struct { Delivery *TransportDelivery // (optional) Source IP. + // + // Deprecated: replaced by Source2 Source *net.IP + // (optional) Source IP/host. + Source2 *string + // (optional) destination IP. + // + // Deprecated: replaced by Destination2 Destination *net.IP + // (optional) destination IP/host. + Destination2 *string + // (optional) interleaved frame IDs. InterleavedIDs *[2]int @@ -220,27 +230,18 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { case "source": if v != "" { - ip := net.ParseIP(v) - if ip == nil { - addrs, err2 := net.LookupHost(v) - if err2 != nil { - return fmt.Errorf("invalid source (%v)", v) - } - ip = net.ParseIP(addrs[0]) - if ip == nil { - return fmt.Errorf("invalid source (%v)", v) - } + if ip := net.ParseIP(v); ip != nil { + h.Source = &ip } - h.Source = &ip + h.Source2 = &v } case "destination": if v != "" { - ip := net.ParseIP(v) - if ip == nil { - return fmt.Errorf("invalid destination (%v)", v) + if ip := net.ParseIP(v); ip != nil { + h.Destination = &ip } - h.Destination = &ip + h.Destination2 = &v } case "interleaved": @@ -351,11 +352,21 @@ func (h Transport) Marshal() base.HeaderValue { } if h.Source != nil { - rets = append(rets, "source="+h.Source.String()) + v := h.Source.String() + h.Source2 = &v + } + + if h.Source2 != nil { + rets = append(rets, "source="+*h.Source2) } if h.Destination != nil { - rets = append(rets, "destination="+h.Destination.String()) + v := h.Destination.String() + h.Destination2 = &v + } + + if h.Destination2 != nil { + rets = append(rets, "destination="+*h.Destination2) } if h.InterleavedIDs != nil { diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index fa1fdcbe..a6247ffb 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -54,11 +54,12 @@ var casesTransport = []struct { base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`}, base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`}, Transport{ - Protocol: TransportProtocolUDP, - Delivery: deliveryPtr(TransportDeliveryMulticast), - Destination: ipPtr(net.ParseIP("225.219.201.15")), - TTL: uintPtr(127), - Ports: &[2]int{7000, 7001}, + Protocol: TransportProtocolUDP, + Delivery: deliveryPtr(TransportDeliveryMulticast), + Destination: ipPtr(net.ParseIP("225.219.201.15")), + Destination2: stringPtr("225.219.201.15"), + TTL: uintPtr(127), + Ports: &[2]int{7000, 7001}, }, }, { @@ -94,6 +95,7 @@ var casesTransport = []struct { ClientPorts: &[2]int{14186, 14187}, ServerPorts: &[2]int{5000, 5001}, Source: ipPtr(net.ParseIP("127.0.0.1")), + Source2: stringPtr("127.0.0.1"), }, }, { @@ -164,6 +166,7 @@ var casesTransport = []struct { Protocol: TransportProtocolUDP, Delivery: deliveryPtr(TransportDeliveryUnicast), Source: ipPtr(net.ParseIP("172.16.8.2")), + Source2: stringPtr("172.16.8.2"), ClientPorts: &[2]int{14236, 14237}, ServerPorts: &[2]int{56002, 56003}, }, diff --git a/server_play_test.go b/server_play_test.go index 78651aca..e045feaa 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -914,7 +914,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)}) require.NoError(t, err) } @@ -928,7 +928,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)}) require.NoError(t, err) } @@ -1042,7 +1042,7 @@ func TestServerPlay(t *testing.T) { case "multicast": _, err = l2.WriteTo(buf, &net.UDPAddr{ - IP: *th.Destination, + IP: net.ParseIP(*th.Destination2), Port: th.Ports[1], }) require.NoError(t, err) @@ -1234,7 +1234,7 @@ func TestServerPlaySocketError(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)}) require.NoError(t, err) } @@ -1248,7 +1248,7 @@ func TestServerPlaySocketError(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)}) require.NoError(t, err) } diff --git a/server_session.go b/server_session.go index 1d74f590..3d8f6a2d 100644 --- a/server_session.go +++ b/server_session.go @@ -1468,8 +1468,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( th.Delivery = &de v := uint(127) th.TTL = &v - d := stream.medias[medi].multicastWriter.ip() - th.Destination = &d + dest := stream.medias[medi].multicastWriter.ip().String() + th.Destination2 = &dest th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} }