support multiple headers with same key

This commit is contained in:
aler9
2020-01-26 10:21:05 +01:00
parent 2624982a5a
commit c746cc64d3
8 changed files with 247 additions and 152 deletions

26
conn.go
View File

@@ -74,38 +74,38 @@ func (c *Conn) SetCredentials(user string, pass string, realm string, nonce stri
} }
func (c *Conn) ReadRequest() (*Request, error) { func (c *Conn) ReadRequest() (*Request, error) {
return requestDecode(c.nconn) return readRequest(c.nconn)
} }
func (c *Conn) WriteRequest(req *Request) error { func (c *Conn) WriteRequest(req *Request) error {
if c.cseqEnabled { if c.cseqEnabled {
if req.Headers == nil { if req.Header == nil {
req.Headers = make(map[string]string) req.Header = make(Header)
} }
c.cseq += 1 c.cseq += 1
req.Headers["CSeq"] = strconv.FormatInt(int64(c.cseq), 10) req.Header["CSeq"] = []string{strconv.FormatInt(int64(c.cseq), 10)}
} }
if c.session != "" { if c.session != "" {
if req.Headers == nil { if req.Header == nil {
req.Headers = make(map[string]string) req.Header = make(Header)
} }
req.Headers["Session"] = c.session req.Header["Session"] = []string{c.session}
} }
if c.authProv != nil { if c.authProv != nil {
if req.Headers == nil { if req.Header == nil {
req.Headers = make(map[string]string) req.Header = make(Header)
} }
req.Headers["Authorization"] = c.authProv.generateHeader(req.Method, req.Url) req.Header["Authorization"] = []string{c.authProv.generateHeader(req.Method, req.Url)}
} }
return requestEncode(c.nconn, req) return req.write(c.nconn)
} }
func (c *Conn) ReadResponse() (*Response, error) { func (c *Conn) ReadResponse() (*Response, error) {
return responseDecode(c.nconn) return readResponse(c.nconn)
} }
func (c *Conn) WriteResponse(res *Response) error { func (c *Conn) WriteResponse(res *Response) error {
return responseEncode(c.nconn, res) return res.write(c.nconn)
} }
func (c *Conn) ReadInterleavedFrame(buf []byte) (int, int, error) { func (c *Conn) ReadInterleavedFrame(buf []byte) (int, int, error) {

96
header.go Normal file
View File

@@ -0,0 +1,96 @@
package gortsplib
import (
"bufio"
"fmt"
"sort"
)
const (
_MAX_HEADER_COUNT = 255
_MAX_HEADER_KEY_LENGTH = 255
_MAX_HEADER_VALUE_LENGTH = 255
)
type Header map[string][]string
func readHeader(rb *bufio.Reader) (Header, error) {
h := make(Header)
for {
byt, err := rb.ReadByte()
if err != nil {
return nil, err
}
if byt == '\r' {
err := readByteEqual(rb, '\n')
if err != nil {
return nil, err
}
break
}
if len(h) >= _MAX_HEADER_COUNT {
return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT)
}
key := string([]byte{byt})
byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1)
if err != nil {
return nil, err
}
key += string(byts[:len(byts)-1])
err = readByteEqual(rb, ' ')
if err != nil {
return nil, err
}
byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH)
if err != nil {
return nil, err
}
val := string(byts[:len(byts)-1])
if len(val) == 0 {
return nil, fmt.Errorf("empty header value")
}
err = readByteEqual(rb, '\n')
if err != nil {
return nil, err
}
h[key] = append(h[key], val)
}
return h, nil
}
func (h Header) write(wb *bufio.Writer) error {
// sort headers by key
// in order to obtain deterministic results
var keys []string
for key := range h {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
for _, val := range h[key] {
_, err := wb.Write([]byte(key + ": " + val + "\r\n"))
if err != nil {
return err
}
}
}
_, err := wb.Write([]byte("\r\n"))
if err != nil {
return err
}
return nil
}

61
header_test.go Normal file
View File

@@ -0,0 +1,61 @@
package gortsplib
import (
"bufio"
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var casesHeader = []struct {
name string
byts []byte
header Header
}{
{
"single",
[]byte("Proxy-Require: gzipped-messages\r\n" +
"Require: implicit-play\r\n" +
"\r\n"),
Header{
"Require": []string{"implicit-play"},
"Proxy-Require": []string{"gzipped-messages"},
},
},
{
"multiple",
[]byte("WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" +
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
"\r\n"),
Header{
"WWW-Authenticate": []string{
`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`,
`Basic realm="4419b63f5e51"`,
},
},
},
}
func TestHeaderRead(t *testing.T) {
for _, c := range casesHeader {
t.Run(c.name, func(t *testing.T) {
req, err := readHeader(bufio.NewReader(bytes.NewBuffer(c.byts)))
require.NoError(t, err)
require.Equal(t, c.header, req)
})
}
}
func TestHeaderWrite(t *testing.T) {
for _, c := range casesHeader {
t.Run(c.name, func(t *testing.T) {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
err := c.header.write(bw)
bw.Flush()
require.NoError(t, err)
require.Equal(t, c.byts, buf.Bytes())
})
}
}

View File

@@ -9,11 +9,11 @@ import (
type Request struct { type Request struct {
Method string Method string
Url string Url string
Headers map[string]string Header Header
Content []byte Content []byte
} }
func requestDecode(r io.Reader) (*Request, error) { func readRequest(r io.Reader) (*Request, error) {
rb := bufio.NewReader(r) rb := bufio.NewReader(r)
req := &Request{} req := &Request{}
@@ -53,12 +53,12 @@ func requestDecode(r io.Reader) (*Request, error) {
return nil, err return nil, err
} }
req.Headers, err = readHeaders(rb) req.Header, err = readHeader(rb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Content, err = readContent(rb, req.Headers) req.Content, err = readContent(rb, req.Header)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -66,7 +66,7 @@ func requestDecode(r io.Reader) (*Request, error) {
return req, nil return req, nil
} }
func requestEncode(w io.Writer, req *Request) error { func (req *Request) write(w io.Writer) error {
wb := bufio.NewWriter(w) wb := bufio.NewWriter(w)
_, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n")) _, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n"))
@@ -74,7 +74,7 @@ func requestEncode(w io.Writer, req *Request) error {
return err return err
} }
err = writeHeaders(wb, req.Headers) err = req.Header.write(wb)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -22,10 +22,10 @@ var casesRequest = []struct {
&Request{ &Request{
Method: "OPTIONS", Method: "OPTIONS",
Url: "rtsp://example.com/media.mp4", Url: "rtsp://example.com/media.mp4",
Headers: map[string]string{ Header: Header{
"CSeq": "1", "CSeq": []string{"1"},
"Require": "implicit-play", "Require": []string{"implicit-play"},
"Proxy-Require": "gzipped-messages", "Proxy-Require": []string{"gzipped-messages"},
}, },
}, },
}, },
@@ -37,8 +37,8 @@ var casesRequest = []struct {
&Request{ &Request{
Method: "DESCRIBE", Method: "DESCRIBE",
Url: "rtsp://example.com/media.mp4", Url: "rtsp://example.com/media.mp4",
Headers: map[string]string{ Header: Header{
"CSeq": "2", "CSeq": []string{"2"},
}, },
}, },
}, },
@@ -65,12 +65,12 @@ var casesRequest = []struct {
&Request{ &Request{
Method: "ANNOUNCE", Method: "ANNOUNCE",
Url: "rtsp://example.com/media.mp4", Url: "rtsp://example.com/media.mp4",
Headers: map[string]string{ Header: Header{
"CSeq": "7", "CSeq": []string{"7"},
"Date": "23 Jan 1997 15:35:06 GMT", "Date": []string{"23 Jan 1997 15:35:06 GMT"},
"Session": "12345678", "Session": []string{"12345678"},
"Content-Type": "application/sdp", "Content-Type": []string{"application/sdp"},
"Content-Length": "306", "Content-Length": []string{"306"},
}, },
Content: []byte("v=0\n" + Content: []byte("v=0\n" +
"o=mhandley 2890844526 2890845468 IN IP4 126.16.64.4\n" + "o=mhandley 2890844526 2890845468 IN IP4 126.16.64.4\n" +
@@ -99,11 +99,11 @@ var casesRequest = []struct {
&Request{ &Request{
Method: "GET_PARAMETER", Method: "GET_PARAMETER",
Url: "rtsp://example.com/media.mp4", Url: "rtsp://example.com/media.mp4",
Headers: map[string]string{ Header: Header{
"CSeq": "9", "CSeq": []string{"9"},
"Content-Type": "text/parameters", "Content-Type": []string{"text/parameters"},
"Session": "12345678", "Session": []string{"12345678"},
"Content-Length": "24", "Content-Length": []string{"24"},
}, },
Content: []byte("packets_received\n" + Content: []byte("packets_received\n" +
"jitter\n", "jitter\n",
@@ -112,21 +112,21 @@ var casesRequest = []struct {
}, },
} }
func TestRequestDecode(t *testing.T) { func TestRequestRead(t *testing.T) {
for _, c := range casesRequest { for _, c := range casesRequest {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := requestDecode(bytes.NewBuffer(c.byts)) req, err := readRequest(bytes.NewBuffer(c.byts))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.req, req) require.Equal(t, c.req, req)
}) })
} }
} }
func TestRequestEncode(t *testing.T) { func TestRequestWrite(t *testing.T) {
for _, c := range casesRequest { for _, c := range casesRequest {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
err := requestEncode(&buf, c.req) err := c.req.write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.byts, buf.Bytes()) require.Equal(t, c.byts, buf.Bytes())
}) })

View File

@@ -10,11 +10,11 @@ import (
type Response struct { type Response struct {
StatusCode int StatusCode int
Status string Status string
Headers map[string]string Header Header
Content []byte Content []byte
} }
func responseDecode(r io.Reader) (*Response, error) { func readResponse(r io.Reader) (*Response, error) {
rb := bufio.NewReader(r) rb := bufio.NewReader(r)
res := &Response{} res := &Response{}
@@ -56,12 +56,12 @@ func responseDecode(r io.Reader) (*Response, error) {
return nil, err return nil, err
} }
res.Headers, err = readHeaders(rb) res.Header, err = readHeader(rb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
res.Content, err = readContent(rb, res.Headers) res.Content, err = readContent(rb, res.Header)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -69,7 +69,7 @@ func responseDecode(r io.Reader) (*Response, error) {
return res, nil return res, nil
} }
func responseEncode(w io.Writer, res *Response) error { func (res *Response) write(w io.Writer) error {
wb := bufio.NewWriter(w) wb := bufio.NewWriter(w)
_, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n")) _, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n"))
@@ -78,10 +78,10 @@ func responseEncode(w io.Writer, res *Response) error {
} }
if len(res.Content) != 0 { if len(res.Content) != 0 {
res.Headers["Content-Length"] = strconv.FormatInt(int64(len(res.Content)), 10) res.Header["Content-Length"] = []string{strconv.FormatInt(int64(len(res.Content)), 10)}
} }
err = writeHeaders(wb, res.Headers) err = res.Header.write(wb)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,7 +13,7 @@ var casesResponse = []struct {
res *Response res *Response
}{ }{
{ {
"ok", "ok with single header",
[]byte("RTSP/1.0 200 OK\r\n" + []byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 1\r\n" + "CSeq: 1\r\n" +
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" + "Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
@@ -22,9 +22,33 @@ var casesResponse = []struct {
&Response{ &Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Header: Header{
"CSeq": "1", "CSeq": []string{"1"},
"Public": "DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE", "Public": []string{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"},
},
},
},
{
"ok with multiple headers",
[]byte("RTSP/1.0 200 OK\r\n" +
"CSeq: 2\r\n" +
"Date: Sat, Aug 16 2014 02:22:28 GMT\r\n" +
"Session: 645252166\r\n" +
"WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" +
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
"\r\n",
),
&Response{
StatusCode: 200,
Status: "OK",
Header: Header{
"CSeq": []string{"2"},
"Session": []string{"645252166"},
"WWW-Authenticate": []string{
"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"",
"Basic realm=\"4419b63f5e51\"",
},
"Date": []string{"Sat, Aug 16 2014 02:22:28 GMT"},
}, },
}, },
}, },
@@ -56,11 +80,11 @@ var casesResponse = []struct {
&Response{ &Response{
StatusCode: 200, StatusCode: 200,
Status: "OK", Status: "OK",
Headers: map[string]string{ Header: Header{
"Content-Base": "rtsp://example.com/media.mp4", "Content-Base": []string{"rtsp://example.com/media.mp4"},
"Content-Length": "444", "Content-Length": []string{"444"},
"Content-Type": "application/sdp", "Content-Type": []string{"application/sdp"},
"CSeq": "2", "CSeq": []string{"2"},
}, },
Content: []byte("m=video 0 RTP/AVP 96\n" + Content: []byte("m=video 0 RTP/AVP 96\n" +
"a=control:streamid=0\n" + "a=control:streamid=0\n" +
@@ -83,21 +107,21 @@ var casesResponse = []struct {
}, },
} }
func TestResponseDecode(t *testing.T) { func TestResponseRead(t *testing.T) {
for _, c := range casesResponse { for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
res, err := responseDecode(bytes.NewBuffer(c.byts)) res, err := readResponse(bytes.NewBuffer(c.byts))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.res, res) require.Equal(t, c.res, res)
}) })
} }
} }
func TestResponseEncode(t *testing.T) { func TestResponseWrite(t *testing.T) {
for _, c := range casesResponse { for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
err := responseEncode(&buf, c.res) err := c.res.write(&buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.byts, buf.Bytes()) require.Equal(t, c.byts, buf.Bytes())
}) })

View File

@@ -4,15 +4,11 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"sort"
"strconv" "strconv"
) )
const ( const (
_RTSP_PROTO = "RTSP/1.0" _RTSP_PROTO = "RTSP/1.0"
_MAX_HEADER_COUNT = 255
_MAX_HEADER_KEY_LENGTH = 255
_MAX_HEADER_VALUE_LENGTH = 255
_MAX_CONTENT_LENGTH = 4096 _MAX_CONTENT_LENGTH = 4096
) )
@@ -44,95 +40,13 @@ func readByteEqual(rb *bufio.Reader, cmp byte) error {
return nil return nil
} }
func readHeaders(rb *bufio.Reader) (map[string]string, error) { func readContent(rb *bufio.Reader, header Header) ([]byte, error) {
ret := make(map[string]string) cls, ok := header["Content-Length"]
if !ok || len(cls) != 1 {
for {
byt, err := rb.ReadByte()
if err != nil {
return nil, err
}
if byt == '\r' {
err := readByteEqual(rb, '\n')
if err != nil {
return nil, err
}
break
}
if len(ret) >= _MAX_HEADER_COUNT {
return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT)
}
key := string([]byte{byt})
byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1)
if err != nil {
return nil, err
}
key += string(byts[:len(byts)-1])
err = readByteEqual(rb, ' ')
if err != nil {
return nil, err
}
byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH)
if err != nil {
return nil, err
}
val := string(byts[:len(byts)-1])
if len(val) == 0 {
return nil, fmt.Errorf("empty header value")
}
err = readByteEqual(rb, '\n')
if err != nil {
return nil, err
}
// set only if not set previously
if _, ok := ret[key]; !ok {
ret[key] = val
}
}
return ret, nil
}
func writeHeaders(wb *bufio.Writer, headers map[string]string) error {
// sort headers by key
// in order to obtain deterministic results
var keys []string
for key := range headers {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
_, err := wb.Write([]byte(key + ": " + headers[key] + "\r\n"))
if err != nil {
return err
}
}
_, err := wb.Write([]byte("\r\n"))
if err != nil {
return err
}
return nil
}
func readContent(rb *bufio.Reader, headers map[string]string) ([]byte, error) {
cls, ok := headers["Content-Length"]
if !ok {
return nil, nil return nil, nil
} }
cl, err := strconv.ParseInt(cls, 10, 64) cl, err := strconv.ParseInt(cls[0], 10, 64)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid Content-Length") return nil, fmt.Errorf("invalid Content-Length")
} }