mirror of
https://github.com/aler9/gortsplib
synced 2025-10-04 06:46:42 +08:00
@@ -2,6 +2,7 @@ package gortsplib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/description"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/format"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/headers"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
|
||||
)
|
||||
|
||||
@@ -126,19 +128,37 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
|
||||
}
|
||||
|
||||
func TestClientRecord(t *testing.T) {
|
||||
for _, transport := range []string{
|
||||
"udp",
|
||||
"tcp",
|
||||
"tls",
|
||||
for _, ca := range []struct {
|
||||
scheme string
|
||||
transport string
|
||||
secure string
|
||||
}{
|
||||
{
|
||||
"rtsp",
|
||||
"udp",
|
||||
"unsecure",
|
||||
},
|
||||
{
|
||||
"rtsp",
|
||||
"tcp",
|
||||
"unsecure",
|
||||
},
|
||||
{
|
||||
"rtsps",
|
||||
"udp",
|
||||
"secure",
|
||||
},
|
||||
{
|
||||
"rtsps",
|
||||
"tcp",
|
||||
"secure",
|
||||
},
|
||||
} {
|
||||
t.Run(transport, func(t *testing.T) {
|
||||
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
var scheme string
|
||||
if transport == "tls" {
|
||||
scheme = "rtsps"
|
||||
|
||||
if ca.scheme == "rtsps" {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(serverCert, serverKey)
|
||||
require.NoError(t, err)
|
||||
@@ -147,8 +167,6 @@ func TestClientRecord(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
} else {
|
||||
scheme = "rtsp"
|
||||
|
||||
l, err = net.Listen("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
@@ -156,6 +174,7 @@ func TestClientRecord(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -167,7 +186,7 @@ func TestClientRecord(t *testing.T) {
|
||||
req, err2 := conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Options, req.Method)
|
||||
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
|
||||
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
@@ -184,7 +203,7 @@ func TestClientRecord(t *testing.T) {
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Announce, req.Method)
|
||||
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
|
||||
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
|
||||
|
||||
var desc sdp.SessionDescription
|
||||
err = desc.Unmarshal(req.Body)
|
||||
@@ -194,6 +213,13 @@ func TestClientRecord(t *testing.T) {
|
||||
err = desc2.Unmarshal(&desc)
|
||||
require.NoError(t, err2)
|
||||
|
||||
if ca.secure == "secure" {
|
||||
require.True(t, desc2.Medias[0].Secure)
|
||||
|
||||
_, err = mikeyToContext(desc2.Medias[0].KeyMgmtMikey)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
@@ -203,7 +229,7 @@ func TestClientRecord(t *testing.T) {
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Setup, req.Method)
|
||||
require.Equal(t, mustParseURL(
|
||||
scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
|
||||
ca.scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
|
||||
|
||||
var inTH headers.Transport
|
||||
err2 = inTH.Unmarshal(req.Header["Transport"])
|
||||
@@ -213,7 +239,7 @@ func TestClientRecord(t *testing.T) {
|
||||
|
||||
var l1 net.PacketConn
|
||||
var l2 net.PacketConn
|
||||
if transport == "udp" {
|
||||
if ca.transport == "udp" {
|
||||
l1, err2 = net.ListenPacket("udp", "localhost:34556")
|
||||
require.NoError(t, err2)
|
||||
defer l1.Close()
|
||||
@@ -223,11 +249,62 @@ func TestClientRecord(t *testing.T) {
|
||||
defer l2.Close()
|
||||
}
|
||||
|
||||
th := headers.Transport{
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
h := base.Header{
|
||||
"Session": headers.Session{
|
||||
Session: "ABCDE",
|
||||
Timeout: uintPtr(1),
|
||||
}.Marshal(),
|
||||
}
|
||||
|
||||
if transport == "udp" {
|
||||
th := headers.Transport{
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
Secure: inTH.Secure,
|
||||
}
|
||||
|
||||
var srtpInCtx *wrappedSRTPContext
|
||||
var srtpOutCtx *wrappedSRTPContext
|
||||
|
||||
if ca.secure == "secure" {
|
||||
th.Secure = true
|
||||
|
||||
require.True(t, th.Secure)
|
||||
|
||||
var keyMgmt headers.KeyMgmt
|
||||
err = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
|
||||
require.NoError(t, err)
|
||||
|
||||
pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage)
|
||||
pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc2.Medias[0].KeyMgmtMikey)
|
||||
require.Equal(t, pl1, pl2)
|
||||
|
||||
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
|
||||
require.NoError(t, err)
|
||||
|
||||
outKey := make([]byte, srtpKeyLength)
|
||||
_, err = rand.Read(outKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
srtpOutCtx = &wrappedSRTPContext{
|
||||
key: outKey,
|
||||
ssrcs: []uint32{2345423},
|
||||
}
|
||||
err = srtpOutCtx.initialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
var mikeyMsg *mikey.Message
|
||||
mikeyMsg, err = mikeyGenerate(srtpOutCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var enc base.HeaderValue
|
||||
enc, err = headers.KeyMgmt{
|
||||
URL: req.URL.String(),
|
||||
MikeyMessage: mikeyMsg,
|
||||
}.Marshal()
|
||||
require.NoError(t, err)
|
||||
h["KeyMgmt"] = enc
|
||||
}
|
||||
|
||||
if ca.transport == "udp" {
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
th.ServerPorts = &[2]int{34556, 34557}
|
||||
th.ClientPorts = inTH.ClientPorts
|
||||
@@ -236,54 +313,55 @@ func TestClientRecord(t *testing.T) {
|
||||
th.InterleavedIDs = inTH.InterleavedIDs
|
||||
}
|
||||
|
||||
h["Transport"] = th.Marshal()
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Transport": th.Marshal(),
|
||||
"Session": headers.Session{
|
||||
Session: "ABCDE",
|
||||
Timeout: uintPtr(1),
|
||||
}.Marshal(),
|
||||
},
|
||||
Header: h,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Record, req.Method)
|
||||
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
|
||||
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
var pl []byte
|
||||
|
||||
// client -> server
|
||||
|
||||
if transport == "udp" {
|
||||
buf := make([]byte, 2048)
|
||||
var buf []byte
|
||||
|
||||
if ca.transport == "udp" {
|
||||
buf = make([]byte, 2048)
|
||||
var n int
|
||||
n, _, err2 = l1.ReadFrom(buf)
|
||||
require.NoError(t, err2)
|
||||
pl = buf[:n]
|
||||
buf = buf[:n]
|
||||
} else {
|
||||
var f *base.InterleavedFrame
|
||||
f, err2 = conn.ReadInterleavedFrame()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, 0, f.Channel)
|
||||
pl = f.Payload
|
||||
buf = f.Payload
|
||||
}
|
||||
|
||||
if ca.secure == "secure" {
|
||||
buf, err2 = srtpInCtx.decryptRTP(buf, buf, nil)
|
||||
require.NoError(t, err2)
|
||||
}
|
||||
|
||||
var pkt rtp.Packet
|
||||
err2 = pkt.Unmarshal(pl)
|
||||
err2 = pkt.Unmarshal(buf)
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, testRTPPacket, pkt)
|
||||
|
||||
// client -> server keepalive
|
||||
|
||||
if transport == "udp" {
|
||||
if ca.transport == "udp" {
|
||||
recv := make(chan struct{})
|
||||
go func() {
|
||||
defer close(recv)
|
||||
@@ -301,8 +379,17 @@ func TestClientRecord(t *testing.T) {
|
||||
|
||||
// server -> client
|
||||
|
||||
if transport == "udp" {
|
||||
_, err2 = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
|
||||
buf = testRTCPPacketMarshaled
|
||||
|
||||
if ca.secure == "secure" {
|
||||
encr := make([]byte, 2000)
|
||||
encr, err2 = srtpOutCtx.encryptRTCP(encr, buf, nil)
|
||||
require.NoError(t, err2)
|
||||
buf = encr
|
||||
}
|
||||
|
||||
if ca.transport == "udp" {
|
||||
_, err2 = l2.WriteTo(buf, &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: th.ClientPorts[1],
|
||||
})
|
||||
@@ -310,7 +397,7 @@ func TestClientRecord(t *testing.T) {
|
||||
} else {
|
||||
err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{
|
||||
Channel: 1,
|
||||
Payload: testRTCPPacketMarshaled,
|
||||
Payload: buf,
|
||||
}, make([]byte, 1024))
|
||||
require.NoError(t, err2)
|
||||
}
|
||||
@@ -318,7 +405,7 @@ func TestClientRecord(t *testing.T) {
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Teardown, req.Method)
|
||||
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
|
||||
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
@@ -333,7 +420,7 @@ func TestClientRecord(t *testing.T) {
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Transport: func() *Transport {
|
||||
if transport == "udp" {
|
||||
if ca.transport == "udp" {
|
||||
v := TransportUDP
|
||||
return &v
|
||||
}
|
||||
@@ -345,7 +432,7 @@ func TestClientRecord(t *testing.T) {
|
||||
medi := testH264Media
|
||||
medias := []*description.Media{medi}
|
||||
|
||||
err = record(&c, scheme+"://localhost:8554/teststream", medias,
|
||||
err = record(&c, ca.scheme+"://localhost:8554/teststream", medias,
|
||||
func(_ *description.Media, pkt rtcp.Packet) {
|
||||
require.Equal(t, &testRTCPPacket, pkt)
|
||||
close(recvDone)
|
||||
@@ -397,9 +484,9 @@ func TestClientRecord(t *testing.T) {
|
||||
}, s)
|
||||
|
||||
require.Greater(t, s.Session.BytesSent, uint64(15))
|
||||
require.Less(t, s.Session.BytesSent, uint64(17))
|
||||
require.Less(t, s.Session.BytesSent, uint64(30))
|
||||
require.Greater(t, s.Session.BytesReceived, uint64(19))
|
||||
require.Less(t, s.Session.BytesReceived, uint64(21))
|
||||
require.Less(t, s.Session.BytesReceived, uint64(40))
|
||||
|
||||
c.Close()
|
||||
<-done
|
||||
@@ -414,33 +501,18 @@ func TestClientRecordSocketError(t *testing.T) {
|
||||
for _, transport := range []string{
|
||||
"udp",
|
||||
"tcp",
|
||||
"tls",
|
||||
} {
|
||||
t.Run(transport, func(t *testing.T) {
|
||||
var l net.Listener
|
||||
var err error
|
||||
|
||||
var scheme string
|
||||
if transport == "tls" {
|
||||
scheme = "rtsps"
|
||||
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(serverCert, serverKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
l, err = tls.Listen("tcp", "localhost:8554", &tls.Config{Certificates: []tls.Certificate{cert}})
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
} else {
|
||||
scheme = "rtsp"
|
||||
|
||||
l, err = net.Listen("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
}
|
||||
l, err = net.Listen("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -530,7 +602,7 @@ func TestClientRecordSocketError(t *testing.T) {
|
||||
medi := testH264Media
|
||||
medias := []*description.Media{medi}
|
||||
|
||||
err = record(&c, scheme+"://localhost:8554/teststream", medias, nil)
|
||||
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
@@ -559,6 +631,7 @@ func TestClientRecordPauseRecordSerial(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -707,6 +780,7 @@ func TestClientRecordPauseRecordParallel(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -885,6 +959,7 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -1016,6 +1091,7 @@ func TestClientRecordDecodeErrors(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -1186,6 +1262,7 @@ func TestClientRecordRTCPReport(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
@@ -1371,6 +1448,7 @@ func TestClientRecordIgnoreTCPRTPPackets(t *testing.T) {
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
|
Reference in New Issue
Block a user