From 27f8015ac619e2e6a1097e979f0562e8d05cb255 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 15 Aug 2022 00:10:21 +0200 Subject: [PATCH] simplify usage of Conn --- client.go | 107 +++++++++++------------ client_publish_test.go | 82 +++++++++--------- client_read_test.go | 187 ++++++++++++++++++++--------------------- client_test.go | 28 ++---- pkg/conn/conn.go | 98 +++++++++------------ pkg/conn/conn_test.go | 65 ++++++++------ server_publish_test.go | 6 +- server_read_test.go | 17 ++-- server_test.go | 10 +-- serverconn.go | 19 ++--- 10 files changed, 284 insertions(+), 335 deletions(-) diff --git a/client.go b/client.go index c483c37f..bf100aad 100644 --- a/client.go +++ b/client.go @@ -769,8 +769,7 @@ func (c *Client) runReader() { c.readerErr <- func() error { if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { for { - var res base.Response - err := c.conn.ReadResponse(&res) + _, err := c.conn.ReadResponse() if err != nil { return err } @@ -852,17 +851,14 @@ func (c *Client) runReader() { } } - var frame base.InterleavedFrame - var res base.Response - for { - what, err := c.conn.ReadInterleavedFrameOrResponse(&frame, &res) + what, err := c.conn.ReadInterleavedFrameOrResponse() if err != nil { return err } - if _, ok := what.(*base.InterleavedFrame); ok { - channel := frame.Channel + if fr, ok := what.(*base.InterleavedFrame); ok { + channel := fr.Channel isRTP := true if (channel % 2) != 0 { channel-- @@ -874,7 +870,7 @@ func (c *Client) runReader() { continue } - err := processFunc(track, isRTP, frame.Payload) + err := processFunc(track, isRTP, fr.Payload) if err != nil { return err } @@ -1051,61 +1047,58 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba return nil, err } - var res base.Response + if skipResponse { + return nil, nil + } - if !skipResponse { - c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + var res *base.Response + if allowFrames { + // read the response and ignore interleaved frames in between; + // interleaved frames are sent in two cases: + // * when the server is v4lrtspserver, before the PLAY response + // * when the stream is already playing + res, err = c.conn.ReadResponseIgnoreFrames() + } else { + res, err = c.conn.ReadResponse() + } + if err != nil { + return nil, err + } - if allowFrames { - // read the response and ignore interleaved frames in between; - // interleaved frames are sent in two cases: - // * when the server is v4lrtspserver, before the PLAY response - // * when the stream is already playing - err = c.conn.ReadResponseIgnoreFrames(&res) - if err != nil { - return nil, err - } - } else { - err = c.conn.ReadResponse(&res) - if err != nil { - return nil, err - } + if c.OnResponse != nil { + c.OnResponse(res) + } + + // get session from response + if v, ok := res.Header["Session"]; ok { + var sx headers.Session + err := sx.Unmarshal(v) + if err != nil { + return nil, liberrors.ErrClientSessionHeaderInvalid{Err: err} } + c.session = sx.Session - if c.OnResponse != nil { - c.OnResponse(&res) - } - - // get session from response - if v, ok := res.Header["Session"]; ok { - var sx headers.Session - err := sx.Unmarshal(v) - if err != nil { - return nil, liberrors.ErrClientSessionHeaderInvalid{Err: err} - } - c.session = sx.Session - - if sx.Timeout != nil && *sx.Timeout > 0 { - c.keepalivePeriod = time.Duration(float64(*sx.Timeout)*0.8) * time.Second - } - } - - // if required, send request again with authentication - if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && c.sender == nil { - pass, _ := req.URL.User.Password() - user := req.URL.User.Username() - - sender, err := auth.NewSender(res.Header["WWW-Authenticate"], user, pass) - if err != nil { - return nil, fmt.Errorf("unable to setup authentication: %s", err) - } - c.sender = sender - - return c.do(req, skipResponse, allowFrames) + if sx.Timeout != nil && *sx.Timeout > 0 { + c.keepalivePeriod = time.Duration(float64(*sx.Timeout)*0.8) * time.Second } } - return &res, nil + // if required, send request again with authentication + if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && c.sender == nil { + pass, _ := req.URL.User.Password() + user := req.URL.User.Username() + + sender, err := auth.NewSender(res.Header["WWW-Authenticate"], user, pass) + if err != nil { + return nil, fmt.Errorf("unable to setup authentication: %s", err) + } + c.sender = sender + + return c.do(req, skipResponse, allowFrames) + } + + return res, nil } func (c *Client) doOptions(u *url.URL) (*base.Response, error) { diff --git a/client_publish_test.go b/client_publish_test.go index 41ad0df0..eb369960 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -83,7 +83,7 @@ func TestClientPublishSerial(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) @@ -100,7 +100,7 @@ func TestClientPublishSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) @@ -110,7 +110,7 @@ func TestClientPublishSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream/trackID=0"), req.URL) @@ -155,7 +155,7 @@ func TestClientPublishSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) @@ -175,8 +175,7 @@ func TestClientPublishSerial(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTPPacket, pkt) } else { - var f base.InterleavedFrame - err = conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet @@ -199,7 +198,7 @@ func TestClientPublishSerial(t *testing.T) { require.NoError(t, err) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) @@ -293,7 +292,7 @@ func TestClientPublishParallel(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -309,7 +308,7 @@ func TestClientPublishParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) @@ -318,7 +317,7 @@ func TestClientPublishParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -350,7 +349,7 @@ func TestClientPublishParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -359,7 +358,7 @@ func TestClientPublishParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(conn) + req, err = conn.ReadRequestIgnoreFrames() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -436,7 +435,7 @@ func TestClientPublishPauseSerial(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -453,7 +452,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) @@ -462,7 +461,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -494,7 +493,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -503,7 +502,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(conn) + req, err = conn.ReadRequestIgnoreFrames() require.NoError(t, err) require.Equal(t, base.Pause, req.Method) @@ -512,7 +511,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -521,7 +520,7 @@ func TestClientPublishPauseSerial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(conn) + req, err = conn.ReadRequestIgnoreFrames() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -588,7 +587,7 @@ func TestClientPublishPauseParallel(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -605,7 +604,7 @@ func TestClientPublishPauseParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) @@ -614,7 +613,7 @@ func TestClientPublishPauseParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -646,7 +645,7 @@ func TestClientPublishPauseParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -655,7 +654,7 @@ func TestClientPublishPauseParallel(t *testing.T) { }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(conn) + req, err = conn.ReadRequestIgnoreFrames() require.NoError(t, err) require.Equal(t, base.Pause, req.Method) @@ -727,7 +726,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -744,7 +743,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -754,7 +753,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -763,7 +762,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -789,7 +788,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -799,8 +798,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - var f base.InterleavedFrame - err = conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet @@ -808,7 +806,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTPPacket, pkt) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -852,7 +850,7 @@ func TestClientPublishRTCPReport(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -868,7 +866,7 @@ func TestClientPublishRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) @@ -877,7 +875,7 @@ func TestClientPublishRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -909,7 +907,7 @@ func TestClientPublishRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -942,7 +940,7 @@ func TestClientPublishRTCPReport(t *testing.T) { close(reportReceived) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -998,7 +996,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1014,7 +1012,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Announce, req.Method) @@ -1023,7 +1021,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1048,7 +1046,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Record, req.Method) @@ -1069,7 +1067,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) diff --git a/client_read_test.go b/client_read_test.go index c0d94051..22c72b2e 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -93,7 +93,7 @@ func TestClientReadTracks(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -109,7 +109,7 @@ func TestClientReadTracks(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -128,7 +128,7 @@ func TestClientReadTracks(t *testing.T) { require.NoError(t, err) for i := 0; i < 3; i++ { - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(fmt.Sprintf("rtsp://localhost:8554/teststream/trackID=%d", i)), req.URL) @@ -156,7 +156,7 @@ func TestClientReadTracks(t *testing.T) { require.NoError(t, err) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) @@ -166,7 +166,7 @@ func TestClientReadTracks(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) @@ -223,7 +223,7 @@ func TestClientRead(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) @@ -240,7 +240,7 @@ func TestClientRead(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) @@ -264,7 +264,7 @@ func TestClientRead(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/trackID=0"), req.URL) @@ -345,7 +345,7 @@ func TestClientRead(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) @@ -397,8 +397,7 @@ func TestClientRead(t *testing.T) { close(packetRecv) case "tcp", "tls": - var f base.InterleavedFrame - err := conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 1, f.Channel) packets, err := rtcp.Unmarshal(f.Payload) @@ -407,7 +406,7 @@ func TestClientRead(t *testing.T) { close(packetRecv) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) @@ -472,7 +471,7 @@ func TestClientReadPartial(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -488,7 +487,7 @@ func TestClientReadPartial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream"), req.URL) @@ -518,7 +517,7 @@ func TestClientReadPartial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/trackID=1"), req.URL) @@ -545,7 +544,7 @@ func TestClientReadPartial(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) @@ -561,7 +560,7 @@ func TestClientReadPartial(t *testing.T) { }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) @@ -625,7 +624,7 @@ func TestClientReadContentBase(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -641,7 +640,7 @@ func TestClientReadContentBase(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -681,7 +680,7 @@ func TestClientReadContentBase(t *testing.T) { require.NoError(t, err) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -708,7 +707,7 @@ func TestClientReadContentBase(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -718,7 +717,7 @@ func TestClientReadContentBase(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -762,7 +761,7 @@ func TestClientReadAnyPort(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -778,7 +777,7 @@ func TestClientReadAnyPort(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -801,7 +800,7 @@ func TestClientReadAnyPort(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -847,7 +846,7 @@ func TestClientReadAnyPort(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -920,7 +919,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -936,7 +935,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -959,7 +958,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -968,7 +967,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -992,7 +991,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1046,7 +1045,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1062,7 +1061,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1076,7 +1075,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1093,7 +1092,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -1120,7 +1119,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1129,7 +1128,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -1145,7 +1144,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1161,7 +1160,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1175,7 +1174,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1189,7 +1188,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -1218,7 +1217,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1233,7 +1232,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -1276,7 +1275,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1292,7 +1291,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -1316,7 +1315,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -1342,7 +1341,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) @@ -1358,7 +1357,7 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) @@ -1419,7 +1418,7 @@ func TestClientReadRedirect(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1435,7 +1434,7 @@ func TestClientReadRedirect(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1454,7 +1453,7 @@ func TestClientReadRedirect(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1471,7 +1470,7 @@ func TestClientReadRedirect(t *testing.T) { require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1497,7 +1496,7 @@ func TestClientReadRedirect(t *testing.T) { }) require.NoError(t, err) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) authHeaderVal, exists := req.Header["Authorization"] require.True(t, exists) @@ -1526,7 +1525,7 @@ func TestClientReadRedirect(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1550,7 +1549,7 @@ func TestClientReadRedirect(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1647,7 +1646,7 @@ func TestClientReadPause(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1663,7 +1662,7 @@ func TestClientReadPause(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1686,7 +1685,7 @@ func TestClientReadPause(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1718,7 +1717,7 @@ func TestClientReadPause(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1729,7 +1728,7 @@ func TestClientReadPause(t *testing.T) { writerTerminate, writerDone := writeFrames(&inTH, conn) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Pause, req.Method) @@ -1741,7 +1740,7 @@ func TestClientReadPause(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1752,7 +1751,7 @@ func TestClientReadPause(t *testing.T) { writerTerminate, writerDone = writeFrames(&inTH, conn) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -1821,7 +1820,7 @@ func TestClientReadRTCPReport(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -1837,7 +1836,7 @@ func TestClientReadRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1860,7 +1859,7 @@ func TestClientReadRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1892,7 +1891,7 @@ func TestClientReadRTCPReport(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -1960,7 +1959,7 @@ func TestClientReadRTCPReport(t *testing.T) { close(reportReceived) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -2002,7 +2001,7 @@ func TestClientReadErrorTimeout(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -2018,7 +2017,7 @@ func TestClientReadErrorTimeout(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2041,7 +2040,7 @@ func TestClientReadErrorTimeout(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2079,7 +2078,7 @@ func TestClientReadErrorTimeout(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2096,7 +2095,7 @@ func TestClientReadErrorTimeout(t *testing.T) { }) } - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -2154,7 +2153,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -2170,7 +2169,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2193,7 +2192,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2218,7 +2217,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2239,7 +2238,7 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -2283,7 +2282,7 @@ func TestClientReadSeek(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -2299,7 +2298,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2322,7 +2321,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2347,7 +2346,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2365,7 +2364,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Pause, req.Method) @@ -2374,7 +2373,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2391,7 +2390,7 @@ func TestClientReadSeek(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -2455,7 +2454,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -2471,7 +2470,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2494,7 +2493,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2525,7 +2524,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2537,7 +2536,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { recv := make(chan struct{}) go func() { defer close(recv) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -2582,7 +2581,7 @@ func TestClientReadDifferentSource(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value"), req.URL) @@ -2599,7 +2598,7 @@ func TestClientReadDifferentSource(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value"), req.URL) @@ -2623,7 +2622,7 @@ func TestClientReadDifferentSource(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/trackID=0"), req.URL) @@ -2662,7 +2661,7 @@ func TestClientReadDifferentSource(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) @@ -2680,7 +2679,7 @@ func TestClientReadDifferentSource(t *testing.T) { Port: th.ClientPorts[0], }) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) diff --git a/client_test.go b/client_test.go index 7bcae613..2a1df6dd 100644 --- a/client_test.go +++ b/client_test.go @@ -22,18 +22,6 @@ func mustParseURL(s string) *url.URL { return u } -func readRequest(conn *conn.Conn) (*base.Request, error) { - var req base.Request - err := conn.ReadRequest(&req) - return &req, err -} - -func readRequestIgnoreFrames(conn *conn.Conn) (*base.Request, error) { - var req base.Request - err := conn.ReadRequestIgnoreFrames(&req) - return &req, err -} - func TestClientTLSSetServerName(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) @@ -96,7 +84,7 @@ func TestClientSession(t *testing.T) { conn := conn.NewConn(nconn) defer nconn.Close() - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -111,7 +99,7 @@ func TestClientSession(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -165,7 +153,7 @@ func TestClientAuth(t *testing.T) { conn := conn.NewConn(nconn) defer nconn.Close() - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -179,7 +167,7 @@ func TestClientAuth(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -193,7 +181,7 @@ func TestClientAuth(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -247,7 +235,7 @@ func TestClientDescribeCharset(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) @@ -261,7 +249,7 @@ func TestClientDescribeCharset(t *testing.T) { }) require.NoError(t, err) - req, err = readRequest(conn) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -347,7 +335,7 @@ func TestClientCloseDuringRequest(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - req, err := readRequest(conn) + req, err := conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Options, req.Method) diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go index 9c4b4a0a..d6b60ea5 100644 --- a/pkg/conn/conn.go +++ b/pkg/conn/conn.go @@ -13,8 +13,11 @@ const ( // Conn is a RTSP TCP connection. type Conn struct { - w io.Writer - br *bufio.Reader + w io.Writer + br *bufio.Reader + req base.Request + res base.Response + fr base.InterleavedFrame } // NewConn allocates a Conn. @@ -25,26 +28,26 @@ func NewConn(rw io.ReadWriter) *Conn { } } -// ReadResponse reads a Response. -func (c *Conn) ReadResponse(res *base.Response) error { - return res.Read(c.br) +// ReadRequest reads a Request. +func (c *Conn) ReadRequest() (*base.Request, error) { + err := c.req.Read(c.br) + return &c.req, err } -// ReadRequest reads a Request. -func (c *Conn) ReadRequest(req *base.Request) error { - return req.Read(c.br) +// ReadResponse reads a Response. +func (c *Conn) ReadResponse() (*base.Response, error) { + err := c.res.Read(c.br) + return &c.res, err } // ReadInterleavedFrame reads a InterleavedFrame. -func (c *Conn) ReadInterleavedFrame(fr *base.InterleavedFrame) error { - return fr.Read(c.br) +func (c *Conn) ReadInterleavedFrame() (*base.InterleavedFrame, error) { + err := c.fr.Read(c.br) + return &c.fr, err } // ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Request. -func (c *Conn) ReadInterleavedFrameOrRequest( - frame *base.InterleavedFrame, - req *base.Request, -) (interface{}, error) { +func (c *Conn) ReadInterleavedFrameOrRequest() (interface{}, error) { b, err := c.br.ReadByte() if err != nil { return nil, err @@ -52,26 +55,14 @@ func (c *Conn) ReadInterleavedFrameOrRequest( c.br.UnreadByte() if b == base.InterleavedFrameMagicByte { - err := frame.Read(c.br) - if err != nil { - return nil, err - } - return frame, err + return c.ReadInterleavedFrame() } - err = req.Read(c.br) - if err != nil { - return nil, err - } - - return req, nil + return c.ReadRequest() } // ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. -func (c *Conn) ReadInterleavedFrameOrResponse( - frame *base.InterleavedFrame, - res *base.Response, -) (interface{}, error) { +func (c *Conn) ReadInterleavedFrameOrResponse() (interface{}, error) { b, err := c.br.ReadByte() if err != nil { return nil, err @@ -79,49 +70,36 @@ func (c *Conn) ReadInterleavedFrameOrResponse( c.br.UnreadByte() if b == base.InterleavedFrameMagicByte { - err := frame.Read(c.br) + return c.ReadInterleavedFrame() + } + + return c.ReadResponse() +} + +// ReadRequestIgnoreFrames reads a Request and ignores frames in between. +func (c *Conn) ReadRequestIgnoreFrames() (*base.Request, error) { + for { + recv, err := c.ReadInterleavedFrameOrRequest() if err != nil { return nil, err } - return frame, err - } - err = res.Read(c.br) - if err != nil { - return nil, err - } - - return res, nil -} - -// ReadRequestIgnoreFrames reads a Request and ignore frames in between. -func (c *Conn) ReadRequestIgnoreFrames(req *base.Request) error { - var f base.InterleavedFrame - - for { - recv, err := c.ReadInterleavedFrameOrRequest(&f, req) - if err != nil { - return err - } - - if _, ok := recv.(*base.Request); ok { - return nil + if req, ok := recv.(*base.Request); ok { + return req, nil } } } -// ReadResponseIgnoreFrames reads a Response and ignore frames in between. -func (c *Conn) ReadResponseIgnoreFrames(res *base.Response) error { - var f base.InterleavedFrame - +// ReadResponseIgnoreFrames reads a Response and ignores frames in between. +func (c *Conn) ReadResponseIgnoreFrames() (*base.Response, error) { for { - recv, err := c.ReadInterleavedFrameOrResponse(&f, res) + recv, err := c.ReadInterleavedFrameOrResponse() if err != nil { - return err + return nil, err } - if _, ok := recv.(*base.Response); ok { - return nil + if res, ok := recv.(*base.Response); ok { + return res, nil } } } diff --git a/pkg/conn/conn_test.go b/pkg/conn/conn_test.go index c4b27eea..350aa812 100644 --- a/pkg/conn/conn_test.go +++ b/pkg/conn/conn_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/url" ) func TestReadInterleavedFrameOrRequest(t *testing.T) { @@ -16,17 +17,29 @@ func TestReadInterleavedFrameOrRequest(t *testing.T) { "\r\n") byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - var f base.InterleavedFrame - var req base.Request conn := NewConn(bytes.NewBuffer(byts)) - out, err := conn.ReadInterleavedFrameOrRequest(&f, &req) + out, err := conn.ReadInterleavedFrameOrRequest() require.NoError(t, err) - require.Equal(t, &req, out) + require.Equal(t, &base.Request{ + Method: base.Describe, + URL: &url.URL{ + Scheme: "rtsp", + Host: "example.com", + Path: "/media.mp4", + }, + Header: base.Header{ + "Accept": base.HeaderValue{"application/sdp"}, + "CSeq": base.HeaderValue{"2"}, + }, + }, out) - out, err = conn.ReadInterleavedFrameOrRequest(&f, &req) + out, err = conn.ReadInterleavedFrameOrRequest() require.NoError(t, err) - require.Equal(t, &f, out) + require.Equal(t, &base.InterleavedFrame{ + Channel: 6, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, out) } func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { @@ -52,10 +65,8 @@ func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - var f base.InterleavedFrame - var req base.Request conn := NewConn(bytes.NewBuffer(ca.byts)) - _, err := conn.ReadInterleavedFrameOrRequest(&f, &req) + _, err := conn.ReadInterleavedFrameOrRequest() require.EqualError(t, err, ca.err) }) } @@ -68,17 +79,25 @@ func TestReadInterleavedFrameOrResponse(t *testing.T) { "\r\n") byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - var f base.InterleavedFrame - var res base.Response conn := NewConn(bytes.NewBuffer(byts)) - out, err := conn.ReadInterleavedFrameOrResponse(&f, &res) + out, err := conn.ReadInterleavedFrameOrResponse() require.NoError(t, err) - require.Equal(t, &res, out) + require.Equal(t, &base.Response{ + StatusCode: 200, + StatusMessage: "OK", + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Public": base.HeaderValue{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"}, + }, + }, out) - out, err = conn.ReadInterleavedFrameOrResponse(&f, &res) + out, err = conn.ReadInterleavedFrameOrResponse() require.NoError(t, err) - require.Equal(t, &f, out) + require.Equal(t, &base.InterleavedFrame{ + Channel: 6, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }, out) } func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { @@ -104,10 +123,8 @@ func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - var f base.InterleavedFrame - var res base.Response conn := NewConn(bytes.NewBuffer(ca.byts)) - _, err := conn.ReadInterleavedFrameOrResponse(&f, &res) + _, err := conn.ReadInterleavedFrameOrResponse() require.EqualError(t, err, ca.err) }) } @@ -122,8 +139,7 @@ func TestReadRequestIgnoreFrames(t *testing.T) { "\r\n")...) conn := NewConn(bytes.NewBuffer(byts)) - var req base.Request - err := conn.ReadRequestIgnoreFrames(&req) + _, err := conn.ReadRequestIgnoreFrames() require.NoError(t, err) } @@ -131,8 +147,7 @@ func TestReadRequestIgnoreFramesErrors(t *testing.T) { byts := []byte{0x25} conn := NewConn(bytes.NewBuffer(byts)) - var req base.Request - err := conn.ReadRequestIgnoreFrames(&req) + _, err := conn.ReadRequestIgnoreFrames() require.EqualError(t, err, "EOF") } @@ -144,8 +159,7 @@ func TestReadResponseIgnoreFrames(t *testing.T) { "\r\n")...) conn := NewConn(bytes.NewBuffer(byts)) - var res base.Response - err := conn.ReadResponseIgnoreFrames(&res) + _, err := conn.ReadResponseIgnoreFrames() require.NoError(t, err) } @@ -153,7 +167,6 @@ func TestReadResponseIgnoreFramesErrors(t *testing.T) { byts := []byte{0x25} conn := NewConn(bytes.NewBuffer(byts)) - var res base.Response - err := conn.ReadResponseIgnoreFrames(&res) + _, err := conn.ReadResponseIgnoreFrames() require.EqualError(t, err, "EOF") } diff --git a/server_publish_test.go b/server_publish_test.go index fdb53211..1fe600d9 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -753,8 +753,7 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { - var f base.InterleavedFrame - err := conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) @@ -803,8 +802,7 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { - var f base.InterleavedFrame - err := conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) diff --git a/server_read_test.go b/server_read_test.go index 4a9815be..48fd7ef3 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -518,8 +518,7 @@ func TestServerRead(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, buf[:n]) case "tcp", "tls": - var f base.InterleavedFrame - err := conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) switch f.Channel { @@ -546,10 +545,8 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { - var f base.InterleavedFrame - for i := 0; i < 2; i++ { - err := conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) switch f.Channel { @@ -899,8 +896,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - var f base.InterleavedFrame - err = conn.ReadInterleavedFrame(&f) + _, err = conn.ReadInterleavedFrame() require.NoError(t, err) } @@ -1227,7 +1223,7 @@ func TestServerReadPlayPausePause(t *testing.T) { }) require.NoError(t, err) - res, err = readResIgnoreFrames(conn) + res, err = conn.ReadResponseIgnoreFrames() require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) @@ -1241,7 +1237,7 @@ func TestServerReadPlayPausePause(t *testing.T) { }) require.NoError(t, err) - res, err = readResIgnoreFrames(conn) + res, err = conn.ReadResponseIgnoreFrames() require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) } @@ -1667,8 +1663,7 @@ func TestServerReadPartialTracks(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - var f base.InterleavedFrame - err = conn.ReadInterleavedFrame(&f) + f, err := conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 4, f.Channel) require.Equal(t, testRTPPacketMarshaled, f.Payload) diff --git a/server_test.go b/server_test.go index f51ebfaf..05481111 100644 --- a/server_test.go +++ b/server_test.go @@ -76,15 +76,7 @@ func writeReqReadRes( return nil, err } - var res base.Response - err = conn.ReadResponse(&res) - return &res, err -} - -func readResIgnoreFrames(conn *conn.Conn) (*base.Response, error) { - var res base.Response - err := conn.ReadResponseIgnoreFrames(&res) - return &res, err + return conn.ReadResponse() } type testServerHandler struct { diff --git a/serverconn.go b/serverconn.go index 51e75228..cefe55fd 100644 --- a/serverconn.go +++ b/serverconn.go @@ -187,17 +187,15 @@ func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { // reset deadline sc.nconn.SetReadDeadline(time.Time{}) - var req base.Request - for { - err := sc.conn.ReadRequest(&req) + req, err := sc.conn.ReadRequest() if err != nil { return err } cres := make(chan error) select { - case readRequest <- readReq{req: &req, res: cres}: + case readRequest <- readReq{req: req, res: cres}: err = <-cres if err != nil { return err @@ -294,22 +292,19 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } } - var req base.Request - var frame base.InterleavedFrame - for { if sc.session.state == ServerSessionStateRecord { sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) } - what, err := sc.conn.ReadInterleavedFrameOrRequest(&frame, &req) + what, err := sc.conn.ReadInterleavedFrameOrRequest() if err != nil { return err } - switch what.(type) { + switch twhat := what.(type) { case *base.InterleavedFrame: - channel := frame.Channel + channel := twhat.Channel isRTP := true if (channel % 2) != 0 { channel-- @@ -318,7 +313,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { // forward frame only if it has been set up if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok { - err := processFunc(trackID, isRTP, frame.Payload) + err := processFunc(trackID, isRTP, twhat.Payload) if err != nil { return err } @@ -327,7 +322,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { case *base.Request: cres := make(chan error) select { - case readRequest <- readReq{req: &req, res: cres}: + case readRequest <- readReq{req: twhat, res: cres}: err := <-cres if err != nil { return err