diff --git a/connclient.go b/connclient.go index e358c1f5..21a50064 100644 --- a/connclient.go +++ b/connclient.go @@ -8,6 +8,7 @@ import ( type ConnClient struct { nconn net.Conn + br *bufio.Reader bw *bufio.Writer session string cseqEnabled bool @@ -18,7 +19,8 @@ type ConnClient struct { func NewConnClient(nconn net.Conn) *ConnClient { return &ConnClient{ nconn: nconn, - bw: bufio.NewWriterSize(nconn, _INTERLEAVED_FRAME_MAX_SIZE), + br: bufio.NewReaderSize(nconn, 4096), + bw: bufio.NewWriterSize(nconn, 4096), } } @@ -58,15 +60,15 @@ func (c *ConnClient) WriteRequest(req *Request) error { } req.Header["Authorization"] = []string{c.authProv.generateHeader(req.Method, req.Url)} } - return req.write(c.nconn) + return req.write(c.bw) } func (c *ConnClient) ReadResponse() (*Response, error) { - return readResponse(c.nconn) + return readResponse(c.br) } func (c *ConnClient) ReadInterleavedFrame() (*InterleavedFrame, error) { - return readInterleavedFrame(c.nconn) + return readInterleavedFrame(c.br) } func (c *ConnClient) WriteInterleavedFrame(frame *InterleavedFrame) error { diff --git a/connserver.go b/connserver.go index 1bf8bc41..07738231 100644 --- a/connserver.go +++ b/connserver.go @@ -7,13 +7,15 @@ import ( type ConnServer struct { nconn net.Conn + br *bufio.Reader bw *bufio.Writer } func NewConnServer(nconn net.Conn) *ConnServer { return &ConnServer{ nconn: nconn, - bw: bufio.NewWriterSize(nconn, _INTERLEAVED_FRAME_MAX_SIZE), + br: bufio.NewReaderSize(nconn, 4096), + bw: bufio.NewWriterSize(nconn, 4096), } } @@ -22,15 +24,15 @@ func (s *ConnServer) NetConn() net.Conn { } func (s *ConnServer) ReadRequest() (*Request, error) { - return readRequest(s.nconn) + return readRequest(s.br) } func (s *ConnServer) WriteResponse(res *Response) error { - return res.write(s.nconn) + return res.write(s.bw) } func (s *ConnServer) ReadInterleavedFrame() (*InterleavedFrame, error) { - return readInterleavedFrame(s.nconn) + return readInterleavedFrame(s.br) } func (s *ConnServer) WriteInterleavedFrame(frame *InterleavedFrame) error { diff --git a/header_test.go b/header_test.go index 5550e00d..6299ff11 100644 --- a/header_test.go +++ b/header_test.go @@ -53,8 +53,8 @@ func TestHeaderWrite(t *testing.T) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) err := c.header.write(bw) - bw.Flush() require.NoError(t, err) + bw.Flush() require.Equal(t, c.byts, buf.Bytes()) }) } diff --git a/interleavedframe.go b/interleavedframe.go index 08a03e4e..07e6aa72 100644 --- a/interleavedframe.go +++ b/interleavedframe.go @@ -24,11 +24,6 @@ func readInterleavedFrame(r io.Reader) (*InterleavedFrame, error) { return nil, err } - // connection terminated - if header[0] == 0x54 { - return nil, io.EOF - } - if header[0] != 0x24 { return nil, fmt.Errorf("wrong magic byte (0x%.2x)", header[0]) } diff --git a/request.go b/request.go index 51637d7f..e84d9aa9 100644 --- a/request.go +++ b/request.go @@ -3,7 +3,6 @@ package gortsplib import ( "bufio" "fmt" - "io" ) type Request struct { @@ -13,12 +12,10 @@ type Request struct { Content []byte } -func readRequest(r io.Reader) (*Request, error) { - rb := bufio.NewReader(r) - +func readRequest(br *bufio.Reader) (*Request, error) { req := &Request{} - byts, err := readBytesLimited(rb, ' ', 255) + byts, err := readBytesLimited(br, ' ', 255) if err != nil { return nil, err } @@ -28,7 +25,7 @@ func readRequest(r io.Reader) (*Request, error) { return nil, fmt.Errorf("empty method") } - byts, err = readBytesLimited(rb, ' ', 255) + byts, err = readBytesLimited(br, ' ', 255) if err != nil { return nil, err } @@ -38,7 +35,7 @@ func readRequest(r io.Reader) (*Request, error) { return nil, fmt.Errorf("empty path") } - byts, err = readBytesLimited(rb, '\r', 255) + byts, err = readBytesLimited(br, '\r', 255) if err != nil { return nil, err } @@ -48,17 +45,17 @@ func readRequest(r io.Reader) (*Request, error) { return nil, fmt.Errorf("expected '%s', got '%s'", _RTSP_PROTO, proto) } - err = readByteEqual(rb, '\n') + err = readByteEqual(br, '\n') if err != nil { return nil, err } - req.Header, err = readHeader(rb) + req.Header, err = readHeader(br) if err != nil { return nil, err } - req.Content, err = readContent(rb, req.Header) + req.Content, err = readContent(br, req.Header) if err != nil { return nil, err } @@ -66,23 +63,21 @@ func readRequest(r io.Reader) (*Request, error) { return req, nil } -func (req *Request) write(w io.Writer) error { - wb := bufio.NewWriter(w) - - _, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n")) +func (req *Request) write(bw *bufio.Writer) error { + _, err := bw.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n")) if err != nil { return err } - err = req.Header.write(wb) + err = req.Header.write(bw) if err != nil { return err } - err = writeContent(wb, req.Content) + err = writeContent(bw, req.Content) if err != nil { return err } - return wb.Flush() + return bw.Flush() } diff --git a/request_test.go b/request_test.go index 5b6f3752..a5de4b99 100644 --- a/request_test.go +++ b/request_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bufio" "bytes" "testing" @@ -115,7 +116,7 @@ var casesRequest = []struct { func TestRequestRead(t *testing.T) { for _, c := range casesRequest { t.Run(c.name, func(t *testing.T) { - req, err := readRequest(bytes.NewBuffer(c.byts)) + req, err := readRequest(bufio.NewReader(bytes.NewBuffer(c.byts))) require.NoError(t, err) require.Equal(t, c.req, req) }) @@ -126,8 +127,10 @@ func TestRequestWrite(t *testing.T) { for _, c := range casesRequest { t.Run(c.name, func(t *testing.T) { var buf bytes.Buffer - err := c.req.write(&buf) + bw := bufio.NewWriter(&buf) + err := c.req.write(bw) require.NoError(t, err) + bw.Flush() require.Equal(t, c.byts, buf.Bytes()) }) } diff --git a/response.go b/response.go index cbbed6d9..8fc1e5f6 100644 --- a/response.go +++ b/response.go @@ -3,7 +3,6 @@ package gortsplib import ( "bufio" "fmt" - "io" "strconv" ) @@ -14,12 +13,10 @@ type Response struct { Content []byte } -func readResponse(r io.Reader) (*Response, error) { - rb := bufio.NewReader(r) - +func readResponse(br *bufio.Reader) (*Response, error) { res := &Response{} - byts, err := readBytesLimited(rb, ' ', 255) + byts, err := readBytesLimited(br, ' ', 255) if err != nil { return nil, err } @@ -29,7 +26,7 @@ func readResponse(r io.Reader) (*Response, error) { return nil, fmt.Errorf("expected '%s', got '%s'", _RTSP_PROTO, proto) } - byts, err = readBytesLimited(rb, ' ', 4) + byts, err = readBytesLimited(br, ' ', 4) if err != nil { return nil, err } @@ -41,7 +38,7 @@ func readResponse(r io.Reader) (*Response, error) { return nil, fmt.Errorf("unable to parse status code") } - byts, err = readBytesLimited(rb, '\r', 255) + byts, err = readBytesLimited(br, '\r', 255) if err != nil { return nil, err } @@ -51,17 +48,17 @@ func readResponse(r io.Reader) (*Response, error) { return nil, fmt.Errorf("empty status") } - err = readByteEqual(rb, '\n') + err = readByteEqual(br, '\n') if err != nil { return nil, err } - res.Header, err = readHeader(rb) + res.Header, err = readHeader(br) if err != nil { return nil, err } - res.Content, err = readContent(rb, res.Header) + res.Content, err = readContent(br, res.Header) if err != nil { return nil, err } @@ -69,10 +66,8 @@ func readResponse(r io.Reader) (*Response, error) { return res, nil } -func (res *Response) write(w io.Writer) error { - wb := bufio.NewWriter(w) - - _, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n")) +func (res *Response) write(bw *bufio.Writer) error { + _, err := bw.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n")) if err != nil { return err } @@ -81,15 +76,15 @@ func (res *Response) write(w io.Writer) error { res.Header["Content-Length"] = []string{strconv.FormatInt(int64(len(res.Content)), 10)} } - err = res.Header.write(wb) + err = res.Header.write(bw) if err != nil { return err } - err = writeContent(wb, res.Content) + err = writeContent(bw, res.Content) if err != nil { return err } - return wb.Flush() + return bw.Flush() } diff --git a/response_test.go b/response_test.go index 1ea95ebc..02549f31 100644 --- a/response_test.go +++ b/response_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bufio" "bytes" "testing" @@ -110,7 +111,7 @@ var casesResponse = []struct { func TestResponseRead(t *testing.T) { for _, c := range casesResponse { t.Run(c.name, func(t *testing.T) { - res, err := readResponse(bytes.NewBuffer(c.byts)) + res, err := readResponse(bufio.NewReader(bytes.NewBuffer(c.byts))) require.NoError(t, err) require.Equal(t, c.res, res) }) @@ -121,8 +122,10 @@ func TestResponseWrite(t *testing.T) { for _, c := range casesResponse { t.Run(c.name, func(t *testing.T) { var buf bytes.Buffer - err := c.res.write(&buf) + bw := bufio.NewWriter(&buf) + err := c.res.write(bw) require.NoError(t, err) + bw.Flush() require.Equal(t, c.byts, buf.Bytes()) }) } diff --git a/utils.go b/utils.go index 83a163da..fb196482 100644 --- a/utils.go +++ b/utils.go @@ -12,23 +12,23 @@ const ( _MAX_CONTENT_LENGTH = 4096 ) -func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) { +func readBytesLimited(br *bufio.Reader, delim byte, n int) ([]byte, error) { for i := 1; i <= n; i++ { - byts, err := rb.Peek(i) + byts, err := br.Peek(i) if err != nil { return nil, err } if byts[len(byts)-1] == delim { - rb.Discard(len(byts)) + br.Discard(len(byts)) return byts, nil } } return nil, fmt.Errorf("buffer length exceeds %d", n) } -func readByteEqual(rb *bufio.Reader, cmp byte) error { - byt, err := rb.ReadByte() +func readByteEqual(br *bufio.Reader, cmp byte) error { + byt, err := br.ReadByte() if err != nil { return err } @@ -40,7 +40,7 @@ func readByteEqual(rb *bufio.Reader, cmp byte) error { return nil } -func readContent(rb *bufio.Reader, header Header) ([]byte, error) { +func readContent(br *bufio.Reader, header Header) ([]byte, error) { cls, ok := header["Content-Length"] if !ok || len(cls) != 1 { return nil, nil @@ -56,7 +56,7 @@ func readContent(rb *bufio.Reader, header Header) ([]byte, error) { } ret := make([]byte, cl) - n, err := io.ReadFull(rb, ret) + n, err := io.ReadFull(br, ret) if err != nil && n != len(ret) { return nil, err } @@ -64,12 +64,12 @@ func readContent(rb *bufio.Reader, header Header) ([]byte, error) { return ret, nil } -func writeContent(wb *bufio.Writer, content []byte) error { +func writeContent(bw *bufio.Writer, content []byte) error { if len(content) == 0 { return nil } - _, err := wb.Write(content) + _, err := bw.Write(content) if err != nil { return err }