support encrypted streams with SRTP and MIKEY (#520) (#809)

This commit is contained in:
Alessandro Ros
2025-07-05 12:48:13 +02:00
committed by GitHub
parent a5ff92f130
commit 616fa7ea89
104 changed files with 4179 additions and 766 deletions

View File

@@ -1,6 +1,7 @@
package gortsplib
import (
"bytes"
"context"
"fmt"
"log"
@@ -20,6 +21,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtptime"
@@ -165,21 +167,160 @@ func findMediaByTrackID(medias []*description.Media, trackID string) *descriptio
return medias[id]
}
func findFirstSupportedTransportHeader(s *Server, tsh headers.Transports) *headers.Transport {
// Per RFC2326 section 12.39, client specifies transports in order of preference.
// Filter out the ones we don't support and then pick first supported transport.
for _, tr := range tsh {
func isTransportSupported(s *Server, tr *headers.Transport) bool {
// prevent using UDP/UDP-multicast when listeners are disabled
if tr.Protocol == headers.TransportProtocolUDP {
isMulticast := tr.Delivery != nil && *tr.Delivery == headers.TransportDeliveryMulticast
if tr.Protocol == headers.TransportProtocolUDP &&
((!isMulticast && s.udpRTPListener == nil) ||
(isMulticast && s.MulticastIPRange == "")) {
continue
if !isMulticast && s.udpRTPListener == nil {
return false
}
if isMulticast && s.MulticastIPRange == "" {
return false
}
}
// prevent using unsecure UDP with RTSPS
if tr.Protocol == headers.TransportProtocolUDP && !tr.Secure && s.TLSConfig != nil {
return false
}
// prevent using secure profiles with plain RTSP, since keys are in plain
if tr.Secure && s.TLSConfig == nil {
return false
}
return true
}
func pickFirstSupportedTransport(s *Server, tsh headers.Transports) *headers.Transport {
for _, tr := range tsh {
if isTransportSupported(s, &tr) {
return &tr
}
return &tr
}
return nil
}
func mikeyDecodeTime(t uint64) time.Time {
sec := t >> 32
dec := t & 0xFFFFFFFF
sec -= 2208988800
return time.Unix(int64(sec), int64(dec))
}
func mikeyEncodeTime(n time.Time) uint64 {
nano := uint64(n.UnixNano())
sec := nano / 1000000000
dec := nano % 1000000000
sec += 2208988800
return sec<<32 | dec
}
func mikeyGetPayload[T mikey.Payload](mikeyMsg *mikey.Message) (T, bool) {
var zero T
for _, wrapped := range mikeyMsg.Payloads {
if val, ok := wrapped.(T); ok {
return val, true
}
}
return zero, false
}
func mikeyGetSPPolicy(spPayload *mikey.PayloadSP, typ mikey.PayloadSPPolicyParamType) ([]byte, bool) {
for _, pl := range spPayload.PolicyParams {
if pl.Type == typ {
return pl.Value, true
}
}
return nil, false
}
func mikeyToContext(mikeyMsg *mikey.Message) (*wrappedSRTPContext, error) {
timePayload, ok := mikeyGetPayload[*mikey.PayloadT](mikeyMsg)
if !ok {
return nil, fmt.Errorf("time payload not present")
}
ts := mikeyDecodeTime(timePayload.TSValue)
diff := time.Since(ts)
if diff < -time.Hour || diff > time.Hour {
return nil, fmt.Errorf("NTP difference is too high")
}
spPayload, ok := mikeyGetPayload[*mikey.PayloadSP](mikeyMsg)
if !ok {
return nil, fmt.Errorf("SP payload not present")
}
v, ok := mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeEncrAlg)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeEncrAlg")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionEncrKeyLen)
if !ok || !bytes.Equal(v, []byte{0x10}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionEncrKeyLen")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeAuthAlg)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeAuthAlg")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionAuthKeyLen)
if !ok || !bytes.Equal(v, []byte{0x0a}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionAuthKeyLen")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPEncrOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPEncrOffOn")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTCPEncrOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTCPEncrOffOn")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPAuthOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPAuthOffOn")
}
kemacPayload, ok := mikeyGetPayload[*mikey.PayloadKEMAC](mikeyMsg)
if !ok {
return nil, fmt.Errorf("KEMAC payload not present")
}
if len(kemacPayload.SubPayloads) != 1 {
return nil, fmt.Errorf("multiple keys are present")
}
if len(kemacPayload.SubPayloads[0].KeyData) != srtpKeyLength {
return nil, fmt.Errorf("unexpected key size: %d", len(kemacPayload.SubPayloads[0].KeyData))
}
ssrcs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo))
startROCs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo))
for i, entry := range mikeyMsg.Header.CSIDMapInfo {
ssrcs[i] = entry.SSRC
startROCs[i] = entry.ROC
}
srtpCtx := &wrappedSRTPContext{
key: kemacPayload.SubPayloads[0].KeyData,
ssrcs: ssrcs,
startROCs: startROCs,
}
err := srtpCtx.initialize()
if err != nil {
return nil, err
}
return srtpCtx, nil
}
func generateRTPInfoEntry(ssm *serverStreamMedia, now time.Time) *headers.RTPInfoEntry {
// do not generate a RTP-Info entry when
// there are multiple formats inside a single media stream,
@@ -293,6 +434,7 @@ type ServerSession struct {
setuppedMediasOrdered []*serverSessionMedia
tcpCallbackByChannel map[int]readFunc
setuppedTransport *Transport
setuppedSecure bool
setuppedStream *ServerStream // play
setuppedPath string
setuppedQuery string
@@ -371,6 +513,13 @@ func (ss *ServerSession) SetuppedTransport() *Transport {
return ss.setuppedTransport
}
// SetuppedSecure returns whether a secure profile is in use.
// If this is false, it does not mean that the stream is not secure, since
// there are some combinations that are secure nonetheless, like RTSPS+TCP+unsecure.
func (ss *ServerSession) SetuppedSecure() bool {
return ss.setuppedSecure
}
// SetuppedStream returns the stream associated with the session.
func (ss *ServerSession) SetuppedStream() *ServerStream {
return ss.setuppedStream
@@ -947,7 +1096,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerTransportHeaderInvalid{Err: err}
}
inTH := findFirstSupportedTransportHeader(ss.s, transportHeaders)
// Per RFC2326 section 12.39, client specifies transports in order of preference.
// pick the first supported one.
inTH := pickFirstSupportedTransport(ss.s, transportHeaders)
if inTH == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
@@ -978,20 +1129,41 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
var transport Transport
if inTH.Protocol == headers.TransportProtocolUDP {
switch inTH.Protocol {
case headers.TransportProtocolUDP:
if inTH.Delivery != nil && *inTH.Delivery == headers.TransportDeliveryMulticast {
transport = TransportUDPMulticast
} else {
transport = TransportUDP
}
} else {
case headers.TransportProtocolTCP:
transport = TransportTCP
}
if ss.setuppedTransport != nil && *ss.setuppedTransport != transport {
var srtpInCtx *wrappedSRTPContext
if inTH.Secure {
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err}
}
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err}
}
}
if ss.setuppedTransport != nil && (*ss.setuppedTransport != transport || ss.setuppedSecure != inTH.Secure) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerMediasDifferentProtocols{}
}, liberrors.ErrServerMediasDifferentTransports{}
}
switch transport {
@@ -1052,7 +1224,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
// this was causing problems during unit tests.
// this was causing problems during E2E tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") {
select {
@@ -1092,6 +1264,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
ss.setuppedTransport = &transport
ss.setuppedSecure = inTH.Secure
if ss.state == ServerSessionStateInitial {
err = stream.readerAdd(ss,
@@ -1109,7 +1282,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.setuppedStream = stream
}
th := headers.Transport{}
th := headers.Transport{
Secure: inTH.Secure,
}
if ss.state == ServerSessionStatePrePlay {
if stream != ss.setuppedStream {
@@ -1131,6 +1306,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
sm := &serverSessionMedia{
ss: ss,
media: medi,
srtpInCtx: srtpInCtx,
onPacketRTCP: func(_ rtcp.Packet) {},
}
err = sm.initialize()
@@ -1141,46 +1317,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
switch transport {
case TransportUDP:
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
}
case TransportUDP, TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
case TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
if transport == TransportUDP {
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
}
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
} else {
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
}
default: // TCP
th.Protocol = headers.TransportProtocolTCP
if inTH.InterleavedIDs != nil {
sm.tcpChannel = inTH.InterleavedIDs[0]
} else {
sm.tcpChannel = ss.findFreeChannelPair()
}
th.Protocol = headers.TransportProtocolTCP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
@@ -1193,6 +1371,38 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
res.Header["Transport"] = th.Marshal()
if inTH.Secure {
ssrcs := make([]uint32, len(sm.formats))
n := 0
for _, sf := range sm.formats {
ssrcs[n] = sf.localSSRC
n++
}
var mk *mikey.Message
mk, err = mikeyGenerate(sm.srtpOutCtx)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mk,
}.Marshal()
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
// always return KeyMgmt even if redundant when playing
// (since it's already present in the SDP)
res.Header["KeyMgmt"] = enc
}
}
return res, err
@@ -1239,7 +1449,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
err = sm.start()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
}
if *ss.setuppedTransport == TransportTCP {
@@ -1329,7 +1544,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
err = sm.start()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
}
if *ss.setuppedTransport == TransportTCP {
@@ -1523,12 +1743,6 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu
sm.onPacketRTCP = cb
}
func (ss *ServerSession) writePacketRTPEncoded(medi *description.Media, payloadType uint8, byts []byte) error {
sm := ss.setuppedMedias[medi]
sf := sm.formats[payloadType]
return sf.writePacketRTPEncoded(byts)
}
// WritePacketRTP writes a RTP packet to the session.
func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error {
sm := ss.setuppedMedias[medi]
@@ -1536,22 +1750,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
return sf.writePacketRTP(pkt)
}
func (ss *ServerSession) writePacketRTCPEncoded(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
return sm.writePacketRTCPEncoded(byts)
}
// WritePacketRTCP writes a RTCP packet to the session.
func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}
return ss.writePacketRTCPEncoded(medi, byts)
sm := ss.setuppedMedias[medi]
return sm.writePacketRTCP(pkt)
}
// PacketPTS returns the PTS of an incoming RTP packet.
// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
//
// Deprecated: replaced by PacketPTS2.
@@ -1567,7 +1772,7 @@ func (ss *ServerSession) PacketPTS(medi *description.Media, pkt *rtp.Packet) (ti
return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(sf.format.ClockRate())), true
}
// PacketPTS2 returns the PTS of an incoming RTP packet.
// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) {
sm := ss.setuppedMedias[medi]
@@ -1575,8 +1780,8 @@ func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (i
return ss.timeDecoder.Decode(sf.format, pkt)
}
// PacketNTP returns the NTP timestamp of an incoming RTP packet.
// The NTP timestamp is computed from RTCP sender reports.
// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet.
// The NTP is computed from RTCP sender reports.
func (ss *ServerSession) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) {
sm := ss.setuppedMedias[medi]
sf := sm.formats[pkt.PayloadType]