mirror of
https://github.com/aler9/gortsplib
synced 2025-10-04 23:02:45 +08:00
support multiple headers with same key
This commit is contained in:
26
conn.go
26
conn.go
@@ -74,38 +74,38 @@ func (c *Conn) SetCredentials(user string, pass string, realm string, nonce stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadRequest() (*Request, error) {
|
func (c *Conn) ReadRequest() (*Request, error) {
|
||||||
return requestDecode(c.nconn)
|
return readRequest(c.nconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) WriteRequest(req *Request) error {
|
func (c *Conn) WriteRequest(req *Request) error {
|
||||||
if c.cseqEnabled {
|
if c.cseqEnabled {
|
||||||
if req.Headers == nil {
|
if req.Header == nil {
|
||||||
req.Headers = make(map[string]string)
|
req.Header = make(Header)
|
||||||
}
|
}
|
||||||
c.cseq += 1
|
c.cseq += 1
|
||||||
req.Headers["CSeq"] = strconv.FormatInt(int64(c.cseq), 10)
|
req.Header["CSeq"] = []string{strconv.FormatInt(int64(c.cseq), 10)}
|
||||||
}
|
}
|
||||||
if c.session != "" {
|
if c.session != "" {
|
||||||
if req.Headers == nil {
|
if req.Header == nil {
|
||||||
req.Headers = make(map[string]string)
|
req.Header = make(Header)
|
||||||
}
|
}
|
||||||
req.Headers["Session"] = c.session
|
req.Header["Session"] = []string{c.session}
|
||||||
}
|
}
|
||||||
if c.authProv != nil {
|
if c.authProv != nil {
|
||||||
if req.Headers == nil {
|
if req.Header == nil {
|
||||||
req.Headers = make(map[string]string)
|
req.Header = make(Header)
|
||||||
}
|
}
|
||||||
req.Headers["Authorization"] = c.authProv.generateHeader(req.Method, req.Url)
|
req.Header["Authorization"] = []string{c.authProv.generateHeader(req.Method, req.Url)}
|
||||||
}
|
}
|
||||||
return requestEncode(c.nconn, req)
|
return req.write(c.nconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadResponse() (*Response, error) {
|
func (c *Conn) ReadResponse() (*Response, error) {
|
||||||
return responseDecode(c.nconn)
|
return readResponse(c.nconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) WriteResponse(res *Response) error {
|
func (c *Conn) WriteResponse(res *Response) error {
|
||||||
return responseEncode(c.nconn, res)
|
return res.write(c.nconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadInterleavedFrame(buf []byte) (int, int, error) {
|
func (c *Conn) ReadInterleavedFrame(buf []byte) (int, int, error) {
|
||||||
|
96
header.go
Normal file
96
header.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package gortsplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
_MAX_HEADER_COUNT = 255
|
||||||
|
_MAX_HEADER_KEY_LENGTH = 255
|
||||||
|
_MAX_HEADER_VALUE_LENGTH = 255
|
||||||
|
)
|
||||||
|
|
||||||
|
type Header map[string][]string
|
||||||
|
|
||||||
|
func readHeader(rb *bufio.Reader) (Header, error) {
|
||||||
|
h := make(Header)
|
||||||
|
|
||||||
|
for {
|
||||||
|
byt, err := rb.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if byt == '\r' {
|
||||||
|
err := readByteEqual(rb, '\n')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(h) >= _MAX_HEADER_COUNT {
|
||||||
|
return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := string([]byte{byt})
|
||||||
|
byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
key += string(byts[:len(byts)-1])
|
||||||
|
|
||||||
|
err = readByteEqual(rb, ' ')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
val := string(byts[:len(byts)-1])
|
||||||
|
|
||||||
|
if len(val) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty header value")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = readByteEqual(rb, '\n')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
h[key] = append(h[key], val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Header) write(wb *bufio.Writer) error {
|
||||||
|
// sort headers by key
|
||||||
|
// in order to obtain deterministic results
|
||||||
|
var keys []string
|
||||||
|
for key := range h {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
for _, val := range h[key] {
|
||||||
|
_, err := wb.Write([]byte(key + ": " + val + "\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := wb.Write([]byte("\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
61
header_test.go
Normal file
61
header_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package gortsplib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var casesHeader = []struct {
|
||||||
|
name string
|
||||||
|
byts []byte
|
||||||
|
header Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"single",
|
||||||
|
[]byte("Proxy-Require: gzipped-messages\r\n" +
|
||||||
|
"Require: implicit-play\r\n" +
|
||||||
|
"\r\n"),
|
||||||
|
Header{
|
||||||
|
"Require": []string{"implicit-play"},
|
||||||
|
"Proxy-Require": []string{"gzipped-messages"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple",
|
||||||
|
[]byte("WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" +
|
||||||
|
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
|
||||||
|
"\r\n"),
|
||||||
|
Header{
|
||||||
|
"WWW-Authenticate": []string{
|
||||||
|
`Digest realm="4419b63f5e51", nonce="8b84a3b789283a8bea8da7fa7d41f08b", stale="FALSE"`,
|
||||||
|
`Basic realm="4419b63f5e51"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderRead(t *testing.T) {
|
||||||
|
for _, c := range casesHeader {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
req, err := readHeader(bufio.NewReader(bytes.NewBuffer(c.byts)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, c.header, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderWrite(t *testing.T) {
|
||||||
|
for _, c := range casesHeader {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
bw := bufio.NewWriter(&buf)
|
||||||
|
err := c.header.write(bw)
|
||||||
|
bw.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, c.byts, buf.Bytes())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
12
request.go
12
request.go
@@ -9,11 +9,11 @@ import (
|
|||||||
type Request struct {
|
type Request struct {
|
||||||
Method string
|
Method string
|
||||||
Url string
|
Url string
|
||||||
Headers map[string]string
|
Header Header
|
||||||
Content []byte
|
Content []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestDecode(r io.Reader) (*Request, error) {
|
func readRequest(r io.Reader) (*Request, error) {
|
||||||
rb := bufio.NewReader(r)
|
rb := bufio.NewReader(r)
|
||||||
|
|
||||||
req := &Request{}
|
req := &Request{}
|
||||||
@@ -53,12 +53,12 @@ func requestDecode(r io.Reader) (*Request, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Headers, err = readHeaders(rb)
|
req.Header, err = readHeader(rb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Content, err = readContent(rb, req.Headers)
|
req.Content, err = readContent(rb, req.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -66,7 +66,7 @@ func requestDecode(r io.Reader) (*Request, error) {
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestEncode(w io.Writer, req *Request) error {
|
func (req *Request) write(w io.Writer) error {
|
||||||
wb := bufio.NewWriter(w)
|
wb := bufio.NewWriter(w)
|
||||||
|
|
||||||
_, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n"))
|
_, err := wb.Write([]byte(req.Method + " " + req.Url + " " + _RTSP_PROTO + "\r\n"))
|
||||||
@@ -74,7 +74,7 @@ func requestEncode(w io.Writer, req *Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = writeHeaders(wb, req.Headers)
|
err = req.Header.write(wb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -22,10 +22,10 @@ var casesRequest = []struct {
|
|||||||
&Request{
|
&Request{
|
||||||
Method: "OPTIONS",
|
Method: "OPTIONS",
|
||||||
Url: "rtsp://example.com/media.mp4",
|
Url: "rtsp://example.com/media.mp4",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"CSeq": "1",
|
"CSeq": []string{"1"},
|
||||||
"Require": "implicit-play",
|
"Require": []string{"implicit-play"},
|
||||||
"Proxy-Require": "gzipped-messages",
|
"Proxy-Require": []string{"gzipped-messages"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -37,8 +37,8 @@ var casesRequest = []struct {
|
|||||||
&Request{
|
&Request{
|
||||||
Method: "DESCRIBE",
|
Method: "DESCRIBE",
|
||||||
Url: "rtsp://example.com/media.mp4",
|
Url: "rtsp://example.com/media.mp4",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"CSeq": "2",
|
"CSeq": []string{"2"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -65,12 +65,12 @@ var casesRequest = []struct {
|
|||||||
&Request{
|
&Request{
|
||||||
Method: "ANNOUNCE",
|
Method: "ANNOUNCE",
|
||||||
Url: "rtsp://example.com/media.mp4",
|
Url: "rtsp://example.com/media.mp4",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"CSeq": "7",
|
"CSeq": []string{"7"},
|
||||||
"Date": "23 Jan 1997 15:35:06 GMT",
|
"Date": []string{"23 Jan 1997 15:35:06 GMT"},
|
||||||
"Session": "12345678",
|
"Session": []string{"12345678"},
|
||||||
"Content-Type": "application/sdp",
|
"Content-Type": []string{"application/sdp"},
|
||||||
"Content-Length": "306",
|
"Content-Length": []string{"306"},
|
||||||
},
|
},
|
||||||
Content: []byte("v=0\n" +
|
Content: []byte("v=0\n" +
|
||||||
"o=mhandley 2890844526 2890845468 IN IP4 126.16.64.4\n" +
|
"o=mhandley 2890844526 2890845468 IN IP4 126.16.64.4\n" +
|
||||||
@@ -99,11 +99,11 @@ var casesRequest = []struct {
|
|||||||
&Request{
|
&Request{
|
||||||
Method: "GET_PARAMETER",
|
Method: "GET_PARAMETER",
|
||||||
Url: "rtsp://example.com/media.mp4",
|
Url: "rtsp://example.com/media.mp4",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"CSeq": "9",
|
"CSeq": []string{"9"},
|
||||||
"Content-Type": "text/parameters",
|
"Content-Type": []string{"text/parameters"},
|
||||||
"Session": "12345678",
|
"Session": []string{"12345678"},
|
||||||
"Content-Length": "24",
|
"Content-Length": []string{"24"},
|
||||||
},
|
},
|
||||||
Content: []byte("packets_received\n" +
|
Content: []byte("packets_received\n" +
|
||||||
"jitter\n",
|
"jitter\n",
|
||||||
@@ -112,21 +112,21 @@ var casesRequest = []struct {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestDecode(t *testing.T) {
|
func TestRequestRead(t *testing.T) {
|
||||||
for _, c := range casesRequest {
|
for _, c := range casesRequest {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
req, err := requestDecode(bytes.NewBuffer(c.byts))
|
req, err := readRequest(bytes.NewBuffer(c.byts))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.req, req)
|
require.Equal(t, c.req, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestEncode(t *testing.T) {
|
func TestRequestWrite(t *testing.T) {
|
||||||
for _, c := range casesRequest {
|
for _, c := range casesRequest {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err := requestEncode(&buf, c.req)
|
err := c.req.write(&buf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.byts, buf.Bytes())
|
require.Equal(t, c.byts, buf.Bytes())
|
||||||
})
|
})
|
||||||
|
14
response.go
14
response.go
@@ -10,11 +10,11 @@ import (
|
|||||||
type Response struct {
|
type Response struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Status string
|
Status string
|
||||||
Headers map[string]string
|
Header Header
|
||||||
Content []byte
|
Content []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseDecode(r io.Reader) (*Response, error) {
|
func readResponse(r io.Reader) (*Response, error) {
|
||||||
rb := bufio.NewReader(r)
|
rb := bufio.NewReader(r)
|
||||||
|
|
||||||
res := &Response{}
|
res := &Response{}
|
||||||
@@ -56,12 +56,12 @@ func responseDecode(r io.Reader) (*Response, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Headers, err = readHeaders(rb)
|
res.Header, err = readHeader(rb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Content, err = readContent(rb, res.Headers)
|
res.Content, err = readContent(rb, res.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -69,7 +69,7 @@ func responseDecode(r io.Reader) (*Response, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseEncode(w io.Writer, res *Response) error {
|
func (res *Response) write(w io.Writer) error {
|
||||||
wb := bufio.NewWriter(w)
|
wb := bufio.NewWriter(w)
|
||||||
|
|
||||||
_, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n"))
|
_, err := wb.Write([]byte(_RTSP_PROTO + " " + strconv.FormatInt(int64(res.StatusCode), 10) + " " + res.Status + "\r\n"))
|
||||||
@@ -78,10 +78,10 @@ func responseEncode(w io.Writer, res *Response) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Content) != 0 {
|
if len(res.Content) != 0 {
|
||||||
res.Headers["Content-Length"] = strconv.FormatInt(int64(len(res.Content)), 10)
|
res.Header["Content-Length"] = []string{strconv.FormatInt(int64(len(res.Content)), 10)}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = writeHeaders(wb, res.Headers)
|
err = res.Header.write(wb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -13,7 +13,7 @@ var casesResponse = []struct {
|
|||||||
res *Response
|
res *Response
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"ok",
|
"ok with single header",
|
||||||
[]byte("RTSP/1.0 200 OK\r\n" +
|
[]byte("RTSP/1.0 200 OK\r\n" +
|
||||||
"CSeq: 1\r\n" +
|
"CSeq: 1\r\n" +
|
||||||
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
|
"Public: DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE\r\n" +
|
||||||
@@ -22,9 +22,33 @@ var casesResponse = []struct {
|
|||||||
&Response{
|
&Response{
|
||||||
StatusCode: 200,
|
StatusCode: 200,
|
||||||
Status: "OK",
|
Status: "OK",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"CSeq": "1",
|
"CSeq": []string{"1"},
|
||||||
"Public": "DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE",
|
"Public": []string{"DESCRIBE, SETUP, TEARDOWN, PLAY, PAUSE"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ok with multiple headers",
|
||||||
|
[]byte("RTSP/1.0 200 OK\r\n" +
|
||||||
|
"CSeq: 2\r\n" +
|
||||||
|
"Date: Sat, Aug 16 2014 02:22:28 GMT\r\n" +
|
||||||
|
"Session: 645252166\r\n" +
|
||||||
|
"WWW-Authenticate: Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"\r\n" +
|
||||||
|
"WWW-Authenticate: Basic realm=\"4419b63f5e51\"\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
),
|
||||||
|
&Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Status: "OK",
|
||||||
|
Header: Header{
|
||||||
|
"CSeq": []string{"2"},
|
||||||
|
"Session": []string{"645252166"},
|
||||||
|
"WWW-Authenticate": []string{
|
||||||
|
"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\"",
|
||||||
|
"Basic realm=\"4419b63f5e51\"",
|
||||||
|
},
|
||||||
|
"Date": []string{"Sat, Aug 16 2014 02:22:28 GMT"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -56,11 +80,11 @@ var casesResponse = []struct {
|
|||||||
&Response{
|
&Response{
|
||||||
StatusCode: 200,
|
StatusCode: 200,
|
||||||
Status: "OK",
|
Status: "OK",
|
||||||
Headers: map[string]string{
|
Header: Header{
|
||||||
"Content-Base": "rtsp://example.com/media.mp4",
|
"Content-Base": []string{"rtsp://example.com/media.mp4"},
|
||||||
"Content-Length": "444",
|
"Content-Length": []string{"444"},
|
||||||
"Content-Type": "application/sdp",
|
"Content-Type": []string{"application/sdp"},
|
||||||
"CSeq": "2",
|
"CSeq": []string{"2"},
|
||||||
},
|
},
|
||||||
Content: []byte("m=video 0 RTP/AVP 96\n" +
|
Content: []byte("m=video 0 RTP/AVP 96\n" +
|
||||||
"a=control:streamid=0\n" +
|
"a=control:streamid=0\n" +
|
||||||
@@ -83,21 +107,21 @@ var casesResponse = []struct {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResponseDecode(t *testing.T) {
|
func TestResponseRead(t *testing.T) {
|
||||||
for _, c := range casesResponse {
|
for _, c := range casesResponse {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
res, err := responseDecode(bytes.NewBuffer(c.byts))
|
res, err := readResponse(bytes.NewBuffer(c.byts))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.res, res)
|
require.Equal(t, c.res, res)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResponseEncode(t *testing.T) {
|
func TestResponseWrite(t *testing.T) {
|
||||||
for _, c := range casesResponse {
|
for _, c := range casesResponse {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err := responseEncode(&buf, c.res)
|
err := c.res.write(&buf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, c.byts, buf.Bytes())
|
require.Equal(t, c.byts, buf.Bytes())
|
||||||
})
|
})
|
||||||
|
98
utils.go
98
utils.go
@@ -4,16 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
_RTSP_PROTO = "RTSP/1.0"
|
_RTSP_PROTO = "RTSP/1.0"
|
||||||
_MAX_HEADER_COUNT = 255
|
_MAX_CONTENT_LENGTH = 4096
|
||||||
_MAX_HEADER_KEY_LENGTH = 255
|
|
||||||
_MAX_HEADER_VALUE_LENGTH = 255
|
|
||||||
_MAX_CONTENT_LENGTH = 4096
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {
|
func readBytesLimited(rb *bufio.Reader, delim byte, n int) ([]byte, error) {
|
||||||
@@ -44,95 +40,13 @@ func readByteEqual(rb *bufio.Reader, cmp byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHeaders(rb *bufio.Reader) (map[string]string, error) {
|
func readContent(rb *bufio.Reader, header Header) ([]byte, error) {
|
||||||
ret := make(map[string]string)
|
cls, ok := header["Content-Length"]
|
||||||
|
if !ok || len(cls) != 1 {
|
||||||
for {
|
|
||||||
byt, err := rb.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if byt == '\r' {
|
|
||||||
err := readByteEqual(rb, '\n')
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ret) >= _MAX_HEADER_COUNT {
|
|
||||||
return nil, fmt.Errorf("headers count exceeds %d", _MAX_HEADER_COUNT)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := string([]byte{byt})
|
|
||||||
byts, err := readBytesLimited(rb, ':', _MAX_HEADER_KEY_LENGTH-1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key += string(byts[:len(byts)-1])
|
|
||||||
|
|
||||||
err = readByteEqual(rb, ' ')
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
byts, err = readBytesLimited(rb, '\r', _MAX_HEADER_VALUE_LENGTH)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
val := string(byts[:len(byts)-1])
|
|
||||||
|
|
||||||
if len(val) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty header value")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = readByteEqual(rb, '\n')
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set only if not set previously
|
|
||||||
if _, ok := ret[key]; !ok {
|
|
||||||
ret[key] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeHeaders(wb *bufio.Writer, headers map[string]string) error {
|
|
||||||
// sort headers by key
|
|
||||||
// in order to obtain deterministic results
|
|
||||||
var keys []string
|
|
||||||
for key := range headers {
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
for _, key := range keys {
|
|
||||||
_, err := wb.Write([]byte(key + ": " + headers[key] + "\r\n"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := wb.Write([]byte("\r\n"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readContent(rb *bufio.Reader, headers map[string]string) ([]byte, error) {
|
|
||||||
cls, ok := headers["Content-Length"]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cl, err := strconv.ParseInt(cls, 10, 64)
|
cl, err := strconv.ParseInt(cls[0], 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid Content-Length")
|
return nil, fmt.Errorf("invalid Content-Length")
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user