reuse structs when reading Requests, Responses and Headers

This commit is contained in:
aler9
2020-10-06 10:07:57 +02:00
parent cbf56d59d9
commit eba2fb39d1
11 changed files with 135 additions and 111 deletions

45
base/content.go Normal file
View File

@@ -0,0 +1,45 @@
package base
import (
"bufio"
"fmt"
"io"
"strconv"
)
func contentRead(rb *bufio.Reader, header Header) ([]byte, error) {
cls, ok := header["Content-Length"]
if !ok || len(cls) != 1 {
return nil, nil
}
cl, err := strconv.ParseInt(cls[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid Content-Length")
}
if cl > rtspMaxContentLength {
return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength)
}
ret := make([]byte, cl)
n, err := io.ReadFull(rb, ret)
if err != nil && n != len(ret) {
return nil, err
}
return ret, nil
}
func contentWrite(bw *bufio.Writer, content []byte) error {
if len(content) == 0 {
return nil
}
_, err := bw.Write(content)
if err != nil {
return err
}
return nil
}

View File

@@ -34,32 +34,30 @@ type HeaderValue []string
// Header is a RTSP reader, present in both Requests and Responses. // Header is a RTSP reader, present in both Requests and Responses.
type Header map[string]HeaderValue type Header map[string]HeaderValue
func headerRead(rb *bufio.Reader) (Header, error) { func (h Header) read(rb *bufio.Reader) error {
h := make(Header)
for { for {
byt, err := rb.ReadByte() byt, err := rb.ReadByte()
if err != nil { if err != nil {
return nil, err return err
} }
if byt == '\r' { if byt == '\r' {
err := readByteEqual(rb, '\n') err := readByteEqual(rb, '\n')
if err != nil { if err != nil {
return nil, err return err
} }
break break
} }
if len(h) >= headerMaxEntryCount { if len(h) >= headerMaxEntryCount {
return nil, fmt.Errorf("headers count exceeds %d", headerMaxEntryCount) return fmt.Errorf("headers count exceeds %d", headerMaxEntryCount)
} }
key := string([]byte{byt}) key := string([]byte{byt})
byts, err := readBytesLimited(rb, ':', headerMaxKeyLength-1) byts, err := readBytesLimited(rb, ':', headerMaxKeyLength-1)
if err != nil { if err != nil {
return nil, err return err
} }
key += string(byts[:len(byts)-1]) key += string(byts[:len(byts)-1])
key = headerKeyNormalize(key) key = headerKeyNormalize(key)
@@ -69,7 +67,7 @@ func headerRead(rb *bufio.Reader) (Header, error) {
for { for {
byt, err := rb.ReadByte() byt, err := rb.ReadByte()
if err != nil { if err != nil {
return nil, err return err
} }
if byt != ' ' { if byt != ' ' {
@@ -80,23 +78,23 @@ func headerRead(rb *bufio.Reader) (Header, error) {
byts, err = readBytesLimited(rb, '\r', headerMaxValueLength) byts, err = readBytesLimited(rb, '\r', headerMaxValueLength)
if err != nil { if err != nil {
return nil, err return err
} }
val := string(byts[:len(byts)-1]) val := string(byts[:len(byts)-1])
if len(val) == 0 { if len(val) == 0 {
return nil, fmt.Errorf("empty header value") return fmt.Errorf("empty header value")
} }
err = readByteEqual(rb, '\n') err = readByteEqual(rb, '\n')
if err != nil { if err != nil {
return nil, err return err
} }
h[key] = append(h[key], val) h[key] = append(h[key], val)
} }
return h, nil return nil
} }
func (h Header) write(wb *bufio.Writer) error { func (h Header) write(wb *bufio.Writer) error {

View File

@@ -93,9 +93,10 @@ var casesHeader = []struct {
func TestHeaderRead(t *testing.T) { func TestHeaderRead(t *testing.T) {
for _, c := range casesHeader { for _, c := range casesHeader {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := headerRead(bufio.NewReader(bytes.NewBuffer(c.dec))) h := make(Header)
err := h.read(bufio.NewReader(bytes.NewBuffer(c.dec)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.header, req) require.Equal(t, c.header, h)
}) })
} }
} }

View File

@@ -11,27 +11,8 @@ const (
interleavedFrameMagicByte = 0x24 interleavedFrameMagicByte = 0x24
) )
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
}
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(br)
if err != nil {
return nil, err
}
return frame, err
}
return ReadResponse(br)
}
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. // ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) { func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) {
b, err := br.ReadByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -46,7 +27,34 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (i
return frame, err return frame, err
} }
return ReadRequest(br) err = req.Read(br)
if err != nil {
return nil, err
}
return req, nil
}
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
}
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(br)
if err != nil {
return nil, err
}
return frame, err
}
err = res.Read(br)
if err != nil {
return nil, err
}
return res, nil
} }
// InterleavedFrame is an interleaved frame, and allows to transfer binary data // InterleavedFrame is an interleaved frame, and allows to transfer binary data

View File

@@ -53,66 +53,65 @@ type Request struct {
SkipResponse bool SkipResponse bool
} }
// ReadRequest reads a request. // Read reads a request.
func ReadRequest(rb *bufio.Reader) (*Request, error) { func (req *Request) Read(rb *bufio.Reader) error {
req := &Request{}
byts, err := readBytesLimited(rb, ' ', requestMaxLethodLength) byts, err := readBytesLimited(rb, ' ', requestMaxLethodLength)
if err != nil { if err != nil {
return nil, err return err
} }
req.Method = Method(byts[:len(byts)-1]) req.Method = Method(byts[:len(byts)-1])
if req.Method == "" { if req.Method == "" {
return nil, fmt.Errorf("empty method") return fmt.Errorf("empty method")
} }
byts, err = readBytesLimited(rb, ' ', requestMaxPathLength) byts, err = readBytesLimited(rb, ' ', requestMaxPathLength)
if err != nil { if err != nil {
return nil, err return err
} }
rawUrl := string(byts[:len(byts)-1]) rawUrl := string(byts[:len(byts)-1])
if rawUrl == "" { if rawUrl == "" {
return nil, fmt.Errorf("empty url") return fmt.Errorf("empty url")
} }
ur, err := url.Parse(rawUrl) ur, err := url.Parse(rawUrl)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse url (%v)", rawUrl) return fmt.Errorf("unable to parse url (%v)", rawUrl)
} }
req.Url = ur req.Url = ur
if req.Url.Scheme != "rtsp" { if req.Url.Scheme != "rtsp" {
return nil, fmt.Errorf("invalid url scheme (%v)", rawUrl) return fmt.Errorf("invalid url scheme (%v)", rawUrl)
} }
byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength) byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength)
if err != nil { if err != nil {
return nil, err return err
} }
proto := string(byts[:len(byts)-1]) proto := string(byts[:len(byts)-1])
if proto != rtspProtocol10 { if proto != rtspProtocol10 {
return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
} }
err = readByteEqual(rb, '\n') err = readByteEqual(rb, '\n')
if err != nil { if err != nil {
return nil, err return err
} }
req.Header, err = headerRead(rb) req.Header = make(Header)
err = req.Header.read(rb)
if err != nil { if err != nil {
return nil, err return err
} }
req.Content, err = readContent(rb, req.Header) req.Content, err = contentRead(rb, req.Header)
if err != nil { if err != nil {
return nil, err return err
} }
return req, nil return nil
} }
// Write writes a request. // Write writes a request.
@@ -139,7 +138,7 @@ func (req Request) Write(bw *bufio.Writer) error {
return err return err
} }
err = writeContent(bw, req.Content) err = contentWrite(bw, req.Content)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -12,7 +12,7 @@ import (
var casesRequest = []struct { var casesRequest = []struct {
name string name string
byts []byte byts []byte
req *Request req Request
}{ }{
{ {
"options", "options",
@@ -21,7 +21,7 @@ var casesRequest = []struct {
"Proxy-Require: gzipped-messages\r\n" + "Proxy-Require: gzipped-messages\r\n" +
"Require: implicit-play\r\n" + "Require: implicit-play\r\n" +
"\r\n"), "\r\n"),
&Request{ Request{
Method: "OPTIONS", Method: "OPTIONS",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{ Header: Header{
@@ -36,7 +36,7 @@ var casesRequest = []struct {
[]byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" + []byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"CSeq: 2\r\n" + "CSeq: 2\r\n" +
"\r\n"), "\r\n"),
&Request{ Request{
Method: "DESCRIBE", Method: "DESCRIBE",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{ Header: Header{
@@ -64,7 +64,7 @@ var casesRequest = []struct {
"a=recvonly\n" + "a=recvonly\n" +
"m=audio 3456 RTP/AVP 0\n" + "m=audio 3456 RTP/AVP 0\n" +
"m=video 2232 RTP/AVP 31\n"), "m=video 2232 RTP/AVP 31\n"),
&Request{ Request{
Method: "ANNOUNCE", Method: "ANNOUNCE",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{ Header: Header{
@@ -98,7 +98,7 @@ var casesRequest = []struct {
"\r\n" + "\r\n" +
"packets_received\n" + "packets_received\n" +
"jitter\n"), "jitter\n"),
&Request{ Request{
Method: "GET_PARAMETER", Method: "GET_PARAMETER",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{ Header: Header{
@@ -115,9 +115,10 @@ var casesRequest = []struct {
} }
func TestRequestRead(t *testing.T) { func TestRequestRead(t *testing.T) {
var req Request
for _, c := range casesRequest { for _, c := range casesRequest {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadRequest(bufio.NewReader(bytes.NewBuffer(c.byts))) err := req.Read(bufio.NewReader(bytes.NewBuffer(c.byts)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.req, req) require.Equal(t, c.req, req)
}) })

View File

@@ -132,58 +132,57 @@ type Response struct {
Content []byte Content []byte
} }
// ReadResponse reads a response. // Read reads a response.
func ReadResponse(rb *bufio.Reader) (*Response, error) { func (res *Response) Read(rb *bufio.Reader) error {
res := &Response{}
byts, err := readBytesLimited(rb, ' ', 255) byts, err := readBytesLimited(rb, ' ', 255)
if err != nil { if err != nil {
return nil, err return err
} }
proto := string(byts[:len(byts)-1]) proto := string(byts[:len(byts)-1])
if proto != rtspProtocol10 { if proto != rtspProtocol10 {
return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto) return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
} }
byts, err = readBytesLimited(rb, ' ', 4) byts, err = readBytesLimited(rb, ' ', 4)
if err != nil { if err != nil {
return nil, err return err
} }
statusCodeStr := string(byts[:len(byts)-1]) statusCodeStr := string(byts[:len(byts)-1])
statusCode64, err := strconv.ParseInt(statusCodeStr, 10, 32) statusCode64, err := strconv.ParseInt(statusCodeStr, 10, 32)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse status code") return fmt.Errorf("unable to parse status code")
} }
res.StatusCode = StatusCode(statusCode64) res.StatusCode = StatusCode(statusCode64)
byts, err = readBytesLimited(rb, '\r', 255) byts, err = readBytesLimited(rb, '\r', 255)
if err != nil { if err != nil {
return nil, err return err
} }
res.StatusMessage = string(byts[:len(byts)-1]) res.StatusMessage = string(byts[:len(byts)-1])
if len(res.StatusMessage) == 0 { if len(res.StatusMessage) == 0 {
return nil, fmt.Errorf("empty status") return fmt.Errorf("empty status")
} }
err = readByteEqual(rb, '\n') err = readByteEqual(rb, '\n')
if err != nil { if err != nil {
return nil, err return err
} }
res.Header, err = headerRead(rb) res.Header = make(Header)
err = res.Header.read(rb)
if err != nil { if err != nil {
return nil, err return err
} }
res.Content, err = readContent(rb, res.Header) res.Content, err = contentRead(rb, res.Header)
if err != nil { if err != nil {
return nil, err return err
} }
return res, nil return nil
} }
// Write writes a Response. // Write writes a Response.
@@ -208,7 +207,7 @@ func (res Response) Write(bw *bufio.Writer) error {
return err return err
} }
err = writeContent(bw, res.Content) err = contentWrite(bw, res.Content)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -11,7 +11,7 @@ import (
var casesResponse = []struct { var casesResponse = []struct {
name string name string
byts []byte byts []byte
res *Response res Response
}{ }{
{ {
"ok with single header", "ok with single header",
@@ -20,7 +20,7 @@ var casesResponse = []struct {
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
"\r\n", "\r\n",
), ),
&Response{ Response{
StatusCode: StatusOK, StatusCode: StatusOK,
StatusMessage: "OK", StatusMessage: "OK",
Header: Header{ Header: Header{
@@ -39,7 +39,7 @@ var casesResponse = []struct {
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" + "WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
"\r\n", "\r\n",
), ),
&Response{ Response{
StatusCode: StatusOK, StatusCode: StatusOK,
StatusMessage: "OK", StatusMessage: "OK",
Header: Header{ Header: Header{
@@ -78,7 +78,7 @@ var casesResponse = []struct {
"a=AvgBitRate:integer;65790\n" + "a=AvgBitRate:integer;65790\n" +
"a=StreamName:string;\"hinted audio track\"\n", "a=StreamName:string;\"hinted audio track\"\n",
), ),
&Response{ Response{
StatusCode: 200, StatusCode: 200,
StatusMessage: "OK", StatusMessage: "OK",
Header: Header{ Header: Header{
@@ -109,9 +109,10 @@ var casesResponse = []struct {
} }
func TestResponseRead(t *testing.T) { func TestResponseRead(t *testing.T) {
var res Response
for _, c := range casesResponse { for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
res, err := ReadResponse(bufio.NewReader(bytes.NewBuffer(c.byts))) err := res.Read(bufio.NewReader(bytes.NewBuffer(c.byts)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.res, res) require.Equal(t, c.res, res)
}) })

View File

@@ -3,8 +3,6 @@ package base
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"io"
"strconv"
) )
const ( const (
@@ -38,40 +36,3 @@ func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {
} }
return nil, fmt.Errorf("buffer length exceeds %d", n) return nil, fmt.Errorf("buffer length exceeds %d", n)
} }
func readContent(rb *bufio.Reader, header Header) ([]byte, error) {
cls, ok := header["Content-Length"]
if !ok || len(cls) != 1 {
return nil, nil
}
cl, err := strconv.ParseInt(cls[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid Content-Length")
}
if cl > rtspMaxContentLength {
return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength)
}
ret := make([]byte, cl)
n, err := io.ReadFull(rb, ret)
if err != nil && n != len(ret) {
return nil, err
}
return ret, nil
}
func writeContent(bw *bufio.Writer, content []byte) error {
if len(content) == 0 {
return nil
}
_, err := bw.Write(content)
if err != nil {
return err
}
return nil
}

View File

@@ -184,9 +184,10 @@ func (c *ConnClient) NetConn() net.Conn {
func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) {
frame := c.tcpFrames.next() frame := c.tcpFrames.next()
res := base.Response{}
c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
return base.ReadInterleavedFrameOrResponse(frame, c.br) return base.ReadInterleavedFrameOrResponse(frame, &res, c.br)
} }
// ReadFrameTCP reads an InterleavedFrame. // ReadFrameTCP reads an InterleavedFrame.
@@ -735,7 +736,8 @@ func (c *ConnClient) LoopUDP() error {
go func() { go func() {
for { for {
c.conf.Conn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout)) c.conf.Conn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout))
_, err := base.ReadResponse(c.br) var res base.Response
err := res.Read(c.br)
if err != nil { if err != nil {
readDone <- err readDone <- err
return return

View File

@@ -73,18 +73,27 @@ func (s *ConnServer) NetConn() net.Conn {
// ReadRequest reads a Request. // ReadRequest reads a Request.
func (s *ConnServer) ReadRequest() (*base.Request, error) { func (s *ConnServer) ReadRequest() (*base.Request, error) {
req := &base.Request{}
s.conf.Conn.SetReadDeadline(time.Time{}) // disable deadline s.conf.Conn.SetReadDeadline(time.Time{}) // disable deadline
return base.ReadRequest(s.br) err := req.Read(s.br)
if err != nil {
return nil, err
}
return req, nil
} }
// ReadFrameTCPOrRequest reads an InterleavedFrame or a Request. // ReadFrameTCPOrRequest reads an InterleavedFrame or a Request.
func (s *ConnServer) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) { func (s *ConnServer) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) {
frame := s.tcpFrames.next()
req := base.Request{}
if timeout { if timeout {
s.conf.Conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) s.conf.Conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
} }
frame := s.tcpFrames.next() return base.ReadInterleavedFrameOrRequest(frame, &req, s.br)
return base.ReadInterleavedFrameOrRequest(frame, s.br)
} }
// WriteResponse writes a Response. // WriteResponse writes a Response.