improve coverage

This commit is contained in:
aler9
2022-12-11 23:36:43 +01:00
parent 0c13440721
commit 46cbb885b7
22 changed files with 126 additions and 44 deletions

View File

@@ -1074,7 +1074,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
err = v.ValidateRequest(req) err = v.ValidateRequest(req, nil)
require.NoError(t, err) require.NoError(t, err)
err = conn.WriteResponse(&base.Response{ err = conn.WriteResponse(&base.Response{
@@ -1188,7 +1188,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) {
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"), req.URL)
err = v.ValidateRequest(req) err = v.ValidateRequest(req, nil)
require.NoError(t, err) require.NoError(t, err)
var inTH headers.Transport var inTH headers.Transport

View File

@@ -180,7 +180,7 @@ func TestClientAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
err = v.ValidateRequest(req) err = v.ValidateRequest(req, nil)
require.NoError(t, err) require.NoError(t, err)
medias := media.Medias{testH264Media.Clone()} medias := media.Medias{testH264Media.Clone()}

View File

@@ -78,7 +78,7 @@ func TestAuth(t *testing.T) {
req.URL = mustParseURL("rtsp://myhost/mypath") req.URL = mustParseURL("rtsp://myhost/mypath")
err = va.ValidateRequest(req) err = va.ValidateRequest(req, nil)
if conf != "nofail" { if conf != "nofail" {
require.Error(t, err) require.Error(t, err)
@@ -93,7 +93,7 @@ func TestAuth(t *testing.T) {
func TestAuthVLC(t *testing.T) { func TestAuthVLC(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
clientURL string clientURL string
serverURL string mediaURL string
}{ }{
{ {
"rtsp://myhost/mypath/", "rtsp://myhost/mypath/",
@@ -115,11 +115,13 @@ func TestAuthVLC(t *testing.T) {
URL: mustParseURL(ca.clientURL), URL: mustParseURL(ca.clientURL),
} }
se.AddAuthorization(req) se.AddAuthorization(req)
req.URL = mustParseURL(ca.mediaURL)
req.URL = mustParseURL(ca.serverURL) err = va.ValidateRequest(req, mustParseURL(ca.clientURL))
err = va.ValidateRequest(req)
require.NoError(t, err) require.NoError(t, err)
err = va.ValidateRequest(req, mustParseURL("rtsp://invalid"))
require.Error(t, err)
} }
} }
@@ -155,7 +157,7 @@ func TestAuthHashed(t *testing.T) {
} }
va.AddAuthorization(req) va.AddAuthorization(req)
err = se.ValidateRequest(req) err = se.ValidateRequest(req, nil)
if conf != "nofail" { if conf != "nofail" {
require.Error(t, err) require.Error(t, err)

View File

@@ -11,34 +11,6 @@ import (
"github.com/aler9/gortsplib/v2/pkg/url" "github.com/aler9/gortsplib/v2/pkg/url"
) )
func stringsReverseIndex(s, substr string) int {
for i := len(s) - 1 - len(substr); i >= 0; i-- {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}
func generateAltURL(req *base.Request) (*url.URL, bool) {
if req.Method != base.Setup {
return nil, false
}
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return nil, false
}
i := stringsReverseIndex(pathAndQuery, "/trackID=")
if i < 0 {
return nil, false
}
ur, _ := url.Parse(req.URL.Scheme + "://" + req.URL.Host + "/" + pathAndQuery[:i+1])
return ur, true
}
// Validator allows to validate credentials generated by a Sender. // Validator allows to validate credentials generated by a Sender.
type Validator struct { type Validator struct {
user string user string
@@ -113,7 +85,7 @@ func (va *Validator) Header() base.HeaderValue {
} }
// ValidateRequest validates a request sent by a client. // ValidateRequest validates a request sent by a client.
func (va *Validator) ValidateRequest(req *base.Request) error { func (va *Validator) ValidateRequest(req *base.Request, baseURL *url.URL) error {
var auth headers.Authorization var auth headers.Authorization
err := auth.Unmarshal(req.Header["Authorization"]) err := auth.Unmarshal(req.Header["Authorization"])
if err != nil { if err != nil {
@@ -179,10 +151,9 @@ func (va *Validator) ValidateRequest(req *base.Request) error {
if *auth.DigestValues.URI != ur.String() { if *auth.DigestValues.URI != ur.String() {
// in SETUP requests, VLC strips the control attribute. // in SETUP requests, VLC strips the control attribute.
// try again with an alternative URL without the control attribute. // try again with the base URL.
if altURL, ok := generateAltURL(req); ok { if baseURL != nil {
ur = altURL ur = baseURL
if *auth.DigestValues.URI != ur.String() { if *auth.DigestValues.URI != ur.String() {
return fmt.Errorf("wrong URL") return fmt.Errorf("wrong URL")
} }

View File

@@ -64,7 +64,7 @@ func TestValidatorErrors(t *testing.T) {
Header: base.Header{ Header: base.Header{
"Authorization": ca.hv, "Authorization": ca.hv,
}, },
}) }, nil)
require.EqualError(t, err, ca.err) require.EqualError(t, err, ca.err)
}) })
} }

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,6 +12,7 @@ func TestG711Attributes(t *testing.T) {
require.Equal(t, "G711", format.String()) require.Equal(t, "G711", format.String())
require.Equal(t, 8000, format.ClockRate()) require.Equal(t, 8000, format.ClockRate())
require.Equal(t, uint8(8), format.PayloadType()) require.Equal(t, uint8(8), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
format = &G711{ format = &G711{
MULaw: true, MULaw: true,

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,6 +12,7 @@ func TestG722Attributes(t *testing.T) {
require.Equal(t, "G722", format.String()) require.Equal(t, "G722", format.String())
require.Equal(t, 8000, format.ClockRate()) require.Equal(t, 8000, format.ClockRate())
require.Equal(t, uint8(9), format.PayloadType()) require.Equal(t, uint8(9), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestG722Clone(t *testing.T) { func TestG722Clone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -19,6 +20,7 @@ func TestGenericAttributes(t *testing.T) {
require.Equal(t, "Generic", format.String()) require.Equal(t, "Generic", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(98), format.PayloadType()) require.Equal(t, uint8(98), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestGenericClone(t *testing.T) { func TestGenericClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -25,6 +26,22 @@ func TestH264Attributes(t *testing.T) {
require.Equal(t, []byte{0x09, 0x0A}, format.SafePPS()) require.Equal(t, []byte{0x09, 0x0A}, format.SafePPS())
} }
func TestH264PTSEqualsDTS(t *testing.T) {
format := &H264{
PayloadTyp: 96,
SPS: []byte{0x01, 0x02},
PPS: []byte{0x03, 0x04},
PacketizationMode: 1,
}
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{
Payload: []byte{0x05},
}))
require.Equal(t, false, format.PTSEqualsDTS(&rtp.Packet{
Payload: []byte{0x01},
}))
}
func TestH264Clone(t *testing.T) { func TestH264Clone(t *testing.T) {
format := &H264{ format := &H264{
PayloadTyp: 96, PayloadTyp: 96,

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -16,6 +17,7 @@ func TestH265Attributes(t *testing.T) {
require.Equal(t, "H265", format.String()) require.Equal(t, "H265", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(96), format.PayloadType()) require.Equal(t, uint8(96), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
require.Equal(t, []byte{0x01, 0x02}, format.SafeVPS()) require.Equal(t, []byte{0x01, 0x02}, format.SafeVPS())
require.Equal(t, []byte{0x03, 0x04}, format.SafeSPS()) require.Equal(t, []byte{0x03, 0x04}, format.SafeSPS())
require.Equal(t, []byte{0x05, 0x06}, format.SafePPS()) require.Equal(t, []byte{0x05, 0x06}, format.SafePPS())

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,6 +12,7 @@ func TestJPEGAttributes(t *testing.T) {
require.Equal(t, "JPEG", format.String()) require.Equal(t, "JPEG", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(26), format.PayloadType()) require.Equal(t, uint8(26), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestJPEGClone(t *testing.T) { func TestJPEGClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -16,6 +17,7 @@ func TestLPCMAttributes(t *testing.T) {
require.Equal(t, "LPCM", format.String()) require.Equal(t, "LPCM", format.String())
require.Equal(t, 44100, format.ClockRate()) require.Equal(t, 44100, format.ClockRate())
require.Equal(t, uint8(96), format.PayloadType()) require.Equal(t, uint8(96), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestTracLPCMClone(t *testing.T) { func TestTracLPCMClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,6 +12,7 @@ func TestMPEG2AudioAttributes(t *testing.T) {
require.Equal(t, "MPEG2-audio", format.String()) require.Equal(t, "MPEG2-audio", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(14), format.PayloadType()) require.Equal(t, uint8(14), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestMPEG2AudioClone(t *testing.T) { func TestMPEG2AudioClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,6 +12,7 @@ func TestMPEG2VideoAttributes(t *testing.T) {
require.Equal(t, "MPEG2-video", format.String()) require.Equal(t, "MPEG2-video", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(32), format.PayloadType()) require.Equal(t, uint8(32), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestMPEG2VideoClone(t *testing.T) { func TestMPEG2VideoClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/v2/pkg/mpeg4audio" "github.com/aler9/gortsplib/v2/pkg/mpeg4audio"
@@ -23,6 +24,7 @@ func TestMPEG4AudioAttributes(t *testing.T) {
require.Equal(t, "MPEG4-audio", format.String()) require.Equal(t, "MPEG4-audio", format.String())
require.Equal(t, 48000, format.ClockRate()) require.Equal(t, 48000, format.ClockRate())
require.Equal(t, uint8(96), format.PayloadType()) require.Equal(t, uint8(96), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestMPEG4AudioClone(t *testing.T) { func TestMPEG4AudioClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -15,6 +16,7 @@ func TestOpusAttributes(t *testing.T) {
require.Equal(t, "Opus", format.String()) require.Equal(t, "Opus", format.String())
require.Equal(t, 48000, format.ClockRate()) require.Equal(t, 48000, format.ClockRate())
require.Equal(t, uint8(96), format.PayloadType()) require.Equal(t, uint8(96), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestTracOpusClone(t *testing.T) { func TestTracOpusClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -16,6 +17,7 @@ func TestVorbisAttributes(t *testing.T) {
require.Equal(t, "Vorbis", format.String()) require.Equal(t, "Vorbis", format.String())
require.Equal(t, 48000, format.ClockRate()) require.Equal(t, 48000, format.ClockRate())
require.Equal(t, uint8(96), format.PayloadType()) require.Equal(t, uint8(96), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestTracVorbisClone(t *testing.T) { func TestTracVorbisClone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -13,6 +14,7 @@ func TestVP8ttributes(t *testing.T) {
require.Equal(t, "VP8", format.String()) require.Equal(t, "VP8", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(99), format.PayloadType()) require.Equal(t, uint8(99), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestVP8Clone(t *testing.T) { func TestVP8Clone(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package format
import ( import (
"testing" "testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -13,6 +14,7 @@ func TestVP9Attributes(t *testing.T) {
require.Equal(t, "VP9", format.String()) require.Equal(t, "VP9", format.String())
require.Equal(t, 90000, format.ClockRate()) require.Equal(t, 90000, format.ClockRate())
require.Equal(t, uint8(100), format.PayloadType()) require.Equal(t, uint8(100), format.PayloadType())
require.Equal(t, true, format.PTSEqualsDTS(&rtp.Packet{}))
} }
func TestVP9Clone(t *testing.T) { func TestVP9Clone(t *testing.T) {

View File

@@ -0,0 +1,17 @@
package h264
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIDRPresent(t *testing.T) {
require.Equal(t, true, IDRPresent([][]byte{
{0x05},
{0x07},
}))
require.Equal(t, false, IDRPresent([][]byte{
{0x01},
}))
}

View File

@@ -67,6 +67,15 @@ var configCases = []struct {
CoreCoderDelay: 385, CoreCoderDelay: 385,
}, },
}, },
{
"aac-lc 44.1khz 8 chans",
[]byte{0x12, 0x38},
Config{
Type: ObjectTypeAACLC,
SampleRate: 44100,
ChannelCount: 8,
},
},
{ {
"sbr (he-aac v1) 44.1khz stereo", "sbr (he-aac v1) 44.1khz stereo",
[]byte{0x2b, 0x8a, 0x00}, []byte{0x2b, 0x8a, 0x00},
@@ -90,6 +99,46 @@ func TestConfigUnmarshal(t *testing.T) {
} }
} }
func TestConfigUnmarshalErrors(t *testing.T) {
for _, ca := range []struct {
name string
enc []byte
err string
}{
{
"empty",
[]byte{},
"not enough bits",
},
{
"unsupported object type",
[]byte{0xF1},
"unsupported object type: 30",
},
{
"no sample rate index",
[]byte{0b00010000},
"not enough bits",
},
{
"invalid sample rate index",
[]byte{0b00010110, 0b10000000},
"invalid sample rate index (13)",
},
{
"channel config 0",
[]byte{0b00010100, 0b00000000},
"not yet supported",
},
} {
t.Run(ca.name, func(t *testing.T) {
var dec Config
err := dec.Unmarshal(ca.enc)
require.EqualError(t, err, ca.err)
})
}
}
func TestConfigMarshal(t *testing.T) { func TestConfigMarshal(t *testing.T) {
for _, ca := range configCases { for _, ca := range configCases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {

View File

@@ -1136,7 +1136,7 @@ func TestServerAuth(t *testing.T) {
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
err := authValidator.ValidateRequest(ctx.Request) err := authValidator.ValidateRequest(ctx.Request, nil)
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,