improve tests

This commit is contained in:
aler9
2021-04-04 20:20:29 +02:00
parent c2d5ced43b
commit 14ce8dbc45
8 changed files with 161 additions and 48 deletions

View File

@@ -25,9 +25,11 @@ var casesInterleavedFrame = []struct {
} }
func TestInterleavedFrameRead(t *testing.T) { func TestInterleavedFrameRead(t *testing.T) {
// keep f global to make sure that all its fields are overridden.
var f InterleavedFrame
for _, ca := range casesInterleavedFrame { for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
f.Payload = make([]byte, 1024) f.Payload = make([]byte, 1024)
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc))) err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc)))
require.NoError(t, err) require.NoError(t, err)
@@ -48,3 +50,34 @@ func TestInterleavedFrameWrite(t *testing.T) {
}) })
} }
} }
func TestInterleavedFrameReadErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
}{
{
"empty",
[]byte{},
},
{
"invalid magic byte",
[]byte{0x55, 0x00, 0x00, 0x00},
},
{
"length too big",
[]byte{0x24, 0x00, 0x00, 0x08},
},
{
"invalid payload",
[]byte{0x24, 0x00, 0x00, 0x08, 0x01, 0x02},
},
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
f.Payload = make([]byte, 5)
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.Error(t, err)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
const ( const (
rtspProtocol10 = "RTSP/1.0" rtspProtocol10 = "RTSP/1.0"
requestMaxMethodLength = 64 requestMaxMethodLength = 64
requestMaxPathLength = 2048 requestMaxURLLength = 2048
requestMaxProtocolLength = 64 requestMaxProtocolLength = 64
) )
@@ -63,7 +63,7 @@ func (req *Request) Read(rb *bufio.Reader) error {
return fmt.Errorf("empty method") return fmt.Errorf("empty method")
} }
byts, err = readBytesLimited(rb, ' ', requestMaxPathLength) byts, err = readBytesLimited(rb, ' ', requestMaxURLLength)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -131,25 +131,65 @@ var casesRequest = []struct {
} }
func TestRequestRead(t *testing.T) { func TestRequestRead(t *testing.T) {
// keep req global to make sure that all its fields are overridden.
var req Request var req Request
for _, c := range casesRequest {
t.Run(c.name, func(t *testing.T) { for _, ca := range casesRequest {
err := req.Read(bufio.NewReader(bytes.NewBuffer(c.byts))) t.Run(ca.name, func(t *testing.T) {
err := req.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.req, req) require.Equal(t, ca.req, req)
}) })
} }
} }
func TestRequestWrite(t *testing.T) { func TestRequestWrite(t *testing.T) {
for _, c := range casesRequest { for _, ca := range casesRequest {
t.Run(c.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
bw := bufio.NewWriter(&buf) bw := bufio.NewWriter(&buf)
err := c.req.Write(bw) err := ca.req.Write(bw)
require.NoError(t, err) require.NoError(t, err)
// do NOT call flush(), write() must have already done it // do NOT call flush(), write() must have already done it
require.Equal(t, c.byts, buf.Bytes()) require.Equal(t, ca.byts, buf.Bytes())
})
}
}
func TestRequestReadErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
}{
{
"empty",
[]byte{},
},
{
"missing url, protocol, eol",
[]byte("GET"),
},
{
"missing protocol, eol",
[]byte("GET rtsp://testing123/test"),
},
{
"missing eol",
[]byte("GET rtsp://testing123/test RTSP/1.0"),
},
{
"invalid URL",
[]byte("GET http://testing123 RTSP/1.0\r\n"),
},
{
"invalid protocol",
[]byte("GET rtsp://testing123 RTSP/2.0\r\n"),
},
} {
t.Run(ca.name, func(t *testing.T) {
var req Request
err := req.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.Error(t, err)
}) })
} }
} }

View File

@@ -109,7 +109,9 @@ var casesResponse = []struct {
} }
func TestResponseRead(t *testing.T) { func TestResponseRead(t *testing.T) {
// keep res global to make sure that all its fields are overridden.
var res Response var res Response
for _, c := range casesResponse { for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
err := res.Read(bufio.NewReader(bytes.NewBuffer(c.byts))) err := res.Read(bufio.NewReader(bytes.NewBuffer(c.byts)))
@@ -132,7 +134,45 @@ func TestResponseWrite(t *testing.T) {
} }
} }
func TestResponseWriteStatusAutofill(t *testing.T) { func TestResponseReadErrors(t *testing.T) {
for _, ca := range []struct {
name string
byts []byte
}{
{
"empty",
[]byte{},
},
{
"missing code, message, eol",
[]byte("RTSP/1.0"),
},
{
"missing message, eol",
[]byte("RTSP/1.0 200"),
},
{
"missing eol",
[]byte("RTSP/1.0 200 OK"),
},
{
"invalid protocol",
[]byte("RTSP/2.0 200 OK\r\n"),
},
{
"invalid code",
[]byte("RTSP/2.0 string OK\r\n"),
},
} {
t.Run(ca.name, func(t *testing.T) {
var res Response
err := res.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.Error(t, err)
})
}
}
func TestResponseWriteAutoFillStatus(t *testing.T) {
res := &Response{ res := &Response{
StatusCode: StatusMethodNotAllowed, StatusCode: StatusMethodNotAllowed,
Header: Header{ Header: Header{

View File

@@ -183,6 +183,15 @@ func TestAuthRead(t *testing.T) {
} }
} }
func TestAuthWrite(t *testing.T) {
for _, ca := range casesAuth {
t.Run(ca.name, func(t *testing.T) {
vout := ca.h.Write()
require.Equal(t, ca.vout, vout)
})
}
}
func TestAuthReadError(t *testing.T) { func TestAuthReadError(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
name string name string
@@ -212,12 +221,3 @@ func TestAuthReadError(t *testing.T) {
}) })
} }
} }
func TestAuthWrite(t *testing.T) {
for _, ca := range casesAuth {
t.Run(ca.name, func(t *testing.T) {
vout := ca.h.Write()
require.Equal(t, ca.vout, vout)
})
}
}

View File

@@ -178,6 +178,15 @@ func TestRTPInfoRead(t *testing.T) {
} }
} }
func TestRTPInfoWrite(t *testing.T) {
for _, ca := range casesRTPInfo {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}
func TestRTPInfoReadError(t *testing.T) { func TestRTPInfoReadError(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
name string name string
@@ -199,12 +208,3 @@ func TestRTPInfoReadError(t *testing.T) {
}) })
} }
} }
func TestRTPInfoWrite(t *testing.T) {
for _, ca := range casesRTPInfo {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}

View File

@@ -59,6 +59,15 @@ func TestSessionRead(t *testing.T) {
} }
} }
func TestSessionWrite(t *testing.T) {
for _, ca := range casesSession {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}
func TestSessionReadError(t *testing.T) { func TestSessionReadError(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
name string name string
@@ -80,12 +89,3 @@ func TestSessionReadError(t *testing.T) {
}) })
} }
} }
func TestSessionWrite(t *testing.T) {
for _, ca := range casesSession {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}

View File

@@ -124,6 +124,15 @@ func TestTransportRead(t *testing.T) {
} }
} }
func TestTransportWrite(t *testing.T) {
for _, ca := range casesTransport {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}
func TestTransportReadError(t *testing.T) { func TestTransportReadError(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
name string name string
@@ -145,12 +154,3 @@ func TestTransportReadError(t *testing.T) {
}) })
} }
} }
func TestTransportWrite(t *testing.T) {
for _, ca := range casesTransport {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}