support encrypted streams with SRTP and MIKEY (#520) (#809)

This commit is contained in:
Alessandro Ros
2025-07-05 12:48:13 +02:00
committed by GitHub
parent a5ff92f130
commit 616fa7ea89
104 changed files with 4179 additions and 766 deletions

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"fmt"
"net"
@@ -19,6 +20,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -126,19 +128,37 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
}
func TestClientRecord(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
var l net.Listener
var err error
var scheme string
if transport == "tls" {
scheme = "rtsps"
if ca.scheme == "rtsps" {
var cert tls.Certificate
cert, err = tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
@@ -147,8 +167,6 @@ func TestClientRecord(t *testing.T) {
require.NoError(t, err)
defer l.Close()
} else {
scheme = "rtsp"
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
@@ -156,6 +174,7 @@ func TestClientRecord(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -167,7 +186,7 @@ func TestClientRecord(t *testing.T) {
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -184,7 +203,7 @@ func TestClientRecord(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
var desc sdp.SessionDescription
err = desc.Unmarshal(req.Body)
@@ -194,6 +213,13 @@ func TestClientRecord(t *testing.T) {
err = desc2.Unmarshal(&desc)
require.NoError(t, err2)
if ca.secure == "secure" {
require.True(t, desc2.Medias[0].Secure)
_, err = mikeyToContext(desc2.Medias[0].KeyMgmtMikey)
require.NoError(t, err)
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
@@ -203,7 +229,7 @@ func TestClientRecord(t *testing.T) {
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(
scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
ca.scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
@@ -213,7 +239,7 @@ func TestClientRecord(t *testing.T) {
var l1 net.PacketConn
var l2 net.PacketConn
if transport == "udp" {
if ca.transport == "udp" {
l1, err2 = net.ListenPacket("udp", "localhost:34556")
require.NoError(t, err2)
defer l1.Close()
@@ -223,11 +249,62 @@ func TestClientRecord(t *testing.T) {
defer l2.Close()
}
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
h := base.Header{
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
}
if transport == "udp" {
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Secure: inTH.Secure,
}
var srtpInCtx *wrappedSRTPContext
var srtpOutCtx *wrappedSRTPContext
if ca.secure == "secure" {
th.Secure = true
require.True(t, th.Secure)
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
require.NoError(t, err)
pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage)
pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc2.Medias[0].KeyMgmtMikey)
require.Equal(t, pl1, pl2)
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err)
outKey := make([]byte, srtpKeyLength)
_, err = rand.Read(outKey)
require.NoError(t, err)
srtpOutCtx = &wrappedSRTPContext{
key: outKey,
ssrcs: []uint32{2345423},
}
err = srtpOutCtx.initialize()
require.NoError(t, err)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx)
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
if ca.transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
@@ -236,54 +313,55 @@ func TestClientRecord(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
h["Transport"] = th.Marshal()
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
},
Header: h,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
var pl []byte
// client -> server
if transport == "udp" {
buf := make([]byte, 2048)
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err2 = l1.ReadFrom(buf)
require.NoError(t, err2)
pl = buf[:n]
buf = buf[:n]
} else {
var f *base.InterleavedFrame
f, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
require.Equal(t, 0, f.Channel)
pl = f.Payload
buf = f.Payload
}
if ca.secure == "secure" {
buf, err2 = srtpInCtx.decryptRTP(buf, buf, nil)
require.NoError(t, err2)
}
var pkt rtp.Packet
err2 = pkt.Unmarshal(pl)
err2 = pkt.Unmarshal(buf)
require.NoError(t, err2)
require.Equal(t, testRTPPacket, pkt)
// client -> server keepalive
if transport == "udp" {
if ca.transport == "udp" {
recv := make(chan struct{})
go func() {
defer close(recv)
@@ -301,8 +379,17 @@ func TestClientRecord(t *testing.T) {
// server -> client
if transport == "udp" {
_, err2 = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
buf = testRTCPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err2 = srtpOutCtx.encryptRTCP(encr, buf, nil)
require.NoError(t, err2)
buf = encr
}
if ca.transport == "udp" {
_, err2 = l2.WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
@@ -310,7 +397,7 @@ func TestClientRecord(t *testing.T) {
} else {
err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
Payload: buf,
}, make([]byte, 1024))
require.NoError(t, err2)
}
@@ -318,7 +405,7 @@ func TestClientRecord(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -333,7 +420,7 @@ func TestClientRecord(t *testing.T) {
InsecureSkipVerify: true,
},
Transport: func() *Transport {
if transport == "udp" {
if ca.transport == "udp" {
v := TransportUDP
return &v
}
@@ -345,7 +432,7 @@ func TestClientRecord(t *testing.T) {
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, scheme+"://localhost:8554/teststream", medias,
err = record(&c, ca.scheme+"://localhost:8554/teststream", medias,
func(_ *description.Media, pkt rtcp.Packet) {
require.Equal(t, &testRTCPPacket, pkt)
close(recvDone)
@@ -397,9 +484,9 @@ func TestClientRecord(t *testing.T) {
}, s)
require.Greater(t, s.Session.BytesSent, uint64(15))
require.Less(t, s.Session.BytesSent, uint64(17))
require.Less(t, s.Session.BytesSent, uint64(30))
require.Greater(t, s.Session.BytesReceived, uint64(19))
require.Less(t, s.Session.BytesReceived, uint64(21))
require.Less(t, s.Session.BytesReceived, uint64(40))
c.Close()
<-done
@@ -414,33 +501,18 @@ func TestClientRecordSocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var l net.Listener
var err error
var scheme string
if transport == "tls" {
scheme = "rtsps"
var cert tls.Certificate
cert, err = tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
l, err = tls.Listen("tcp", "localhost:8554", &tls.Config{Certificates: []tls.Certificate{cert}})
require.NoError(t, err)
defer l.Close()
} else {
scheme = "rtsp"
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
}
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -530,7 +602,7 @@ func TestClientRecordSocketError(t *testing.T) {
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, scheme+"://localhost:8554/teststream", medias, nil)
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
defer c.Close()
@@ -559,6 +631,7 @@ func TestClientRecordPauseRecordSerial(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -707,6 +780,7 @@ func TestClientRecordPauseRecordParallel(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -885,6 +959,7 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1016,6 +1091,7 @@ func TestClientRecordDecodeErrors(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1186,6 +1262,7 @@ func TestClientRecordRTCPReport(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1371,6 +1448,7 @@ func TestClientRecordIgnoreTCPRTPPackets(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)