reuse structs when reading Requests, Responses and Headers

This commit is contained in:
aler9
2020-10-06 10:07:57 +02:00
parent cbf56d59d9
commit eba2fb39d1
11 changed files with 135 additions and 111 deletions

45
base/content.go Normal file
View File

@@ -0,0 +1,45 @@
package base
import (
"bufio"
"fmt"
"io"
"strconv"
)
func contentRead(rb *bufio.Reader, header Header) ([]byte, error) {
cls, ok := header["Content-Length"]
if !ok || len(cls) != 1 {
return nil, nil
}
cl, err := strconv.ParseInt(cls[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid Content-Length")
}
if cl > rtspMaxContentLength {
return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength)
}
ret := make([]byte, cl)
n, err := io.ReadFull(rb, ret)
if err != nil && n != len(ret) {
return nil, err
}
return ret, nil
}
func contentWrite(bw *bufio.Writer, content []byte) error {
if len(content) == 0 {
return nil
}
_, err := bw.Write(content)
if err != nil {
return err
}
return nil
}

View File

@@ -34,32 +34,30 @@ type HeaderValue []string
// Header is a RTSP reader, present in both Requests and Responses.
type Header map[string]HeaderValue
func headerRead(rb *bufio.Reader) (Header, error) {
h := make(Header)
func (h Header) read(rb *bufio.Reader) error {
for {
byt, err := rb.ReadByte()
if err != nil {
return nil, err
return err
}
if byt == '\r' {
err := readByteEqual(rb, '\n')
if err != nil {
return nil, err
return err
}
break
}
if len(h) >= headerMaxEntryCount {
return nil, fmt.Errorf("headers count exceeds %d", headerMaxEntryCount)
return fmt.Errorf("headers count exceeds %d", headerMaxEntryCount)
}
key := string([]byte{byt})
byts, err := readBytesLimited(rb, ':', headerMaxKeyLength-1)
if err != nil {
return nil, err
return err
}
key += string(byts[:len(byts)-1])
key = headerKeyNormalize(key)
@@ -69,7 +67,7 @@ func headerRead(rb *bufio.Reader) (Header, error) {
for {
byt, err := rb.ReadByte()
if err != nil {
return nil, err
return err
}
if byt != ' ' {
@@ -80,23 +78,23 @@ func headerRead(rb *bufio.Reader) (Header, error) {
byts, err = readBytesLimited(rb, '\r', headerMaxValueLength)
if err != nil {
return nil, err
return err
}
val := string(byts[:len(byts)-1])
if len(val) == 0 {
return nil, fmt.Errorf("empty header value")
return fmt.Errorf("empty header value")
}
err = readByteEqual(rb, '\n')
if err != nil {
return nil, err
return err
}
h[key] = append(h[key], val)
}
return h, nil
return nil
}
func (h Header) write(wb *bufio.Writer) error {

View File

@@ -93,9 +93,10 @@ var casesHeader = []struct {
func TestHeaderRead(t *testing.T) {
for _, c := range casesHeader {
t.Run(c.name, func(t *testing.T) {
req, err := headerRead(bufio.NewReader(bytes.NewBuffer(c.dec)))
h := make(Header)
err := h.read(bufio.NewReader(bytes.NewBuffer(c.dec)))
require.NoError(t, err)
require.Equal(t, c.header, req)
require.Equal(t, c.header, h)
})
}
}

View File

@@ -11,8 +11,8 @@ const (
interleavedFrameMagicByte = 0x24
)
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) {
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
@@ -27,11 +27,15 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, br *bufio.Reader) (
return frame, err
}
return ReadResponse(br)
err = req.Read(br)
if err != nil {
return nil, err
}
return req, nil
}
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (interface{}, error) {
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
@@ -46,7 +50,11 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, br *bufio.Reader) (i
return frame, err
}
return ReadRequest(br)
err = res.Read(br)
if err != nil {
return nil, err
}
return res, nil
}
// InterleavedFrame is an interleaved frame, and allows to transfer binary data

View File

@@ -53,66 +53,65 @@ type Request struct {
SkipResponse bool
}
// ReadRequest reads a request.
func ReadRequest(rb *bufio.Reader) (*Request, error) {
req := &Request{}
// Read reads a request.
func (req *Request) Read(rb *bufio.Reader) error {
byts, err := readBytesLimited(rb, ' ', requestMaxLethodLength)
if err != nil {
return nil, err
return err
}
req.Method = Method(byts[:len(byts)-1])
if req.Method == "" {
return nil, fmt.Errorf("empty method")
return fmt.Errorf("empty method")
}
byts, err = readBytesLimited(rb, ' ', requestMaxPathLength)
if err != nil {
return nil, err
return err
}
rawUrl := string(byts[:len(byts)-1])
if rawUrl == "" {
return nil, fmt.Errorf("empty url")
return fmt.Errorf("empty url")
}
ur, err := url.Parse(rawUrl)
if err != nil {
return nil, fmt.Errorf("unable to parse url (%v)", rawUrl)
return fmt.Errorf("unable to parse url (%v)", rawUrl)
}
req.Url = ur
if req.Url.Scheme != "rtsp" {
return nil, fmt.Errorf("invalid url scheme (%v)", rawUrl)
return fmt.Errorf("invalid url scheme (%v)", rawUrl)
}
byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength)
if err != nil {
return nil, err
return err
}
proto := string(byts[:len(byts)-1])
if proto != rtspProtocol10 {
return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
}
err = readByteEqual(rb, '\n')
if err != nil {
return nil, err
return err
}
req.Header, err = headerRead(rb)
req.Header = make(Header)
err = req.Header.read(rb)
if err != nil {
return nil, err
return err
}
req.Content, err = readContent(rb, req.Header)
req.Content, err = contentRead(rb, req.Header)
if err != nil {
return nil, err
return err
}
return req, nil
return nil
}
// Write writes a request.
@@ -139,7 +138,7 @@ func (req Request) Write(bw *bufio.Writer) error {
return err
}
err = writeContent(bw, req.Content)
err = contentWrite(bw, req.Content)
if err != nil {
return err
}

View File

@@ -12,7 +12,7 @@ import (
var casesRequest = []struct {
name string
byts []byte
req *Request
req Request
}{
{
"options",
@@ -21,7 +21,7 @@ var casesRequest = []struct {
"Proxy-Require: gzipped-messages\r\n" +
"Require: implicit-play\r\n" +
"\r\n"),
&Request{
Request{
Method: "OPTIONS",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{
@@ -36,7 +36,7 @@ var casesRequest = []struct {
[]byte("DESCRIBE rtsp://example.com/media.mp4 RTSP/1.0\r\n" +
"CSeq: 2\r\n" +
"\r\n"),
&Request{
Request{
Method: "DESCRIBE",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{
@@ -64,7 +64,7 @@ var casesRequest = []struct {
"a=recvonly\n" +
"m=audio 3456 RTP/AVP 0\n" +
"m=video 2232 RTP/AVP 31\n"),
&Request{
Request{
Method: "ANNOUNCE",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{
@@ -98,7 +98,7 @@ var casesRequest = []struct {
"\r\n" +
"packets_received\n" +
"jitter\n"),
&Request{
Request{
Method: "GET_PARAMETER",
Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"},
Header: Header{
@@ -115,9 +115,10 @@ var casesRequest = []struct {
}
func TestRequestRead(t *testing.T) {
var req Request
for _, c := range casesRequest {
t.Run(c.name, func(t *testing.T) {
req, err := ReadRequest(bufio.NewReader(bytes.NewBuffer(c.byts)))
err := req.Read(bufio.NewReader(bytes.NewBuffer(c.byts)))
require.NoError(t, err)
require.Equal(t, c.req, req)
})

View File

@@ -132,58 +132,57 @@ type Response struct {
Content []byte
}
// ReadResponse reads a response.
func ReadResponse(rb *bufio.Reader) (*Response, error) {
res := &Response{}
// Read reads a response.
func (res *Response) Read(rb *bufio.Reader) error {
byts, err := readBytesLimited(rb, ' ', 255)
if err != nil {
return nil, err
return err
}
proto := string(byts[:len(byts)-1])
if proto != rtspProtocol10 {
return nil, fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
return fmt.Errorf("expected '%s', got '%s'", rtspProtocol10, proto)
}
byts, err = readBytesLimited(rb, ' ', 4)
if err != nil {
return nil, err
return err
}
statusCodeStr := string(byts[:len(byts)-1])
statusCode64, err := strconv.ParseInt(statusCodeStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("unable to parse status code")
return fmt.Errorf("unable to parse status code")
}
res.StatusCode = StatusCode(statusCode64)
byts, err = readBytesLimited(rb, '\r', 255)
if err != nil {
return nil, err
return err
}
res.StatusMessage = string(byts[:len(byts)-1])
if len(res.StatusMessage) == 0 {
return nil, fmt.Errorf("empty status")
return fmt.Errorf("empty status")
}
err = readByteEqual(rb, '\n')
if err != nil {
return nil, err
return err
}
res.Header, err = headerRead(rb)
res.Header = make(Header)
err = res.Header.read(rb)
if err != nil {
return nil, err
return err
}
res.Content, err = readContent(rb, res.Header)
res.Content, err = contentRead(rb, res.Header)
if err != nil {
return nil, err
return err
}
return res, nil
return nil
}
// Write writes a Response.
@@ -208,7 +207,7 @@ func (res Response) Write(bw *bufio.Writer) error {
return err
}
err = writeContent(bw, res.Content)
err = contentWrite(bw, res.Content)
if err != nil {
return err
}

View File

@@ -11,7 +11,7 @@ import (
var casesResponse = []struct {
name string
byts []byte
res *Response
res Response
}{
{
"ok with single header",
@@ -20,7 +20,7 @@ var casesResponse = []struct {
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
"\r\n",
),
&Response{
Response{
StatusCode: StatusOK,
StatusMessage: "OK",
Header: Header{
@@ -39,7 +39,7 @@ var casesResponse = []struct {
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
"\r\n",
),
&Response{
Response{
StatusCode: StatusOK,
StatusMessage: "OK",
Header: Header{
@@ -78,7 +78,7 @@ var casesResponse = []struct {
"a=AvgBitRate:integer;65790\n" +
"a=StreamName:string;\"hinted audio track\"\n",
),
&Response{
Response{
StatusCode: 200,
StatusMessage: "OK",
Header: Header{
@@ -109,9 +109,10 @@ var casesResponse = []struct {
}
func TestResponseRead(t *testing.T) {
var res Response
for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) {
res, err := ReadResponse(bufio.NewReader(bytes.NewBuffer(c.byts)))
err := res.Read(bufio.NewReader(bytes.NewBuffer(c.byts)))
require.NoError(t, err)
require.Equal(t, c.res, res)
})

View File

@@ -3,8 +3,6 @@ package base
import (
"bufio"
"fmt"
"io"
"strconv"
)
const (
@@ -38,40 +36,3 @@ func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {
}
return nil, fmt.Errorf("buffer length exceeds %d", n)
}
func readContent(rb *bufio.Reader, header Header) ([]byte, error) {
cls, ok := header["Content-Length"]
if !ok || len(cls) != 1 {
return nil, nil
}
cl, err := strconv.ParseInt(cls[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid Content-Length")
}
if cl > rtspMaxContentLength {
return nil, fmt.Errorf("Content-Length exceeds %d", rtspMaxContentLength)
}
ret := make([]byte, cl)
n, err := io.ReadFull(rb, ret)
if err != nil && n != len(ret) {
return nil, err
}
return ret, nil
}
func writeContent(bw *bufio.Writer, content []byte) error {
if len(content) == 0 {
return nil
}
_, err := bw.Write(content)
if err != nil {
return err
}
return nil
}

View File

@@ -184,9 +184,10 @@ func (c *ConnClient) NetConn() net.Conn {
func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) {
frame := c.tcpFrames.next()
res := base.Response{}
c.conf.Conn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
return base.ReadInterleavedFrameOrResponse(frame, c.br)
return base.ReadInterleavedFrameOrResponse(frame, &res, c.br)
}
// ReadFrameTCP reads an InterleavedFrame.
@@ -735,7 +736,8 @@ func (c *ConnClient) LoopUDP() error {
go func() {
for {
c.conf.Conn.SetReadDeadline(time.Now().Add(clientUDPKeepalivePeriod + c.conf.ReadTimeout))
_, err := base.ReadResponse(c.br)
var res base.Response
err := res.Read(c.br)
if err != nil {
readDone <- err
return

View File

@@ -73,18 +73,27 @@ func (s *ConnServer) NetConn() net.Conn {
// ReadRequest reads a Request.
func (s *ConnServer) ReadRequest() (*base.Request, error) {
req := &base.Request{}
s.conf.Conn.SetReadDeadline(time.Time{}) // disable deadline
return base.ReadRequest(s.br)
err := req.Read(s.br)
if err != nil {
return nil, err
}
return req, nil
}
// ReadFrameTCPOrRequest reads an InterleavedFrame or a Request.
func (s *ConnServer) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) {
frame := s.tcpFrames.next()
req := base.Request{}
if timeout {
s.conf.Conn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout))
}
frame := s.tcpFrames.next()
return base.ReadInterleavedFrameOrRequest(frame, s.br)
return base.ReadInterleavedFrameOrRequest(frame, &req, s.br)
}
// WriteResponse writes a Response.