From 30b43a9eadedba4c4afa69f8eea79cd515fb320f Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 21 Jan 2021 09:02:15 +0100 Subject: [PATCH] support servers which don't provide the server-port field (#21) --- clientconf_test.go | 194 ++++++++++++++++++++++++--------------------- clientconn.go | 34 +++++--- 2 files changed, 124 insertions(+), 104 deletions(-) diff --git a/clientconf_test.go b/clientconf_test.go index 1c36197b..fd8e3d8c 100644 --- a/clientconf_test.go +++ b/clientconf_test.go @@ -149,114 +149,126 @@ func TestClientDialRead(t *testing.T) { } } -func TestClientDialReadZeroServerPorts(t *testing.T) { - l, err := net.Listen("tcp", "localhost:8554") - require.NoError(t, err) - defer l.Close() +func TestClientDialReadNoServerPorts(t *testing.T) { + for _, ca := range []string{ + "zero", + "no", + } { + t.Run(ca, func(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) - conn, err := l.Accept() - require.NoError(t, err) - defer conn.Close() - bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - var req base.Request - err = req.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) + var req base.Request + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) - err = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Write(bconn.Writer) - require.NoError(t, err) + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) - err = req.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) - track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) - require.NoError(t, err) - sdp := Tracks{track}.Write() + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + sdp := Tracks{track}.Write() - err = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"application/sdp"}, - }, - Body: sdp, - }.Write(bconn.Writer) - require.NoError(t, err) + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: sdp, + }.Write(bconn.Writer) + require.NoError(t, err) - err = req.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) - th, err := headers.ReadTransport(req.Header["Transport"]) - require.NoError(t, err) + th, err := headers.ReadTransport(req.Header["Transport"]) + require.NoError(t, err) - err = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": headers.Transport{ - Protocol: StreamProtocolUDP, - Delivery: func() *base.StreamDelivery { - v := base.StreamDeliveryUnicast - return &v - }(), - ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{0, 0}, - }.Write(), - }, - }.Write(bconn.Writer) - require.NoError(t, err) + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Protocol: StreamProtocolUDP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + ClientPorts: th.ClientPorts, + ServerPorts: func() *[2]int { + if ca == "zero" { + return &[2]int{0, 0} + } + return nil + }(), + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) - err = req.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.Play, req.Method) + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) - err = base.Response{ - StatusCode: base.StatusOK, - }.Write(bconn.Writer) - require.NoError(t, err) + err = base.Response{ + StatusCode: base.StatusOK, + }.Write(bconn.Writer) + require.NoError(t, err) - time.Sleep(1 * time.Second) + time.Sleep(1 * time.Second) - l1, err := net.ListenPacket("udp", "localhost:0") - require.NoError(t, err) - defer l1.Close() + l1, err := net.ListenPacket("udp", "localhost:0") + require.NoError(t, err) + defer l1.Close() - l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], + l1.WriteTo([]byte("\x00\x00\x00\x00"), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + }() + + conf := ClientConf{ + AnyPortEnable: true, + } + + conn, err := conf.DialRead("rtsp://localhost:8554/teststream") + require.NoError(t, err) + + frameRecv := make(chan struct{}) + done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + close(frameRecv) + }) + + <-frameRecv + conn.Close() + <-done }) - }() - - conf := ClientConf{ - AnyPortEnable: true, } - - conn, err := conf.DialRead("rtsp://localhost:8554/teststream") - require.NoError(t, err) - - frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) - - <-frameRecv - conn.Close() - <-done } func TestClientDialReadAutomaticProtocol(t *testing.T) { diff --git a/clientconn.go b/clientconn.go index 745bac7f..421563b8 100644 --- a/clientconn.go +++ b/clientconn.go @@ -575,21 +575,25 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track, } if proto == StreamProtocolUDP { - if thRes.ServerPorts == nil { - rtpListener.close() - rtcpListener.close() - return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server") - } - - if (thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] != 0) || - (thRes.ServerPorts[0] != 0 && thRes.ServerPorts[1] == 0) { - rtpListener.close() - rtcpListener.close() - return nil, fmt.Errorf("server ports must be both zero or both not zero") + if thRes.ServerPorts != nil { + if (thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] != 0) || + (thRes.ServerPorts[0] != 0 && thRes.ServerPorts[1] == 0) { + rtpListener.close() + rtcpListener.close() + return nil, fmt.Errorf("server ports must be both zero or both not zero") + } } if !c.conf.AnyPortEnable { + if thRes.ServerPorts == nil { + rtpListener.close() + rtcpListener.close() + return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server") + } + if thRes.ServerPorts[0] == 0 && thRes.ServerPorts[1] == 0 { + rtpListener.close() + rtcpListener.close() return nil, fmt.Errorf("server ports have not been provided. Use AnyPortEnable to communicate with this server") } } @@ -623,14 +627,18 @@ func (c *ClientConn) Setup(mode headers.TransportMode, track *Track, if proto == StreamProtocolUDP { rtpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone - rtpListener.remotePort = (*thRes.ServerPorts)[0] + if thRes.ServerPorts != nil { + rtpListener.remotePort = (*thRes.ServerPorts)[0] + } rtpListener.trackID = track.ID rtpListener.streamType = StreamTypeRTP c.udpRTPListeners[track.ID] = rtpListener rtcpListener.remoteIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteZone = c.nconn.RemoteAddr().(*net.TCPAddr).Zone - rtcpListener.remotePort = (*thRes.ServerPorts)[1] + if thRes.ServerPorts != nil { + rtcpListener.remotePort = (*thRes.ServerPorts)[1] + } rtcpListener.trackID = track.ID rtcpListener.streamType = StreamTypeRTCP c.udpRTCPListeners[track.ID] = rtcpListener