move host resolution from headers to client/server (#883)

This commit is contained in:
Alessandro Ros
2025-09-07 15:39:02 +02:00
committed by GitHub
parent c466c342ba
commit 9db34842c8
7 changed files with 93 additions and 45 deletions

View File

@@ -1867,8 +1867,17 @@ func (c *Client) doSetup(
}
var remoteIP net.IP
if thRes.Source != nil {
remoteIP = *thRes.Source
if thRes.Source2 != nil {
if ip := net.ParseIP(*thRes.Source2); ip != nil {
remoteIP = ip
} else {
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", *thRes.Source2)
if err != nil {
return nil, fmt.Errorf("unable to solve source host: %w", err)
}
remoteIP = addr.IP
}
} else {
remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
@@ -1902,21 +1911,41 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
}
if thRes.Ports == nil {
return nil, liberrors.ErrClientTransportHeaderNoPorts{}
}
if thRes.Destination == nil {
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
}
var remoteIP net.IP
if thRes.Source != nil {
remoteIP = *thRes.Source
if thRes.Source2 != nil {
if ip := net.ParseIP(*thRes.Source2); ip != nil {
remoteIP = ip
} else {
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", *thRes.Source2)
if err != nil {
return nil, fmt.Errorf("unable to solve source host: %w", err)
}
remoteIP = addr.IP
}
} else {
remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
var destIP net.IP
if thRes.Destination2 == nil {
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
}
if ip := net.ParseIP(*thRes.Destination2); ip != nil {
destIP = ip
} else {
var addr *net.UDPAddr
addr, err = net.ResolveUDPAddr("udp", *thRes.Destination2)
if err != nil {
return nil, fmt.Errorf("unable to solve destination host: %w", err)
}
destIP = addr.IP
}
if thRes.Ports == nil {
return nil, liberrors.ErrClientTransportHeaderNoPorts{}
}
var intf *net.Interface
intf, err = interfaceOfConn(c.nconn)
if err != nil {
@@ -1927,8 +1956,8 @@ func (c *Client) doSetup(
c,
true,
intf,
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
net.JoinHostPort(destIP.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
net.JoinHostPort(destIP.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
)
if err != nil {
return nil, err

View File

@@ -28,6 +28,10 @@ import (
"github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio"
)
func stringPtr(v string) *string {
return &v
}
func ipPtr(v net.IP) *net.IP {
return &v
}
@@ -461,8 +465,7 @@ func TestClientPlay(t *testing.T) {
v := headers.TransportDeliveryMulticast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2
th.Destination2 = stringPtr("224.1.0.1")
th.Ports = &[2]int{25000 + i*2, 25001 + i*2}
l1s[i], err2 = net.ListenPacket("udp", net.JoinHostPort("224.0.0.0", strconv.FormatInt(int64(th.Ports[0]), 10)))

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("prot=mikey;uri")

View File

@@ -140,11 +140,21 @@ type Transport struct {
Delivery *TransportDelivery
// (optional) Source IP.
//
// Deprecated: replaced by Source2
Source *net.IP
// (optional) Source IP/host.
Source2 *string
// (optional) destination IP.
//
// Deprecated: replaced by Destination2
Destination *net.IP
// (optional) destination IP/host.
Destination2 *string
// (optional) interleaved frame IDs.
InterleavedIDs *[2]int
@@ -220,27 +230,18 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
case "source":
if v != "" {
ip := net.ParseIP(v)
if ip == nil {
addrs, err2 := net.LookupHost(v)
if err2 != nil {
return fmt.Errorf("invalid source (%v)", v)
}
ip = net.ParseIP(addrs[0])
if ip == nil {
return fmt.Errorf("invalid source (%v)", v)
}
if ip := net.ParseIP(v); ip != nil {
h.Source = &ip
}
h.Source = &ip
h.Source2 = &v
}
case "destination":
if v != "" {
ip := net.ParseIP(v)
if ip == nil {
return fmt.Errorf("invalid destination (%v)", v)
if ip := net.ParseIP(v); ip != nil {
h.Destination = &ip
}
h.Destination = &ip
h.Destination2 = &v
}
case "interleaved":
@@ -351,11 +352,21 @@ func (h Transport) Marshal() base.HeaderValue {
}
if h.Source != nil {
rets = append(rets, "source="+h.Source.String())
v := h.Source.String()
h.Source2 = &v
}
if h.Source2 != nil {
rets = append(rets, "source="+*h.Source2)
}
if h.Destination != nil {
rets = append(rets, "destination="+h.Destination.String())
v := h.Destination.String()
h.Destination2 = &v
}
if h.Destination2 != nil {
rets = append(rets, "destination="+*h.Destination2)
}
if h.InterleavedIDs != nil {

View File

@@ -54,11 +54,12 @@ var casesTransport = []struct {
base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
base.HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
Transport{
Protocol: TransportProtocolUDP,
Delivery: deliveryPtr(TransportDeliveryMulticast),
Destination: ipPtr(net.ParseIP("225.219.201.15")),
TTL: uintPtr(127),
Ports: &[2]int{7000, 7001},
Protocol: TransportProtocolUDP,
Delivery: deliveryPtr(TransportDeliveryMulticast),
Destination: ipPtr(net.ParseIP("225.219.201.15")),
Destination2: stringPtr("225.219.201.15"),
TTL: uintPtr(127),
Ports: &[2]int{7000, 7001},
},
},
{
@@ -94,6 +95,7 @@ var casesTransport = []struct {
ClientPorts: &[2]int{14186, 14187},
ServerPorts: &[2]int{5000, 5001},
Source: ipPtr(net.ParseIP("127.0.0.1")),
Source2: stringPtr("127.0.0.1"),
},
},
{
@@ -164,6 +166,7 @@ var casesTransport = []struct {
Protocol: TransportProtocolUDP,
Delivery: deliveryPtr(TransportDeliveryUnicast),
Source: ipPtr(net.ParseIP("172.16.8.2")),
Source2: stringPtr("172.16.8.2"),
ClientPorts: &[2]int{14236, 14237},
ServerPorts: &[2]int{56002, 56003},
},

View File

@@ -914,7 +914,7 @@ func TestServerPlay(t *testing.T) {
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)})
require.NoError(t, err)
}
@@ -928,7 +928,7 @@ func TestServerPlay(t *testing.T) {
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)})
require.NoError(t, err)
}
@@ -1042,7 +1042,7 @@ func TestServerPlay(t *testing.T) {
case "multicast":
_, err = l2.WriteTo(buf, &net.UDPAddr{
IP: *th.Destination,
IP: net.ParseIP(*th.Destination2),
Port: th.Ports[1],
})
require.NoError(t, err)
@@ -1234,7 +1234,7 @@ func TestServerPlaySocketError(t *testing.T) {
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)})
require.NoError(t, err)
}
@@ -1248,7 +1248,7 @@ func TestServerPlaySocketError(t *testing.T) {
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
err = p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination2)})
require.NoError(t, err)
}

View File

@@ -1468,8 +1468,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
dest := stream.medias[medi].multicastWriter.ip().String()
th.Destination2 = &dest
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
}