From eba2fb39d192d7cd7a1052ff389f2246d922ce67 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 6 Oct 2020 10:07:57 +0200 Subject: [PATCH] reuse structs when reading Requests, Responses and Headers --- base/content.go | 45 ++++++++++++++++++++++++++++++++++++++++ base/header.go | 22 +++++++++----------- base/header_test.go | 5 +++-- base/interleavedframe.go | 20 ++++++++++++------ base/request.go | 37 ++++++++++++++++----------------- base/request_test.go | 13 ++++++------ base/response.go | 33 ++++++++++++++--------------- base/response_test.go | 11 +++++----- base/utils.go | 39 ---------------------------------- connclient.go | 6 ++++-- connserver.go | 15 +++++++++++--- 11 files changed, 135 insertions(+), 111 deletions(-) create mode 100644 base/content.go diff --git a/base/content.go b/base/content.go new file mode 100644 index 00000000..2a7a2dfa --- /dev/null +++ b/base/content.go @@ -0,0 +1,45 @@ +package base + +import ( + "bufio" + "fmt" + "io" + "strconv" +) + +func contentRead(rb *bufio.Reader, header Header) ([]byte, error) { + cls, ok := header["Content-Length"] + if !ok || len(cls) != 1 { + return nil, nil + } + + cl, err := strconv.ParseInt(cls[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid Content-Length") + } + + if cl > rtspMaxContentLength { + return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength) + } + + ret := make([]byte, cl) + n, err := io.ReadFull(rb, ret) + if err != nil && n != len(ret) { + return nil, err + } + + return ret, nil +} + +func contentWrite(bw *bufio.Writer, content []byte) error { + if len(content) == 0 { + return nil + } + + _, err := bw.Write(content) + if err != nil { + return err + } + + return nil +} diff --git a/base/header.go b/base/header.go index 2fd13ca5..1fb28e31 100644 --- a/base/header.go +++ b/base/header.go @@ -34,32 +34,30 @@ type HeaderValue []string // Header is a RTSP reader, present in both Requests and Responses. type Header map[string]HeaderValue -func headerRead(rb *bufio.Reader) (Header, error) { - h := make(Header) - +func (h Header) read(rb *bufio.Reader) error { for { byt, err := rb.ReadByte() if err != nil { - return nil, err + return err } if byt == '\r' { err := readByteEqual(rb, '\n') if err != nil { - return nil, err + return err } break } if len(h) >= headerMaxEntryCount { - return nil, fmt.Errorf("headers count exceeds %d", headerMaxEntryCount) + return fmt.Errorf("headers count exceeds %d", headerMaxEntryCount) } key := string([]byte{byt}) byts, err := readBytesLimited(rb, ':', headerMaxKeyLength-1) if err != nil { - return nil, err + return err } key += string(byts[:len(byts)-1]) key = headerKeyNormalize(key) @@ -69,7 +67,7 @@ func headerRead(rb *bufio.Reader) (Header, error) { for { byt, err := rb.ReadByte() if err != nil { - return nil, err + return err } if byt != ' ' { @@ -80,23 +78,23 @@ func headerRead(rb *bufio.Reader) (Header, error) { byts, err = readBytesLimited(rb, '\r', headerMaxValueLength) if err != nil { - return nil, err + return err } val := string(byts[:len(byts)-1]) if len(val) == 0 { - return nil, fmt.Errorf("empty header value") + return fmt.Errorf("empty header value") } err = readByteEqual(rb, '\n') if err != nil { - return nil, err + return err } h[key] = append(h[key], val) } - return h, nil + return nil } func (h Header) write(wb *bufio.Writer) error { diff --git a/base/header_test.go b/base/header_test.go index 5c02cdb1..14c691f7 100644 --- a/base/header_test.go +++ b/base/header_test.go @@ -93,9 +93,10 @@ var casesHeader = []struct { func TestHeaderRead(t *testing.T) { for _, c := range casesHeader { t.Run(c.name, func(t *testing.T) { - req, err := headerRead(bufio.NewReader(bytes.NewBuffer(c.dec))) + h := make(Header) + err := h.read(bufio.NewReader(bytes.NewBuffer(c.dec))) require.NoError(t, err) - require.Equal(t, c.header, req) + require.Equal(t, c.header, h) }) } } diff --git a/base/interleavedframe.go b/base/interleavedframe.go index 9f554cb2..0ed069b5 100644 --- a/base/interleavedframe.go +++ b/base/interleavedframe.go @@ -11,8 +11,8 @@ const ( interleavedFrameMagicByte = 0x24 ) -// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) { +// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. +func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) { b, err := br.ReadByte() if err != nil { return nil, err @@ -27,11 +27,15 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, br *bufio.Reader) ( return frame, err } - return ReadResponse(br) + err = req.Read(br) + if err != nil { + return nil, err + } + return req, nil } -// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) { +// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. +func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) { b, err := br.ReadByte() if err != nil { return nil, err @@ -46,7 +50,11 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (i return frame, err } - return ReadRequest(br) + err = res.Read(br) + if err != nil { + return nil, err + } + return res, nil } // InterleavedFrame is an interleaved frame, and allows to transfer binary data diff --git a/base/request.go b/base/request.go index 72e4028d..017dc752 100644 --- a/base/request.go +++ b/base/request.go @@ -53,66 +53,65 @@ type Request struct { SkipResponse bool } -// ReadRequest reads a request. -func ReadRequest(rb *bufio.Reader) (*Request, error) { - req := &Request{} - +// Read reads a request. +func (req *Request) Read(rb *bufio.Reader) error { byts, err := readBytesLimited(rb, ' ', requestMaxLethodLength) if err != nil { - return nil, err + return err } req.Method = Method(byts[:len(byts)-1]) if req.Method == "" { - return nil, fmt.Errorf("empty method") + return fmt.Errorf("empty method") } byts, err = readBytesLimited(rb, ' ', requestMaxPathLength) if err != nil { - return nil, err + return err } rawUrl := string(byts[:len(byts)-1]) if rawUrl == "" { - return nil, fmt.Errorf("empty url") + return fmt.Errorf("empty url") } ur, err := url.Parse(rawUrl) if err != nil { - return nil, fmt.Errorf("unable to parse url (%v)", rawUrl) + return fmt.Errorf("unable to parse url (%v)", rawUrl) } req.Url = ur if req.Url.Scheme != "rtsp" { - return nil, fmt.Errorf("invalid url scheme (%v)", rawUrl) + return fmt.Errorf("invalid url scheme (%v)", rawUrl) } byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength) if err != nil { - return nil, err + return err } proto := string(byts[:len(byts)-1]) if proto != rtspProtocol10 { - return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) + return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) } err = readByteEqual(rb, '\n') if err != nil { - return nil, err + return err } - req.Header, err = headerRead(rb) + req.Header = make(Header) + err = req.Header.read(rb) if err != nil { - return nil, err + return err } - req.Content, err = readContent(rb, req.Header) + req.Content, err = contentRead(rb, req.Header) if err != nil { - return nil, err + return err } - return req, nil + return nil } // Write writes a request. @@ -139,7 +138,7 @@ func (req Request) Write(bw *bufio.Writer) error { return err } - err = writeContent(bw, req.Content) + err = contentWrite(bw, req.Content) if err != nil { return err } diff --git a/base/request_test.go b/base/request_test.go index b8225e9c..79b6d424 100644 --- a/base/request_test.go +++ b/base/request_test.go @@ -12,7 +12,7 @@ import ( var casesRequest = []struct { name string byts []byte - req *Request + req Request }{ { "options", @@ -21,7 +21,7 @@ var casesRequest = []struct { "Proxy-Require: gzipped-messages\r\n" + "Require: implicit-play\r\n" + "\r\n"), - &Request{ + Request{ Method: "OPTIONS", Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ @@ -36,7 +36,7 @@ var casesRequest = []struct { []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + "CSeq: 2\r\n" + "\r\n"), - &Request{ + Request{ Method: "DESCRIBE", Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ @@ -64,7 +64,7 @@ var casesRequest = []struct { "a=recvonly\n" + "m=audio 3456 RTP/AVP 0\n" + "m=video 2232 RTP/AVP 31\n"), - &Request{ + Request{ Method: "ANNOUNCE", Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ @@ -98,7 +98,7 @@ var casesRequest = []struct { "\r\n" + "packets_received\n" + "jitter\n"), - &Request{ + Request{ Method: "GET_PARAMETER", Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ @@ -115,9 +115,10 @@ var casesRequest = []struct { } func TestRequestRead(t *testing.T) { + var req Request for _, c := range casesRequest { t.Run(c.name, func(t *testing.T) { - req, err := ReadRequest(bufio.NewReader(bytes.NewBuffer(c.byts))) + err := req.Read(bufio.NewReader(bytes.NewBuffer(c.byts))) require.NoError(t, err) require.Equal(t, c.req, req) }) diff --git a/base/response.go b/base/response.go index 8b877631..341fbed0 100644 --- a/base/response.go +++ b/base/response.go @@ -132,58 +132,57 @@ type Response struct { Content []byte } -// ReadResponse reads a response. -func ReadResponse(rb *bufio.Reader) (*Response, error) { - res := &Response{} - +// Read reads a response. +func (res *Response) Read(rb *bufio.Reader) error { byts, err := readBytesLimited(rb, ' ', 255) if err != nil { - return nil, err + return err } proto := string(byts[:len(byts)-1]) if proto != rtspProtocol10 { - return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) + return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) } byts, err = readBytesLimited(rb, ' ', 4) if err != nil { - return nil, err + return err } statusCodeStr := string(byts[:len(byts)-1]) statusCode64, err := strconv.ParseInt(statusCodeStr, 10, 32) if err != nil { - return nil, fmt.Errorf("unable to parse status code") + return fmt.Errorf("unable to parse status code") } res.StatusCode = StatusCode(statusCode64) byts, err = readBytesLimited(rb, '\r', 255) if err != nil { - return nil, err + return err } res.StatusMessage = string(byts[:len(byts)-1]) if len(res.StatusMessage) == 0 { - return nil, fmt.Errorf("empty status") + return fmt.Errorf("empty status") } err = readByteEqual(rb, '\n') if err != nil { - return nil, err + return err } - res.Header, err = headerRead(rb) + res.Header = make(Header) + err = res.Header.read(rb) if err != nil { - return nil, err + return err } - res.Content, err = readContent(rb, res.Header) + res.Content, err = contentRead(rb, res.Header) if err != nil { - return nil, err + return err } - return res, nil + return nil } // Write writes a Response. @@ -208,7 +207,7 @@ func (res Response) Write(bw *bufio.Writer) error { return err } - err = writeContent(bw, res.Content) + err = contentWrite(bw, res.Content) if err != nil { return err } diff --git a/base/response_test.go b/base/response_test.go index 295f7003..c165af2b 100644 --- a/base/response_test.go +++ b/base/response_test.go @@ -11,7 +11,7 @@ import ( var casesResponse = []struct { name string byts []byte - res *Response + res Response }{ { "ok with single header", @@ -20,7 +20,7 @@ var casesResponse = []struct { "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + "\r\n", ), - &Response{ + Response{ StatusCode: StatusOK, StatusMessage: "OK", Header: Header{ @@ -39,7 +39,7 @@ var casesResponse = []struct { "WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" + "\r\n", ), - &Response{ + Response{ StatusCode: StatusOK, StatusMessage: "OK", Header: Header{ @@ -78,7 +78,7 @@ var casesResponse = []struct { "a=AvgBitRate:integer;65790\n" + "a=StreamName:string;\"hinted audio track\"\n", ), - &Response{ + Response{ StatusCode: 200, StatusMessage: "OK", Header: Header{ @@ -109,9 +109,10 @@ var casesResponse = []struct { } func TestResponseRead(t *testing.T) { + var res Response for _, c := range casesResponse { t.Run(c.name, func(t *testing.T) { - res, err := ReadResponse(bufio.NewReader(bytes.NewBuffer(c.byts))) + err := res.Read(bufio.NewReader(bytes.NewBuffer(c.byts))) require.NoError(t, err) require.Equal(t, c.res, res) }) diff --git a/base/utils.go b/base/utils.go index eae7e4b4..80d8ddab 100644 --- a/base/utils.go +++ b/base/utils.go @@ -3,8 +3,6 @@ package base import ( "bufio" "fmt" - "io" - "strconv" ) const ( @@ -38,40 +36,3 @@ func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) { } return nil, fmt.Errorf("buffer length exceeds %d", n) } - -func readContent(rb *bufio.Reader, header Header) ([]byte, error) { - cls, ok := header["Content-Length"] - if !ok || len(cls) != 1 { - return nil, nil - } - - cl, err := strconv.ParseInt(cls[0], 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid Content-Length") - } - - if cl > rtspMaxContentLength { - return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength) - } - - ret := make([]byte, cl) - n, err := io.ReadFull(rb, ret) - if err != nil && n != len(ret) { - return nil, err - } - - return ret, nil -} - -func writeContent(bw *bufio.Writer, content []byte) error { - if len(content) == 0 { - return nil - } - - _, err := bw.Write(content) - if err != nil { - return err - } - - return nil -} diff --git a/connclient.go b/connclient.go index f31efc40..9864a8f2 100644 --- a/connclient.go +++ b/connclient.go @@ -184,9 +184,10 @@ func (c *ConnClient) NetConn() net.Conn { func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { frame := c.tcpFrames.next() + res := base.Response{} c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) - return base.ReadInterleavedFrameOrResponse(frame, c.br) + return base.ReadInterleavedFrameOrResponse(frame, &res, c.br) } // ReadFrameTCP reads an InterleavedFrame. @@ -735,7 +736,8 @@ func (c *ConnClient) LoopUDP() error { go func() { for { c.conf.Conn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout)) - _, err := base.ReadResponse(c.br) + var res base.Response + err := res.Read(c.br) if err != nil { readDone <- err return diff --git a/connserver.go b/connserver.go index cb844d86..ebcd2b07 100644 --- a/connserver.go +++ b/connserver.go @@ -73,18 +73,27 @@ func (s *ConnServer) NetConn() net.Conn { // ReadRequest reads a Request. func (s *ConnServer) ReadRequest() (*base.Request, error) { + req := &base.Request{} + s.conf.Conn.SetReadDeadline(time.Time{}) // disable deadline - return base.ReadRequest(s.br) + err := req.Read(s.br) + if err != nil { + return nil, err + } + + return req, nil } // ReadFrameTCPOrRequest reads an InterleavedFrame or a Request. func (s *ConnServer) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) { + frame := s.tcpFrames.next() + req := base.Request{} + if timeout { s.conf.Conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) } - frame := s.tcpFrames.next() - return base.ReadInterleavedFrameOrRequest(frame, s.br) + return base.ReadInterleavedFrameOrRequest(frame, &req, s.br) } // WriteResponse writes a Response.