From 06bed24dd9e07023ab68ffe0c50e19f763297780 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 14 Aug 2022 23:43:01 +0200 Subject: [PATCH] add intermediate layer between net.Conn and client / server --- client.go | 77 +- client_publish_test.go | 329 ++++---- client_read_test.go | 1304 +++++++++++++---------------- client_test.go | 93 +- constants.go | 6 - pkg/base/interleavedframe.go | 68 +- pkg/base/interleavedframe_test.go | 112 +-- pkg/base/request.go | 17 - pkg/base/request_test.go | 23 - pkg/base/response.go | 17 - pkg/base/response_test.go | 22 - pkg/conn/conn.go | 148 ++++ pkg/conn/conn_test.go | 159 ++++ server_publish_test.go | 193 +++-- server_read_test.go | 243 +++--- server_test.go | 155 ++-- serverconn.go | 36 +- serversession.go | 18 +- 18 files changed, 1459 insertions(+), 1561 deletions(-) create mode 100644 pkg/conn/conn.go create mode 100644 pkg/conn/conn_test.go diff --git a/client.go b/client.go index aa42c632..c483c37f 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,6 @@ Examples are available at https://github.com/aler9/gortsplib/tree/master/example package gortsplib import ( - "bufio" "context" "crypto/tls" "fmt" @@ -24,6 +23,7 @@ import ( "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/ringbuffer" @@ -256,8 +256,8 @@ type Client struct { ctx context.Context ctxCancel func() state clientState - conn net.Conn - br *bufio.Reader + nconn net.Conn + conn *conn.Conn session string sender *auth.Sender cseq int @@ -581,11 +581,13 @@ func (c *Client) doClose() { URL: c.baseURL, }, true, false) - c.conn.Close() + c.nconn.Close() + c.nconn = nil c.conn = nil - } else if c.conn != nil { + } else if c.nconn != nil { c.connCloserStop() - c.conn.Close() + c.nconn.Close() + c.nconn = nil c.conn = nil } @@ -756,7 +758,7 @@ func (c *Client) playRecordStart() { // for some reason, SetReadDeadline() must always be called in the same // goroutine, otherwise Read() freezes. // therefore, we disable the deadline and perform a check with a ticker. - c.conn.SetReadDeadline(time.Time{}) + c.nconn.SetReadDeadline(time.Time{}) // start reader c.readerErr = make(chan error) @@ -768,7 +770,7 @@ func (c *Client) runReader() { if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { for { var res base.Response - err := res.Read(c.br) + err := c.conn.ReadResponse(&res) if err != nil { return err } @@ -854,7 +856,7 @@ func (c *Client) runReader() { var res base.Response for { - what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br) + what, err := c.conn.ReadInterleavedFrameOrResponse(&frame, &res) if err != nil { return err } @@ -885,7 +887,7 @@ func (c *Client) runReader() { func (c *Client) playRecordStop(isClosing bool) { // stop reader if c.readerErr != nil { - c.conn.SetReadDeadline(time.Now()) + c.nconn.SetReadDeadline(time.Now()) <-c.readerErr } @@ -963,7 +965,7 @@ func (c *Client) connOpen() error { return err } - c.conn = func() net.Conn { + c.nconn = func() net.Conn { if c.scheme == "rtsps" { tlsConfig := c.TLSConfig @@ -979,7 +981,8 @@ func (c *Client) connOpen() error { return nconn }() - c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize) + c.conn = conn.NewConn(c.nconn) + c.connCloserStart() return nil } @@ -993,7 +996,7 @@ func (c *Client) connCloserStart() { select { case <-c.ctx.Done(): - c.conn.Close() + c.nconn.Close() case <-c.connCloserTerminate: } @@ -1007,7 +1010,7 @@ func (c *Client) connCloserStop() { } func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*base.Response, error) { - if c.conn == nil { + if c.nconn == nil { err := c.connOpen() if err != nil { return nil, err @@ -1042,10 +1045,8 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba c.OnRequest(req) } - byts, _ := req.Marshal() - - c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - _, err := c.conn.Write(byts) + c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + err := c.conn.WriteRequest(req) if err != nil { return nil, err } @@ -1053,19 +1054,19 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba var res base.Response if !skipResponse { - c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + c.nconn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) if allowFrames { // read the response and ignore interleaved frames in between; // interleaved frames are sent in two cases: // * when the server is v4lrtspserver, before the PLAY response // * when the stream is already playing - err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br) + err = c.conn.ReadResponseIgnoreFrames(&res) if err != nil { return nil, err } } else { - err = res.Read(c.br) + err = c.conn.ReadResponse(&res) if err != nil { return nil, err } @@ -1491,13 +1492,13 @@ func (c *Client) doSetup( if thRes.Source != nil { return *thRes.Source } - return c.conn.RemoteAddr().(*net.TCPAddr).IP + return c.nconn.RemoteAddr().(*net.TCPAddr).IP }() if thRes.ServerPorts != nil { ct.udpRTPListener.readPort = thRes.ServerPorts[0] ct.udpRTPListener.writeAddr = &net.UDPAddr{ - IP: c.conn.RemoteAddr().(*net.TCPAddr).IP, - Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone, + IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP, + Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Port: thRes.ServerPorts[0], } } @@ -1506,13 +1507,13 @@ func (c *Client) doSetup( if thRes.Source != nil { return *thRes.Source } - return c.conn.RemoteAddr().(*net.TCPAddr).IP + return c.nconn.RemoteAddr().(*net.TCPAddr).IP }() if thRes.ServerPorts != nil { ct.udpRTCPListener.readPort = thRes.ServerPorts[1] ct.udpRTCPListener.writeAddr = &net.UDPAddr{ - IP: c.conn.RemoteAddr().(*net.TCPAddr).IP, - Zone: c.conn.RemoteAddr().(*net.TCPAddr).Zone, + IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP, + Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Port: thRes.ServerPorts[1], } } @@ -1551,14 +1552,14 @@ func (c *Client) doSetup( return nil, err } - ct.udpRTPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP + ct.udpRTPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP ct.udpRTPListener.readPort = thRes.Ports[0] ct.udpRTPListener.writeAddr = &net.UDPAddr{ IP: *thRes.Destination, Port: thRes.Ports[0], } - ct.udpRTCPListener.readIP = c.conn.RemoteAddr().(*net.TCPAddr).IP + ct.udpRTCPListener.readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP ct.udpRTCPListener.readPort = thRes.Ports[1] ct.udpRTCPListener.writeAddr = &net.UDPAddr{ IP: *thRes.Destination, @@ -1848,19 +1849,17 @@ func (c *Client) runWriter() { writeFunc = func(trackID int, isRTP bool, payload []byte) { if isRTP { - f := rtpFrames[trackID] - f.Payload = payload - n, _ := f.MarshalTo(buf) + fr := rtpFrames[trackID] + fr.Payload = payload - c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - c.conn.Write(buf[:n]) + c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + c.conn.WriteInterleavedFrame(fr, buf) } else { - f := rtcpFrames[trackID] - f.Payload = payload - n, _ := f.MarshalTo(buf) + fr := rtcpFrames[trackID] + fr.Payload = payload - c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - c.conn.Write(buf[:n]) + c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + c.conn.WriteInterleavedFrame(fr, buf) } } } diff --git a/client_publish_test.go b/client_publish_test.go index d842c672..41ad0df0 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "crypto/tls" "net" "strings" @@ -13,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" ) @@ -78,17 +78,17 @@ func TestClientPublishSerial(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -97,22 +97,20 @@ func TestClientPublishSerial(t *testing.T) { string(base.Record), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream/trackID=0"), req.URL) @@ -149,24 +147,22 @@ func TestClientPublishSerial(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) // client -> server (RTP) @@ -180,7 +176,7 @@ func TestClientPublishSerial(t *testing.T) { require.Equal(t, testRTPPacket, pkt) } else { var f base.InterleavedFrame - err = f.Read(1024, br) + err = conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet @@ -196,23 +192,21 @@ func TestClientPublishSerial(t *testing.T) { Port: th.ClientPorts[1], }) } else { - byts, _ := base.InterleavedFrame{ + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, Payload: testRTCPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) } - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -294,16 +288,16 @@ func TestClientPublishParallel(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -312,21 +306,19 @@ func TestClientPublishParallel(t *testing.T) { string(base.Record), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -350,33 +342,30 @@ func TestClientPublishParallel(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(br) + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -442,16 +431,16 @@ func TestClientPublishPauseSerial(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -461,21 +450,19 @@ func TestClientPublishPauseSerial(t *testing.T) { string(base.Pause), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -499,53 +486,48 @@ func TestClientPublishPauseSerial(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(br) + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(br) + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -601,16 +583,16 @@ func TestClientPublishPauseParallel(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -620,21 +602,19 @@ func TestClientPublishPauseParallel(t *testing.T) { string(base.Pause), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -658,33 +638,30 @@ func TestClientPublishPauseParallel(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequestIgnoreFrames(br) + req, err = readRequestIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -745,17 +722,17 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -764,32 +741,29 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { string(base.Record), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusUnsupportedTransport, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -807,28 +781,26 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) var f base.InterleavedFrame - err = f.Read(2048, br) + err = conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet @@ -836,14 +808,13 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTPPacket, pkt) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -876,16 +847,16 @@ func TestClientPublishRTCPReport(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -894,21 +865,19 @@ func TestClientPublishRTCPReport(t *testing.T) { string(base.Record), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -924,7 +893,7 @@ func TestClientPublishRTCPReport(t *testing.T) { require.NoError(t, err) defer l2.Close() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": headers.Transport{ @@ -937,18 +906,16 @@ func TestClientPublishRTCPReport(t *testing.T) { ServerPorts: &[2]int{34556, 34557}, }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) buf := make([]byte, 2048) @@ -975,14 +942,13 @@ func TestClientPublishRTCPReport(t *testing.T) { close(reportReceived) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -1027,16 +993,16 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -1045,21 +1011,19 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { string(base.Record), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Announce, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1076,47 +1040,42 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { InterleavedIDs: inTH.InterleavedIDs, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Record, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, Payload: testRTCPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() diff --git a/client_read_test.go b/client_read_test.go index f5d31f6d..c0d94051 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "crypto/tls" "fmt" "net" @@ -17,6 +16,7 @@ import ( "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/mpeg4audio" "github.com/aler9/gortsplib/pkg/url" @@ -88,16 +88,16 @@ func TestClientReadTracks(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -106,11 +106,10 @@ func TestClientReadTracks(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -118,19 +117,18 @@ func TestClientReadTracks(t *testing.T) { tracks := Tracks{track1, track2, track3} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) for i := 0; i < 3; i++ { - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(fmt.Sprintf("rtsp://localhost:8554/teststream/trackID=%d", i)), req.URL) @@ -149,36 +147,33 @@ func TestClientReadTracks(t *testing.T) { ServerPorts: &[2]int{34556 + i*2, 34557 + i*2}, } - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) } - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -223,17 +218,17 @@ func TestClientRead(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -242,11 +237,10 @@ func TestClientRead(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) @@ -260,18 +254,17 @@ func TestClientRead(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{scheme + "://" + listenIP + ":8554/test/stream?param=value/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/trackID=0"), req.URL) @@ -344,25 +337,23 @@ func TestClientRead(t *testing.T) { th.InterleavedIDs = &[2]int{0, 1} } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"]) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) // server -> client (RTP) @@ -380,11 +371,10 @@ func TestClientRead(t *testing.T) { }) case "tcp", "tls": - byts, _ := base.InterleavedFrame{ + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) } @@ -408,7 +398,7 @@ func TestClientRead(t *testing.T) { case "tcp", "tls": var f base.InterleavedFrame - err := f.Read(2048, br) + err := conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 1, f.Channel) packets, err := rtcp.Unmarshal(f.Payload) @@ -417,15 +407,14 @@ func TestClientRead(t *testing.T) { close(packetRecv) } - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -478,16 +467,16 @@ func TestClientReadPartial(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -496,11 +485,10 @@ func TestClientReadPartial(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream"), req.URL) @@ -520,18 +508,17 @@ func TestClientReadPartial(t *testing.T) { tracks := Tracks{track1, track2} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://" + listenIP + ":8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/trackID=1"), req.URL) @@ -550,42 +537,38 @@ func TestClientReadPartial(t *testing.T) { InterleavedIDs: inTH.InterleavedIDs, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -637,16 +620,16 @@ func TestClientReadContentBase(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -655,11 +638,10 @@ func TestClientReadContentBase(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -675,33 +657,31 @@ func TestClientReadContentBase(t *testing.T) { switch ca { case "absent": - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) case "inside control attribute": body := string(tracks.Marshal(false)) body = strings.Replace(body, "t=0 0", "t=0 0\r\na=control:rtsp://localhost:8554/teststream", 1) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream2/"}, }, Body: []byte(body), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) } - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -720,35 +700,32 @@ func TestClientReadContentBase(t *testing.T) { ServerPorts: &[2]int{34556, 34557}, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -780,16 +757,16 @@ func TestClientReadAnyPort(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -798,11 +775,10 @@ func TestClientReadAnyPort(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -815,18 +791,17 @@ func TestClientReadAnyPort(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -842,7 +817,7 @@ func TestClientReadAnyPort(t *testing.T) { require.NoError(t, err) defer l1b.Close() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": headers.Transport{ @@ -869,18 +844,16 @@ func TestClientReadAnyPort(t *testing.T) { }(), }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) time.Sleep(500 * time.Millisecond) @@ -942,16 +915,16 @@ func TestClientReadAutomaticProtocol(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -960,11 +933,10 @@ func TestClientReadAutomaticProtocol(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -977,28 +949,26 @@ func TestClientReadAutomaticProtocol(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusUnsupportedTransport, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1007,7 +977,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, headers.TransportProtocolTCP, inTH.Protocol) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": headers.Transport{ @@ -1019,25 +989,22 @@ func TestClientReadAutomaticProtocol(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) }() @@ -1066,228 +1033,215 @@ func TestClientReadAutomaticProtocol(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() - require.NoError(t, err) - br := bufio.NewReader(conn) - - req, err := readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) - - byts, _ := base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) - - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) - - v := auth.NewValidator("myuser", "mypass", nil) - - byts, _ = base.Response{ - StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "WWW-Authenticate": v.Header(), - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) - - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) - - err = v.ValidateRequest(req) - require.NoError(t, err) - - track := &TrackH264{ + tracks := Tracks{&TrackH264{ PayloadType: 96, SPS: []byte{0x01, 0x02, 0x03, 0x04}, PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } - - tracks := Tracks{track} + }} tracks.setControls() - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"application/sdp"}, - "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, - }, - Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) + req, err := readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) - var inTH headers.Transport - err = inTH.Unmarshal(req.Header["Transport"]) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) - th := headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Protocol: headers.TransportProtocolUDP, - ServerPorts: &[2]int{34556, 34557}, - ClientPorts: inTH.ClientPorts, - } + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + v := auth.NewValidator("myuser", "mypass", nil) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Play, req.Method) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "WWW-Authenticate": v.Header(), + }, + }) + require.NoError(t, err) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Teardown, req.Method) + err = v.ValidateRequest(req) + require.NoError(t, err) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) - conn.Close() + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) - conn, err = l.Accept() - require.NoError(t, err) - br = bufio.NewReader(conn) + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) + th := headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ServerPorts: &[2]int{34556, 34557}, + ClientPorts: inTH.ClientPorts, + } - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"application/sdp"}, - "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, - }, - Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) - v = auth.NewValidator("myuser", "mypass", nil) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + }() - byts, _ = base.Response{ - StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "WWW-Authenticate": v.Header(), - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) + req, err := readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) - err = v.ValidateRequest(req) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) - inTH = headers.Transport{} - err = inTH.Unmarshal(req.Header["Transport"]) - require.NoError(t, err) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) - th = headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Protocol: headers.TransportProtocolTCP, - InterleavedIDs: inTH.InterleavedIDs, - } + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Play, req.Method) + v := auth.NewValidator("myuser", "mypass", nil) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "WWW-Authenticate": v.Header(), + }, + }) + require.NoError(t, err) - byts, _ = base.InterleavedFrame{ - Channel: 0, - Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Teardown, req.Method) + err = v.ValidateRequest(req) + require.NoError(t, err) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) - conn.Close() + th := headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolTCP, + InterleavedIDs: inTH.InterleavedIDs, + } + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + }() }() packetRecv := make(chan struct{}) @@ -1317,16 +1271,16 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -1335,11 +1289,10 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -1353,18 +1306,17 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { tracks := Tracks{track1} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) @@ -1382,42 +1334,38 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { InterleavedIDs: &[2]int{2, 3}, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 2, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -1465,167 +1413,163 @@ func TestClientReadRedirect(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() - require.NoError(t, err) - br := bufio.NewReader(conn) - - req, err := readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) - - byts, _ := base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) - - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) - - byts, _ = base.Response{ - StatusCode: base.StatusMovedPermanently, - Header: base.Header{ - "Location": base.HeaderValue{"rtsp://localhost:8554/test"}, - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) - - conn.Close() - - conn, err = l.Accept() - require.NoError(t, err) - defer conn.Close() - br = bufio.NewReader(conn) - - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) - - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Marshal() - - _, err = conn.Write(byts) - require.NoError(t, err) - - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) - - if withCredentials { - if _, exists := req.Header["Authorization"]; !exists { - authRealm := "example@localhost" - authNonce := "exampleNonce" - authOpaque := "exampleOpaque" - authStale := "FALSE" - authAlg := "MD5" - byts, _ = base.Response{ - Header: base.Header{ - "WWW-Authenticate": headers.Authenticate{ - Method: headers.AuthDigest, - Realm: &authRealm, - Nonce: &authNonce, - Opaque: &authOpaque, - Stale: &authStale, - Algorithm: &authAlg, - }.Marshal(), - }, - StatusCode: base.StatusUnauthorized, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) - } - req, err = readRequest(br) + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err := readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = readRequest(conn) require.NoError(t, err) - authHeaderVal, exists := req.Header["Authorization"] - require.True(t, exists) - var authHeader headers.Authenticate - require.NoError(t, authHeader.Unmarshal(authHeaderVal)) - require.Equal(t, *authHeader.Username, "testusr") require.Equal(t, base.Describe, req.Method) - } - track := &TrackH264{ - PayloadType: 96, - SPS: []byte{0x01, 0x02, 0x03, 0x04}, - PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusMovedPermanently, + Header: base.Header{ + "Location": base.HeaderValue{"rtsp://localhost:8554/test"}, + }, + }) + require.NoError(t, err) + }() - tracks := Tracks{track} - tracks.setControls() + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"application/sdp"}, - "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, - }, - Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + req, err := readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) - var th headers.Transport - err = th.Unmarshal(req.Header["Transport"]) - require.NoError(t, err) + require.NoError(t, err) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": headers.Transport{ - Protocol: headers.TransportProtocolUDP, - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{34556, 34557}, - }.Marshal(), - }, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Play, req.Method) + if withCredentials { + if _, exists := req.Header["Authorization"]; !exists { + authRealm := "example@localhost" + authNonce := "exampleNonce" + authOpaque := "exampleOpaque" + authStale := "FALSE" + authAlg := "MD5" + err = conn.WriteResponse(&base.Response{ + Header: base.Header{ + "WWW-Authenticate": headers.Authenticate{ + Method: headers.AuthDigest, + Realm: &authRealm, + Nonce: &authNonce, + Opaque: &authOpaque, + Stale: &authStale, + Algorithm: &authAlg, + }.Marshal(), + }, + StatusCode: base.StatusUnauthorized, + }) + require.NoError(t, err) + } + req, err = readRequest(conn) + require.NoError(t, err) + authHeaderVal, exists := req.Header["Authorization"] + require.True(t, exists) + var authHeader headers.Authenticate + require.NoError(t, authHeader.Unmarshal(authHeaderVal)) + require.Equal(t, *authHeader.Username, "testusr") + require.Equal(t, base.Describe, req.Method) + } - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) - require.NoError(t, err) + track := &TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + } - time.Sleep(500 * time.Millisecond) + tracks := Tracks{track} + tracks.setControls() - l1, err := net.ListenPacket("udp", "localhost:34556") - require.NoError(t, err) - defer l1.Close() + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + var th headers.Transport + err = th.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolUDP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + ClientPorts: th.ClientPorts, + ServerPorts: &[2]int{34556, 34557}, + }.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = readRequest(conn) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + time.Sleep(500 * time.Millisecond) + + l1, err := net.ListenPacket("udp", "localhost:34556") + require.NoError(t, err) + defer l1.Close() + + l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + }() }() ru := "rtsp://localhost:8554/path1" @@ -1642,7 +1586,7 @@ func TestClientReadRedirect(t *testing.T) { } func TestClientReadPause(t *testing.T) { - writeFrames := func(inTH *headers.Transport, conn net.Conn, br *bufio.Reader) (chan struct{}, chan struct{}) { + writeFrames := func(inTH *headers.Transport, conn *conn.Conn) (chan struct{}, chan struct{}) { writerTerminate := make(chan struct{}) writerDone := make(chan struct{}) @@ -1669,11 +1613,10 @@ func TestClientReadPause(t *testing.T) { Port: inTH.ClientPorts[0], }) } else { - byts, _ := base.InterleavedFrame{ + conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - conn.Write(byts) + }, make([]byte, 1024)) } case <-writerTerminate: @@ -1699,16 +1642,16 @@ func TestClientReadPause(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -1717,11 +1660,10 @@ func TestClientReadPause(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1734,18 +1676,17 @@ func TestClientReadPause(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1769,63 +1710,58 @@ func TestClientReadPause(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - writerTerminate, writerDone := writeFrames(&inTH, conn, br) + writerTerminate, writerDone := writeFrames(&inTH, conn) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) close(writerTerminate) <-writerDone - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - writerTerminate, writerDone = writeFrames(&inTH, conn, br) + writerTerminate, writerDone = writeFrames(&inTH, conn) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) close(writerTerminate) <-writerDone - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -1880,16 +1816,16 @@ func TestClientReadRTCPReport(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -1898,11 +1834,10 @@ func TestClientReadRTCPReport(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -1915,18 +1850,17 @@ func TestClientReadRTCPReport(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -1942,7 +1876,7 @@ func TestClientReadRTCPReport(t *testing.T) { require.NoError(t, err) defer l2.Close() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": headers.Transport{ @@ -1955,18 +1889,16 @@ func TestClientReadRTCPReport(t *testing.T) { ClientPorts: inTH.ClientPorts, }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) // skip firewall opening @@ -1985,7 +1917,7 @@ func TestClientReadRTCPReport(t *testing.T) { }, Payload: []byte{0x05, 0x02, 0x03, 0x04}, } - byts, _ = pkt.Marshal() + byts, _ := pkt.Marshal() _, err = l1.WriteTo(byts, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], @@ -2028,14 +1960,13 @@ func TestClientReadRTCPReport(t *testing.T) { close(reportReceived) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -2066,16 +1997,16 @@ func TestClientReadErrorTimeout(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -2084,11 +2015,10 @@ func TestClientReadErrorTimeout(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2101,18 +2031,17 @@ func TestClientReadErrorTimeout(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2142,23 +2071,21 @@ func TestClientReadErrorTimeout(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) if transport == "udp" || transport == "auto" { @@ -2169,14 +2096,13 @@ func TestClientReadErrorTimeout(t *testing.T) { }) } - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -2223,16 +2149,16 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -2241,11 +2167,10 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2258,18 +2183,17 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2286,47 +2210,42 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { th.Protocol = headers.TransportProtocolTCP th.InterleavedIDs = inTH.InterleavedIDs - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 6, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -2359,16 +2278,16 @@ func TestClientReadSeek(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -2377,11 +2296,10 @@ func TestClientReadSeek(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2394,18 +2312,17 @@ func TestClientReadSeek(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2422,16 +2339,15 @@ func TestClientReadSeek(t *testing.T) { InterleavedIDs: inTH.InterleavedIDs, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2444,23 +2360,21 @@ func TestClientReadSeek(t *testing.T) { }, }, ra) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Pause, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -2472,20 +2386,18 @@ func TestClientReadSeek(t *testing.T) { }, }, ra) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -2538,16 +2450,16 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -2556,11 +2468,10 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -2573,18 +2484,17 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) @@ -2592,7 +2502,7 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { err = inTH.Unmarshal(req.Header["Transport"]) require.NoError(t, err) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": headers.Transport{ @@ -2612,31 +2522,28 @@ func TestClientReadKeepaliveFromSession(t *testing.T) { }(), }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) recv := make(chan struct{}) go func() { defer close(recv) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -2670,17 +2577,17 @@ func TestClientReadDifferentSource(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value"), req.URL) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -2689,11 +2596,10 @@ func TestClientReadDifferentSource(t *testing.T) { string(base.Play), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value"), req.URL) @@ -2707,18 +2613,17 @@ func TestClientReadDifferentSource(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/test/stream?param=value/"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/trackID=0"), req.URL) @@ -2749,25 +2654,23 @@ func TestClientReadDifferentSource(t *testing.T) { require.NoError(t, err) defer l2.Close() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"]) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) // server -> client (RTP) @@ -2777,15 +2680,14 @@ func TestClientReadDifferentSource(t *testing.T) { Port: th.ClientPorts[0], }) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() diff --git a/client_test.go b/client_test.go index 45ade5d2..7bcae613 100644 --- a/client_test.go +++ b/client_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "crypto/tls" "net" "strings" @@ -11,6 +10,7 @@ import ( "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/url" ) @@ -22,15 +22,15 @@ func mustParseURL(s string) *url.URL { return u } -func readRequest(br *bufio.Reader) (*base.Request, error) { +func readRequest(conn *conn.Conn) (*base.Request, error) { var req base.Request - err := req.Read(br) + err := conn.ReadRequest(&req) return &req, err } -func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) { +func readRequestIgnoreFrames(conn *conn.Conn) (*base.Request, error) { var req base.Request - err := req.ReadIgnoreFrames(2048, br) + err := conn.ReadRequestIgnoreFrames(&req) return &req, err } @@ -44,14 +44,14 @@ func TestClientTLSSetServerName(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() + defer nconn.Close() cert, err := tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) - tconn := tls.Server(conn, &tls.Config{ + tnconn := tls.Server(nconn, &tls.Config{ Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, VerifyConnection: func(cs tls.ConnectionState) error { @@ -60,7 +60,7 @@ func TestClientTLSSetServerName(t *testing.T) { }, }) - err = tconn.Handshake() + err = tnconn.Handshake() require.EqualError(t, err, "remote error: tls: bad certificate") }() @@ -91,16 +91,16 @@ func TestClientSession(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - br := bufio.NewReader(conn) - defer conn.Close() + conn := conn.NewConn(nconn) + defer nconn.Close() - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ @@ -108,11 +108,10 @@ func TestClientSession(t *testing.T) { }, ", ")}, "Session": base.HeaderValue{"123456"}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -127,15 +126,14 @@ func TestClientSession(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, "Session": base.HeaderValue{"123456"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -162,42 +160,40 @@ func TestClientAuth(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - br := bufio.NewReader(conn) - defer conn.Close() + conn := conn.NewConn(nconn) + defer nconn.Close() - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ string(base.Describe), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) v := auth.NewValidator("myuser", "mypass", nil) - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ "WWW-Authenticate": v.Header(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) @@ -213,14 +209,13 @@ func TestClientAuth(t *testing.T) { tracks := Tracks{track} tracks.setControls() - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, }, Body: tracks.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -247,27 +242,26 @@ func TestClientDescribeCharset(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Public": base.HeaderValue{strings.Join([]string{ string(base.Describe), }, ", ")}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - req, err = readRequest(br) + req, err = readRequest(conn) require.NoError(t, err) require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) @@ -278,15 +272,14 @@ func TestClientDescribeCharset(t *testing.T) { PPS: []byte{0x01, 0x02, 0x03, 0x04}, } - byts, _ = base.Response{ + err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp; charset=utf-8"}, "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, }, Body: Tracks{track1}.Marshal(false), - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) }() @@ -349,12 +342,12 @@ func TestClientCloseDuringRequest(t *testing.T) { go func() { defer close(serverDone) - conn, err := l.Accept() + nconn, err := l.Accept() require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - req, err := readRequest(br) + req, err := readRequest(conn) require.NoError(t, err) require.Equal(t, base.Options, req.Method) diff --git a/constants.go b/constants.go index 862e4b61..8db66f0c 100644 --- a/constants.go +++ b/constants.go @@ -1,12 +1,6 @@ package gortsplib const ( - tcpReadBufferSize = 4096 - - // this must fit an entire H264 NALU and a RTP header. - // with a 250 Mbps H264 video, the maximum NALU size is 2.2MB - tcpMaxFramePayloadSize = 3 * 1024 * 1024 - // same size as GStreamer's rtspsrc udpKernelReadBufferSize = 0x80000 diff --git a/pkg/base/interleavedframe.go b/pkg/base/interleavedframe.go index 09c4e09a..7d0b1ae3 100644 --- a/pkg/base/interleavedframe.go +++ b/pkg/base/interleavedframe.go @@ -7,65 +7,10 @@ import ( ) const ( - interleavedFrameMagicByte = 0x24 + // InterleavedFrameMagicByte is the first byte of an interleaved frame. + InterleavedFrameMagicByte = 0x24 ) -// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrRequest( - frame *InterleavedFrame, - maxPayloadSize int, - req *Request, - br *bufio.Reader, -) (interface{}, error) { - b, err := br.ReadByte() - if err != nil { - return nil, err - } - br.UnreadByte() - - if b == interleavedFrameMagicByte { - err := frame.Read(maxPayloadSize, br) - if err != nil { - return nil, err - } - return frame, err - } - - err = req.Read(br) - if err != nil { - return nil, err - } - return req, nil -} - -// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrResponse( - frame *InterleavedFrame, - maxPayloadSize int, - res *Response, - br *bufio.Reader, -) (interface{}, error) { - b, err := br.ReadByte() - if err != nil { - return nil, err - } - br.UnreadByte() - - if b == interleavedFrameMagicByte { - err := frame.Read(maxPayloadSize, br) - if err != nil { - return nil, err - } - return frame, err - } - - err = res.Read(br) - if err != nil { - return nil, err - } - return res, nil -} - // InterleavedFrame is an interleaved frame, and allows to transfer binary data // within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP. type InterleavedFrame struct { @@ -77,22 +22,19 @@ type InterleavedFrame struct { } // Read decodes an interleaved frame. -func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error { +func (f *InterleavedFrame) Read(br *bufio.Reader) error { var header [4]byte _, err := io.ReadFull(br, header[:]) if err != nil { return err } - if header[0] != interleavedFrameMagicByte { + if header[0] != InterleavedFrameMagicByte { return fmt.Errorf("invalid magic byte (0x%.2x)", header[0]) } + // it's useless to check payloadLen since it's limited to 65535 payloadLen := int(uint16(header[2])<<8 | uint16(header[3])) - if payloadLen > maxPayloadSize { - return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", - payloadLen, maxPayloadSize) - } f.Channel = int(header[1]) f.Payload = make([]byte, payloadLen) diff --git a/pkg/base/interleavedframe_test.go b/pkg/base/interleavedframe_test.go index 00538fbb..baec9c54 100644 --- a/pkg/base/interleavedframe_test.go +++ b/pkg/base/interleavedframe_test.go @@ -37,7 +37,7 @@ func TestInterleavedFrameRead(t *testing.T) { for _, ca := range casesInterleavedFrame { t.Run(ca.name, func(t *testing.T) { - err := f.Read(1024, bufio.NewReader(bytes.NewBuffer(ca.enc))) + err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc))) require.NoError(t, err) require.Equal(t, ca.dec, f) }) @@ -60,11 +60,6 @@ func TestInterleavedFrameReadErrors(t *testing.T) { []byte{0x55, 0x00, 0x00, 0x00}, "invalid magic byte (0x55)", }, - { - "payload size too big", - []byte{0x24, 0x00, 0x00, 0x08}, - "payload size (8) greater than maximum allowed (5)", - }, { "payload invalid", []byte{0x24, 0x00, 0x00, 0x05, 0x01, 0x02}, @@ -73,7 +68,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) { } { t.Run(ca.name, func(t *testing.T) { var f InterleavedFrame - err := f.Read(5, bufio.NewReader(bytes.NewBuffer(ca.byts))) + err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts))) require.EqualError(t, err, ca.err) }) } @@ -88,106 +83,3 @@ func TestInterleavedFrameMarshal(t *testing.T) { }) } } - -func TestReadInterleavedFrameOrRequest(t *testing.T) { - byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + - "Accept: application/sdp\r\n" + - "CSeq: 2\r\n" + - "\r\n") - byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - - var f InterleavedFrame - var req Request - br := bufio.NewReader(bytes.NewBuffer(byts)) - - out, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br) - require.NoError(t, err) - require.Equal(t, &req, out) - - out, err = ReadInterleavedFrameOrRequest(&f, 10, &req, br) - require.NoError(t, err) - require.Equal(t, &f, out) -} - -func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { - for _, ca := range []struct { - name string - byts []byte - err string - }{ - { - "empty", - []byte{}, - "EOF", - }, - { - "invalid frame", - []byte{0x24, 0x00}, - "unexpected EOF", - }, - { - "invalid request", - []byte("DESCRIBE"), - "EOF", - }, - } { - t.Run(ca.name, func(t *testing.T) { - var f InterleavedFrame - var req Request - br := bufio.NewReader(bytes.NewBuffer(ca.byts)) - _, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br) - require.EqualError(t, err, ca.err) - }) - } -} - -func TestReadInterleavedFrameOrResponse(t *testing.T) { - byts := []byte("RTSP/1.0 200 OK\r\n" + - "CSeq: 1\r\n" + - "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + - "\r\n") - byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) - - var f InterleavedFrame - var res Response - br := bufio.NewReader(bytes.NewBuffer(byts)) - out, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br) - require.NoError(t, err) - require.Equal(t, &res, out) - - out, err = ReadInterleavedFrameOrResponse(&f, 10, &res, br) - require.NoError(t, err) - require.Equal(t, &f, out) -} - -func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { - for _, ca := range []struct { - name string - byts []byte - err string - }{ - { - "empty", - []byte{}, - "EOF", - }, - { - "invalid frame", - []byte{0x24, 0x00}, - "unexpected EOF", - }, - { - "invalid response", - []byte("RTSP/1.0"), - "EOF", - }, - } { - t.Run(ca.name, func(t *testing.T) { - var f InterleavedFrame - var res Response - br := bufio.NewReader(bytes.NewBuffer(ca.byts)) - _, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br) - require.EqualError(t, err, ca.err) - }) - } -} diff --git a/pkg/base/request.go b/pkg/base/request.go index 4b38cf0a..f7ee2df6 100644 --- a/pkg/base/request.go +++ b/pkg/base/request.go @@ -100,23 +100,6 @@ func (req *Request) Read(rb *bufio.Reader) error { return nil } -// ReadIgnoreFrames reads a request and ignores any interleaved frame sent -// before the request. -func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error { - var f InterleavedFrame - - for { - recv, err := ReadInterleavedFrameOrRequest(&f, maxPayloadSize, req, rb) - if err != nil { - return err - } - - if _, ok := recv.(*Request); ok { - return nil - } - } -} - // MarshalSize returns the size of a Request. func (req Request) MarshalSize() int { n := 0 diff --git a/pkg/base/request_test.go b/pkg/base/request_test.go index 3e345ad6..19801455 100644 --- a/pkg/base/request_test.go +++ b/pkg/base/request_test.go @@ -238,29 +238,6 @@ func TestRequestMarshal(t *testing.T) { } } -func TestRequestReadIgnoreFrames(t *testing.T) { - byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} - byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+ - "CSeq: 1\r\n"+ - "Proxy-Require: gzipped-messages\r\n"+ - "Require: implicit-play\r\n"+ - "\r\n")...) - - rb := bufio.NewReader(bytes.NewBuffer(byts)) - var req Request - err := req.ReadIgnoreFrames(10, rb) - require.NoError(t, err) -} - -func TestRequestReadIgnoreFramesErrors(t *testing.T) { - byts := []byte{0x25} - - rb := bufio.NewReader(bytes.NewBuffer(byts)) - var req Request - err := req.ReadIgnoreFrames(10, rb) - require.EqualError(t, err, "EOF") -} - func TestRequestString(t *testing.T) { byts := []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n" + "CSeq: 1\r\n" + diff --git a/pkg/base/response.go b/pkg/base/response.go index c7a0bba2..fddb5af5 100644 --- a/pkg/base/response.go +++ b/pkg/base/response.go @@ -184,23 +184,6 @@ func (res *Response) Read(rb *bufio.Reader) error { return nil } -// ReadIgnoreFrames reads a response and ignores any interleaved frame sent -// before the response. -func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error { - var f InterleavedFrame - - for { - recv, err := ReadInterleavedFrameOrResponse(&f, maxPayloadSize, res, rb) - if err != nil { - return err - } - - if _, ok := recv.(*Response); ok { - return nil - } - } -} - // MarshalSize returns the size of a Response. func (res Response) MarshalSize() int { n := 0 diff --git a/pkg/base/response_test.go b/pkg/base/response_test.go index df890a66..2b71dc6b 100644 --- a/pkg/base/response_test.go +++ b/pkg/base/response_test.go @@ -212,28 +212,6 @@ func TestResponseMarshalAutoFillStatus(t *testing.T) { require.Equal(t, byts, buf) } -func TestResponseReadIgnoreFrames(t *testing.T) { - byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} - byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+ - "CSeq: 1\r\n"+ - "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+ - "\r\n")...) - - rb := bufio.NewReader(bytes.NewBuffer(byts)) - var res Response - err := res.ReadIgnoreFrames(10, rb) - require.NoError(t, err) -} - -func TestResponseReadIgnoreFramesErrors(t *testing.T) { - byts := []byte{0x25} - - rb := bufio.NewReader(bytes.NewBuffer(byts)) - var res Response - err := res.ReadIgnoreFrames(10, rb) - require.EqualError(t, err, "EOF") -} - func TestResponseString(t *testing.T) { byts := []byte("RTSP/1.0 200 OK\r\n" + "CSeq: 3\r\n" + diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go new file mode 100644 index 00000000..9c4b4a0a --- /dev/null +++ b/pkg/conn/conn.go @@ -0,0 +1,148 @@ +package conn + +import ( + "bufio" + "io" + + "github.com/aler9/gortsplib/pkg/base" +) + +const ( + readBufferSize = 4096 +) + +// Conn is a RTSP TCP connection. +type Conn struct { + w io.Writer + br *bufio.Reader +} + +// NewConn allocates a Conn. +func NewConn(rw io.ReadWriter) *Conn { + return &Conn{ + w: rw, + br: bufio.NewReaderSize(rw, readBufferSize), + } +} + +// ReadResponse reads a Response. +func (c *Conn) ReadResponse(res *base.Response) error { + return res.Read(c.br) +} + +// ReadRequest reads a Request. +func (c *Conn) ReadRequest(req *base.Request) error { + return req.Read(c.br) +} + +// ReadInterleavedFrame reads a InterleavedFrame. +func (c *Conn) ReadInterleavedFrame(fr *base.InterleavedFrame) error { + return fr.Read(c.br) +} + +// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Request. +func (c *Conn) ReadInterleavedFrameOrRequest( + frame *base.InterleavedFrame, + req *base.Request, +) (interface{}, error) { + b, err := c.br.ReadByte() + if err != nil { + return nil, err + } + c.br.UnreadByte() + + if b == base.InterleavedFrameMagicByte { + err := frame.Read(c.br) + if err != nil { + return nil, err + } + return frame, err + } + + err = req.Read(c.br) + if err != nil { + return nil, err + } + + return req, nil +} + +// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. +func (c *Conn) ReadInterleavedFrameOrResponse( + frame *base.InterleavedFrame, + res *base.Response, +) (interface{}, error) { + b, err := c.br.ReadByte() + if err != nil { + return nil, err + } + c.br.UnreadByte() + + if b == base.InterleavedFrameMagicByte { + err := frame.Read(c.br) + if err != nil { + return nil, err + } + return frame, err + } + + err = res.Read(c.br) + if err != nil { + return nil, err + } + + return res, nil +} + +// ReadRequestIgnoreFrames reads a Request and ignore frames in between. +func (c *Conn) ReadRequestIgnoreFrames(req *base.Request) error { + var f base.InterleavedFrame + + for { + recv, err := c.ReadInterleavedFrameOrRequest(&f, req) + if err != nil { + return err + } + + if _, ok := recv.(*base.Request); ok { + return nil + } + } +} + +// ReadResponseIgnoreFrames reads a Response and ignore frames in between. +func (c *Conn) ReadResponseIgnoreFrames(res *base.Response) error { + var f base.InterleavedFrame + + for { + recv, err := c.ReadInterleavedFrameOrResponse(&f, res) + if err != nil { + return err + } + + if _, ok := recv.(*base.Response); ok { + return nil + } + } +} + +// WriteRequest writes a request. +func (c *Conn) WriteRequest(req *base.Request) error { + buf, _ := req.Marshal() + _, err := c.w.Write(buf) + return err +} + +// WriteResponse writes a response. +func (c *Conn) WriteResponse(res *base.Response) error { + buf, _ := res.Marshal() + _, err := c.w.Write(buf) + return err +} + +// WriteInterleavedFrame writes an interleaved frame. +func (c *Conn) WriteInterleavedFrame(fr *base.InterleavedFrame, buf []byte) error { + n, _ := fr.MarshalTo(buf) + _, err := c.w.Write(buf[:n]) + return err +} diff --git a/pkg/conn/conn_test.go b/pkg/conn/conn_test.go new file mode 100644 index 00000000..c4b27eea --- /dev/null +++ b/pkg/conn/conn_test.go @@ -0,0 +1,159 @@ +package conn + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aler9/gortsplib/pkg/base" +) + +func TestReadInterleavedFrameOrRequest(t *testing.T) { + byts := []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + + "Accept: application/sdp\r\n" + + "CSeq: 2\r\n" + + "\r\n") + byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) + + var f base.InterleavedFrame + var req base.Request + conn := NewConn(bytes.NewBuffer(byts)) + + out, err := conn.ReadInterleavedFrameOrRequest(&f, &req) + require.NoError(t, err) + require.Equal(t, &req, out) + + out, err = conn.ReadInterleavedFrameOrRequest(&f, &req) + require.NoError(t, err) + require.Equal(t, &f, out) +} + +func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { + for _, ca := range []struct { + name string + byts []byte + err string + }{ + { + "empty", + []byte{}, + "EOF", + }, + { + "invalid frame", + []byte{0x24, 0x00}, + "unexpected EOF", + }, + { + "invalid request", + []byte("DESCRIBE"), + "EOF", + }, + } { + t.Run(ca.name, func(t *testing.T) { + var f base.InterleavedFrame + var req base.Request + conn := NewConn(bytes.NewBuffer(ca.byts)) + _, err := conn.ReadInterleavedFrameOrRequest(&f, &req) + require.EqualError(t, err, ca.err) + }) + } +} + +func TestReadInterleavedFrameOrResponse(t *testing.T) { + byts := []byte("RTSP/1.0 200 OK\r\n" + + "CSeq: 1\r\n" + + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + + "\r\n") + byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) + + var f base.InterleavedFrame + var res base.Response + conn := NewConn(bytes.NewBuffer(byts)) + + out, err := conn.ReadInterleavedFrameOrResponse(&f, &res) + require.NoError(t, err) + require.Equal(t, &res, out) + + out, err = conn.ReadInterleavedFrameOrResponse(&f, &res) + require.NoError(t, err) + require.Equal(t, &f, out) +} + +func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { + for _, ca := range []struct { + name string + byts []byte + err string + }{ + { + "empty", + []byte{}, + "EOF", + }, + { + "invalid frame", + []byte{0x24, 0x00}, + "unexpected EOF", + }, + { + "invalid response", + []byte("RTSP/1.0"), + "EOF", + }, + } { + t.Run(ca.name, func(t *testing.T) { + var f base.InterleavedFrame + var res base.Response + conn := NewConn(bytes.NewBuffer(ca.byts)) + _, err := conn.ReadInterleavedFrameOrResponse(&f, &res) + require.EqualError(t, err, ca.err) + }) + } +} + +func TestReadRequestIgnoreFrames(t *testing.T) { + byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} + byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+ + "CSeq: 1\r\n"+ + "Proxy-Require: gzipped-messages\r\n"+ + "Require: implicit-play\r\n"+ + "\r\n")...) + + conn := NewConn(bytes.NewBuffer(byts)) + var req base.Request + err := conn.ReadRequestIgnoreFrames(&req) + require.NoError(t, err) +} + +func TestReadRequestIgnoreFramesErrors(t *testing.T) { + byts := []byte{0x25} + + conn := NewConn(bytes.NewBuffer(byts)) + var req base.Request + err := conn.ReadRequestIgnoreFrames(&req) + require.EqualError(t, err, "EOF") +} + +func TestReadResponseIgnoreFrames(t *testing.T) { + byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} + byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+ + "CSeq: 1\r\n"+ + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n"+ + "\r\n")...) + + conn := NewConn(bytes.NewBuffer(byts)) + var res base.Response + err := conn.ReadResponseIgnoreFrames(&res) + require.NoError(t, err) +} + +func TestReadResponseIgnoreFramesErrors(t *testing.T) { + byts := []byte{0x25} + + conn := NewConn(bytes.NewBuffer(byts)) + var res base.Response + err := conn.ReadResponseIgnoreFrames(&res) + require.EqualError(t, err, "EOF") +} diff --git a/server_publish_test.go b/server_publish_test.go index 36463d14..fdb53211 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "crypto/tls" "net" "testing" @@ -13,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" ) @@ -113,13 +113,13 @@ func TestServerPublishErrorAnnounce(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { require.EqualError(t, ctx.Error, ca.err) - close(connClosed) + close(nconnClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -134,15 +134,15 @@ func TestServerPublishErrorAnnounce(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - _, err = writeReqReadRes(conn, br, ca.req) + _, err = writeReqReadRes(conn, ca.req) require.NoError(t, err) - <-connClosed + <-nconnClosed }) } } @@ -225,10 +225,10 @@ func TestServerPublishSetupPath(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -255,7 +255,7 @@ func TestServerPublishSetupPath(t *testing.T) { byts, _ := sout.Marshal() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/" + ca.path), Header: base.Header{ @@ -280,7 +280,7 @@ func TestServerPublishSetupPath(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL(ca.url), Header: base.Header{ @@ -320,10 +320,10 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -334,7 +334,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -359,7 +359,7 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/test2stream/trackID=0"), Header: base.Header{ @@ -400,10 +400,10 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -414,7 +414,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -439,7 +439,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -454,7 +454,7 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -501,10 +501,10 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track1 := &TrackH264{ PayloadType: 96, @@ -521,7 +521,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { tracks := Tracks{track1, track2} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -546,7 +546,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -561,7 +561,7 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -583,18 +583,18 @@ func TestServerPublish(t *testing.T) { "tls", } { t.Run(transport, func(t *testing.T) { - connOpened := make(chan struct{}) - connClosed := make(chan struct{}) + nconnOpened := make(chan struct{}) + nconnClosed := make(chan struct{}) sessionOpened := make(chan struct{}) sessionClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { - close(connOpened) + close(nconnOpened) }, onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) @@ -649,19 +649,19 @@ func TestServerPublish(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() + defer nconn.Close() - conn = func() net.Conn { + nconn = func() net.Conn { if transport == "tls" { - return tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } - return conn + return nconn }() - br := bufio.NewReader(conn) + conn := conn.NewConn(nconn) - <-connOpened + <-nconnOpened track := &TrackH264{ PayloadType: 96, @@ -672,7 +672,7 @@ func TestServerPublish(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -716,7 +716,7 @@ func TestServerPublish(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -735,7 +735,7 @@ func TestServerPublish(t *testing.T) { err = th.Unmarshal(res.Header["Transport"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -754,7 +754,7 @@ func TestServerPublish(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame - err := f.Read(2048, br) + err := conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) @@ -783,18 +783,16 @@ func TestServerPublish(t *testing.T) { Port: th.ServerPorts[1], }) } else { - byts, _ := base.InterleavedFrame{ + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) - byts, _ = base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, Payload: testRTCPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) } @@ -806,13 +804,13 @@ func TestServerPublish(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame - err := f.Read(2048, br) + err := conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Teardown, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -825,8 +823,8 @@ func TestServerPublish(t *testing.T) { <-sessionClosed - conn.Close() - <-connClosed + nconn.Close() + <-nconnClosed }) } } @@ -862,10 +860,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -876,7 +874,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -901,7 +899,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -920,7 +918,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { err = th.Unmarshal(res.Header["Transport"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -931,11 +929,10 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - byts, _ := base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: []byte{0x01, 0x02, 0x03, 0x04}, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) } @@ -968,10 +965,10 @@ func TestServerPublishRTCPReport(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -982,7 +979,7 @@ func TestServerPublishRTCPReport(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1002,7 +999,7 @@ func TestServerPublishRTCPReport(t *testing.T) { require.NoError(t, err) defer l2.Close() - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1032,7 +1029,7 @@ func TestServerPublishRTCPReport(t *testing.T) { err = th.Unmarshal(res.Header["Transport"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1105,13 +1102,13 @@ func TestServerPublishTimeout(t *testing.T) { "tcp", } { t.Run(transport, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) sessionClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) @@ -1145,10 +1142,10 @@ func TestServerPublishTimeout(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -1159,7 +1156,7 @@ func TestServerPublishTimeout(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1190,7 +1187,7 @@ func TestServerPublishTimeout(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1209,7 +1206,7 @@ func TestServerPublishTimeout(t *testing.T) { err = th.Unmarshal(res.Header["Transport"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1223,7 +1220,7 @@ func TestServerPublishTimeout(t *testing.T) { <-sessionClosed if transport == "tcp" { - <-connClosed + <-nconnClosed } }) } @@ -1235,13 +1232,13 @@ func TestServerPublishWithoutTeardown(t *testing.T) { "tcp", } { t.Run(transport, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) sessionClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) @@ -1275,9 +1272,9 @@ func TestServerPublishWithoutTeardown(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - br := bufio.NewReader(conn) + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -1288,7 +1285,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1319,7 +1316,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1338,7 +1335,7 @@ func TestServerPublishWithoutTeardown(t *testing.T) { err = th.Unmarshal(res.Header["Transport"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1349,10 +1346,10 @@ func TestServerPublishWithoutTeardown(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - conn.Close() + nconn.Close() <-sessionClosed - <-connClosed + <-nconnClosed }) } } @@ -1395,10 +1392,10 @@ func TestServerPublishUDPChangeConn(t *testing.T) { sxID := "" func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -1409,7 +1406,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) { tracks := Tracks{track} tracks.setControls() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Announce, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1434,7 +1431,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1449,7 +1446,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1464,12 +1461,12 @@ func TestServerPublishUDPChangeConn(t *testing.T) { }() func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.GetParameter, URL: mustParseURL("rtsp://localhost:8554/teststream/"), Header: base.Header{ diff --git a/server_read_test.go b/server_read_test.go index 7bebf5ca..4a9815be 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "crypto/tls" "net" "strconv" @@ -16,6 +15,7 @@ import ( "golang.org/x/net/ipv4" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/url" ) @@ -118,10 +118,10 @@ func TestServerReadSetupPath(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) th := &headers.Transport{ Protocol: headers.TransportProtocolTCP, @@ -136,7 +136,7 @@ func TestServerReadSetupPath(t *testing.T) { InterleavedIDs: &[2]int{ca.trackID * 2, (ca.trackID * 2) + 1}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL(ca.url), Header: base.Header{ @@ -157,7 +157,7 @@ func TestServerReadSetupErrors(t *testing.T) { "closed stream", } { t.Run(ca, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) track := &TrackH264{ PayloadType: 96, @@ -185,7 +185,7 @@ func TestServerReadSetupErrors(t *testing.T) { case "closed stream": require.EqualError(t, ctx.Error, "stream is closed") } - close(connClosed) + close(nconnClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { return &base.Response{ @@ -200,10 +200,10 @@ func TestServerReadSetupErrors(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) th := &headers.Transport{ Protocol: headers.TransportProtocolTCP, @@ -218,7 +218,7 @@ func TestServerReadSetupErrors(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -237,7 +237,7 @@ func TestServerReadSetupErrors(t *testing.T) { require.NoError(t, err) th.InterleavedIDs = &[2]int{2, 3} - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/test12stream/trackID=1"), Header: base.Header{ @@ -258,7 +258,7 @@ func TestServerReadSetupErrors(t *testing.T) { require.NoError(t, err) th.InterleavedIDs = &[2]int{2, 3} - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -275,7 +275,7 @@ func TestServerReadSetupErrors(t *testing.T) { require.Equal(t, base.StatusBadRequest, res.StatusCode) } - <-connClosed + <-nconnClosed }) } } @@ -288,8 +288,8 @@ func TestServerRead(t *testing.T) { "multicast", } { t.Run(transport, func(t *testing.T) { - connOpened := make(chan struct{}) - connClosed := make(chan struct{}) + nconnOpened := make(chan struct{}) + nconnClosed := make(chan struct{}) sessionOpened := make(chan struct{}) sessionClosed := make(chan struct{}) framesReceived := make(chan struct{}) @@ -310,10 +310,10 @@ func TestServerRead(t *testing.T) { s := &Server{ Handler: &testServerHandler{ onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { - close(connOpened) + close(nconnOpened) }, onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) @@ -385,18 +385,18 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", listenIP+":8554") + nconn, err := net.Dial("tcp", listenIP+":8554") require.NoError(t, err) - conn = func() net.Conn { + nconn = func() net.Conn { if transport == "tls" { - return tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } - return conn + return nconn }() - br := bufio.NewReader(conn) + conn := conn.NewConn(nconn) - <-connOpened + <-nconnOpened inTH := &headers.Transport{ Mode: func() *headers.TransportMode { @@ -424,7 +424,7 @@ func TestServerRead(t *testing.T) { inTH.InterleavedIDs = &[2]int{4, 5} } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream/trackID=0"), Header: base.Header{ @@ -498,7 +498,7 @@ func TestServerRead(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ @@ -519,7 +519,7 @@ func TestServerRead(t *testing.T) { case "tcp", "tls": var f base.InterleavedFrame - err := f.Read(2048, br) + err := conn.ReadInterleavedFrame(&f) require.NoError(t, err) switch f.Channel { @@ -549,7 +549,7 @@ func TestServerRead(t *testing.T) { var f base.InterleavedFrame for i := 0; i < 2; i++ { - err := f.Read(2048, br) + err := conn.ReadInterleavedFrame(&f) require.NoError(t, err) switch f.Channel { @@ -582,18 +582,17 @@ func TestServerRead(t *testing.T) { <-framesReceived default: - byts, _ := base.InterleavedFrame{ + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 5, Payload: testRTCPPacketMarshaled, - }.Marshal() - _, err = conn.Write(byts) + }, make([]byte, 1024)) require.NoError(t, err) <-framesReceived } if transport == "udp" || transport == "multicast" { // ping with OPTIONS - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Options, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ @@ -605,7 +604,7 @@ func TestServerRead(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) // ping with GET_PARAMETER - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.GetParameter, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ @@ -617,7 +616,7 @@ func TestServerRead(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Teardown, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ @@ -630,8 +629,8 @@ func TestServerRead(t *testing.T) { <-sessionClosed - conn.Close() - <-connClosed + nconn.Close() + <-nconnClosed }) } } @@ -669,10 +668,10 @@ func TestServerReadRTCPReport(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Mode: func() *headers.TransportMode { @@ -687,7 +686,7 @@ func TestServerReadRTCPReport(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -710,7 +709,7 @@ func TestServerReadRTCPReport(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -737,7 +736,7 @@ func TestServerReadRTCPReport(t *testing.T) { OctetCount: 8, }, packets[0]) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Teardown, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -779,12 +778,12 @@ func TestServerReadVLCMulticast(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", listenIP+":8554") + nconn, err := net.Dial("tcp", listenIP+":8554") require.NoError(t, err) - br := bufio.NewReader(conn) - defer conn.Close() + conn := conn.NewConn(nconn) + defer nconn.Close() - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Describe, URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream?vlcmulticast"), Header: base.Header{ @@ -858,12 +857,12 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -889,7 +888,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -900,8 +899,8 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - var fr base.InterleavedFrame - err = fr.Read(2048, br) + var f base.InterleavedFrame + err = conn.ReadInterleavedFrame(&f) require.NoError(t, err) } @@ -937,12 +936,12 @@ func TestServerReadPlayPlay(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -968,7 +967,7 @@ func TestServerReadPlayPlay(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -979,7 +978,7 @@ func TestServerReadPlayPlay(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1053,12 +1052,12 @@ func TestServerReadPlayPausePlay(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1084,7 +1083,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1095,7 +1094,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Pause, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1106,7 +1105,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1176,12 +1175,12 @@ func TestServerReadPlayPausePause(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1207,7 +1206,7 @@ func TestServerReadPlayPausePause(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1218,33 +1217,31 @@ func TestServerReadPlayPausePause(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - byts, _ := base.Request{ + err = conn.WriteRequest(&base.Request{ Method: base.Pause, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Session": base.HeaderValue{sx.Session}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - res, err = readResIgnoreFrames(br) + res, err = readResIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - byts, _ = base.Request{ + err = conn.WriteRequest(&base.Request{ Method: base.Pause, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Session": base.HeaderValue{sx.Session}, }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) - res, err = readResIgnoreFrames(br) + res, err = readResIgnoreFrames(conn) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) } @@ -1308,10 +1305,10 @@ func TestServerReadTimeout(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Mode: func() *headers.TransportMode { @@ -1333,7 +1330,7 @@ func TestServerReadTimeout(t *testing.T) { inTH.Protocol = headers.TransportProtocolUDP } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1348,7 +1345,7 @@ func TestServerReadTimeout(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1370,7 +1367,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { "tcp", } { t.Run(transport, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) sessionClosed := make(chan struct{}) track := &TrackH264{ @@ -1385,7 +1382,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { s := &Server{ Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) @@ -1420,10 +1417,10 @@ func TestServerReadWithoutTeardown(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Delivery: func() *headers.TransportDelivery { @@ -1444,7 +1441,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1459,7 +1456,7 @@ func TestServerReadWithoutTeardown(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1470,10 +1467,10 @@ func TestServerReadWithoutTeardown(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - conn.Close() + nconn.Close() <-sessionClosed - <-connClosed + <-nconnClosed }) } } @@ -1518,10 +1515,10 @@ func TestServerReadUDPChangeConn(t *testing.T) { sxID := "" func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Delivery: func() *headers.TransportDelivery { @@ -1536,7 +1533,7 @@ func TestServerReadUDPChangeConn(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1551,7 +1548,7 @@ func TestServerReadUDPChangeConn(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1566,12 +1563,12 @@ func TestServerReadUDPChangeConn(t *testing.T) { }() func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.GetParameter, URL: mustParseURL("rtsp://localhost:8554/teststream/"), Header: base.Header{ @@ -1626,10 +1623,10 @@ func TestServerReadPartialTracks(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Delivery: func() *headers.TransportDelivery { @@ -1644,7 +1641,7 @@ func TestServerReadPartialTracks(t *testing.T) { InterleavedIDs: &[2]int{4, 5}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=1"), Header: base.Header{ @@ -1659,7 +1656,7 @@ func TestServerReadPartialTracks(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1671,7 +1668,7 @@ func TestServerReadPartialTracks(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) var f base.InterleavedFrame - err = f.Read(2048, br) + err = conn.ReadInterleavedFrame(&f) require.NoError(t, err) require.Equal(t, 4, f.Channel) require.Equal(t, testRTPPacketMarshaled, f.Payload) @@ -1679,10 +1676,10 @@ func TestServerReadPartialTracks(t *testing.T) { func TestServerReadAdditionalInfos(t *testing.T) { getInfos := func() (*headers.RTPInfo, []*uint32) { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) ssrcs := make([]*uint32, 2) @@ -1699,7 +1696,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1732,7 +1729,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=1"), Header: base.Header{ @@ -1749,7 +1746,7 @@ func TestServerReadAdditionalInfos(t *testing.T) { require.NoError(t, err) ssrcs[1] = th.SSRC - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1913,10 +1910,10 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { defer s.Close() func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Delivery: func() *headers.TransportDelivery { @@ -1931,7 +1928,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -1946,7 +1943,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -1959,10 +1956,10 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { }() func() { - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) inTH := &headers.Transport{ Delivery: func() *headers.TransportDelivery { @@ -1977,7 +1974,7 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ diff --git a/server_test.go b/server_test.go index 263ab54b..f51ebfaf 100644 --- a/server_test.go +++ b/server_test.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "fmt" "net" "testing" @@ -10,6 +9,7 @@ import ( "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/headers" ) @@ -67,24 +67,23 @@ NkxNic7oHgsZpIkZ8HK+QjAAWA== -----END PRIVATE KEY----- `) -func writeReqReadRes(conn net.Conn, - br *bufio.Reader, +func writeReqReadRes( + conn *conn.Conn, req base.Request, ) (*base.Response, error) { - byts, _ := req.Marshal() - _, err := conn.Write(byts) + err := conn.WriteRequest(&req) if err != nil { return nil, err } var res base.Response - err = res.Read(br) + err = conn.ReadResponse(&res) return &res, err } -func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) { +func readResIgnoreFrames(conn *conn.Conn) (*base.Response, error) { var res base.Response - err := res.ReadIgnoreFrames(2048, br) + err := conn.ReadResponseIgnoreFrames(&res) return &res, err } @@ -232,7 +231,7 @@ func TestServerErrorInvalidUDPPorts(t *testing.T) { } func TestServerConnClose(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ @@ -241,7 +240,7 @@ func TestServerConnClose(t *testing.T) { ctx.Conn.Close() }, onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - close(connClosed) + close(nconnClosed) }, }, RTSPAddress: "localhost:8554", @@ -251,11 +250,11 @@ func TestServerConnClose(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() + defer nconn.Close() - <-connClosed + <-nconnClosed } func TestServerCSeq(t *testing.T) { @@ -266,12 +265,12 @@ func TestServerCSeq(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Options, URL: mustParseURL("rtsp://localhost:8554/"), Header: base.Header{ @@ -285,13 +284,13 @@ func TestServerCSeq(t *testing.T) { } func TestServerErrorCSeqMissing(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { require.EqualError(t, ctx.Error, "CSeq is missing") - close(connClosed) + close(nconnClosed) }, }, RTSPAddress: "localhost:8554", @@ -300,12 +299,12 @@ func TestServerErrorCSeqMissing(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Options, URL: mustParseURL("rtsp://localhost:8554/"), Header: base.Header{}, @@ -313,7 +312,7 @@ func TestServerErrorCSeqMissing(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusBadRequest, res.StatusCode) - <-connClosed + <-nconnClosed } type testServerErrMethodNotImplemented struct { @@ -349,15 +348,15 @@ func TestServerErrorMethodNotImplemented(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) var sx headers.Session if ca == "inside session" { - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -389,7 +388,7 @@ func TestServerErrorMethodNotImplemented(t *testing.T) { headers["Session"] = base.HeaderValue{sx.Session} } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.SetParameter, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: headers, @@ -404,7 +403,7 @@ func TestServerErrorMethodNotImplemented(t *testing.T) { headers["Session"] = base.HeaderValue{sx.Session} } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Options, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: headers, @@ -450,12 +449,12 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) { require.NoError(t, err) defer s.Close() - conn1, err := net.Dial("tcp", "localhost:8554") + nconn1, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn1.Close() - br1 := bufio.NewReader(conn1) + defer nconn1.Close() + conn1 := conn.NewConn(nconn1) - res, err := writeReqReadRes(conn1, br1, base.Request{ + res, err := writeReqReadRes(conn1, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -481,7 +480,7 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn1, br1, base.Request{ + res, err = writeReqReadRes(conn1, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -492,12 +491,12 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - conn2, err := net.Dial("tcp", "localhost:8554") + nconn2, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn2.Close() - br2 := bufio.NewReader(conn2) + defer nconn2.Close() + conn2 := conn.NewConn(nconn2) - res, err = writeReqReadRes(conn2, br2, base.Request{ + res, err = writeReqReadRes(conn2, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -556,12 +555,12 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -587,7 +586,7 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Play, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -598,7 +597,7 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -668,15 +667,15 @@ func TestServerGetSetParameter(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) var sx headers.Session if ca == "inside session" { - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -709,7 +708,7 @@ func TestServerGetSetParameter(t *testing.T) { headers["Session"] = base.HeaderValue{sx.Session} } - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.SetParameter, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: headers, @@ -725,7 +724,7 @@ func TestServerGetSetParameter(t *testing.T) { headers["Session"] = base.HeaderValue{sx.Session} } - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.GetParameter, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: headers, @@ -771,12 +770,12 @@ func TestServerErrorInvalidSession(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: method, URL: mustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ @@ -815,11 +814,12 @@ func TestServerSessionClose(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() + defer nconn.Close() + conn := conn.NewConn(nconn) - byts, _ := base.Request{ + err = conn.WriteRequest(&base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -837,8 +837,7 @@ func TestServerSessionClose(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, }.Marshal(), }, - }.Marshal() - _, err = conn.Write(byts) + }) require.NoError(t, err) <-sessionClosed @@ -884,11 +883,11 @@ func TestServerSessionAutoClose(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - br := bufio.NewReader(conn) + conn := conn.NewConn(nconn) - _, err = writeReqReadRes(conn, br, base.Request{ + _, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -909,7 +908,7 @@ func TestServerSessionAutoClose(t *testing.T) { }) require.NoError(t, err) - conn.Close() + nconn.Close() <-sessionClosed }) @@ -919,7 +918,7 @@ func TestServerSessionAutoClose(t *testing.T) { func TestServerErrorInvalidPath(t *testing.T) { for _, ca := range []string{"inside session", "outside session"} { t.Run(ca, func(t *testing.T) { - connClosed := make(chan struct{}) + nconnClosed := make(chan struct{}) track := &TrackH264{ PayloadType: 96, @@ -934,7 +933,7 @@ func TestServerErrorInvalidPath(t *testing.T) { Handler: &testServerHandler{ onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { require.EqualError(t, ctx.Error, "invalid path") - close(connClosed) + close(nconnClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { return &base.Response{ @@ -949,13 +948,13 @@ func TestServerErrorInvalidPath(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) if ca == "inside session" { - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -981,7 +980,7 @@ func TestServerErrorInvalidPath(t *testing.T) { err = sx.Unmarshal(res.Header["Session"]) require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ + res, err = writeReqReadRes(conn, base.Request{ Method: base.SetParameter, URL: mustParseURL("rtsp://localhost:8554"), Header: base.Header{ @@ -992,7 +991,7 @@ func TestServerErrorInvalidPath(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusBadRequest, res.StatusCode) } else { - res, err := writeReqReadRes(conn, br, base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.SetParameter, URL: mustParseURL("rtsp://localhost:8554"), Header: base.Header{ @@ -1003,7 +1002,7 @@ func TestServerErrorInvalidPath(t *testing.T) { require.Equal(t, base.StatusBadRequest, res.StatusCode) } - <-connClosed + <-nconnClosed }) } } @@ -1036,10 +1035,10 @@ func TestServerAuth(t *testing.T) { require.NoError(t, err) defer s.Close() - conn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + defer nconn.Close() + conn := conn.NewConn(nconn) track := &TrackH264{ PayloadType: 96, @@ -1057,7 +1056,7 @@ func TestServerAuth(t *testing.T) { Body: Tracks{track}.Marshal(false), } - res, err := writeReqReadRes(conn, br, req) + res, err := writeReqReadRes(conn, req) require.NoError(t, err) require.Equal(t, base.StatusUnauthorized, res.StatusCode) @@ -1065,7 +1064,7 @@ func TestServerAuth(t *testing.T) { require.NoError(t, err) sender.AddAuthorization(&req) - res, err = writeReqReadRes(conn, br, req) + res, err = writeReqReadRes(conn, req) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) } diff --git a/serverconn.go b/serverconn.go index 9e45089f..51e75228 100644 --- a/serverconn.go +++ b/serverconn.go @@ -1,7 +1,6 @@ package gortsplib import ( - "bufio" "context" "crypto/tls" "errors" @@ -14,6 +13,7 @@ import ( "github.com/pion/rtcp" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/url" ) @@ -32,13 +32,13 @@ type readReq struct { // ServerConn is a server-side RTSP connection. type ServerConn struct { - s *Server - conn net.Conn + s *Server + nconn net.Conn ctx context.Context ctxCancel func() remoteAddr *net.TCPAddr - br *bufio.Reader + conn *conn.Conn session *ServerSession readFunc func(readRequest chan readReq) error @@ -55,7 +55,7 @@ func newServerConn( ) *ServerConn { ctx, ctxCancel := context.WithCancel(s.ctx) - conn := func() net.Conn { + nconn = func() net.Conn { if s.TLSConfig != nil { return tls.Server(nconn, s.TLSConfig) } @@ -64,10 +64,10 @@ func newServerConn( sc := &ServerConn{ s: s, - conn: conn, + nconn: nconn, ctx: ctx, ctxCancel: ctxCancel, - remoteAddr: conn.RemoteAddr().(*net.TCPAddr), + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), sessionRemove: make(chan *ServerSession), done: make(chan struct{}), } @@ -88,7 +88,7 @@ func (sc *ServerConn) Close() error { // NetConn returns the underlying net.Conn. func (sc *ServerConn) NetConn() net.Conn { - return sc.conn + return sc.nconn } func (sc *ServerConn) ip() net.IP { @@ -109,7 +109,7 @@ func (sc *ServerConn) run() { }) } - sc.br = bufio.NewReaderSize(sc.conn, tcpReadBufferSize) + sc.conn = conn.NewConn(sc.nconn) readRequest := make(chan readReq) readErr := make(chan error) @@ -120,7 +120,7 @@ func (sc *ServerConn) run() { sc.ctxCancel() - sc.conn.Close() + sc.nconn.Close() <-readDone if sc.session != nil { @@ -185,12 +185,12 @@ func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, re func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { // reset deadline - sc.conn.SetReadDeadline(time.Time{}) + sc.nconn.SetReadDeadline(time.Time{}) var req base.Request for { - err := req.Read(sc.br) + err := sc.conn.ReadRequest(&req) if err != nil { return err } @@ -211,7 +211,7 @@ func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { // reset deadline - sc.conn.SetReadDeadline(time.Time{}) + sc.nconn.SetReadDeadline(time.Time{}) select { case sc.session.startWriter <- struct{}{}: @@ -299,10 +299,10 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { for { if sc.session.state == ServerSessionStateRecord { - sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) + sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) } - what, err := base.ReadInterleavedFrameOrRequest(&frame, tcpMaxFramePayloadSize, &req, sc.br) + what, err := sc.conn.ReadInterleavedFrameOrRequest(&frame, &req) if err != nil { return err } @@ -532,10 +532,8 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { h.OnResponse(sc, res) } - byts, _ := res.Marshal() - - sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) - sc.conn.Write(byts) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) + sc.conn.WriteResponse(res) return err } diff --git a/serversession.go b/serversession.go index 77396fc9..6d0b668e 100644 --- a/serversession.go +++ b/serversession.go @@ -1163,19 +1163,17 @@ func (ss *ServerSession) runWriter() { writeFunc = func(trackID int, isRTP bool, payload []byte) { if isRTP { - f := rtpFrames[trackID] - f.Payload = payload - n, _ := f.MarshalTo(buf) + fr := rtpFrames[trackID] + fr.Payload = payload - ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) - ss.tcpConn.conn.Write(buf[:n]) + ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) + ss.tcpConn.conn.WriteInterleavedFrame(fr, buf) } else { - f := rtcpFrames[trackID] - f.Payload = payload - n, _ := f.MarshalTo(buf) + fr := rtcpFrames[trackID] + fr.Payload = payload - ss.tcpConn.conn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) - ss.tcpConn.conn.Write(buf[:n]) + ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(ss.s.WriteTimeout)) + ss.tcpConn.conn.WriteInterleavedFrame(fr, buf) } } }