From e9518993d461e4ca6447646bcf3cf00ddaecb0c1 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 24 Apr 2021 11:46:48 +0200 Subject: [PATCH] base: improve negative tests --- pkg/base/header.go | 13 ++++----- pkg/base/header_test.go | 57 +++++++++++++++++++++++++++++++++++++++ pkg/base/request.go | 2 +- pkg/base/request_test.go | 29 ++++++++++++++++++-- pkg/base/response.go | 2 +- pkg/base/response_test.go | 22 ++++++++++++--- 6 files changed, 109 insertions(+), 16 deletions(-) diff --git a/pkg/base/header.go b/pkg/base/header.go index 0add5a3f..d223a52d 100644 --- a/pkg/base/header.go +++ b/pkg/base/header.go @@ -36,6 +36,7 @@ type Header map[string]HeaderValue func (h *Header) read(rb *bufio.Reader) error { *h = make(Header) + count := 0 for { byt, err := rb.ReadByte() @@ -52,15 +53,14 @@ func (h *Header) read(rb *bufio.Reader) error { break } - if len(*h) >= headerMaxEntryCount { - return fmt.Errorf("headers count exceeds %d (it's %d)", - headerMaxEntryCount, len(*h)) + if count >= headerMaxEntryCount { + return fmt.Errorf("headers count exceeds %d", headerMaxEntryCount) } key := string([]byte{byt}) byts, err := readBytesLimited(rb, ':', headerMaxKeyLength-1) if err != nil { - return err + return fmt.Errorf("value is missing") } key += string(byts[:len(byts)-1]) key = headerKeyNormalize(key) @@ -85,16 +85,13 @@ func (h *Header) read(rb *bufio.Reader) error { } val := string(byts[:len(byts)-1]) - if len(val) == 0 { - return fmt.Errorf("empty header value") - } - err = readByteEqual(rb, '\n') if err != nil { return err } (*h)[key] = append((*h)[key], val) + count++ } return nil diff --git a/pkg/base/header_test.go b/pkg/base/header_test.go index ac3152d3..df8edf72 100644 --- a/pkg/base/header_test.go +++ b/pkg/base/header_test.go @@ -42,6 +42,16 @@ var casesHeader = []struct { }, }, }, + { + "empty", + []byte("Testing:\r\n" + + "\r\n"), + []byte("Testing: \r\n" + + "\r\n"), + Header{ + "Testing": HeaderValue{""}, + }, + }, { "without space", []byte("CSeq:2\r\n" + @@ -116,3 +126,50 @@ func TestHeaderWrite(t *testing.T) { }) } } + +func TestHeaderReadErrors(t *testing.T) { + for _, ca := range []struct { + name string + dec []byte + err string + }{ + { + "empty", + []byte{}, + "EOF", + }, + { + "r without n", + []byte("Testing: val\rTesting: val\r\n"), + "expected '\n', got 'T'", + }, + { + "final r without n", + []byte("Testing: val\r\nTesting: val\r\n\r"), + "EOF", + }, + { + "missing value", + []byte("Testing\r\n"), + "value is missing", + }, + { + "too many entries", + func() []byte { + var ret []byte + for i := 0; i < headerMaxEntryCount+2; i++ { + ret = append(ret, []byte("Testing: val\r\n")...) + } + ret = append(ret, []byte("\r\n")...) + return ret + }(), + "headers count exceeds 255", + }, + } { + t.Run(ca.name, func(t *testing.T) { + h := make(Header) + err := h.read(bufio.NewReader(bytes.NewBuffer(ca.dec))) + require.Equal(t, ca.err, err.Error()) + }) + } +} diff --git a/pkg/base/request.go b/pkg/base/request.go index 91870b90..561ae3e7 100644 --- a/pkg/base/request.go +++ b/pkg/base/request.go @@ -71,7 +71,7 @@ func (req *Request) Read(rb *bufio.Reader) error { ur, err := ParseURL(rawURL) if err != nil { - return fmt.Errorf("unable to parse url (%v)", rawURL) + return fmt.Errorf("invalid URL (%v)", rawURL) } req.URL = ur diff --git a/pkg/base/request_test.go b/pkg/base/request_test.go index ba1d2885..4a007426 100644 --- a/pkg/base/request_test.go +++ b/pkg/base/request_test.go @@ -160,48 +160,73 @@ func TestRequestReadErrors(t *testing.T) { for _, ca := range []struct { name string byts []byte + err string }{ { "empty", []byte{}, + "EOF", }, { "missing url, protocol, eol", []byte("GET"), + "EOF", }, { "missing protocol, eol", []byte("GET rtsp://testing123/test"), + "EOF", }, { "missing eol", []byte("GET rtsp://testing123/test RTSP/1.0"), + "EOF", }, { "empty method", []byte(" rtsp://testing123 RTSP/1.0\r\n"), + "empty method", }, { "empty URL", []byte("GET RTSP/1.0\r\n"), + "invalid URL ()", }, { "empty protocol", - []byte("GET http://testing123 \r\n"), + []byte("GET rtsp://testing123 \r\n"), + "expected 'RTSP/1.0', got ''", }, { "invalid URL", []byte("GET http://testing123 RTSP/1.0\r\n"), + "invalid URL (http://testing123)", }, { "invalid protocol", []byte("GET rtsp://testing123 RTSP/2.0\r\n"), + "expected 'RTSP/1.0', got 'RTSP/2.0'", }, } { t.Run(ca.name, func(t *testing.T) { var req Request err := req.Read(bufio.NewReader(bytes.NewBuffer(ca.byts))) - require.Error(t, err) + require.Equal(t, ca.err, err.Error()) }) } } + +func TestRequestReadIgnoreFrames(t *testing.T) { + byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} + byts = append(byts, []byte("OPTIONS rtsp://example.com/media.mp4 RTSP/1.0\r\n"+ + "CSeq: 1\r\n"+ + "Proxy-Require: gzipped-messages\r\n"+ + "Require: implicit-play\r\n"+ + "\r\n")...) + + rb := bufio.NewReader(bytes.NewBuffer(byts)) + buf := make([]byte, 10) + var req Request + err := req.ReadIgnoreFrames(rb, buf) + require.NoError(t, err) +} diff --git a/pkg/base/response.go b/pkg/base/response.go index e3a3da04..38b8886b 100644 --- a/pkg/base/response.go +++ b/pkg/base/response.go @@ -164,7 +164,7 @@ func (res *Response) Read(rb *bufio.Reader) error { res.StatusMessage = string(byts[:len(byts)-1]) if len(res.StatusMessage) == 0 { - return fmt.Errorf("empty status") + return fmt.Errorf("empty status message") } err = readByteEqual(rb, '\n') diff --git a/pkg/base/response_test.go b/pkg/base/response_test.go index 0b859fbb..873946fd 100644 --- a/pkg/base/response_test.go +++ b/pkg/base/response_test.go @@ -122,44 +122,58 @@ func TestResponseReadErrors(t *testing.T) { for _, ca := range []struct { name string byts []byte + err string }{ { "empty", []byte{}, + "EOF", }, { "missing code, message, eol", []byte("RTSP/1.0"), + "EOF", }, { "missing message, eol", []byte("RTSP/1.0 200"), + "EOF", }, { "missing eol", []byte("RTSP/1.0 200 OK"), + "EOF", }, { "missing eol 2", []byte("RTSP/1.0 200 OK\r"), + "EOF", }, { "invalid protocol", []byte("RTSP/2.0 200 OK\r\n"), + "expected 'RTSP/1.0', got 'RTSP/2.0'", + }, + { + "code too long", + []byte("RTSP/1.0 1234 OK\r\n"), + "buffer length exceeds 4", }, { "invalid code", - []byte("RTSP/2.0 string OK\r\n"), + []byte("RTSP/1.0 str OK\r\n"), + "unable to parse status code", }, { "empty message", - []byte("RTSP/2.0 string \r\n"), + []byte("RTSP/1.0 200 \r\n"), + "empty status message", }, } { t.Run(ca.name, func(t *testing.T) { var res Response err := res.Read(bufio.NewReader(bytes.NewBuffer(ca.byts))) - require.Error(t, err) + require.Equal(t, ca.err, err.Error()) }) } } @@ -193,7 +207,7 @@ func TestResponseWriteAutoFillStatus(t *testing.T) { require.Equal(t, byts, buf.Bytes()) } -func TestReadIgnoreFrames(t *testing.T) { +func TestResponseReadIgnoreFrames(t *testing.T) { byts := []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4} byts = append(byts, []byte("RTSP/1.0 200 OK\r\n"+ "CSeq: 1\r\n"+