support encrypted streams with SRTP and MIKEY (#520) (#809)

This commit is contained in:
Alessandro Ros
2025-07-05 12:48:13 +02:00
committed by GitHub
parent a5ff92f130
commit 616fa7ea89
104 changed files with 4179 additions and 766 deletions

473
client.go
View File

@@ -1,5 +1,5 @@
/*
Package gortsplib is a RTSP 1.0 library for the Go programming language.
Package gortsplib is an RTSP library for the Go programming language.
Examples are available at https://github.com/bluenviron/gortsplib/tree/main/examples
*/
@@ -7,10 +7,12 @@ package gortsplib
import (
"context"
"crypto/rand"
"crypto/tls"
"fmt"
"log"
"net"
"slices"
"strconv"
"strings"
"sync"
@@ -28,6 +30,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtptime"
@@ -118,10 +121,121 @@ func findBaseURL(sd *sdp.SessionDescription, res *base.Response, u *base.URL) (*
return u, nil
}
func prepareForAnnounce(desc *description.Session) {
for i, media := range desc.Medias {
media.Control = "trackID=" + strconv.FormatInt(int64(i), 10)
// clientAnnounceDataFormat holds the data generated for a single format
// of a media before an ANNOUNCE request is sent.
type clientAnnounceDataFormat struct {
// SSRC used locally for this format. It is picked by announceDataPickLocalSSRC,
// which returns a random non-zero value not used by any other announced format.
localSSRC uint32
}
// clientAnnounceDataMedia holds the data generated for a single media
// before an ANNOUNCE request is sent.
type clientAnnounceDataMedia struct {
// random key of srtpKeyLength bytes used for outgoing SRTP.
// generateAnnounceData fills it only when the stream is secure; otherwise it is left nil.
srtpOutKey []byte
// per-format data, indexed by the payload type of the format.
formats map[uint8]*clientAnnounceDataFormat
}
// announceDataPickLocalSSRC returns a random SSRC that is non-zero and that is
// not already assigned to any format of am or of any media stored in data.
// It returns an error only when the random generator fails.
func announceDataPickLocalSSRC(
	am *clientAnnounceDataMedia,
	data map[*description.Media]*clientAnnounceDataMedia,
) (uint32, error) {
	var used []uint32 //nolint:prealloc

	// gather the SSRCs of a single media into the list of used ones.
	collect := func(media *clientAnnounceDataMedia) {
		for _, f := range media.formats {
			used = append(used, f.localSSRC)
		}
	}

	// medias that have already been completed.
	for _, other := range data {
		collect(other)
	}

	// media that is currently being filled and is not part of data yet.
	collect(am)

	// draw random values until a free, non-zero one is found.
	for {
		candidate, err := randUint32()
		if err != nil {
			return 0, err
		}

		if candidate == 0 || slices.Contains(used, candidate) {
			continue
		}

		return candidate, nil
	}
}
// generateAnnounceData creates, for every media of desc, the data needed
// to announce it: a local SSRC for each format and, when secure is true,
// a random SRTP key of srtpKeyLength bytes.
// The result is indexed by media. On failure, nil and the error are returned.
func generateAnnounceData(
	desc *description.Session,
	secure bool,
) (map[*description.Media]*clientAnnounceDataMedia, error) {
	out := make(map[*description.Media]*clientAnnounceDataMedia)

	for _, medi := range desc.Medias {
		mediaData := &clientAnnounceDataMedia{
			formats: make(map[uint8]*clientAnnounceDataFormat),
		}

		for _, forma := range medi.Formats {
			// the SSRC must differ from the ones of this media and of the previous medias.
			ssrc, err := announceDataPickLocalSSRC(mediaData, out)
			if err != nil {
				return nil, err
			}

			mediaData.formats[forma.PayloadType()] = &clientAnnounceDataFormat{
				localSSRC: ssrc,
			}
		}

		if secure {
			key := make([]byte, srtpKeyLength)
			if _, err := rand.Read(key); err != nil {
				return nil, err
			}
			mediaData.srtpOutKey = key
		}

		out[medi] = mediaData
	}

	return out, nil
}
// prepareForAnnounce fills the fields of desc that must be set before the
// session description is marshaled into an ANNOUNCE request.
//
// Every media gets a "trackID=<index>" control attribute and its Secure flag
// set to secure. When secure is true, a MIKEY message is generated from the
// SRTP key and the local SSRCs stored in announceData, and is assigned to the
// KeyMgmtMikey field of the media.
//
// An error is returned when announceData has no entry for a media, when the
// SRTP context cannot be initialized, or when the MIKEY message cannot be
// generated. Medias processed before the failure keep their modifications.
func prepareForAnnounce(
	desc *description.Session,
	announceData map[*description.Media]*clientAnnounceDataMedia,
	secure bool,
) error {
	for i, m := range desc.Medias {
		m.Control = "trackID=" + strconv.FormatInt(int64(i), 10)
		m.Secure = secure

		if !secure {
			continue
		}

		// a missing entry must not cause a nil pointer dereference.
		announceDataMedia := announceData[m]
		if announceDataMedia == nil {
			return fmt.Errorf("announce data not found for media %d", i)
		}

		// size the list from the map that is iterated, not from m.Formats:
		// the map is indexed by payload type, so it can hold fewer entries than
		// m.Formats, and the exceeding slots would be announced as SSRC zero.
		ssrcs := make([]uint32, 0, len(announceDataMedia.formats))
		for _, af := range announceDataMedia.formats {
			ssrcs = append(ssrcs, af.localSSRC)
		}

		// use a dummy Context.
		// Context is needed to extract ROC, but since client has not started streaming,
		// ROC is always zero, therefore a dummy Context can be used.
		srtpCtx := &wrappedSRTPContext{
			key:   announceDataMedia.srtpOutKey,
			ssrcs: ssrcs,
		}
		err := srtpCtx.initialize()
		if err != nil {
			return err
		}

		mikeyMsg, err := mikeyGenerate(srtpCtx)
		if err != nil {
			return err
		}

		m.KeyMgmtMikey = mikeyMsg
	}

	return nil
}
func supportsGetParameter(header base.Header) bool {
@@ -330,7 +444,8 @@ type Client struct {
receiverReportPeriod time.Duration
checkTimeoutPeriod time.Duration
connURL *base.URL
scheme string
host string
ctx context.Context
ctxCancel func()
state clientState
@@ -342,8 +457,11 @@ type Client struct {
optionsSent bool
useGetParameter bool
lastDescribeURL *base.URL
lastDescribeDesc *description.Session
baseURL *base.URL
announceData map[*description.Media]*clientAnnounceDataMedia // record
effectiveTransport *Transport
effectiveSecure bool
backChannelSetupped bool
stdChannelSetupped bool
setuppedMedias map[*description.Media]*clientMedia
@@ -474,10 +592,8 @@ func (c *Client) Start(scheme string, host string) error {
ctx, ctxCancel := context.WithCancel(context.Background())
c.connURL = &base.URL{
Scheme: scheme,
Host: host,
}
c.scheme = scheme
c.host = host
c.ctx = ctx
c.ctxCancel = ctxCancel
c.checkTimeoutTimer = emptyTimer()
@@ -820,7 +936,6 @@ func (c *Client) checkState(allowed map[clientState]struct{}) error {
func (c *Client) trySwitchingProtocol() error {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP{})
prevConnURL := c.connURL
prevBaseURL := c.baseURL
prevMedias := c.setuppedMedias
@@ -828,7 +943,6 @@ func (c *Client) trySwitchingProtocol() error {
v := TransportTCP
c.effectiveTransport = &v
c.connURL = prevConnURL
// some Hikvision cameras require a describe before a setup
_, _, err := c.doDescribe(c.lastDescribeURL)
@@ -856,26 +970,6 @@ func (c *Client) trySwitchingProtocol() error {
return nil
}
// trySwitchingProtocol2 restarts the setup of medi with the TCP transport.
// It notifies OnTransportSwitch with ErrClientSwitchToTCP2, resets the client,
// forces the effective transport to TCP, repeats the last DESCRIBE request and
// then repeats the SETUP request for medi against baseURL with automatic ports (0, 0).
// It returns the response of the SETUP request, or the error of the first request that fails.
func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.URL) (*base.Response, error) {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
// keep the connection URL so that it can be restored after reset().
// NOTE(review): presumably reset() clears connURL - confirm against reset().
prevConnURL := c.connURL
c.reset()
v := TransportTCP
c.effectiveTransport = &v
c.connURL = prevConnURL
// some Hikvision cameras require a describe before a setup
_, _, err := c.doDescribe(c.lastDescribeURL)
if err != nil {
return nil, err
}
return c.doSetup(baseURL, medi, 0, 0)
}
func (c *Client) startTransportRoutines() {
c.timeDecoder = &rtptime.GlobalDecoder2{}
c.timeDecoder.Initialize()
@@ -968,28 +1062,30 @@ func (c *Client) connOpen() error {
return nil
}
if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" {
return liberrors.ErrClientUnsupportedScheme{Scheme: c.connURL.Scheme}
}
if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP {
return liberrors.ErrClientRTSPSTCP{}
if c.scheme != "rtsp" && c.scheme != "rtsps" {
return liberrors.ErrClientUnsupportedScheme{Scheme: c.scheme}
}
dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout)
defer dialCtxCancel()
nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(c.connURL))
nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(&base.URL{
Scheme: c.scheme,
Host: c.host,
}))
if err != nil {
return err
}
if c.connURL.Scheme == "rtsps" {
if c.scheme == "rtsps" {
tlsConfig := c.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
tlsConfig.ServerName = c.connURL.Hostname()
tlsConfig.ServerName = (&base.URL{
Scheme: c.scheme,
Host: c.host,
}).Hostname()
nconn = tls.Client(nconn, tlsConfig)
}
@@ -1256,7 +1352,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
return nil, nil, err
}
if c.connURL.Scheme == "rtsps" && ru.Scheme != "rtsps" {
if c.scheme == "rtsps" && ru.Scheme != "rtsps" {
return nil, nil, fmt.Errorf("connection cannot be downgraded from RTSPS to RTSP")
}
@@ -1264,10 +1360,8 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
ru.User = u.User
}
c.connURL = &base.URL{
Scheme: ru.Scheme,
Host: ru.Host,
}
c.scheme = ru.Scheme
c.host = ru.Host
return c.doDescribe(ru)
}
@@ -1306,6 +1400,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
desc.BaseURL = baseURL
c.lastDescribeURL = u
c.lastDescribeDesc = &desc
return &desc, res, nil
}
@@ -1340,7 +1435,15 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo
return nil, err
}
prepareForAnnounce(desc)
announceData, err := generateAnnounceData(desc, c.scheme == "rtsps")
if err != nil {
return nil, err
}
err = prepareForAnnounce(desc, announceData, c.scheme == "rtsps")
if err != nil {
return nil, err
}
byts, err := desc.Marshal(false)
if err != nil {
@@ -1367,6 +1470,7 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo
c.baseURL = u.Clone()
c.state = clientStatePreRecord
c.announceData = announceData
return res, nil
}
@@ -1408,72 +1512,91 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{}
}
th := headers.Transport{
Mode: func() *headers.TransportMode {
if c.state == clientStatePreRecord {
v := headers.TransportModeRecord
return &v
}
// when playing, omit mode, since it causes errors with some servers.
return nil
}(),
th := headers.Transport{}
// when playing, omit mode, since it causes errors with some servers.
if c.state == clientStatePreRecord {
v := headers.TransportModeRecord
th.Mode = &v
}
var transport Transport
switch {
// use transport from previous SETUP calls
case c.effectiveTransport != nil:
transport = *c.effectiveTransport
th.Secure = c.effectiveSecure
if th.Secure && !medi.Secure {
return nil, fmt.Errorf("previous media was setupped securely but current cannot")
}
// use transport from config, secure flag from server
case c.Transport != nil:
transport = *c.Transport
th.Secure = medi.Secure && c.scheme == "rtsps"
// try UDP if unencrypted or secure is supported by server, otherwise try TCP
default:
th.Secure = medi.Secure && c.scheme == "rtsps"
if th.Secure || c.scheme == "rtsp" {
transport = TransportUDP
} else {
transport = TransportTCP
}
}
cm := &clientMedia{
c: c,
media: medi,
c: c,
media: medi,
secure: th.Secure,
}
err = cm.initialize()
if err != nil {
return nil, err
}
if c.effectiveTransport == nil {
if c.connURL.Scheme == "rtsps" { // always use TCP if encrypted
v := TransportTCP
c.effectiveTransport = &v
} else if c.Transport != nil { // take transport from config
c.effectiveTransport = c.Transport
}
}
var desiredTransport Transport
if c.effectiveTransport != nil {
desiredTransport = *c.effectiveTransport
} else {
desiredTransport = TransportUDP
}
switch desiredTransport {
case TransportUDP:
if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) {
return nil, liberrors.ErrClientUDPPortsZero{}
switch transport {
case TransportUDP, TransportUDPMulticast:
if c.scheme == "rtsps" && !medi.Secure {
cm.close()
return nil, fmt.Errorf("server does not support secure UDP")
}
if rtpPort != 0 && rtcpPort != (rtpPort+1) {
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
if err != nil {
return nil, err
}
v1 := headers.TransportDeliveryUnicast
th.Delivery = &v1
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()}
case TransportUDPMulticast:
v1 := headers.TransportDeliveryMulticast
th.Delivery = &v1
th.Protocol = headers.TransportProtocolUDP
if transport == TransportUDP {
if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) {
cm.close()
return nil, liberrors.ErrClientUDPPortsZero{}
}
if rtpPort != 0 && rtcpPort != (rtpPort+1) {
cm.close()
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
if err != nil {
cm.close()
return nil, err
}
v1 := headers.TransportDeliveryUnicast
th.Delivery = &v1
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()}
} else {
v1 := headers.TransportDeliveryMulticast
th.Delivery = &v1
}
case TransportTCP:
v1 := headers.TransportDeliveryUnicast
@@ -1497,6 +1620,34 @@ func (c *Client) doSetup(
header["Require"] = base.HeaderValue{"www.onvif.org/ver20/backchannel"}
}
if th.Secure {
ssrcs := make([]uint32, len(cm.formats))
n := 0
for _, cf := range cm.formats {
ssrcs[n] = cf.localSSRC
n++
}
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx)
if err != nil {
cm.close()
return nil, err
}
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: mediaURL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
if err != nil {
cm.close()
return nil, err
}
header["KeyMgmt"] = enc
}
res, err := c.do(&base.Request{
Method: base.Setup,
URL: mediaURL,
@@ -1512,10 +1663,12 @@ func (c *Client) doSetup(
// switch transport automatically
if res.StatusCode == base.StatusUnsupportedTransport &&
c.effectiveTransport == nil {
c.effectiveTransport == nil && c.Transport == nil {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
v := TransportTCP
c.effectiveTransport = &v
c.effectiveSecure = th.Secure
return c.doSetup(baseURL, medi, 0, 0)
}
@@ -1529,23 +1682,37 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err}
}
switch desiredTransport {
switch transport {
case TransportUDP, TransportUDPMulticast:
if thRes.Protocol == headers.TransportProtocolTCP {
cm.close()
// switch transport automatically
if c.effectiveTransport == nil &&
c.Transport == nil {
if c.effectiveTransport == nil && c.Transport == nil {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
c.baseURL = baseURL
return c.trySwitchingProtocol2(medi, baseURL)
c.reset()
v := TransportTCP
c.effectiveTransport = &v
c.effectiveSecure = th.Secure
// some Hikvision cameras require a describe before a setup
_, _, err = c.doDescribe(c.lastDescribeURL)
if err != nil {
return nil, err
}
return c.doSetup(baseURL, medi, 0, 0)
}
return nil, liberrors.ErrClientServerRequestedTCP{}
}
}
switch desiredTransport {
switch transport {
case TransportUDP:
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
@@ -1592,14 +1759,17 @@ func (c *Client) doSetup(
case TransportUDPMulticast:
if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
}
if thRes.Ports == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoPorts{}
}
if thRes.Destination == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
}
@@ -1617,6 +1787,7 @@ func (c *Client) doSetup(
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
)
if err != nil {
cm.close()
return nil, err
}
@@ -1636,22 +1807,27 @@ func (c *Client) doSetup(
case TransportTCP:
if thRes.Protocol != headers.TransportProtocolTCP {
cm.close()
return nil, liberrors.ErrClientServerRequestedUDP{}
}
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
}
if thRes.InterleavedIDs == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{}
}
if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{}
}
if c.isChannelPairInUse(thRes.InterleavedIDs[0]) {
cm.close()
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{}
@@ -1660,6 +1836,48 @@ func (c *Client) doSetup(
cm.tcpChannel = thRes.InterleavedIDs[0]
}
if cm.secure {
if !thRes.Secure {
cm.close()
return nil, fmt.Errorf("transport was not setupped securely")
}
var mikeyMsg *mikey.Message
// extract key-mgmt from (in order of priority):
// - response
// - media SDP attributes
// - session SDP attributes
switch {
case res.Header["KeyMgmt"] != nil:
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
if err != nil {
cm.close()
return nil, err
}
mikeyMsg = keyMgmt.MikeyMessage
case medi.KeyMgmtMikey != nil:
mikeyMsg = medi.KeyMgmtMikey
case c.lastDescribeDesc.KeyMgmtMikey != nil:
mikeyMsg = c.lastDescribeDesc.KeyMgmtMikey
default:
return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way")
}
cm.srtpInCtx, err = mikeyToContext(mikeyMsg)
if err != nil {
cm.close()
return nil, err
}
} else if thRes.Secure {
cm.close()
return nil, fmt.Errorf("received unexpected secure profile")
}
if c.setuppedMedias == nil {
c.setuppedMedias = make(map[*description.Media]*clientMedia)
}
@@ -1667,7 +1885,8 @@ func (c *Client) doSetup(
c.setuppedMedias[medi] = cm
c.baseURL = baseURL
c.effectiveTransport = &desiredTransport
c.effectiveTransport = &transport
c.effectiveSecure = th.Secure
if medi.IsBackChannel {
c.backChannelSetupped = true
@@ -1770,12 +1989,34 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
// do this before sending the PLAY request.
if *c.effectiveTransport == TransportUDP {
for _, cm := range c.setuppedMedias {
if !cm.media.IsBackChannel {
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
cm.udpRTPListener.write(byts) //nolint:errcheck
if !cm.media.IsBackChannel && cm.udpRTPListener.writeAddr != nil {
buf, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
if cm.srtpOutCtx != nil {
encr := make([]byte, cm.c.MaxPacketSize)
encr, err = cm.srtpOutCtx.encryptRTP(encr, buf, nil)
if err != nil {
return nil, err
}
buf = encr
}
err = cm.udpRTPListener.write(buf)
if err != nil {
return nil, err
}
byts, _ = (&rtcp.ReceiverReport{}).Marshal()
cm.udpRTCPListener.write(byts) //nolint:errcheck
buf, _ = (&rtcp.ReceiverReport{}).Marshal()
if cm.srtpOutCtx != nil {
encr := make([]byte, cm.c.MaxPacketSize)
encr, err = cm.srtpOutCtx.encryptRTCP(encr, buf, nil)
if err != nil {
return nil, err
}
buf = encr
}
err = cm.udpRTCPListener.write(buf)
if err != nil {
return nil, err
}
}
}
}
@@ -1981,7 +2222,7 @@ func (c *Client) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error
}
// WritePacketRTPWithNTP writes a RTP packet to the server.
// ntp is the absolute time of the packet, and is sent with periodic RTCP sender reports.
// ntp is the absolute timestamp of the packet, and is sent with periodic RTCP sender reports.
func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, ntp time.Time) error {
select {
case <-c.done:
@@ -2020,7 +2261,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
return cm.writePacketRTCP(pkt)
}
// PacketPTS returns the PTS of an incoming RTP packet.
// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and synchronizing it with other tracks.
//
// Deprecated: replaced by PacketPTS2.
@@ -2036,7 +2277,7 @@ func (c *Client) PacketPTS(medi *description.Media, pkt *rtp.Packet) (time.Durat
return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(ct.format.ClockRate())), true
}
// PacketPTS2 returns the PTS of an incoming RTP packet.
// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and synchronizing it with other tracks.
func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) {
cm := c.setuppedMedias[medi]
@@ -2044,8 +2285,8 @@ func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bo
return c.timeDecoder.Decode(ct.format, pkt)
}
// PacketNTP returns the NTP timestamp of an incoming RTP packet.
// The NTP timestamp is computed from RTCP sender reports.
// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet.
// The NTP is computed from RTCP sender reports.
func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) {
cm := c.setuppedMedias[medi]
ct := cm.formats[pkt.PayloadType]