diff --git a/client_read_test.go b/client_read_test.go index 3bf37cea..931a938e 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -272,7 +272,7 @@ func TestClientRead(t *testing.T) { v := base.StreamDeliveryMulticast th.Delivery = &v th.Protocol = base.StreamProtocolUDP - v2 := "224.1.0.1" + v2 := net.ParseIP("224.1.0.1") th.Destination = &v2 th.Ports = &[2]int{25000, 25001} diff --git a/clientconn.go b/clientconn.go index 3d86e638..c708ba81 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1317,12 +1317,12 @@ func (cc *ClientConn) doSetup( return nil, liberrors.ErrClientTransportHeaderNoDestination{} } - rtpListener, err = newClientConnUDPListener(cc, true, *thRes.Destination+":"+strconv.FormatInt(int64(thRes.Ports[0]), 10)) + rtpListener, err = newClientConnUDPListener(cc, true, thRes.Destination.String()+":"+strconv.FormatInt(int64(thRes.Ports[0]), 10)) if err != nil { return nil, err } - rtcpListener, err = newClientConnUDPListener(cc, true, *thRes.Destination+":"+strconv.FormatInt(int64(thRes.Ports[1]), 10)) + rtcpListener, err = newClientConnUDPListener(cc, true, thRes.Destination.String()+":"+strconv.FormatInt(int64(thRes.Ports[1]), 10)) if err != nil { rtpListener.close() return nil, err diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index be327619..308deaf0 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "net" "strconv" "strings" @@ -29,8 +30,8 @@ type Transport struct { // (optional) delivery method of the stream Delivery *base.StreamDelivery - // (optional) destination - Destination *string + // (optional) destination IP + Destination *net.IP // (optional) interleaved frame ids InterleavedIDs *[2]int @@ -122,7 +123,11 @@ func (h *Transport) Read(v base.HeaderValue) error { h.Delivery = &v case "destination": - h.Destination = &v + ip := net.ParseIP(v) + if ip == nil { + return fmt.Errorf("invalid destination (%v)", v) + } + h.Destination = &ip case "interleaved": ports, err := parsePorts(v) @@ -231,7 +236,7 @@ func (h Transport) Write() base.HeaderValue { } if h.Destination != nil { - rets = append(rets, "destination="+*h.Destination) + rets = append(rets, "destination="+h.Destination.String()) } if h.InterleavedIDs != nil { diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index 3ab5820e..8357ab7f 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -1,6 +1,7 @@ package headers import ( + "net" "testing" "github.com/stretchr/testify/require" @@ -55,8 +56,8 @@ var casesTransport = []struct { v := base.StreamDeliveryMulticast return &v }(), - Destination: func() *string { - v := "225.219.201.15" + Destination: func() *net.IP { + v := net.ParseIP("225.219.201.15") return &v }(), TTL: func() *uint { @@ -212,6 +213,11 @@ func TestTransportReadErrors(t *testing.T) { base.HeaderValue{`RTP/AVP;unicast;ttl=aa`}, "strconv.ParseUint: parsing \"aa\": invalid syntax", }, + { + "invalid destination", + base.HeaderValue{`RTP/AVP;unicast;destination=aa`}, + "invalid destination (aa)", + }, { "invalid ports 1", base.HeaderValue{`RTP/AVP;unicast;port=aa`}, diff --git a/server_read_test.go b/server_read_test.go index 4e71c5fa..ca84d326 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -456,7 +456,7 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination)}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) require.NoError(t, err) } @@ -470,7 +470,7 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination)}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) require.NoError(t, err) } } @@ -532,7 +532,7 @@ func TestServerRead(t *testing.T) { case "multicast": l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ - IP: net.ParseIP(*th.Destination), + IP: *th.Destination, Port: th.Ports[1], }) <-framesReceived diff --git a/serversession.go b/serversession.go index 7de85539..6de3f316 100644 --- a/serversession.go +++ b/serversession.go @@ -675,7 +675,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base th.Delivery = &de v := uint(127) th.TTL = &v - d := stream.multicastListeners[trackID].rtpListener.ip().String() + d := stream.multicastListeners[trackID].rtpListener.ip() th.Destination = &d th.Ports = &[2]int{ stream.multicastListeners[trackID].rtpListener.port(),