diff --git a/conn.go b/conn.go index f1d760d8..c07b6a9a 100644 --- a/conn.go +++ b/conn.go @@ -74,38 +74,38 @@ func (c *Conn) SetCredentials(user string, pass string, realm string, nonce stri } func (c *Conn) ReadRequest() (*Request, error) { - return requestDecode(c.nconn) + return readRequest(c.nconn) } func (c *Conn) WriteRequest(req *Request) error { if c.cseqEnabled { - if req.Headers == nil { - req.Headers = make(map[string]string) + if req.Header == nil { + req.Header = make(Header) } c.cseq += 1 - req.Headers["CSeq"] = strconv.FormatInt(int64(c.cseq), 10) + req.Header["CSeq"] = []string{strconv.FormatInt(int64(c.cseq), 10)} } if c.session != "" { - if req.Headers == nil { - req.Headers = make(map[string]string) + if req.Header == nil { + req.Header = make(Header) } - req.Headers["Session"] = c.session + req.Header["Session"] = []string{c.session} } if c.authProv != nil { - if req.Headers == nil { - req.Headers = make(map[string]string) + if req.Header == nil { + req.Header = make(Header) } - req.Headers["Authorization"] = c.authProv.generateHeader(req.Method, req.Url) + req.Header["Authorization"] = []string{c.authProv.generateHeader(req.Method, req.Url)} } - return requestEncode(c.nconn, req) + return req.write(c.nconn) } func (c *Conn) ReadResponse() (*Response, error) { - return responseDecode(c.nconn) + return readResponse(c.nconn) } func (c *Conn) WriteResponse(res *Response) error { - return responseEncode(c.nconn, res) + return res.write(c.nconn) } func (c *Conn) ReadInterleavedFrame(buf []byte) (int, int, error) { diff --git a/header.go b/header.go new file mode 100644 index 00000000..c06574da --- /dev/null +++ b/header.go @@ -0,0 +1,96 @@ +package gortsplib + +import ( + "bufio" + "fmt" + "sort" +) + +const ( + _MAX_HEADER_COUNT = 255 + _MAX_HEADER_KEY_LENGTH = 255 + _MAX_HEADER_VALUE_LENGTH = 255 +) + +type Header map[string][]string + +func readHeader(rb *bufio.Reader) (Header, error) { + h := make(Header) + + for { + byt, err := rb.ReadByte() + if err != nil { + return nil, err + } + + if byt == '\r' { + err := readByteEqual(rb, '\n') + if err != nil { + return nil, err + } + + break + } + + if len(h) >= _MAX_HEADER_COUNT { + return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT) + } + + key := string([]byte{byt}) + byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1) + if err != nil { + return nil, err + } + key += string(byts[:len(byts)-1]) + + err = readByteEqual(rb, ' ') + if err != nil { + return nil, err + } + + byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH) + if err != nil { + return nil, err + } + val := string(byts[:len(byts)-1]) + + if len(val) == 0 { + return nil, fmt.Errorf("empty header value") + } + + err = readByteEqual(rb, '\n') + if err != nil { + return nil, err + } + + h[key] = append(h[key], val) + } + + return h, nil +} + +func (h Header) write(wb *bufio.Writer) error { + // sort headers by key + // in order to obtain deterministic results + var keys []string + for key := range h { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + for _, val := range h[key] { + _, err := wb.Write([]byte(key + ": " + val + "\r\n")) + if err != nil { + return err + } + } + } + + _, err := wb.Write([]byte("\r\n")) + if err != nil { + return err + } + + return nil +} diff --git a/header_test.go b/header_test.go new file mode 100644 index 00000000..5550e00d --- /dev/null +++ b/header_test.go @@ -0,0 +1,61 @@ +package gortsplib + +import ( + "bufio" + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +var casesHeader = []struct { + name string + byts []byte + header Header +}{ + { + "single", + []byte("Proxy-Require: gzipped-messages\r\n" + + "Require: implicit-play\r\n" + + "\r\n"), + Header{ + "Require": []string{"implicit-play"}, + "Proxy-Require": []string{"gzipped-messages"}, + }, + }, + { + "multiple", + []byte("WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" + + "WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" + + "\r\n"), + Header{ + "WWW-Authenticate": []string{ + `Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`, + `Basic realm="4419b63f5e51"`, + }, + }, + }, +} + +func TestHeaderRead(t *testing.T) { + for _, c := range casesHeader { + t.Run(c.name, func(t *testing.T) { + req, err := readHeader(bufio.NewReader(bytes.NewBuffer(c.byts))) + require.NoError(t, err) + require.Equal(t, c.header, req) + }) + } +} + +func TestHeaderWrite(t *testing.T) { + for _, c := range casesHeader { + t.Run(c.name, func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + err := c.header.write(bw) + bw.Flush() + require.NoError(t, err) + require.Equal(t, c.byts, buf.Bytes()) + }) + } +} diff --git a/request.go b/request.go index 0a18d8cf..51637d7f 100644 --- a/request.go +++ b/request.go @@ -9,11 +9,11 @@ import ( type Request struct { Method string Url string - Headers map[string]string + Header Header Content []byte } -func requestDecode(r io.Reader) (*Request, error) { +func readRequest(r io.Reader) (*Request, error) { rb := bufio.NewReader(r) req := &Request{} @@ -53,12 +53,12 @@ func requestDecode(r io.Reader) (*Request, error) { return nil, err } - req.Headers, err = readHeaders(rb) + req.Header, err = readHeader(rb) if err != nil { return nil, err } - req.Content, err = readContent(rb, req.Headers) + req.Content, err = readContent(rb, req.Header) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func requestDecode(r io.Reader) (*Request, error) { return req, nil } -func requestEncode(w io.Writer, req *Request) error { +func (req *Request) write(w io.Writer) error { wb := bufio.NewWriter(w) _, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n")) @@ -74,7 +74,7 @@ func requestEncode(w io.Writer, req *Request) error { return err } - err = writeHeaders(wb, req.Headers) + err = req.Header.write(wb) if err != nil { return err } diff --git a/request_test.go b/request_test.go index e6e5ba1a..5b6f3752 100644 --- a/request_test.go +++ b/request_test.go @@ -22,10 +22,10 @@ var casesRequest = []struct { &Request{ Method: "OPTIONS", Url: "rtsp://example.com/media.mp4", - Headers: map[string]string{ - "CSeq": "1", - "Require": "implicit-play", - "Proxy-Require": "gzipped-messages", + Header: Header{ + "CSeq": []string{"1"}, + "Require": []string{"implicit-play"}, + "Proxy-Require": []string{"gzipped-messages"}, }, }, }, @@ -37,8 +37,8 @@ var casesRequest = []struct { &Request{ Method: "DESCRIBE", Url: "rtsp://example.com/media.mp4", - Headers: map[string]string{ - "CSeq": "2", + Header: Header{ + "CSeq": []string{"2"}, }, }, }, @@ -65,12 +65,12 @@ var casesRequest = []struct { &Request{ Method: "ANNOUNCE", Url: "rtsp://example.com/media.mp4", - Headers: map[string]string{ - "CSeq": "7", - "Date": "23 Jan 1997 15:35:06 GMT", - "Session": "12345678", - "Content-Type": "application/sdp", - "Content-Length": "306", + Header: Header{ + "CSeq": []string{"7"}, + "Date": []string{"23 Jan 1997 15:35:06 GMT"}, + "Session": []string{"12345678"}, + "Content-Type": []string{"application/sdp"}, + "Content-Length": []string{"306"}, }, Content: []byte("v=0\n" + "o=mhandley 2890844526 2890845468 IN IP4 126.16.64.4\n" + @@ -99,11 +99,11 @@ var casesRequest = []struct { &Request{ Method: "GET_PARAMETER", Url: "rtsp://example.com/media.mp4", - Headers: map[string]string{ - "CSeq": "9", - "Content-Type": "text/parameters", - "Session": "12345678", - "Content-Length": "24", + Header: Header{ + "CSeq": []string{"9"}, + "Content-Type": []string{"text/parameters"}, + "Session": []string{"12345678"}, + "Content-Length": []string{"24"}, }, Content: []byte("packets_received\n" + "jitter\n", @@ -112,21 +112,21 @@ var casesRequest = []struct { }, } -func TestRequestDecode(t *testing.T) { +func TestRequestRead(t *testing.T) { for _, c := range casesRequest { t.Run(c.name, func(t *testing.T) { - req, err := requestDecode(bytes.NewBuffer(c.byts)) + req, err := readRequest(bytes.NewBuffer(c.byts)) require.NoError(t, err) require.Equal(t, c.req, req) }) } } -func TestRequestEncode(t *testing.T) { +func TestRequestWrite(t *testing.T) { for _, c := range casesRequest { t.Run(c.name, func(t *testing.T) { var buf bytes.Buffer - err := requestEncode(&buf, c.req) + err := c.req.write(&buf) require.NoError(t, err) require.Equal(t, c.byts, buf.Bytes()) }) diff --git a/response.go b/response.go index 0080fb35..cbbed6d9 100644 --- a/response.go +++ b/response.go @@ -10,11 +10,11 @@ import ( type Response struct { StatusCode int Status string - Headers map[string]string + Header Header Content []byte } -func responseDecode(r io.Reader) (*Response, error) { +func readResponse(r io.Reader) (*Response, error) { rb := bufio.NewReader(r) res := &Response{} @@ -56,12 +56,12 @@ func responseDecode(r io.Reader) (*Response, error) { return nil, err } - res.Headers, err = readHeaders(rb) + res.Header, err = readHeader(rb) if err != nil { return nil, err } - res.Content, err = readContent(rb, res.Headers) + res.Content, err = readContent(rb, res.Header) if err != nil { return nil, err } @@ -69,7 +69,7 @@ func responseDecode(r io.Reader) (*Response, error) { return res, nil } -func responseEncode(w io.Writer, res *Response) error { +func (res *Response) write(w io.Writer) error { wb := bufio.NewWriter(w) _, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n")) @@ -78,10 +78,10 @@ func responseEncode(w io.Writer, res *Response) error { } if len(res.Content) != 0 { - res.Headers["Content-Length"] = strconv.FormatInt(int64(len(res.Content)), 10) + res.Header["Content-Length"] = []string{strconv.FormatInt(int64(len(res.Content)), 10)} } - err = writeHeaders(wb, res.Headers) + err = res.Header.write(wb) if err != nil { return err } diff --git a/response_test.go b/response_test.go index cc5d953f..1ea95ebc 100644 --- a/response_test.go +++ b/response_test.go @@ -13,7 +13,7 @@ var casesResponse = []struct { res *Response }{ { - "ok", + "ok with single header", []byte("RTSP/1.0 200 OK\r\n" + "CSeq: 1\r\n" + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + @@ -22,9 +22,33 @@ var casesResponse = []struct { &Response{ StatusCode: 200, Status: "OK", - Headers: map[string]string{ - "CSeq": "1", - "Public": "DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE", + Header: Header{ + "CSeq": []string{"1"}, + "Public": []string{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"}, + }, + }, + }, + { + "ok with multiple headers", + []byte("RTSP/1.0 200 OK\r\n" + + "CSeq: 2\r\n" + + "Date: Sat, Aug 16 2014 02:22:28 GMT\r\n" + + "Session: 645252166\r\n" + + "WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" + + "WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" + + "\r\n", + ), + &Response{ + StatusCode: 200, + Status: "OK", + Header: Header{ + "CSeq": []string{"2"}, + "Session": []string{"645252166"}, + "WWW-Authenticate": []string{ + "Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"", + "Basic realm=\"4419b63f5e51\"", + }, + "Date": []string{"Sat, Aug 16 2014 02:22:28 GMT"}, }, }, }, @@ -56,11 +80,11 @@ var casesResponse = []struct { &Response{ StatusCode: 200, Status: "OK", - Headers: map[string]string{ - "Content-Base": "rtsp://example.com/media.mp4", - "Content-Length": "444", - "Content-Type": "application/sdp", - "CSeq": "2", + Header: Header{ + "Content-Base": []string{"rtsp://example.com/media.mp4"}, + "Content-Length": []string{"444"}, + "Content-Type": []string{"application/sdp"}, + "CSeq": []string{"2"}, }, Content: []byte("m=video 0 RTP/AVP 96\n" + "a=control:streamid=0\n" + @@ -83,21 +107,21 @@ var casesResponse = []struct { }, } -func TestResponseDecode(t *testing.T) { +func TestResponseRead(t *testing.T) { for _, c := range casesResponse { t.Run(c.name, func(t *testing.T) { - res, err := responseDecode(bytes.NewBuffer(c.byts)) + res, err := readResponse(bytes.NewBuffer(c.byts)) require.NoError(t, err) require.Equal(t, c.res, res) }) } } -func TestResponseEncode(t *testing.T) { +func TestResponseWrite(t *testing.T) { for _, c := range casesResponse { t.Run(c.name, func(t *testing.T) { var buf bytes.Buffer - err := responseEncode(&buf, c.res) + err := c.res.write(&buf) require.NoError(t, err) require.Equal(t, c.byts, buf.Bytes()) }) diff --git a/utils.go b/utils.go index a3a754ad..83a163da 100644 --- a/utils.go +++ b/utils.go @@ -4,16 +4,12 @@ import ( "bufio" "fmt" "io" - "sort" "strconv" ) const ( - _RTSP_PROTO = "RTSP/1.0" - _MAX_HEADER_COUNT = 255 - _MAX_HEADER_KEY_LENGTH = 255 - _MAX_HEADER_VALUE_LENGTH = 255 - _MAX_CONTENT_LENGTH = 4096 + _RTSP_PROTO = "RTSP/1.0" + _MAX_CONTENT_LENGTH = 4096 ) func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) { @@ -44,95 +40,13 @@ func readByteEqual(rb *bufio.Reader, cmp byte) error { return nil } -func readHeaders(rb *bufio.Reader) (map[string]string, error) { - ret := make(map[string]string) - - for { - byt, err := rb.ReadByte() - if err != nil { - return nil, err - } - - if byt == '\r' { - err := readByteEqual(rb, '\n') - if err != nil { - return nil, err - } - - break - } - - if len(ret) >= _MAX_HEADER_COUNT { - return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT) - } - - key := string([]byte{byt}) - byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1) - if err != nil { - return nil, err - } - key += string(byts[:len(byts)-1]) - - err = readByteEqual(rb, ' ') - if err != nil { - return nil, err - } - - byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH) - if err != nil { - return nil, err - } - val := string(byts[:len(byts)-1]) - - if len(val) == 0 { - return nil, fmt.Errorf("empty header value") - } - - err = readByteEqual(rb, '\n') - if err != nil { - return nil, err - } - - // set only if not set previously - if _, ok := ret[key]; !ok { - ret[key] = val - } - } - - return ret, nil -} - -func writeHeaders(wb *bufio.Writer, headers map[string]string) error { - // sort headers by key - // in order to obtain deterministic results - var keys []string - for key := range headers { - keys = append(keys, key) - } - sort.Strings(keys) - - for _, key := range keys { - _, err := wb.Write([]byte(key + ": " + headers[key] + "\r\n")) - if err != nil { - return err - } - } - - _, err := wb.Write([]byte("\r\n")) - if err != nil { - return err - } - - return nil -} - -func readContent(rb *bufio.Reader, headers map[string]string) ([]byte, error) { - cls, ok := headers["Content-Length"] - if !ok { +func readContent(rb *bufio.Reader, header Header) ([]byte, error) { + cls, ok := header["Content-Length"] + if !ok || len(cls) != 1 { return nil, nil } - cl, err := strconv.ParseInt(cls, 10, 64) + cl, err := strconv.ParseInt(cls[0], 10, 64) if err != nil { return nil, fmt.Errorf("invalid Content-Length") }