support encrypted streams with SRTP and MIKEY (#520) (#809)

This commit is contained in:
Alessandro Ros
2025-07-05 12:48:13 +02:00
committed by GitHub
parent a5ff92f130
commit 616fa7ea89
104 changed files with 4179 additions and 766 deletions

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"net"
"strconv"
@@ -18,6 +19,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -337,6 +339,9 @@ func TestServerRecordPath(t *testing.T) {
media := testH264Media
media.Control = ca.control
enc, err := media.Marshal2()
require.NoError(t, err)
sout := &sdp.SessionDescription{
SessionName: psdp.SessionName("Stream"),
Origin: psdp.Origin{
@@ -348,7 +353,7 @@ func TestServerRecordPath(t *testing.T) {
TimeDescriptions: []psdp.TimeDescription{
{Timing: psdp.Timing{}},
},
MediaDescriptions: []*psdp.MediaDescription{media.Marshal()},
MediaDescriptions: []*psdp.MediaDescription{enc},
}
byts, _ := sout.Marshal()
@@ -533,12 +538,38 @@ func TestServerRecordErrorRecordPartialMedias(t *testing.T) {
}
func TestServerRecord(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
sessionOpened := make(chan struct{})
@@ -552,9 +583,9 @@ func TestServerRecord(t *testing.T) {
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
s := ctx.Conn.Stats()
require.Greater(t, s.BytesSent, uint64(510))
require.Less(t, s.BytesSent, uint64(560))
require.Less(t, s.BytesSent, uint64(1100))
require.Greater(t, s.BytesReceived, uint64(1000))
require.Less(t, s.BytesReceived, uint64(1200))
require.Less(t, s.BytesReceived, uint64(1800))
close(nconnClosed)
},
@@ -564,9 +595,9 @@ func TestServerRecord(t *testing.T) {
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
s := ctx.Session.Stats()
require.Greater(t, s.BytesSent, uint64(75))
require.Less(t, s.BytesSent, uint64(130))
require.Less(t, s.BytesSent, uint64(140))
require.Greater(t, s.BytesReceived, uint64(70))
require.Less(t, s.BytesReceived, uint64(80))
require.Less(t, s.BytesReceived, uint64(130))
close(sessionClosed)
},
@@ -581,12 +612,12 @@ func TestServerRecord(t *testing.T) {
}, nil, nil
},
onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) {
switch transport {
switch ca.transport {
case "udp":
v := TransportUDP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
case "tcp", "tls":
case "tcp":
v := TransportTCP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
}
@@ -628,12 +659,12 @@ func TestServerRecord(t *testing.T) {
RTSPAddress: "localhost:8554",
}
switch transport {
case "udp":
if ca.transport == "udp" {
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
}
case "tls":
if ca.scheme == "rtsps" {
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
@@ -648,7 +679,7 @@ func TestServerRecord(t *testing.T) {
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
if ca.scheme == "rtsps" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
@@ -686,6 +717,8 @@ func TestServerRecord(t *testing.T) {
var l2s [2]net.PacketConn
var session string
var serverPorts [2]*[2]int
var srtpOutCtx [2]*wrappedSRTPContext
var srtpInCtx [2]*wrappedSRTPContext
for i := 0; i < 2; i++ {
inTH := &headers.Transport{
@@ -693,7 +726,7 @@ func TestServerRecord(t *testing.T) {
Mode: transportModePtr(headers.TransportModeRecord),
}
if transport == "udp" {
if ca.transport == "udp" {
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2}
@@ -709,84 +742,186 @@ func TestServerRecord(t *testing.T) {
inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2}
}
res, th := doSetup(t, conn, "rtsp://localhost:8554/teststream?param=value/"+medias[i].Control, inTH, "")
h := base.Header{
"CSeq": base.HeaderValue{"1"},
}
if session != "" {
h["Session"] = base.HeaderValue{session}
}
if ca.secure == "secure" {
inTH.Secure = true
key := make([]byte, srtpKeyLength)
_, err = rand.Read(key)
require.NoError(t, err)
srtpOutCtx[i] = &wrappedSRTPContext{
key: key,
ssrcs: []uint32{2345423},
}
err = srtpOutCtx[i].initialize()
require.NoError(t, err)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx[i])
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: "rtsp://localhost:8554/teststream?param=value/" + medias[i].Control,
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
h["Transport"] = inTH.Marshal()
var res *base.Response
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream?param=value/" + medias[i].Control),
Header: h,
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var th headers.Transport
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
session = readSession(t, res)
if transport == "udp" {
if ca.transport == "udp" {
serverPorts[i] = th.ServerPorts
}
if ca.secure == "secure" {
require.True(t, th.Secure)
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
require.NoError(t, err)
srtpInCtx[i], err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err)
}
}
doRecord(t, conn, "rtsp://localhost:8554/teststream", session)
for i := 0; i < 2; i++ {
// skip firewall opening
if transport == "udp" {
// skip firewall opening
if ca.transport == "udp" {
for i := 0; i < 2; i++ {
buf := make([]byte, 2048)
_, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
}
}
// server -> client
// server -> client
if transport == "udp" {
buf := make([]byte, 2048)
for i := 0; i < 2; i++ {
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
buf = buf[:n]
} else {
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
buf = f.Payload
}
// client -> server
if ca.secure == "secure" {
buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil)
require.NoError(t, err)
}
if transport == "udp" {
_, err = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
require.Equal(t, testRTCPPacketMarshaled, buf)
}
// client -> server
for i := 0; i < 2; i++ {
buf1 := testRTPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err = srtpOutCtx[i].encryptRTP(encr, buf1, nil)
require.NoError(t, err)
buf1 = encr
}
buf2 := testRTCPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err = srtpOutCtx[i].encryptRTCP(encr, buf2, nil)
require.NoError(t, err)
buf2 = encr
}
if ca.transport == "udp" {
_, err = l1s[i].WriteTo(buf1, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][0],
})
require.NoError(t, err)
_, err = l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
_, err = l2s[i].WriteTo(buf2, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][1],
})
require.NoError(t, err)
} else {
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 2 + i*2,
Payload: testRTPPacketMarshaled,
Payload: buf1,
}, make([]byte, 1024))
require.NoError(t, err)
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 3 + i*2,
Payload: testRTCPPacketMarshaled,
Payload: buf2,
}, make([]byte, 1024))
require.NoError(t, err)
}
}
for i := 0; i < 2; i++ {
// server -> client
// server -> client
if transport == "udp" {
buf := make([]byte, 2048)
n, _, err := l2s[i].ReadFrom(buf)
for i := 0; i < 2; i++ {
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
buf = buf[:n]
} else {
f, err := conn.ReadInterleavedFrame()
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
buf = f.Payload
}
if ca.secure == "secure" {
buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil)
require.NoError(t, err)
}
require.Equal(t, testRTCPPacketMarshaled, buf)
}
doTeardown(t, conn, "rtsp://localhost:8554/teststream", session)