support encrypted streams with SRTP and MIKEY (#520) (#809)

This commit is contained in:
Alessandro Ros
2025-07-05 12:48:13 +02:00
committed by GitHub
parent a5ff92f130
commit 616fa7ea89
104 changed files with 4179 additions and 766 deletions

View File

@@ -6,39 +6,37 @@
[![CodeCov](https://codecov.io/gh/bluenviron/gortsplib/branch/main/graph/badge.svg)](https://app.codecov.io/gh/bluenviron/gortsplib/tree/main)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/bluenviron/gortsplib/v4)](https://pkg.go.dev/github.com/bluenviron/gortsplib/v4#pkg-index)
RTSP 1.0 client and server library for the Go programming language, written for [MediaMTX](https://github.com/bluenviron/mediamtx).
RTSP client and server library for the Go programming language, written for [MediaMTX](https://github.com/bluenviron/mediamtx).
Go ≥ 1.23 is required.
Features:
* Client
* Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY)
* Query servers about available media streams
* Read media streams from a server ("play")
* Read streams with the UDP, UDP-multicast or TCP transport protocol
* Read TLS-encrypted streams (TCP only)
* Switch transport protocol automatically
* Read selected media streams
* Pause or seek without disconnecting from the server
* Write to ONVIF back channels
* Get PTS (relative) timestamp of incoming packets
* Get NTP (absolute) timestamp of incoming packets
* Get PTS (presentation timestamp) of incoming packets
* Get NTP (absolute timestamp) of incoming packets
* Write media streams to a server ("record")
* Write streams with the UDP or TCP transport protocol
* Write TLS-encrypted streams (TCP only)
* Switch transport protocol automatically
* Pause without disconnecting from the server
* Server
* Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY)
* Handle requests from clients
* Validate client credentials
* Read media streams from clients ("record")
* Read streams with the UDP or TCP transport protocol
* Read TLS-encrypted streams (TCP only)
* Get PTS (relative) timestamp of incoming packets
* Get NTP (absolute) timestamp of incoming packets
* Get PTS (presentation timestamp) of incoming packets
* Get NTP (absolute timestamp) of incoming packets
* Serve media streams to clients ("play")
* Write streams with the UDP, UDP-multicast or TCP transport protocol
* Write TLS-encrypted streams (TCP only)
* Compute and provide SSRC, RTP-Info to clients
* Read ONVIF back channels
* Utilities
@@ -94,7 +92,7 @@ Features:
* [client-record-format-vp8](examples/client-record-format-vp8/main.go)
* [client-record-format-vp9](examples/client-record-format-vp9/main.go)
* [server](examples/server/main.go)
* [server-tls](examples/server-tls/main.go)
* [server-secure](examples/server-secure/main.go)
* [server-auth](examples/server-auth/main.go)
* [server-record-format-h264-to-disk](examples/server-record-format-h264-to-disk/main.go)
* [server-play-format-h264-from-disk](examples/server-play-format-h264-from-disk/main.go)
@@ -150,7 +148,10 @@ In RTSP, media streams are transmitted by using RTP packets, which are encoded i
|----|----|
|[RFC2326, RTSP 1.0](https://datatracker.ietf.org/doc/html/rfc2326)|protocol|
|[RFC7826, RTSP 2.0](https://datatracker.ietf.org/doc/html/rfc7826)|protocol|
|[ONVIF Streaming Specification 23.06](https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf)|protocol|
|[RFC8866, SDP: Session Description Protocol](https://datatracker.ietf.org/doc/html/rfc8866)|SDP|
|[RFC4567, Key Management Extensions for Session Description Protocol (SDP) and Real Time Streaming Protocol (RTSP)](https://datatracker.ietf.org/doc/html/rfc4567)|secure variants|
|[RFC3830, MIKEY: Multimedia Internet KEYing](https://datatracker.ietf.org/doc/html/rfc3830)|secure variants|
|[RTP Payload Format For AV1 (v1.0)](https://aomediacodec.github.io/av1-rtp-spec/)|payload formats / AV1|
|[RTP Payload Format for VP9 Video](https://datatracker.ietf.org/doc/html/draft-ietf-payload-vp9-16)|payload formats / VP9|
|[RFC7741, RTP Payload Format for VP8 Video](https://datatracker.ietf.org/doc/html/rfc7741)|payload formats / VP8|
@@ -178,3 +179,4 @@ In RTSP, media streams are transmitted by using RTP packets, which are encoded i
* [pion/sdp (SDP library used internally)](https://github.com/pion/sdp)
* [pion/rtp (RTP library used internally)](https://github.com/pion/rtp)
* [pion/rtcp (RTCP library used internally)](https://github.com/pion/rtcp)
* [pion/srtp (SRTP library used internally)](https://github.com/pion/srtp)

473
client.go
View File

@@ -1,5 +1,5 @@
/*
Package gortsplib is a RTSP 1.0 library for the Go programming language.
Package gortsplib is a RTSP library for the Go programming language.
Examples are available at https://github.com/bluenviron/gortsplib/tree/main/examples
*/
@@ -7,10 +7,12 @@ package gortsplib
import (
"context"
"crypto/rand"
"crypto/tls"
"fmt"
"log"
"net"
"slices"
"strconv"
"strings"
"sync"
@@ -28,6 +30,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtptime"
@@ -118,10 +121,121 @@ func findBaseURL(sd *sdp.SessionDescription, res *base.Response, u *base.URL) (*
return u, nil
}
func prepareForAnnounce(desc *description.Session) {
for i, media := range desc.Medias {
media.Control = "trackID=" + strconv.FormatInt(int64(i), 10)
type clientAnnounceDataFormat struct {
localSSRC uint32
}
type clientAnnounceDataMedia struct {
srtpOutKey []byte
formats map[uint8]*clientAnnounceDataFormat
}
func announceDataPickLocalSSRC(
am *clientAnnounceDataMedia,
data map[*description.Media]*clientAnnounceDataMedia,
) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, am := range data {
for _, af := range am.formats {
takenSSRCs = append(takenSSRCs, af.localSSRC)
}
}
for _, af := range am.formats {
takenSSRCs = append(takenSSRCs, af.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
}
func generateAnnounceData(
desc *description.Session,
secure bool,
) (map[*description.Media]*clientAnnounceDataMedia, error) {
data := make(map[*description.Media]*clientAnnounceDataMedia)
for _, medi := range desc.Medias {
am := &clientAnnounceDataMedia{
formats: make(map[uint8]*clientAnnounceDataFormat),
}
for _, format := range medi.Formats {
dataFormat := &clientAnnounceDataFormat{}
var err error
dataFormat.localSSRC, err = announceDataPickLocalSSRC(am, data)
if err != nil {
return nil, err
}
am.formats[format.PayloadType()] = dataFormat
}
if secure {
am.srtpOutKey = make([]byte, srtpKeyLength)
_, err := rand.Read(am.srtpOutKey)
if err != nil {
return nil, err
}
}
data[medi] = am
}
return data, nil
}
func prepareForAnnounce(
desc *description.Session,
announceData map[*description.Media]*clientAnnounceDataMedia,
secure bool,
) error {
for i, m := range desc.Medias {
m.Control = "trackID=" + strconv.FormatInt(int64(i), 10)
m.Secure = secure
if secure {
announceDataMedia := announceData[m]
ssrcs := make([]uint32, len(m.Formats))
n := 0
for _, af := range announceDataMedia.formats {
ssrcs[n] = af.localSSRC
n++
}
// use a dummy Context.
// Context is needed to extract ROC, but since client has not started streaming,
// ROC is always zero, therefore a dummy Context can be used.
srtpCtx := &wrappedSRTPContext{
key: announceDataMedia.srtpOutKey,
ssrcs: ssrcs,
}
err := srtpCtx.initialize()
if err != nil {
return err
}
mikeyMsg, err := mikeyGenerate(srtpCtx)
if err != nil {
return err
}
m.KeyMgmtMikey = mikeyMsg
}
}
return nil
}
func supportsGetParameter(header base.Header) bool {
@@ -330,7 +444,8 @@ type Client struct {
receiverReportPeriod time.Duration
checkTimeoutPeriod time.Duration
connURL *base.URL
scheme string
host string
ctx context.Context
ctxCancel func()
state clientState
@@ -342,8 +457,11 @@ type Client struct {
optionsSent bool
useGetParameter bool
lastDescribeURL *base.URL
lastDescribeDesc *description.Session
baseURL *base.URL
announceData map[*description.Media]*clientAnnounceDataMedia // record
effectiveTransport *Transport
effectiveSecure bool
backChannelSetupped bool
stdChannelSetupped bool
setuppedMedias map[*description.Media]*clientMedia
@@ -474,10 +592,8 @@ func (c *Client) Start(scheme string, host string) error {
ctx, ctxCancel := context.WithCancel(context.Background())
c.connURL = &base.URL{
Scheme: scheme,
Host: host,
}
c.scheme = scheme
c.host = host
c.ctx = ctx
c.ctxCancel = ctxCancel
c.checkTimeoutTimer = emptyTimer()
@@ -820,7 +936,6 @@ func (c *Client) checkState(allowed map[clientState]struct{}) error {
func (c *Client) trySwitchingProtocol() error {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP{})
prevConnURL := c.connURL
prevBaseURL := c.baseURL
prevMedias := c.setuppedMedias
@@ -828,7 +943,6 @@ func (c *Client) trySwitchingProtocol() error {
v := TransportTCP
c.effectiveTransport = &v
c.connURL = prevConnURL
// some Hikvision cameras require a describe before a setup
_, _, err := c.doDescribe(c.lastDescribeURL)
@@ -856,26 +970,6 @@ func (c *Client) trySwitchingProtocol() error {
return nil
}
func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.URL) (*base.Response, error) {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
prevConnURL := c.connURL
c.reset()
v := TransportTCP
c.effectiveTransport = &v
c.connURL = prevConnURL
// some Hikvision cameras require a describe before a setup
_, _, err := c.doDescribe(c.lastDescribeURL)
if err != nil {
return nil, err
}
return c.doSetup(baseURL, medi, 0, 0)
}
func (c *Client) startTransportRoutines() {
c.timeDecoder = &rtptime.GlobalDecoder2{}
c.timeDecoder.Initialize()
@@ -968,28 +1062,30 @@ func (c *Client) connOpen() error {
return nil
}
if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" {
return liberrors.ErrClientUnsupportedScheme{Scheme: c.connURL.Scheme}
}
if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP {
return liberrors.ErrClientRTSPSTCP{}
if c.scheme != "rtsp" && c.scheme != "rtsps" {
return liberrors.ErrClientUnsupportedScheme{Scheme: c.scheme}
}
dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout)
defer dialCtxCancel()
nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(c.connURL))
nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(&base.URL{
Scheme: c.scheme,
Host: c.host,
}))
if err != nil {
return err
}
if c.connURL.Scheme == "rtsps" {
if c.scheme == "rtsps" {
tlsConfig := c.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
tlsConfig.ServerName = c.connURL.Hostname()
tlsConfig.ServerName = (&base.URL{
Scheme: c.scheme,
Host: c.host,
}).Hostname()
nconn = tls.Client(nconn, tlsConfig)
}
@@ -1256,7 +1352,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
return nil, nil, err
}
if c.connURL.Scheme == "rtsps" && ru.Scheme != "rtsps" {
if c.scheme == "rtsps" && ru.Scheme != "rtsps" {
return nil, nil, fmt.Errorf("connection cannot be downgraded from RTSPS to RTSP")
}
@@ -1264,10 +1360,8 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
ru.User = u.User
}
c.connURL = &base.URL{
Scheme: ru.Scheme,
Host: ru.Host,
}
c.scheme = ru.Scheme
c.host = ru.Host
return c.doDescribe(ru)
}
@@ -1306,6 +1400,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response,
desc.BaseURL = baseURL
c.lastDescribeURL = u
c.lastDescribeDesc = &desc
return &desc, res, nil
}
@@ -1340,7 +1435,15 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo
return nil, err
}
prepareForAnnounce(desc)
announceData, err := generateAnnounceData(desc, c.scheme == "rtsps")
if err != nil {
return nil, err
}
err = prepareForAnnounce(desc, announceData, c.scheme == "rtsps")
if err != nil {
return nil, err
}
byts, err := desc.Marshal(false)
if err != nil {
@@ -1367,6 +1470,7 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo
c.baseURL = u.Clone()
c.state = clientStatePreRecord
c.announceData = announceData
return res, nil
}
@@ -1408,72 +1512,91 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{}
}
th := headers.Transport{
Mode: func() *headers.TransportMode {
if c.state == clientStatePreRecord {
v := headers.TransportModeRecord
return &v
}
// when playing, omit mode, since it causes errors with some servers.
return nil
}(),
th := headers.Transport{}
// when playing, omit mode, since it causes errors with some servers.
if c.state == clientStatePreRecord {
v := headers.TransportModeRecord
th.Mode = &v
}
var transport Transport
switch {
// use transport from previous SETUP calls
case c.effectiveTransport != nil:
transport = *c.effectiveTransport
th.Secure = c.effectiveSecure
if th.Secure && !medi.Secure {
return nil, fmt.Errorf("previous media was setupped securely but current cannot")
}
// use transport from config, secure flag from server
case c.Transport != nil:
transport = *c.Transport
th.Secure = medi.Secure && c.scheme == "rtsps"
// try UDP if unencrypted or secure is supported by server, otherwise try TCP
default:
th.Secure = medi.Secure && c.scheme == "rtsps"
if th.Secure || c.scheme == "rtsp" {
transport = TransportUDP
} else {
transport = TransportTCP
}
}
cm := &clientMedia{
c: c,
media: medi,
c: c,
media: medi,
secure: th.Secure,
}
err = cm.initialize()
if err != nil {
return nil, err
}
if c.effectiveTransport == nil {
if c.connURL.Scheme == "rtsps" { // always use TCP if encrypted
v := TransportTCP
c.effectiveTransport = &v
} else if c.Transport != nil { // take transport from config
c.effectiveTransport = c.Transport
}
}
var desiredTransport Transport
if c.effectiveTransport != nil {
desiredTransport = *c.effectiveTransport
} else {
desiredTransport = TransportUDP
}
switch desiredTransport {
case TransportUDP:
if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) {
return nil, liberrors.ErrClientUDPPortsZero{}
switch transport {
case TransportUDP, TransportUDPMulticast:
if c.scheme == "rtsps" && !medi.Secure {
cm.close()
return nil, fmt.Errorf("server does not support secure UDP")
}
if rtpPort != 0 && rtcpPort != (rtpPort+1) {
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
if err != nil {
return nil, err
}
v1 := headers.TransportDeliveryUnicast
th.Delivery = &v1
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()}
case TransportUDPMulticast:
v1 := headers.TransportDeliveryMulticast
th.Delivery = &v1
th.Protocol = headers.TransportProtocolUDP
if transport == TransportUDP {
if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) {
cm.close()
return nil, liberrors.ErrClientUDPPortsZero{}
}
if rtpPort != 0 && rtcpPort != (rtpPort+1) {
cm.close()
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
)
if err != nil {
cm.close()
return nil, err
}
v1 := headers.TransportDeliveryUnicast
th.Delivery = &v1
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()}
} else {
v1 := headers.TransportDeliveryMulticast
th.Delivery = &v1
}
case TransportTCP:
v1 := headers.TransportDeliveryUnicast
@@ -1497,6 +1620,34 @@ func (c *Client) doSetup(
header["Require"] = base.HeaderValue{"www.onvif.org/ver20/backchannel"}
}
if th.Secure {
ssrcs := make([]uint32, len(cm.formats))
n := 0
for _, cf := range cm.formats {
ssrcs[n] = cf.localSSRC
n++
}
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx)
if err != nil {
cm.close()
return nil, err
}
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: mediaURL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
if err != nil {
cm.close()
return nil, err
}
header["KeyMgmt"] = enc
}
res, err := c.do(&base.Request{
Method: base.Setup,
URL: mediaURL,
@@ -1512,10 +1663,12 @@ func (c *Client) doSetup(
// switch transport automatically
if res.StatusCode == base.StatusUnsupportedTransport &&
c.effectiveTransport == nil {
c.effectiveTransport == nil && c.Transport == nil {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
v := TransportTCP
c.effectiveTransport = &v
c.effectiveSecure = th.Secure
return c.doSetup(baseURL, medi, 0, 0)
}
@@ -1529,23 +1682,37 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err}
}
switch desiredTransport {
switch transport {
case TransportUDP, TransportUDPMulticast:
if thRes.Protocol == headers.TransportProtocolTCP {
cm.close()
// switch transport automatically
if c.effectiveTransport == nil &&
c.Transport == nil {
if c.effectiveTransport == nil && c.Transport == nil {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
c.baseURL = baseURL
return c.trySwitchingProtocol2(medi, baseURL)
c.reset()
v := TransportTCP
c.effectiveTransport = &v
c.effectiveSecure = th.Secure
// some Hikvision cameras require a describe before a setup
_, _, err = c.doDescribe(c.lastDescribeURL)
if err != nil {
return nil, err
}
return c.doSetup(baseURL, medi, 0, 0)
}
return nil, liberrors.ErrClientServerRequestedTCP{}
}
}
switch desiredTransport {
switch transport {
case TransportUDP:
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
@@ -1592,14 +1759,17 @@ func (c *Client) doSetup(
case TransportUDPMulticast:
if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
}
if thRes.Ports == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoPorts{}
}
if thRes.Destination == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
}
@@ -1617,6 +1787,7 @@ func (c *Client) doSetup(
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
)
if err != nil {
cm.close()
return nil, err
}
@@ -1636,22 +1807,27 @@ func (c *Client) doSetup(
case TransportTCP:
if thRes.Protocol != headers.TransportProtocolTCP {
cm.close()
return nil, liberrors.ErrClientServerRequestedUDP{}
}
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
}
if thRes.InterleavedIDs == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{}
}
if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{}
}
if c.isChannelPairInUse(thRes.InterleavedIDs[0]) {
cm.close()
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{}
@@ -1660,6 +1836,48 @@ func (c *Client) doSetup(
cm.tcpChannel = thRes.InterleavedIDs[0]
}
if cm.secure {
if !thRes.Secure {
cm.close()
return nil, fmt.Errorf("transport was not setupped securely")
}
var mikeyMsg *mikey.Message
// extract key-mgmt from (in order of priority):
// - response
// - media SDP attributes
// - session SDP attributes
switch {
case res.Header["KeyMgmt"] != nil:
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
if err != nil {
cm.close()
return nil, err
}
mikeyMsg = keyMgmt.MikeyMessage
case medi.KeyMgmtMikey != nil:
mikeyMsg = medi.KeyMgmtMikey
case c.lastDescribeDesc.KeyMgmtMikey != nil:
mikeyMsg = c.lastDescribeDesc.KeyMgmtMikey
default:
return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way")
}
cm.srtpInCtx, err = mikeyToContext(mikeyMsg)
if err != nil {
cm.close()
return nil, err
}
} else if thRes.Secure {
cm.close()
return nil, fmt.Errorf("received unexpected secure profile")
}
if c.setuppedMedias == nil {
c.setuppedMedias = make(map[*description.Media]*clientMedia)
}
@@ -1667,7 +1885,8 @@ func (c *Client) doSetup(
c.setuppedMedias[medi] = cm
c.baseURL = baseURL
c.effectiveTransport = &desiredTransport
c.effectiveTransport = &transport
c.effectiveSecure = th.Secure
if medi.IsBackChannel {
c.backChannelSetupped = true
@@ -1770,12 +1989,34 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
// do this before sending the PLAY request.
if *c.effectiveTransport == TransportUDP {
for _, cm := range c.setuppedMedias {
if !cm.media.IsBackChannel {
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
cm.udpRTPListener.write(byts) //nolint:errcheck
if !cm.media.IsBackChannel && cm.udpRTPListener.writeAddr != nil {
buf, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
if cm.srtpOutCtx != nil {
encr := make([]byte, cm.c.MaxPacketSize)
encr, err = cm.srtpOutCtx.encryptRTP(encr, buf, nil)
if err != nil {
return nil, err
}
buf = encr
}
err = cm.udpRTPListener.write(buf)
if err != nil {
return nil, err
}
byts, _ = (&rtcp.ReceiverReport{}).Marshal()
cm.udpRTCPListener.write(byts) //nolint:errcheck
buf, _ = (&rtcp.ReceiverReport{}).Marshal()
if cm.srtpOutCtx != nil {
encr := make([]byte, cm.c.MaxPacketSize)
encr, err = cm.srtpOutCtx.encryptRTCP(encr, buf, nil)
if err != nil {
return nil, err
}
buf = encr
}
err = cm.udpRTCPListener.write(buf)
if err != nil {
return nil, err
}
}
}
}
@@ -1981,7 +2222,7 @@ func (c *Client) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error
}
// WritePacketRTPWithNTP writes a RTP packet to the server.
// ntp is the absolute time of the packet, and is sent with periodic RTCP sender reports.
// ntp is the absolute timestamp of the packet, and is sent with periodic RTCP sender reports.
func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, ntp time.Time) error {
select {
case <-c.done:
@@ -2020,7 +2261,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
return cm.writePacketRTCP(pkt)
}
// PacketPTS returns the PTS of an incoming RTP packet.
// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
//
// Deprecated: replaced by PacketPTS2.
@@ -2036,7 +2277,7 @@ func (c *Client) PacketPTS(medi *description.Media, pkt *rtp.Packet) (time.Durat
return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(ct.format.ClockRate())), true
}
// PacketPTS2 returns the PTS of an incoming RTP packet.
// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) {
cm := c.setuppedMedias[medi]
@@ -2044,8 +2285,8 @@ func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bo
return c.timeDecoder.Decode(ct.format, pkt)
}
// PacketNTP returns the NTP timestamp of an incoming RTP packet.
// The NTP timestamp is computed from RTCP sender reports.
// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet.
// The NTP is computed from RTCP sender reports.
func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) {
cm := c.setuppedMedias[medi]
ct := cm.formats[pkt.PayloadType]

View File

@@ -1,6 +1,7 @@
package gortsplib
import (
"slices"
"sync/atomic"
"time"
@@ -15,25 +16,26 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/rtpreorderer"
)
func isClientLocalSSRCTaken(ssrc uint32, c *Client, exclude *clientFormat) bool {
for _, cm := range c.setuppedMedias {
func clientPickLocalSSRC(cf *clientFormat) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, cm := range cf.cm.c.setuppedMedias {
for _, cf := range cm.formats {
if cf != exclude && cf.localSSRC == ssrc {
return true
}
takenSSRCs = append(takenSSRCs, cf.localSSRC)
}
}
return false
}
func clientPickLocalSSRC(cf *clientFormat) (uint32, error) {
for _, cf := range cf.cm.formats {
takenSSRCs = append(takenSSRCs, cf.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !isClientLocalSSRCTaken(ssrc, cf.cm.c, cf) {
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
@@ -56,10 +58,14 @@ type clientFormat struct {
}
func (cf *clientFormat) initialize() error {
var err error
cf.localSSRC, err = clientPickLocalSSRC(cf)
if err != nil {
return err
if cf.cm.c.state == clientStatePreRecord {
cf.localSSRC = cf.cm.c.announceData[cf.cm.media].formats[cf.format.PayloadType()].localSSRC
} else {
var err error
cf.localSSRC, err = clientPickLocalSSRC(cf)
if err != nil {
return err
}
}
cf.rtpPacketsReceived = new(uint64)
@@ -181,17 +187,31 @@ func (cf *clientFormat) handlePacketsLost(lost uint64) {
func (cf *clientFormat) writePacketRTP(pkt *rtp.Packet, ntp time.Time) error {
pkt.SSRC = cf.localSSRC
byts := make([]byte, cf.cm.c.MaxPacketSize)
n, err := pkt.MarshalTo(byts)
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
maxPlainPacketSize := cf.cm.c.MaxPacketSize
if cf.cm.srtpOutCtx != nil {
maxPlainPacketSize -= srtpOverhead
}
buf := make([]byte, maxPlainPacketSize)
n, err := pkt.MarshalTo(buf)
if err != nil {
return err
}
byts = byts[:n]
buf = buf[:n]
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
if cf.cm.srtpOutCtx != nil {
encr := make([]byte, cf.cm.c.MaxPacketSize)
encr, err = cf.cm.srtpOutCtx.encryptRTP(encr, buf, &pkt.Header)
if err != nil {
return err
}
buf = encr
}
ok := cf.cm.c.writer.push(func() error {
return cf.writePacketRTPInQueue(byts)
return cf.writePacketRTPInQueue(buf)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}

View File

@@ -1,7 +1,10 @@
package gortsplib
import (
"crypto/rand"
"fmt"
"net"
"strconv"
"sync/atomic"
"time"
@@ -13,9 +16,12 @@ import (
)
type clientMedia struct {
c *Client
media *description.Media
c *Client
media *description.Media
secure bool
srtpOutCtx *wrappedSRTPContext
srtpInCtx *wrappedSRTPContext
onPacketRTCP OnPacketRTCPFunc
formats map[uint8]*clientFormat
tcpChannel int
@@ -55,6 +61,35 @@ func (cm *clientMedia) initialize() error {
cm.formats[forma.PayloadType()] = f
}
if cm.secure {
var srtpOutKey []byte
if cm.c.state == clientStatePreRecord {
srtpOutKey = cm.c.announceData[cm.media].srtpOutKey
} else {
srtpOutKey = make([]byte, srtpKeyLength)
_, err := rand.Read(srtpOutKey)
if err != nil {
return err
}
}
ssrcs := make([]uint32, len(cm.formats))
n := 0
for _, cf := range cm.formats {
ssrcs[n] = cf.localSSRC
n++
}
cm.srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcs,
}
err := cm.srtpOutCtx.initialize()
if err != nil {
return err
}
}
return nil
}
@@ -99,9 +134,45 @@ func (cm *clientMedia) createUDPListeners(
return nil
}
var err error
cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c)
return err
// pick two consecutive ports in range 65535-10000
// RTP port must be even and RTCP port odd
for {
v, err := randInRange((65535 - 10000) / 2)
if err != nil {
return err
}
rtpPort := v*2 + 10000
rtcpPort := rtpPort + 1
cm.udpRTPListener = &clientUDPListener{
c: cm.c,
multicastEnable: false,
multicastSourceIP: nil,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
}
err = cm.udpRTPListener.initialize()
if err != nil {
cm.udpRTPListener = nil
continue
}
cm.udpRTCPListener = &clientUDPListener{
c: cm.c,
multicastEnable: false,
multicastSourceIP: nil,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
}
err = cm.udpRTCPListener.initialize()
if err != nil {
cm.udpRTPListener.close()
cm.udpRTPListener = nil
cm.udpRTCPListener = nil
continue
}
return nil
}
}
func (cm *clientMedia) start() {
@@ -161,14 +232,44 @@ func (cm *clientMedia) findFormatByRemoteSSRC(ssrc uint32) *clientFormat {
return nil
}
func (cm *clientMedia) decodeRTP(payload []byte) (*rtp.Packet, error) {
if cm.srtpInCtx != nil {
var err error
payload, err = cm.srtpInCtx.decryptRTP(payload, payload, nil)
if err != nil {
return nil, err
}
}
var pkt rtp.Packet
err := pkt.Unmarshal(payload)
return &pkt, err
}
func (cm *clientMedia) decodeRTCP(payload []byte) ([]rtcp.Packet, error) {
if cm.srtpInCtx != nil {
var err error
payload, err = cm.srtpInCtx.decryptRTCP(payload, payload, nil)
if err != nil {
return nil, err
}
}
pkts, err := rtcp.Unmarshal(payload)
if err != nil {
return nil, err
}
return pkts, nil
}
func (cm *clientMedia) readPacketRTPTCPPlay(payload []byte) bool {
atomic.AddUint64(cm.bytesReceived, uint64(len(payload)))
now := cm.c.timeNow()
atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix())
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := cm.decodeRTP(payload)
if err != nil {
cm.onPacketRTPDecodeError(err)
return false
@@ -196,7 +297,7 @@ func (cm *clientMedia) readPacketRTCPTCPPlay(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := cm.decodeRTCP(payload)
if err != nil {
cm.onPacketRTCPDecodeError(err)
return false
@@ -230,7 +331,7 @@ func (cm *clientMedia) readPacketRTCPTCPRecord(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := cm.decodeRTCP(payload)
if err != nil {
cm.onPacketRTCPDecodeError(err)
return false
@@ -253,8 +354,7 @@ func (cm *clientMedia) readPacketRTPUDPPlay(payload []byte) bool {
return false
}
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := cm.decodeRTP(payload)
if err != nil {
cm.onPacketRTPDecodeError(err)
return false
@@ -279,7 +379,7 @@ func (cm *clientMedia) readPacketRTCPUDPPlay(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := cm.decodeRTCP(payload)
if err != nil {
cm.onPacketRTCPDecodeError(err)
return false
@@ -315,7 +415,7 @@ func (cm *clientMedia) readPacketRTCPUDPRecord(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := cm.decodeRTCP(payload)
if err != nil {
cm.onPacketRTCPDecodeError(err)
return false
@@ -341,13 +441,31 @@ func (cm *clientMedia) onPacketRTCPDecodeError(err error) {
}
func (cm *clientMedia) writePacketRTCP(pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
buf, err := pkt.Marshal()
if err != nil {
return err
}
maxPlainPacketSize := cm.c.MaxPacketSize
if cm.srtpOutCtx != nil {
maxPlainPacketSize -= srtcpOverhead
}
if len(buf) > maxPlainPacketSize {
return fmt.Errorf("packet is too big")
}
if cm.srtpOutCtx != nil {
encr := make([]byte, cm.c.MaxPacketSize)
encr, err = cm.srtpOutCtx.encryptRTCP(encr, buf, nil)
if err != nil {
return err
}
buf = encr
}
ok := cm.c.writer.push(func() error {
return cm.writePacketRTCPInQueue(byts)
return cm.writePacketRTCPInQueue(buf)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}

View File

@@ -2,7 +2,9 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"net"
"strconv"
"strings"
@@ -21,6 +23,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio"
)
@@ -45,7 +48,9 @@ func mediasToSDP(medias []*description.Media) []byte {
Medias: medias,
}
prepareForAnnounce(desc)
for i, m := range desc.Medias {
m.Control = "trackID=" + strconv.FormatInt(int64(i), 10)
}
byts, err := desc.Marshal(false)
if err != nil {
@@ -146,6 +151,7 @@ func TestClientPlayFormats(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -242,23 +248,55 @@ func TestClientPlayFormats(t *testing.T) {
}
func TestClientPlay(t *testing.T) {
for _, transport := range []string{
"udp",
"multicast",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"multicast",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"multicast",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
packetRecv := make(chan struct{})
listenIP := multicastCapableIP(t)
var l net.Listener
var err error
var scheme string
if transport == "tls" {
scheme = "rtsps"
if ca.scheme == "rtsps" {
var cert tls.Certificate
cert, err = tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
@@ -267,8 +305,6 @@ func TestClientPlay(t *testing.T) {
require.NoError(t, err)
defer l.Close()
} else {
scheme = "rtsp"
l, err = net.Listen("tcp", listenIP+":8554")
require.NoError(t, err)
defer l.Close()
@@ -276,6 +312,7 @@ func TestClientPlay(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -287,7 +324,7 @@ func TestClientPlay(t *testing.T) {
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -304,7 +341,7 @@ func TestClientPlay(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Describe, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
forma := &format.Generic{
PayloadTyp: 96,
@@ -317,10 +354,12 @@ func TestClientPlay(t *testing.T) {
{
Type: "application",
Formats: []format.Format{forma},
Secure: ca.secure == "secure",
},
{
Type: "application",
Formats: []format.Format{forma},
Secure: ca.secure == "secure",
},
}
@@ -328,7 +367,7 @@ func TestClientPlay(t *testing.T) {
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
"Content-Base": base.HeaderValue{scheme + "://" + listenIP + ":8554/test/stream?param=value/"},
"Content-Base": base.HeaderValue{ca.scheme + "://" + listenIP + ":8554/test/stream?param=value/"},
},
Body: mediasToSDP(medias),
})
@@ -337,13 +376,15 @@ func TestClientPlay(t *testing.T) {
var l1s [2]net.PacketConn
var l2s [2]net.PacketConn
var clientPorts [2]*[2]int
var srtpInCtx [2]*wrappedSRTPContext
var srtpOutCtx [2]*wrappedSRTPContext
for i := 0; i < 2; i++ {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(
scheme+"://"+listenIP+":8554/test/stream?param=value/"+medias[i].Control), req.URL)
ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"+medias[i].Control), req.URL)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
@@ -351,9 +392,48 @@ func TestClientPlay(t *testing.T) {
require.Equal(t, (*headers.TransportMode)(nil), inTH.Mode)
var th headers.Transport
h := base.Header{}
switch transport {
th := headers.Transport{
Secure: inTH.Secure,
}
if ca.secure == "secure" {
require.True(t, inTH.Secure)
var keyMgmt headers.KeyMgmt
err2 = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
require.NoError(t, err2)
srtpInCtx[i], err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err2)
outKey := make([]byte, srtpKeyLength)
_, err2 = rand.Read(outKey)
require.NoError(t, err2)
srtpOutCtx[i] = &wrappedSRTPContext{
key: outKey,
ssrcs: []uint32{2345423},
}
err2 = srtpOutCtx[i].initialize()
require.NoError(t, err2)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx[i])
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
switch ca.transport {
case "udp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
@@ -409,18 +489,18 @@ func TestClientPlay(t *testing.T) {
require.NoError(t, err2)
}
case "tcp", "tls":
case "tcp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2}
}
h["Transport"] = th.Marshal()
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
Header: h,
})
require.NoError(t, err2)
}
@@ -428,7 +508,7 @@ func TestClientPlay(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Play, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL)
require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"])
err2 = conn.WriteResponse(&base.Response{
@@ -439,25 +519,34 @@ func TestClientPlay(t *testing.T) {
// server -> client
for i := 0; i < 2; i++ {
switch transport {
buf := testRTPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err2 = srtpOutCtx[i].encryptRTP(encr, buf, nil)
require.NoError(t, err2)
buf = encr
}
switch ca.transport {
case "udp":
_, err2 = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
_, err2 = l1s[i].WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: clientPorts[i][0],
})
require.NoError(t, err2)
case "multicast":
_, err2 = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
_, err2 = l1s[i].WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("224.1.0.1"),
Port: 25000 + i*2,
})
require.NoError(t, err2)
case "tcp", "tls":
case "tcp":
err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0 + i*2,
Payload: testRTPPacketMarshaled,
Payload: buf,
}, make([]byte, 1024))
require.NoError(t, err2)
}
@@ -465,7 +554,7 @@ func TestClientPlay(t *testing.T) {
// skip firewall opening
if transport == "udp" {
if ca.transport == "udp" {
for i := 0; i < 2; i++ {
buf := make([]byte, 2048)
_, _, err2 = l2s[i].ReadFrom(buf)
@@ -476,27 +565,30 @@ func TestClientPlay(t *testing.T) {
// client -> server
for i := 0; i < 2; i++ {
switch transport {
var buf []byte
switch ca.transport {
case "udp", "multicast":
buf := make([]byte, 2048)
buf = make([]byte, 2048)
var n int
n, _, err2 = l2s[i].ReadFrom(buf)
require.NoError(t, err2)
var packets []rtcp.Packet
packets, err2 = rtcp.Unmarshal(buf[:n])
require.NoError(t, err2)
require.Equal(t, &testRTCPPacket, packets[0])
buf = buf[:n]
case "tcp", "tls":
case "tcp":
var f *base.InterleavedFrame
f, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
require.Equal(t, 1+i*2, f.Channel)
var packets []rtcp.Packet
packets, err2 = rtcp.Unmarshal(f.Payload)
require.NoError(t, err2)
require.Equal(t, &testRTCPPacket, packets[0])
buf = f.Payload
}
if ca.secure == "secure" {
buf, err2 = srtpInCtx[i].decryptRTCP(buf, buf, nil)
require.NoError(t, err2)
}
require.Equal(t, testRTCPPacketMarshaled, buf)
}
close(packetRecv)
@@ -504,7 +596,7 @@ func TestClientPlay(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -513,11 +605,9 @@ func TestClientPlay(t *testing.T) {
}()
c := Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Transport: func() *Transport {
switch transport {
switch ca.transport {
case "udp":
v := TransportUDP
return &v
@@ -526,14 +616,14 @@ func TestClientPlay(t *testing.T) {
v := TransportUDPMulticast
return &v
default: // tcp, tls
default: // tcp
v := TransportTCP
return &v
}
}(),
}
u, err := base.ParseURL(scheme + "://" + listenIP + ":8554/test/stream?param=value")
u, err := base.ParseURL(ca.scheme + "://" + listenIP + ":8554/test/stream?param=value")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
@@ -601,9 +691,221 @@ func TestClientPlay(t *testing.T) {
}, s)
require.Greater(t, s.Session.BytesSent, uint64(19))
require.Less(t, s.Session.BytesSent, uint64(41))
require.Less(t, s.Session.BytesSent, uint64(70))
require.Greater(t, s.Session.BytesReceived, uint64(31))
require.Less(t, s.Session.BytesReceived, uint64(37))
require.Less(t, s.Session.BytesReceived, uint64(80))
})
}
}
func TestClientPlaySRTPVariants(t *testing.T) {
for _, ca := range []string{
"key-mgmt in sdp session",
"key-mgmt in sdp media",
"key-mgmt in setup response",
} {
t.Run(ca, func(t *testing.T) {
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
l, err := tls.Listen("tcp", "127.0.0.1:8554", &tls.Config{Certificates: []tls.Certificate{cert}})
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Setup),
string(base.Play),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Describe, req.Method)
outKey := make([]byte, srtpKeyLength)
_, err2 = rand.Read(outKey)
require.NoError(t, err2)
srtpOutCtx := &wrappedSRTPContext{
key: outKey,
ssrcs: []uint32{845234432},
}
err2 = srtpOutCtx.initialize()
require.NoError(t, err2)
mikeyMsg, err2 := mikeyGenerate(srtpOutCtx)
require.NoError(t, err2)
enc, err2 := mikeyMsg.Marshal()
require.NoError(t, err2)
var sdp string
switch ca {
case "key-mgmt in sdp session":
sdp = "v=0\n" +
"o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" +
"s=Action Movie\n" +
"t=0 0\n" +
"c=IN IP4 movie.example.com\n" +
"a=key-mgmt:mikey " + base64.StdEncoding.EncodeToString(enc) + "\n" +
"m=video 0 RTP/SAVP 96\n" +
"a=rtpmap:96 H264/90000\n" +
"a=control:trackID=0\n"
case "key-mgmt in sdp media":
sdp = "v=0\n" +
"o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" +
"s=Action Movie\n" +
"t=0 0\n" +
"c=IN IP4 movie.example.com\n" +
"m=video 0 RTP/SAVP 96\n" +
"a=key-mgmt:mikey " + base64.StdEncoding.EncodeToString(enc) + "\n" +
"a=rtpmap:96 H264/90000\n" +
"a=control:trackID=0\n"
case "key-mgmt in setup response":
sdp = "v=0\n" +
"o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" +
"s=Action Movie\n" +
"t=0 0\n" +
"c=IN IP4 movie.example.com\n" +
"m=video 0 RTP/SAVP 96\n" +
"a=rtpmap:96 H264/90000\n" +
"a=control:trackID=0\n"
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
"Content-Base": base.HeaderValue{"rtsps://127.0.0.1:8554/stream/"},
},
Body: []byte(sdp),
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
require.Equal(t, (*headers.TransportMode)(nil), inTH.Mode)
th := headers.Transport{
Secure: true,
}
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{34556, 34557}
h := base.Header{
"Transport": th.Marshal(),
}
if ca == "key-mgmt in setup response" {
var enc base.HeaderValue
enc, err2 = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err2)
h["KeyMgmt"] = enc
}
l1, err2 := net.ListenPacket(
"udp", net.JoinHostPort("127.0.0.1", strconv.FormatInt(int64(th.ServerPorts[0]), 10)))
require.NoError(t, err2)
defer l1.Close()
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: h,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Play, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
buf := testRTPPacketMarshaled
encr := make([]byte, 2000)
encr, err2 = srtpOutCtx.encryptRTP(encr, buf, nil)
require.NoError(t, err2)
buf = encr
_, err2 = l1.WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
}()
c := Client{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
u, err := base.ParseURL("rtsps://127.0.0.1:8554/stream")
require.NoError(t, err)
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
sd, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAll(sd.BaseURL, sd.Medias)
require.NoError(t, err)
packetRecv := make(chan struct{})
c.OnPacketRTPAny(func(_ *description.Media, _ format.Format, _ *rtp.Packet) {
close(packetRecv)
})
_, err = c.Play(nil)
require.NoError(t, err)
<-packetRecv
})
}
}
@@ -616,6 +918,7 @@ func TestClientPlayPartial(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -768,6 +1071,7 @@ func TestClientPlayContentBase(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1072,6 +1376,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1186,6 +1491,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1359,6 +1665,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1594,6 +1901,7 @@ func TestClientPlayDifferentInterleavedIDs(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1713,6 +2021,7 @@ func TestClientPlayRedirect(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2000,6 +2309,7 @@ func TestClientPlayPausePlay(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2164,6 +2474,7 @@ func TestClientPlayRTCPReport(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2334,6 +2645,7 @@ func TestClientPlayErrorTimeout(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2476,6 +2788,7 @@ func TestClientPlayIgnoreTCPInvalidMedia(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2594,6 +2907,7 @@ func TestClientPlayKeepAlive(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2765,6 +3079,7 @@ func TestClientPlayDifferentSource(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -2909,6 +3224,7 @@ func TestClientPlayDecodeErrors(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -3169,6 +3485,7 @@ func TestClientPlayPacketNTP(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"fmt"
"net"
@@ -19,6 +20,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -126,19 +128,37 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
}
func TestClientRecord(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
var l net.Listener
var err error
var scheme string
if transport == "tls" {
scheme = "rtsps"
if ca.scheme == "rtsps" {
var cert tls.Certificate
cert, err = tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
@@ -147,8 +167,6 @@ func TestClientRecord(t *testing.T) {
require.NoError(t, err)
defer l.Close()
} else {
scheme = "rtsp"
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
@@ -156,6 +174,7 @@ func TestClientRecord(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -167,7 +186,7 @@ func TestClientRecord(t *testing.T) {
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -184,7 +203,7 @@ func TestClientRecord(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
var desc sdp.SessionDescription
err = desc.Unmarshal(req.Body)
@@ -194,6 +213,13 @@ func TestClientRecord(t *testing.T) {
err = desc2.Unmarshal(&desc)
require.NoError(t, err2)
if ca.secure == "secure" {
require.True(t, desc2.Medias[0].Secure)
_, err = mikeyToContext(desc2.Medias[0].KeyMgmtMikey)
require.NoError(t, err)
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
@@ -203,7 +229,7 @@ func TestClientRecord(t *testing.T) {
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(
scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
ca.scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
@@ -213,7 +239,7 @@ func TestClientRecord(t *testing.T) {
var l1 net.PacketConn
var l2 net.PacketConn
if transport == "udp" {
if ca.transport == "udp" {
l1, err2 = net.ListenPacket("udp", "localhost:34556")
require.NoError(t, err2)
defer l1.Close()
@@ -223,11 +249,62 @@ func TestClientRecord(t *testing.T) {
defer l2.Close()
}
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
h := base.Header{
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
}
if transport == "udp" {
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Secure: inTH.Secure,
}
var srtpInCtx *wrappedSRTPContext
var srtpOutCtx *wrappedSRTPContext
if ca.secure == "secure" {
th.Secure = true
require.True(t, th.Secure)
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
require.NoError(t, err)
pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage)
pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc2.Medias[0].KeyMgmtMikey)
require.Equal(t, pl1, pl2)
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err)
outKey := make([]byte, srtpKeyLength)
_, err = rand.Read(outKey)
require.NoError(t, err)
srtpOutCtx = &wrappedSRTPContext{
key: outKey,
ssrcs: []uint32{2345423},
}
err = srtpOutCtx.initialize()
require.NoError(t, err)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx)
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
if ca.transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
@@ -236,54 +313,55 @@ func TestClientRecord(t *testing.T) {
th.InterleavedIDs = inTH.InterleavedIDs
}
h["Transport"] = th.Marshal()
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
},
Header: h,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
var pl []byte
// client -> server
if transport == "udp" {
buf := make([]byte, 2048)
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err2 = l1.ReadFrom(buf)
require.NoError(t, err2)
pl = buf[:n]
buf = buf[:n]
} else {
var f *base.InterleavedFrame
f, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
require.Equal(t, 0, f.Channel)
pl = f.Payload
buf = f.Payload
}
if ca.secure == "secure" {
buf, err2 = srtpInCtx.decryptRTP(buf, buf, nil)
require.NoError(t, err2)
}
var pkt rtp.Packet
err2 = pkt.Unmarshal(pl)
err2 = pkt.Unmarshal(buf)
require.NoError(t, err2)
require.Equal(t, testRTPPacket, pkt)
// client -> server keepalive
if transport == "udp" {
if ca.transport == "udp" {
recv := make(chan struct{})
go func() {
defer close(recv)
@@ -301,8 +379,17 @@ func TestClientRecord(t *testing.T) {
// server -> client
if transport == "udp" {
_, err2 = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
buf = testRTCPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err2 = srtpOutCtx.encryptRTCP(encr, buf, nil)
require.NoError(t, err2)
buf = encr
}
if ca.transport == "udp" {
_, err2 = l2.WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
@@ -310,7 +397,7 @@ func TestClientRecord(t *testing.T) {
} else {
err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
Payload: buf,
}, make([]byte, 1024))
require.NoError(t, err2)
}
@@ -318,7 +405,7 @@ func TestClientRecord(t *testing.T) {
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL)
require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
@@ -333,7 +420,7 @@ func TestClientRecord(t *testing.T) {
InsecureSkipVerify: true,
},
Transport: func() *Transport {
if transport == "udp" {
if ca.transport == "udp" {
v := TransportUDP
return &v
}
@@ -345,7 +432,7 @@ func TestClientRecord(t *testing.T) {
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, scheme+"://localhost:8554/teststream", medias,
err = record(&c, ca.scheme+"://localhost:8554/teststream", medias,
func(_ *description.Media, pkt rtcp.Packet) {
require.Equal(t, &testRTCPPacket, pkt)
close(recvDone)
@@ -397,9 +484,9 @@ func TestClientRecord(t *testing.T) {
}, s)
require.Greater(t, s.Session.BytesSent, uint64(15))
require.Less(t, s.Session.BytesSent, uint64(17))
require.Less(t, s.Session.BytesSent, uint64(30))
require.Greater(t, s.Session.BytesReceived, uint64(19))
require.Less(t, s.Session.BytesReceived, uint64(21))
require.Less(t, s.Session.BytesReceived, uint64(40))
c.Close()
<-done
@@ -414,33 +501,18 @@ func TestClientRecordSocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var l net.Listener
var err error
var scheme string
if transport == "tls" {
scheme = "rtsps"
var cert tls.Certificate
cert, err = tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
l, err = tls.Listen("tcp", "localhost:8554", &tls.Config{Certificates: []tls.Certificate{cert}})
require.NoError(t, err)
defer l.Close()
} else {
scheme = "rtsp"
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
}
l, err = net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -530,7 +602,7 @@ func TestClientRecordSocketError(t *testing.T) {
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, scheme+"://localhost:8554/teststream", medias, nil)
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
defer c.Close()
@@ -559,6 +631,7 @@ func TestClientRecordPauseRecordSerial(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -707,6 +780,7 @@ func TestClientRecordPauseRecordParallel(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -885,6 +959,7 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1016,6 +1091,7 @@ func TestClientRecordDecodeErrors(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1186,6 +1262,7 @@ func TestClientRecordRTCPReport(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -1371,6 +1448,7 @@ func TestClientRecordIgnoreTCPRTPPackets(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)

View File

@@ -58,6 +58,7 @@ func TestClientTLSSetServerName(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -141,6 +142,7 @@ func TestClientCloseDuringRequest(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -185,6 +187,7 @@ func TestClientSession(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -246,6 +249,7 @@ func TestClientAuth(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -327,6 +331,7 @@ func TestClientCSeq(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -399,6 +404,7 @@ func TestClientDescribeCharset(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
@@ -549,6 +555,7 @@ func TestClientRelativeContentBase(t *testing.T) {
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)

View File

@@ -4,7 +4,6 @@ import (
"crypto/rand"
"math/big"
"net"
"strconv"
"sync/atomic"
"time"
@@ -24,45 +23,6 @@ func randInRange(maxVal int) (int, error) {
return int(n.Int64()), nil
}
func createUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) {
// choose two consecutive ports in range 65535-10000
// RTP port must be even and RTCP port odd
for {
v, err := randInRange((65535 - 10000) / 2)
if err != nil {
return nil, nil, err
}
rtpPort := v*2 + 10000
rtcpPort := rtpPort + 1
rtpListener := &clientUDPListener{
c: c,
multicastEnable: false,
multicastSourceIP: nil,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
}
err = rtpListener.initialize()
if err != nil {
continue
}
rtcpListener := &clientUDPListener{
c: c,
multicastEnable: false,
multicastSourceIP: nil,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
}
err = rtcpListener.initialize()
if err != nil {
rtpListener.close()
continue
}
return rtpListener, rtcpListener, nil
}
}
type packetConn interface {
net.PacketConn
SetReadBuffer(int) error

View File

@@ -6,4 +6,13 @@ const (
// 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
udpMaxPayloadSize = 1472
// 16 master key + 14 master salt
srtpKeyLength = 30
// 10 (HMAC SHA1 authentication tag)
srtpOverhead = 10
// 10 (HMAC SHA1 authentication tag) + 4 (sequence number)
srtcpOverhead = 14
)

View File

@@ -15,7 +15,7 @@ import (
)
// This example shows how to
// 1. create a RTSP server which uses secure protocols only (RTSPS, TLS).
// 1. create a RTSP server which uses secure protocols only (RTSPS, TLS, SRTP).
// 2. allow a single client to publish a stream.
// 3. allow several clients to read the stream.
@@ -175,9 +175,14 @@ func main() {
// when TLSConfig is set, only secure protocols are used.
h := &serverHandler{}
h.server = &gortsplib.Server{
Handler: h,
TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
RTSPAddress: ":8322",
Handler: h,
TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
RTSPAddress: ":8322",
UDPRTPAddress: ":8004",
UDPRTCPAddress: ":8005",
MulticastIPRange: "224.1.0.0/16",
MulticastRTPPort: 8006,
MulticastRTCPPort: 8007,
}
// start server and wait until a fatal error

3
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/pion/rtcp v1.2.15
github.com/pion/rtp v1.8.20
github.com/pion/sdp/v3 v3.0.14
github.com/pion/srtp/v3 v3.0.6
github.com/stretchr/testify v1.10.0
golang.org/x/net v0.41.0
)
@@ -16,7 +17,9 @@ require (
require (
github.com/asticode/go-astikit v0.30.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pion/logging v0.2.3 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

6
go.sum
View File

@@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
@@ -17,6 +19,10 @@ github.com/pion/rtp v1.8.20 h1:8zcyqohadZE8FCBeGdyEvHiclPIezcwRQH9zfapFyYI=
github.com/pion/rtp v1.8.20/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
github.com/pion/sdp/v3 v3.0.14 h1:1h7gBr9FhOWH5GjWWY5lcw/U85MtdcibTyt/o6RxRUI=
github.com/pion/sdp/v3 v3.0.14/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4=
github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

View File

@@ -85,6 +85,18 @@ func TestClientVsServer(t *testing.T) {
readerScheme: "rtsps",
readerProto: "tcp",
},
{
publisherScheme: "rtsps",
publisherProto: "udp",
readerScheme: "rtsps",
readerProto: "tcp",
},
{
publisherScheme: "rtsps",
publisherProto: "udp",
readerScheme: "rtsps",
readerProto: "multicast",
},
} {
t.Run(ca.publisherScheme+"_"+ca.publisherProto+"_"+
ca.readerScheme+"_"+ca.readerProto, func(t *testing.T) {

View File

@@ -223,17 +223,14 @@ func (sh *sampleServer) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base
func (sh *sampleServer) initialize() error {
sh.s = &gortsplib.Server{
Handler: sh,
TLSConfig: sh.tlsConfig,
RTSPAddress: "0.0.0.0:8554",
}
if sh.tlsConfig == nil {
sh.s.UDPRTPAddress = "0.0.0.0:8000"
sh.s.UDPRTCPAddress = "0.0.0.0:8001"
sh.s.MulticastIPRange = "224.1.0.0/16"
sh.s.MulticastRTPPort = 8002
sh.s.MulticastRTCPPort = 8003
Handler: sh,
TLSConfig: sh.tlsConfig,
RTSPAddress: "0.0.0.0:8554",
UDPRTPAddress: "0.0.0.0:8000",
UDPRTCPAddress: "0.0.0.0:8001",
MulticastIPRange: "224.1.0.0/16",
MulticastRTPPort: 8002,
MulticastRTCPPort: 8003,
}
err := sh.s.Start()

View File

@@ -248,6 +248,46 @@ func TestServerVsExternal(t *testing.T) {
readerProto: "tcp",
readerSecure: "unsecure",
},
{
publisherSoft: "gstreamer",
publisherScheme: "rtsps",
publisherProto: "tcp",
publisherSecure: "unsecure",
readerSoft: "gstreamer",
readerScheme: "rtsps",
readerProto: "tcp",
readerSecure: "unsecure",
},
{
publisherSoft: "ffmpeg",
publisherScheme: "rtsps",
publisherProto: "tcp",
publisherSecure: "unsecure",
readerSoft: "gstreamer",
readerScheme: "rtsps",
readerProto: "udp",
readerSecure: "secure",
},
{
publisherSoft: "gstreamer",
publisherScheme: "rtsps",
publisherProto: "udp",
publisherSecure: "secure",
readerSoft: "gstreamer",
readerScheme: "rtsps",
readerProto: "udp",
readerSecure: "secure",
},
{
publisherSoft: "gstreamer",
publisherScheme: "rtsps",
publisherProto: "udp",
publisherSecure: "secure",
readerSoft: "gstreamer",
readerScheme: "rtsps",
readerProto: "multicast",
readerSecure: "secure",
},
} {
t.Run(ca.publisherSoft+"_"+ca.publisherScheme+"_"+ca.publisherProto+"_"+ca.publisherSecure+"_"+
ca.readerSoft+"_"+ca.readerScheme+"_"+ca.readerProto+"_"+ca.readerSecure, func(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"regexp"
"slices"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
@@ -25,15 +26,6 @@ func sha256Hex(in string) string {
return hex.EncodeToString(h.Sum(nil))
}
func contains(list []VerifyMethod, item VerifyMethod) bool {
for _, i := range list {
if i == item {
return true
}
}
return false
}
func urlMatches(expected string, received string, isSetup bool) bool {
if received == expected {
return true
@@ -84,9 +76,9 @@ func Verify(
switch {
case auth.Method == headers.AuthMethodDigest &&
(contains(methods, VerifyMethodDigestMD5) &&
(slices.Contains(methods, VerifyMethodDigestMD5) &&
(auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5) ||
contains(methods, VerifyMethodDigestSHA256) &&
slices.Contains(methods, VerifyMethodDigestSHA256) &&
auth.Algorithm != nil && *auth.Algorithm == headers.AuthAlgorithmSHA256):
if auth.Nonce != nonce {
return fmt.Errorf("wrong nonce")
@@ -118,7 +110,7 @@ func Verify(
return fmt.Errorf("authentication failed")
}
case auth.Method == headers.AuthMethodBasic && contains(methods, VerifyMethodBasic):
case auth.Method == headers.AuthMethodBasic && slices.Contains(methods, VerifyMethodBasic):
if auth.Username != user {
return fmt.Errorf("authentication failed")
}

View File

@@ -24,6 +24,9 @@ func headerKeyNormalize(in string) string {
case "cseq":
return "CSeq"
case "keymgmt":
return "KeyMgmt"
}
return http.CanonicalHeaderKey(in)
}

View File

@@ -92,8 +92,10 @@ var cases = []struct {
[]byte("www-authenticate: value\r\n" +
"cseq: value\r\n" +
"rtp-info: value\r\n" +
"keymgmt: value\r\n" +
"\r\n"),
[]byte("CSeq: value\r\n" +
"KeyMgmt: value\r\n" +
"RTP-Info: value\r\n" +
"WWW-Authenticate: value\r\n" +
"\r\n"),
@@ -101,6 +103,7 @@ var cases = []struct {
"CSeq": HeaderValue{"value"},
"RTP-Info": HeaderValue{"value"},
"WWW-Authenticate": HeaderValue{"value"},
"KeyMgmt": HeaderValue{"value"},
},
},
}

View File

@@ -2,8 +2,10 @@
package description
import (
"encoding/base64"
"fmt"
"reflect"
"slices"
"sort"
"strconv"
"strings"
@@ -13,6 +15,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
)
func getAttribute(attributes []psdp.Attribute, key string) string {
@@ -78,6 +81,12 @@ type Media struct {
// Control attribute.
Control string
// Whether the transport is secure.
Secure bool
// key-mgmt attribute.
KeyMgmtMikey *mikey.Message
// Formats contained into the media.
Formats []format.Format
}
@@ -93,6 +102,24 @@ func (m *Media) Unmarshal(md *psdp.MediaDescription) error {
m.IsBackChannel = isBackChannel(md.Attributes)
m.Control = getAttribute(md.Attributes, "control")
m.Secure = slices.Contains(md.MediaName.Protos, "SAVP")
if enc := getAttribute(md.Attributes, "key-mgmt"); enc != "" {
if !strings.HasPrefix(enc, "mikey ") {
return fmt.Errorf("unsupported key-mgmt: %v", enc)
}
enc2, err := base64.StdEncoding.DecodeString(enc[len("mikey "):])
if err != nil {
return err
}
m.KeyMgmtMikey = &mikey.Message{}
err = m.KeyMgmtMikey.Unmarshal(enc2)
if err != nil {
return err
}
}
m.Formats = nil
@@ -113,11 +140,29 @@ func (m *Media) Unmarshal(md *psdp.MediaDescription) error {
}
// Marshal encodes the media in SDP format.
//
// Deprecated: replaced by Marshal2.
func (m Media) Marshal() *psdp.MediaDescription {
ret, err := m.Marshal2()
if err != nil {
panic(err)
}
return ret
}
// Marshal2 encodes the media in SDP format.
func (m Media) Marshal2() (*psdp.MediaDescription, error) {
var protos []string
if !m.Secure {
protos = []string{"RTP", "AVP"}
} else {
protos = []string{"RTP", "SAVP"}
}
md := &psdp.MediaDescription{
MediaName: psdp.MediaName{
Media: string(m.Type),
Protos: []string{"RTP", "AVP"},
Protos: protos,
},
}
@@ -134,6 +179,18 @@ func (m Media) Marshal() *psdp.MediaDescription {
})
}
if m.KeyMgmtMikey != nil {
keyEnc, err := m.KeyMgmtMikey.Marshal()
if err != nil {
return nil, err
}
md.Attributes = append(md.Attributes, psdp.Attribute{
Key: "key-mgmt",
Value: "mikey " + base64.StdEncoding.EncodeToString(keyEnc),
})
}
md.Attributes = append(md.Attributes, psdp.Attribute{
Key: "control",
Value: m.Control,
@@ -165,7 +222,7 @@ func (m Media) Marshal() *psdp.MediaDescription {
}
}
return md
return md, nil
}
// URL returns the absolute URL of the media.

View File

@@ -1,12 +1,14 @@
package description
import (
"encoding/base64"
"fmt"
"strings"
psdp "github.com/pion/sdp/v3"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -51,6 +53,9 @@ type Session struct {
// Whether to use multicast.
Multicast bool
// key-mgmt attribute.
KeyMgmtMikey *mikey.Message
// FEC groups (RFC5109).
FECGroups []SessionFECGroup
@@ -77,6 +82,23 @@ func (d *Session) Unmarshal(ssd *sdp.SessionDescription) error {
d.Title = ""
}
if enc := getAttribute(ssd.Attributes, "key-mgmt"); enc != "" {
if !strings.HasPrefix(enc, "mikey ") {
return fmt.Errorf("unsupported key-mgmt: %v", enc)
}
enc2, err := base64.StdEncoding.DecodeString(enc[len("mikey "):])
if err != nil {
return err
}
d.KeyMgmtMikey = &mikey.Message{}
err = d.KeyMgmtMikey.Unmarshal(enc2)
if err != nil {
return err
}
}
if len(ssd.MediaDescriptions) == 0 {
return fmt.Errorf("no media streams are present in SDP")
}
@@ -163,11 +185,29 @@ func (d Session) Marshal(_ bool) ([]byte, error) {
})
}
if d.KeyMgmtMikey != nil {
keyEnc, err := d.KeyMgmtMikey.Marshal()
if err != nil {
return nil, err
}
sout.Attributes = append(sout.Attributes, psdp.Attribute{
Key: "key-mgmt",
Value: "mikey " + base64.StdEncoding.EncodeToString(keyEnc),
})
}
sout.MediaDescriptions = make([]*psdp.MediaDescription, len(d.Medias))
for i, media := range d.Medias {
sout.MediaDescriptions[i] = media.Marshal()
med, err := media.Marshal2()
if err != nil {
return nil, err
}
sout.MediaDescriptions[i] = med
}
return sout.Marshal()
out, _ := sout.Marshal()
return out, nil
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -638,6 +639,194 @@ var casesSession = []struct {
},
},
},
{
"key-mgmt in session",
"v=0\n" +
"o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" +
"s=Action Movie\n" +
"t=0 0\n" +
"c=IN IP4 movie.example.com\n" +
"a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" +
"A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" +
"m=video 0 RTP/SAVP 96\n" +
"a=rtpmap:96 H264/90000\n" +
"a=control:trackID=0\n",
"v=0\r\n" +
"o=- 0 0 IN IP4 127.0.0.1\r\n" +
"s=Action Movie\r\n" +
"c=IN IP4 0.0.0.0\r\n" +
"t=0 0\r\n" +
"a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" +
"A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" +
"m=video 0 RTP/SAVP 96\r\n" +
"a=control:trackID=0\r\n" +
"a=rtpmap:96 H264/90000\r\n",
Session{
Title: "Action Movie",
KeyMgmtMikey: &mikey.Message{ //nolint:dupl
Header: mikey.Header{
Version: 1,
CSBID: 2049124702,
CSIDMapInfo: []mikey.SRTPIDEntry{{
SSRC: 2508989223,
}},
},
Payloads: []mikey.Payload{
&mikey.PayloadT{
TSValue: 17003794820816085580,
},
&mikey.PayloadRAND{
Data: []byte{
0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf,
0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7,
},
},
&mikey.PayloadSP{
PolicyParams: []mikey.PayloadSPPolicyParam{
{
Type: 0, Value: []byte{1},
},
{
Type: 1, Value: []byte{0x10},
},
{
Type: 2, Value: []byte{1},
},
{
Type: 3, Value: []byte{0x0a},
},
{
Type: 7, Value: []byte{1},
},
{
Type: 8, Value: []byte{1},
},
{
Type: 10, Value: []byte{1},
},
},
},
&mikey.PayloadKEMAC{
SubPayloads: []*mikey.SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d,
0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11,
0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6,
0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43,
},
},
},
},
},
},
Medias: []*Media{
{
Type: "video",
Control: "trackID=0",
Secure: true,
Formats: []format.Format{&format.H264{
PayloadTyp: 96,
}},
},
},
},
},
{
"key-mgmt in media",
"v=0\n" +
"o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" +
"s=Action Movie\n" +
"t=0 0\n" +
"c=IN IP4 movie.example.com\n" +
"m=video 0 RTP/SAVP 96\n" +
"a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" +
"A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" +
"a=rtpmap:96 H264/90000\n" +
"a=control:trackID=0\n",
"v=0\r\n" +
"o=- 0 0 IN IP4 127.0.0.1\r\n" +
"s=Action Movie\r\n" +
"c=IN IP4 0.0.0.0\r\n" +
"t=0 0\r\n" +
"m=video 0 RTP/SAVP 96\r\n" +
"a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" +
"A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" +
"a=control:trackID=0\r\n" +
"a=rtpmap:96 H264/90000\r\n",
Session{
Title: "Action Movie",
Medias: []*Media{
{
Type: "video",
Control: "trackID=0",
Secure: true,
KeyMgmtMikey: &mikey.Message{ //nolint:dupl
Header: mikey.Header{
Version: 1,
CSBID: 2049124702,
CSIDMapInfo: []mikey.SRTPIDEntry{{
SSRC: 2508989223,
}},
},
Payloads: []mikey.Payload{
&mikey.PayloadT{
TSValue: 17003794820816085580,
},
&mikey.PayloadRAND{
Data: []byte{
0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf,
0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7,
},
},
&mikey.PayloadSP{
PolicyParams: []mikey.PayloadSPPolicyParam{
{
Type: 0, Value: []byte{1},
},
{
Type: 1, Value: []byte{0x10},
},
{
Type: 2, Value: []byte{1},
},
{
Type: 3, Value: []byte{0x0a},
},
{
Type: 7, Value: []byte{1},
},
{
Type: 8, Value: []byte{1},
},
{
Type: 10, Value: []byte{1},
},
},
},
&mikey.PayloadKEMAC{
SubPayloads: []*mikey.SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d,
0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11,
0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6,
0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43,
},
},
},
},
},
},
Formats: []format.Format{&format.H264{
PayloadTyp: 96,
}},
},
},
},
},
}
func TestSessionUnmarshal(t *testing.T) {

View File

@@ -49,7 +49,6 @@ func (d *Decoder) resetFragments() {
}
// Decode decodes frames from a RTP packet.
// It returns the frames and the PTS of the first frame.
func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
if len(pkt.Payload) < 2 {
d.resetFragments()

View File

@@ -22,7 +22,6 @@ func (d *Decoder) Init() error {
}
// Decode decodes audio samples from a RTP packet.
// It returns audio samples and PTS of the first sample.
func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) {
plen := len(pkt.Payload)
if (plen % d.sampleSize) != 0 {

View File

@@ -52,8 +52,6 @@ func (d *Decoder) resetFragments() {
}
// Decode decodes AUs from a RTP packet.
// It returns the AUs and the PTS of the first AU.
// The PTS of subsequent AUs can be calculated by adding time.Second*mpeg4audio.SamplesPerAccessUnit/clockRate.
func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) {
if !d.LATM {
return d.decodeGeneric(pkt)

86
pkg/headers/key_mgmt.go Normal file
View File

@@ -0,0 +1,86 @@
package headers
import (
"encoding/base64"
"fmt"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
)
// KeyMgmt is a KeyMgmt header.
type KeyMgmt struct {
URL string
MikeyMessage *mikey.Message
}
// Unmarshal decodes a KeyMgmt header.
func (h *KeyMgmt) Unmarshal(v base.HeaderValue) error {
if len(v) == 0 {
return fmt.Errorf("value not provided")
}
if len(v) > 1 {
return fmt.Errorf("value provided multiple times (%v)", v)
}
kvs, err := keyValParse(v[0], ';')
if err != nil {
return err
}
protocolProvided := false
uriProvided := false
for k, v := range kvs {
switch k {
case "prot":
if v != "mikey" {
return fmt.Errorf("unsupported protocol: %v", v)
}
protocolProvided = true
case "uri":
h.URL = v
uriProvided = true
case "data":
byts, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return fmt.Errorf("invalid data: %w", err)
}
h.MikeyMessage = &mikey.Message{}
err = h.MikeyMessage.Unmarshal(byts)
if err != nil {
return fmt.Errorf("invalid data: %w", err)
}
}
}
if !protocolProvided {
return fmt.Errorf("protocol not provided")
}
if !uriProvided {
return fmt.Errorf("URI not provided")
}
if h.MikeyMessage == nil {
return fmt.Errorf("mikey message not provided")
}
return nil
}
// Marshal encodes a KeyMgmt header.
func (h KeyMgmt) Marshal() (base.HeaderValue, error) {
buf, err := h.MikeyMessage.Marshal()
if err != nil {
return nil, err
}
encData := base64.StdEncoding.EncodeToString(buf)
return base.HeaderValue{`prot=mikey;uri="` + h.URL + `";data="` + encData + `"`}, nil
}

View File

@@ -0,0 +1,143 @@
package headers
import (
"testing"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/stretchr/testify/require"
)
var casesKeyMgmt = []struct {
name string
vin base.HeaderValue
vout base.HeaderValue
h KeyMgmt
}{
{
"standard",
base.HeaderValue{`prot=mikey;` +
`uri="rtsps://127.0.0.1:8322/stream/trackID=0";` +
`data="AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+` +
`A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA=="`},
base.HeaderValue{`prot=mikey;` +
`uri="rtsps://127.0.0.1:8322/stream/trackID=0";` +
`data="AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+` +
`A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA=="`},
KeyMgmt{
URL: "rtsps://127.0.0.1:8322/stream/trackID=0",
MikeyMessage: &mikey.Message{
Header: mikey.Header{
Version: 1,
CSBID: 2049124702,
CSIDMapInfo: []mikey.SRTPIDEntry{
{
SSRC: 2508989223,
},
},
},
Payloads: []mikey.Payload{
&mikey.PayloadT{
TSValue: 17003794820816085580,
},
&mikey.PayloadRAND{
Data: []byte{
0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf,
0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7,
},
},
&mikey.PayloadSP{
PolicyParams: []mikey.PayloadSPPolicyParam{
{
Type: 0, Value: []byte{1},
},
{
Type: 1, Value: []byte{0x10},
},
{
Type: 2, Value: []byte{1},
},
{
Type: 3, Value: []byte{0x0a},
},
{
Type: 7, Value: []byte{1},
},
{
Type: 8, Value: []byte{1},
},
{
Type: 10, Value: []byte{1},
},
},
},
&mikey.PayloadKEMAC{
SubPayloads: []*mikey.SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d,
0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11,
0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6,
0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43,
},
},
},
},
},
},
},
},
}
func TestKeyMgmtUnmarshal(t *testing.T) {
for _, ca := range casesKeyMgmt {
t.Run(ca.name, func(t *testing.T) {
var h KeyMgmt
err := h.Unmarshal(ca.vin)
require.NoError(t, err)
require.Equal(t, ca.h, h)
})
}
}
func TestKeyMgmtMarshal(t *testing.T) {
for _, ca := range casesKeyMgmt {
t.Run(ca.name, func(t *testing.T) {
req, err := ca.h.Marshal()
require.NoError(t, err)
require.Equal(t, ca.vout, req)
})
}
}
func FuzzKeyMgmtUnmarshal(f *testing.F) {
for _, ca := range casesKeyMgmt {
f.Add(ca.vin[0])
}
f.Fuzz(func(t *testing.T, b string) {
var h KeyMgmt
err := h.Unmarshal(base.HeaderValue{b})
if err != nil {
return
}
_, err = h.Marshal()
require.NoError(t, err)
})
}
func TestKeyMgmtAdditionalErrors(t *testing.T) {
func() {
var h KeyMgmt
err := h.Unmarshal(base.HeaderValue{})
require.Error(t, err)
}()
func() {
var h KeyMgmt
err := h.Unmarshal(base.HeaderValue{"a", "b"})
require.Error(t, err)
}()
}

View File

@@ -268,7 +268,7 @@ func rangeValueUnmarshal(s RangeValue, v string) error {
// Range is a Range header.
type Range struct {
// range expressed in a certain unit.
// range expressed in some measurement units.
Value RangeValue
// time at which the operation is to be made effective.
@@ -285,9 +285,7 @@ func (h *Range) Unmarshal(v base.HeaderValue) error {
return fmt.Errorf("value provided multiple times (%v)", v)
}
v0 := v[0]
kvs, err := keyValParse(v0, ';')
kvs, err := keyValParse(v[0], ';')
if err != nil {
return err
}

View File

@@ -13,7 +13,7 @@ type Session struct {
// session id
Session string
// (optional) a timeout
// (optional) timeout
Timeout *uint
}

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("data")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("prot")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("prot=\"0000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("data=00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("prot=mikey")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("port=0-")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("port=--")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("port=-")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
string("source=0.0.0.0.A.0")

View File

@@ -48,6 +48,8 @@ const (
)
// String implements fmt.Stringer.
//
// Deprecated: not used anymore.
func (p TransportProtocol) String() string {
if p == TransportProtocolUDP {
return "RTP/AVP"
@@ -65,6 +67,8 @@ const (
)
// String implements fmt.Stringer.
//
// Deprecated: not used anymore.
func (d TransportDelivery) String() string {
if d == TransportDeliveryUnicast {
return "unicast"
@@ -112,37 +116,40 @@ func (m TransportMode) String() string {
// Transport is a Transport header.
type Transport struct {
// protocol of the stream
// protocol of the stream.
Protocol TransportProtocol
// (optional) delivery method of the stream
// Whether the secure variant is active.
Secure bool
// (optional) delivery method of the stream.
Delivery *TransportDelivery
// (optional) Source IP
// (optional) Source IP.
Source *net.IP
// (optional) destination IP
// (optional) destination IP.
Destination *net.IP
// (optional) interleaved frame ids
// (optional) interleaved frame IDs.
InterleavedIDs *[2]int
// (optional) TTL
// (optional) TTL.
TTL *uint
// (optional) ports
// (optional) ports.
Ports *[2]int
// (optional) client ports
// (optional) client ports.
ClientPorts *[2]int
// (optional) server ports
// (optional) server ports.
ServerPorts *[2]int
// (optional) SSRC of the packets of the stream
// (optional) SSRC of the packets of the stream.
SSRC *uint32
// (optional) mode
// (optional) mode.
Mode *TransportMode
}
@@ -156,14 +163,12 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
return fmt.Errorf("value provided multiple times (%v)", v)
}
v0 := v[0]
kvs, err := keyValParse(v0, ';')
kvs, err := keyValParse(v[0], ';')
if err != nil {
return err
}
protocolFound := false
profileFound := false
for k, rv := range kvs {
v := rv
@@ -171,11 +176,21 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
switch k {
case "RTP/AVP", "RTP/AVP/UDP":
h.Protocol = TransportProtocolUDP
protocolFound = true
profileFound = true
case "RTP/AVP/TCP":
h.Protocol = TransportProtocolTCP
protocolFound = true
profileFound = true
case "RTP/SAVP", "RTP/SAVP/UDP":
h.Protocol = TransportProtocolUDP
h.Secure = true
profileFound = true
case "RTP/SAVP/TCP":
h.Protocol = TransportProtocolTCP
h.Secure = true
profileFound = true
case "unicast":
v := TransportDeliveryUnicast
@@ -273,8 +288,8 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
}
}
if !protocolFound {
return fmt.Errorf("protocol not found (%v)", v[0])
if !profileFound {
return fmt.Errorf("profile is missing: %v", v[0])
}
return nil
@@ -284,10 +299,33 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
func (h Transport) Marshal() base.HeaderValue {
var rets []string
rets = append(rets, h.Protocol.String())
var profile string
switch {
case h.Protocol == TransportProtocolUDP && !h.Secure:
profile = "RTP/AVP"
case h.Protocol == TransportProtocolTCP && !h.Secure:
profile = "RTP/AVP/TCP"
case h.Protocol == TransportProtocolUDP && h.Secure:
profile = "RTP/SAVP"
case h.Protocol == TransportProtocolTCP && h.Secure:
profile = "RTP/SAVP/TCP"
}
rets = append(rets, profile)
if h.Delivery != nil {
rets = append(rets, h.Delivery.String())
var delivery string
switch *h.Delivery {
case TransportDeliveryUnicast:
delivery = "unicast"
case TransportDeliveryMulticast:
delivery = "multicast"
}
rets = append(rets, delivery)
}
if h.Source != nil {
@@ -337,43 +375,3 @@ func (h Transport) Marshal() base.HeaderValue {
return base.HeaderValue{strings.Join(rets, ";")}
}
// Transports is a Transport header with multiple transports.
type Transports []Transport
// Unmarshal decodes a Transport header.
func (ts *Transports) Unmarshal(v base.HeaderValue) error {
if len(v) == 0 {
return fmt.Errorf("value not provided")
}
if len(v) > 1 {
return fmt.Errorf("value provided multiple times (%v)", v)
}
v0 := v[0]
transports := strings.Split(v0, ",") // , separated per RFC2326 section 12.39
*ts = make([]Transport, len(transports))
for i, transport := range transports {
var tr Transport
err := tr.Unmarshal(base.HeaderValue{strings.TrimLeft(transport, " ")})
if err != nil {
return err
}
(*ts)[i] = tr
}
return nil
}
// Marshal encodes a Transport header.
func (ts Transports) Marshal() base.HeaderValue {
vals := make([]string, len(ts))
for i, th := range ts {
vals[i] = th.Marshal()[0]
}
return base.HeaderValue{strings.Join(vals, ",")}
}

View File

@@ -168,6 +168,28 @@ var casesTransport = []struct {
ServerPorts: &[2]int{56002, 56003},
},
},
{
"secure udp unicast play request",
base.HeaderValue{`RTP/SAVP;unicast;client_port=3456-3457;mode="PLAY"`},
base.HeaderValue{`RTP/SAVP;unicast;client_port=3456-3457;mode=play`},
Transport{
Protocol: TransportProtocolUDP,
Secure: true,
Delivery: deliveryPtr(TransportDeliveryUnicast),
ClientPorts: &[2]int{3456, 3457},
Mode: transportModePtr(TransportModePlay),
},
},
{
"secure tcp play request / response",
base.HeaderValue{`RTP/SAVP/TCP;interleaved=0-1`},
base.HeaderValue{`RTP/SAVP/TCP;interleaved=0-1`},
Transport{
Protocol: TransportProtocolTCP,
Secure: true,
InterleavedIDs: &[2]int{0, 1},
},
},
}
func TestTransportUnmarshal(t *testing.T) {
@@ -190,81 +212,6 @@ func TestTransportMarshal(t *testing.T) {
}
}
var casesTransports = []struct {
name string
vin base.HeaderValue
vout base.HeaderValue
h Transports
}{
{
"a",
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY", RTP/AVP/TCP;unicast;interleaved=0-1`},
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play,RTP/AVP/TCP;unicast;interleaved=0-1`},
Transports{
{
Protocol: TransportProtocolUDP,
Delivery: deliveryPtr(TransportDeliveryUnicast),
ClientPorts: &[2]int{3456, 3457},
Mode: transportModePtr(TransportModePlay),
},
Transport{
Protocol: TransportProtocolTCP,
Delivery: deliveryPtr(TransportDeliveryUnicast),
InterleavedIDs: &[2]int{0, 1},
},
},
},
}
func TestTransportsUnmarshal(t *testing.T) {
for _, ca := range casesTransports {
t.Run(ca.name, func(t *testing.T) {
var h Transports
err := h.Unmarshal(ca.vin)
require.NoError(t, err)
require.Equal(t, ca.h, h)
})
}
}
func TestTransportsMarshal(t *testing.T) {
for _, ca := range casesTransports {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Marshal()
require.Equal(t, ca.vout, req)
})
}
}
func FuzzTransportsUnmarshal(f *testing.F) {
for _, ca := range casesTransports {
f.Add(ca.vin[0])
}
for _, ca := range casesTransport {
f.Add(ca.vin[0])
}
f.Add("source=aa-14187")
f.Add("destination=aa")
f.Add("interleaved=")
f.Add("ttl=")
f.Add("port=")
f.Add("client_port=")
f.Add("server_port=")
f.Add("mode=")
f.Fuzz(func(_ *testing.T, b string) {
var h Transports
err := h.Unmarshal(base.HeaderValue{b})
if err != nil {
return
}
h.Marshal()
})
}
func TestTransportAdditionalErrors(t *testing.T) {
func() {
var h Transport

48
pkg/headers/transports.go Normal file
View File

@@ -0,0 +1,48 @@
package headers
import (
"fmt"
"strings"
"github.com/bluenviron/gortsplib/v4/pkg/base"
)
// Transports is a Transport header with multiple transports.
type Transports []Transport
// Unmarshal decodes a Transport header.
func (ts *Transports) Unmarshal(v base.HeaderValue) error {
if len(v) == 0 {
return fmt.Errorf("value not provided")
}
if len(v) > 1 {
return fmt.Errorf("value provided multiple times (%v)", v)
}
v0 := v[0]
transports := strings.Split(v0, ",") // , separated per RFC2326 section 12.39
*ts = make([]Transport, len(transports))
for i, transport := range transports {
var tr Transport
err := tr.Unmarshal(base.HeaderValue{strings.TrimLeft(transport, " ")})
if err != nil {
return err
}
(*ts)[i] = tr
}
return nil
}
// Marshal encodes a Transport header.
func (ts Transports) Marshal() base.HeaderValue {
vals := make([]string, len(ts))
for i, th := range ts {
vals[i] = th.Marshal()[0]
}
return base.HeaderValue{strings.Join(vals, ",")}
}

View File

@@ -0,0 +1,97 @@
package headers
import (
"testing"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/stretchr/testify/require"
)
var casesTransports = []struct {
name string
vin base.HeaderValue
vout base.HeaderValue
h Transports
}{
{
"a",
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY", RTP/AVP/TCP;unicast;interleaved=0-1`},
base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play,RTP/AVP/TCP;unicast;interleaved=0-1`},
Transports{
{
Protocol: TransportProtocolUDP,
Delivery: deliveryPtr(TransportDeliveryUnicast),
ClientPorts: &[2]int{3456, 3457},
Mode: transportModePtr(TransportModePlay),
},
Transport{
Protocol: TransportProtocolTCP,
Delivery: deliveryPtr(TransportDeliveryUnicast),
InterleavedIDs: &[2]int{0, 1},
},
},
},
}
func TestTransportsUnmarshal(t *testing.T) {
for _, ca := range casesTransports {
t.Run(ca.name, func(t *testing.T) {
var h Transports
err := h.Unmarshal(ca.vin)
require.NoError(t, err)
require.Equal(t, ca.h, h)
})
}
}
func TestTransportsMarshal(t *testing.T) {
for _, ca := range casesTransports {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Marshal()
require.Equal(t, ca.vout, req)
})
}
}
func FuzzTransportsUnmarshal(f *testing.F) {
for _, ca := range casesTransports {
f.Add(ca.vin[0])
}
for _, ca := range casesTransport {
f.Add(ca.vin[0])
}
f.Add("source=aa-14187")
f.Add("destination=aa")
f.Add("interleaved=")
f.Add("ttl=")
f.Add("port=")
f.Add("client_port=")
f.Add("server_port=")
f.Add("mode=")
f.Fuzz(func(_ *testing.T, b string) {
var h Transports
err := h.Unmarshal(base.HeaderValue{b})
if err != nil {
return
}
h.Marshal()
})
}
func TestTransportsAdditionalErrors(t *testing.T) {
func() {
var h Transports
err := h.Unmarshal(base.HeaderValue{})
require.Error(t, err)
}()
func() {
var h Transports
err := h.Unmarshal(base.HeaderValue{"a", "b"})
require.Error(t, err)
}()
}

View File

@@ -224,6 +224,8 @@ func (e ErrClientUnsupportedScheme) Error() string {
}
// ErrClientRTSPSTCP is an error that can be returned by a client.
//
// Deprecated: not used anymore.
type ErrClientRTSPSTCP struct{}
// Error implements the error interface.

View File

@@ -128,12 +128,27 @@ func (e ErrServerMediasDifferentPaths) Error() string {
return "can't setup medias with different paths"
}
// ErrServerInvalidKeyMgmtHeader is an error that can be returned by a server.
type ErrServerInvalidKeyMgmtHeader struct {
Wrapped error
}
// Error implements the error interface.
func (e ErrServerInvalidKeyMgmtHeader) Error() string {
return fmt.Sprintf("invalid KeyMgmt header: %s", e.Wrapped.Error())
}
// ErrServerMediasDifferentProtocols is an error that can be returned by a server.
type ErrServerMediasDifferentProtocols struct{}
//
// Deprecated: replaced by ErrServerMediasDifferentTransports.
type ErrServerMediasDifferentProtocols = ErrServerMediasDifferentTransports
// ErrServerMediasDifferentTransports is an error that can be returned by a server.
type ErrServerMediasDifferentTransports struct{}
// Error implements the error interface.
func (e ErrServerMediasDifferentProtocols) Error() string {
return "can't setup medias with different protocols"
return "can't setup medias with different transports"
}
// ErrServerNoMediasSetup is an error that can be returned by a server.

143
pkg/mikey/header.go Normal file
View File

@@ -0,0 +1,143 @@
package mikey
import "fmt"
func boolToUint8(v bool) uint8 {
if v {
return 1
}
return 0
}
// DataType is a message data type.
type DataType uint8
// RFC3830, Table 6.1.a
const (
DataTypeInitiatorPSK DataType = 0
)
// CSIDMapType is a CS ID map type.
type CSIDMapType uint8
// RFC3830, Table 6.1.d
const (
CSIDMapTypeSRTPID CSIDMapType = 0
)
// SRTPIDEntry is an entry of a SRTP-ID map.
type SRTPIDEntry struct {
PolicyNo uint8
SSRC uint32
ROC uint32
}
// Header is a MIKEY header.
type Header struct {
Version uint8
DataType DataType
V bool
PRFFunc uint8
CSBID uint32
CSIDMapType CSIDMapType
CSIDMapInfo []SRTPIDEntry
}
func (h *Header) unmarshal(buf []byte) (int, payloadType, error) {
if len(buf) < 10 {
return 0, 0, fmt.Errorf("header too short")
}
n := 0
h.Version = buf[n]
n++
if h.Version != 1 {
return 0, 0, fmt.Errorf("unsupported version: %v", h.Version)
}
h.DataType = DataType(buf[n])
n++
if h.DataType != DataTypeInitiatorPSK {
return 0, 0, fmt.Errorf("unsupported data type: %v", h.DataType)
}
nextPayload := payloadType(buf[n])
n++
h.V = (buf[n] >> 7) != 0
h.PRFFunc = buf[n] & 0b01111111
n++
if h.V {
return 0, 0, fmt.Errorf("unsupported V: %v", h.V)
}
if h.PRFFunc != 0 {
return 0, 0, fmt.Errorf("unsupported PRFFunc: %v", h.PRFFunc)
}
h.CSBID = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3])
n += 4
numCS := buf[n]
n++
h.CSIDMapType = CSIDMapType(buf[n])
n++
if h.CSIDMapType != CSIDMapTypeSRTPID {
return 0, 0, fmt.Errorf("unsupported map type: %d", h.CSIDMapType)
}
if len(buf[n:]) < (int(numCS) * 9) {
return 0, 0, fmt.Errorf("header too short")
}
h.CSIDMapInfo = make([]SRTPIDEntry, numCS)
for i := range numCS {
h.CSIDMapInfo[i].PolicyNo = buf[n]
n++
h.CSIDMapInfo[i].SSRC = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3])
n += 4
h.CSIDMapInfo[i].ROC = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3])
n += 4
}
return n, nextPayload, nil
}
func (h *Header) marshalSize() int {
return 10 + len(h.CSIDMapInfo)*9
}
func (h *Header) marshalTo(buf []byte, nextPayload payloadType) (int, error) {
buf[0] = h.Version
buf[1] = byte(h.DataType)
buf[2] = byte(nextPayload)
buf[3] = boolToUint8(h.V)<<7 | h.PRFFunc
buf[4] = byte(h.CSBID >> 24)
buf[5] = byte(h.CSBID >> 16)
buf[6] = byte(h.CSBID >> 8)
buf[7] = byte(h.CSBID)
buf[8] = byte(len(h.CSIDMapInfo))
buf[9] = byte(h.CSIDMapType)
n := 10
for _, mi := range h.CSIDMapInfo {
buf[n] = mi.PolicyNo
buf[n+1] = byte(mi.SSRC >> 24)
buf[n+2] = byte(mi.SSRC >> 16)
buf[n+3] = byte(mi.SSRC >> 8)
buf[n+4] = byte(mi.SSRC)
buf[n+5] = byte(mi.ROC >> 24)
buf[n+6] = byte(mi.ROC >> 16)
buf[n+7] = byte(mi.ROC >> 8)
buf[n+8] = byte(mi.ROC)
n += 9
}
return n, nil
}

91
pkg/mikey/message.go Normal file
View File

@@ -0,0 +1,91 @@
// Package mikey contains functions to decode and encode MIKEY messages.
package mikey
import "fmt"
// Message is a MIKEY message.
type Message struct {
Header Header
Payloads []Payload
}
// Unmarshal decodes a Message.
func (m *Message) Unmarshal(buf []byte) error {
n, nextPayloadType, err := m.Header.unmarshal(buf)
if err != nil {
return err
}
for nextPayloadType != 0 {
var payload Payload
switch nextPayloadType {
case payloadTypeKEMAC:
payload = &PayloadKEMAC{}
case payloadTypeT:
payload = &PayloadT{}
case payloadTypeSP:
payload = &PayloadSP{}
case payloadTypeRAND:
payload = &PayloadRAND{}
default:
return fmt.Errorf("unsupported payload type: %d", nextPayloadType)
}
payloadLen, err := payload.unmarshal(buf[n:])
if err != nil {
return fmt.Errorf("unable to parse payload %d: %w", nextPayloadType, err)
}
nextPayloadType = payloadType(buf[n])
n += payloadLen
m.Payloads = append(m.Payloads, payload)
}
if n < len(buf) {
return fmt.Errorf("detected %d unparsed bytes", len(buf)-n)
}
return nil
}
func (m *Message) marshalSize() int {
n := m.Header.marshalSize()
for _, pl := range m.Payloads {
n += pl.marshalSize()
}
return n
}
// Marshal encodes a Message.
func (m *Message) Marshal() ([]byte, error) {
buf := make([]byte, m.marshalSize())
var nextPayloadType payloadType
if len(m.Payloads) != 0 {
nextPayloadType = m.Payloads[0].typ()
}
n, err := m.Header.marshalTo(buf, nextPayloadType)
if err != nil {
return nil, err
}
for i, pl := range m.Payloads {
if i != len(m.Payloads)-1 {
nextPayloadType = m.Payloads[i+1].typ()
} else {
nextPayloadType = 0
}
buf[n] = byte(nextPayloadType)
n2, err := pl.marshalTo(buf[n:])
if err != nil {
return nil, err
}
n += n2
}
return buf, nil
}

324
pkg/mikey/message_test.go Normal file
View File

@@ -0,0 +1,324 @@
package mikey
import (
"testing"
"github.com/stretchr/testify/require"
)
var cases = []struct {
name string
enc []byte
dec Message
}{
{
"a",
[]byte{
0x01, 0x00, 0x05, 0x00, 0xe6, 0x9d, 0x51, 0xf8,
0x01, 0x00, 0x00, 0x30, 0x68, 0x57, 0x60, 0x00,
0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe, 0x6f,
0x2d, 0xb1, 0xc1, 0x3f, 0xd0, 0x0a, 0x10, 0xc2,
0xdd, 0xe4, 0x43, 0xa8, 0x49, 0x30, 0xa5, 0x75,
0x7a, 0x7e, 0xd9, 0xc3, 0xa4, 0x17, 0xfb, 0x01,
0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01, 0x01,
0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01, 0x0a,
0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a, 0x01,
0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20, 0x00,
0x1e, 0x90, 0x91, 0x78, 0x3d, 0xfc, 0xe8, 0xdd,
0xcd, 0x44, 0x3a, 0x53, 0x50, 0x8b, 0x64, 0x50,
0x9f, 0x35, 0xbd, 0x8a, 0x86, 0xbc, 0x4d, 0x8b,
0x76, 0x37, 0xa5, 0x02, 0x49, 0x3d, 0xaf, 0x00,
},
Message{
Header: Header{
Version: 1,
CSBID: 3869069816,
CSIDMapInfo: []SRTPIDEntry{
{
PolicyNo: 0,
SSRC: 812144480,
ROC: 0,
},
},
},
Payloads: []Payload{
&PayloadT{
TSType: 0,
TSValue: 17005151485044015056,
},
&PayloadRAND{
Data: []byte{
0xc2, 0xdd, 0xe4, 0x43, 0xa8, 0x49, 0x30, 0xa5,
0x75, 0x7a, 0x7e, 0xd9, 0xc3, 0xa4, 0x17, 0xfb,
},
},
&PayloadSP{
PolicyParams: []PayloadSPPolicyParam{
{
Type: PayloadSPPolicyParamTypeEncrAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionEncrKeyLen,
Value: []byte{0x10},
},
{
Type: PayloadSPPolicyParamTypeAuthAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionAuthKeyLen,
Value: []byte{0x0a},
},
{
Type: PayloadSPPolicyParamTypeSRTPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTPAuthOffOn,
Value: []byte{1},
},
},
},
&PayloadKEMAC{
SubPayloads: []*SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x90, 0x91, 0x78, 0x3d, 0xfc, 0xe8, 0xdd, 0xcd,
0x44, 0x3a, 0x53, 0x50, 0x8b, 0x64, 0x50, 0x9f,
0x35, 0xbd, 0x8a, 0x86, 0xbc, 0x4d, 0x8b, 0x76,
0x37, 0xa5, 0x02, 0x49, 0x3d, 0xaf,
},
},
},
},
},
},
},
{
"b",
[]byte{
0x01, 0x00, 0x05, 0x00, 0xfe, 0xaf, 0x97, 0x52,
0x01, 0x00, 0x00, 0xcc, 0x83, 0x62, 0x37, 0x00,
0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe, 0xf6,
0x6b, 0xa2, 0x8c, 0x9b, 0x84, 0x0a, 0x10, 0x27,
0x6e, 0x94, 0x18, 0x0e, 0x88, 0x75, 0xc2, 0xea,
0xad, 0x31, 0xd8, 0x2f, 0x86, 0x46, 0x20, 0x01,
0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01, 0x01,
0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01, 0x0a,
0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a, 0x01,
0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20, 0x00,
0x1e, 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b,
0x4e, 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76,
0x87, 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41,
0x77, 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b, 0x00,
},
Message{
Header: Header{
Version: 1,
CSBID: 4272920402,
CSIDMapInfo: []SRTPIDEntry{
{
PolicyNo: 0,
SSRC: 3431162423,
ROC: 0,
},
},
},
Payloads: []Payload{ //nolint:dupl
&PayloadT{
TSValue: 17005300185146628996,
},
&PayloadRAND{
Data: []byte{
0x27, 0x6e, 0x94, 0x18, 0x0e, 0x88, 0x75, 0xc2,
0xea, 0xad, 0x31, 0xd8, 0x2f, 0x86, 0x46, 0x20,
},
},
&PayloadSP{
PolicyParams: []PayloadSPPolicyParam{
{
Type: PayloadSPPolicyParamTypeEncrAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionEncrKeyLen,
Value: []byte{0x10},
},
{
Type: PayloadSPPolicyParamTypeAuthAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionAuthKeyLen,
Value: []byte{0x0a},
},
{
Type: PayloadSPPolicyParamTypeSRTPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTPAuthOffOn,
Value: []byte{1},
},
},
},
&PayloadKEMAC{
SubPayloads: []*SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b, 0x4e,
0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76, 0x87,
0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41, 0x77,
0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b,
},
},
},
},
},
},
},
{
"c",
[]byte{
0x01, 0x00, 0x05, 0x00, 0x7d, 0xe1, 0x27, 0xa6,
0x02, 0x00, 0x00, 0xcc, 0x83, 0x62, 0x37, 0x00,
0x00, 0x00, 0x00, 0x00, 0xb5, 0xcc, 0x3b, 0xf2,
0x00, 0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe,
0xf6, 0x6b, 0xa2, 0xb1, 0xf6, 0x87, 0x0a, 0x10,
0x61, 0xbb, 0x19, 0x94, 0x32, 0x53, 0x03, 0x56,
0xa2, 0xd1, 0x88, 0x07, 0x15, 0x23, 0x75, 0x95,
0x01, 0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01,
0x01, 0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01,
0x0a, 0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a,
0x01, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20,
0x00, 0x1e, 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09,
0x4b, 0x4e, 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62,
0x76, 0x87, 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1,
0x41, 0x77, 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b,
0x00,
},
Message{
Header: Header{
Version: 1,
CSBID: 2111907750,
CSIDMapInfo: []SRTPIDEntry{
{
PolicyNo: 0,
SSRC: 3431162423,
ROC: 0,
},
{
PolicyNo: 0,
SSRC: 3050060786,
ROC: 0,
},
},
},
Payloads: []Payload{ //nolint:dupl
&PayloadT{
TSValue: 17005300185149077127,
},
&PayloadRAND{
Data: []byte{
0x61, 0xbb, 0x19, 0x94, 0x32, 0x53, 0x03, 0x56,
0xa2, 0xd1, 0x88, 0x07, 0x15, 0x23, 0x75, 0x95,
},
},
&PayloadSP{
PolicyParams: []PayloadSPPolicyParam{
{
Type: PayloadSPPolicyParamTypeEncrAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionEncrKeyLen,
Value: []byte{0x10},
},
{
Type: PayloadSPPolicyParamTypeAuthAlg,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSessionAuthKeyLen,
Value: []byte{0x0a},
},
{
Type: PayloadSPPolicyParamTypeSRTPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn,
Value: []byte{1},
},
{
Type: PayloadSPPolicyParamTypeSRTPAuthOffOn,
Value: []byte{1},
},
},
},
&PayloadKEMAC{
SubPayloads: []*SubPayloadKeyData{
{
Type: 2,
KeyData: []byte{
0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b, 0x4e,
0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76, 0x87,
0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41, 0x77,
0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b,
},
},
},
},
},
},
},
}
func TestUnmarshal(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
var dec Message
err := dec.Unmarshal(ca.enc)
require.NoError(t, err)
require.Equal(t, ca.dec, dec)
})
}
}
func TestMarshal(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
enc, err := ca.dec.Marshal()
require.NoError(t, err)
require.Equal(t, ca.enc, enc)
})
}
}
func FuzzUnmarshal(f *testing.F) {
for _, ca := range cases {
f.Add(ca.enc)
}
f.Fuzz(func(t *testing.T, b []byte) {
var msg Message
err := msg.Unmarshal(b)
if err != nil {
return
}
_, err = msg.Marshal()
require.NoError(t, err)
})
}

20
pkg/mikey/payload.go Normal file
View File

@@ -0,0 +1,20 @@
package mikey
type payloadType uint8
// RFC3830, table 6.1.b
const (
payloadTypeKEMAC payloadType = 1
payloadTypeT payloadType = 5
payloadTypeSP payloadType = 10
payloadTypeRAND payloadType = 11
payloadTypeKeyData payloadType = 20
)
// Payload is a MIKEY payload.
type Payload interface {
unmarshal(buf []byte) (int, error)
typ() payloadType
marshalSize() int
marshalTo(buf []byte) (int, error)
}

131
pkg/mikey/payload_kemac.go Normal file
View File

@@ -0,0 +1,131 @@
package mikey
import "fmt"
// PayloadKEMACEncrAlg is a encryption algorithm.
type PayloadKEMACEncrAlg uint8
// RFC3830, Table 6.2.a
const (
PayloadKEMACEncrAlgNULL PayloadKEMACEncrAlg = 0
)
// PayloadKEMACMacAlg is a authentication algorithm.
type PayloadKEMACMacAlg uint8
// RFC3830, Table 6.2.b
const (
PayloadKEMACMacAlgNULL PayloadKEMACMacAlg = 0
)
// PayloadKEMAC is a Key data transport payload.
type PayloadKEMAC struct {
EncrAlg PayloadKEMACEncrAlg
SubPayloads []*SubPayloadKeyData
MacAlg PayloadKEMACMacAlg
}
func (p *PayloadKEMAC) unmarshal(buf []byte) (int, error) {
if len(buf) < 4 {
return 0, fmt.Errorf("buffer too short")
}
n := 1
p.EncrAlg = PayloadKEMACEncrAlg(buf[n])
n++
if p.EncrAlg != PayloadKEMACEncrAlgNULL {
return 0, fmt.Errorf("unsupported encr alg: %v", p.EncrAlg)
}
encrDataLen := int(uint16(buf[n])<<8 | uint16(buf[n+1]))
n += 2
if len(buf[n:]) < (encrDataLen + 1) {
return 0, fmt.Errorf("buffer too short")
}
encrData := buf[n : n+encrDataLen]
n += encrDataLen
sn := 0
for {
sp := &SubPayloadKeyData{}
spLen, err := sp.unmarshal(encrData[sn:])
if err != nil {
return 0, err
}
nextPayloadType := payloadType(encrData[sn])
sn += spLen
p.SubPayloads = append(p.SubPayloads, sp)
if nextPayloadType == 0 {
break
}
if nextPayloadType != payloadTypeKeyData {
return 0, fmt.Errorf("unsupported payload type: %v", nextPayloadType)
}
}
if sn != len(encrData) {
return 0, fmt.Errorf("detected unread bytes")
}
p.MacAlg = PayloadKEMACMacAlg(buf[n])
n++
if p.MacAlg != PayloadKEMACMacAlgNULL {
return 0, fmt.Errorf("unsupported mac alg: %v", p.MacAlg)
}
return n, nil
}
func (*PayloadKEMAC) typ() payloadType {
return payloadTypeKEMAC
}
func (p *PayloadKEMAC) marshalSize() int {
n := 5
for _, sp := range p.SubPayloads {
n += sp.marshalSize()
}
return n
}
func (p *PayloadKEMAC) marshalTo(buf []byte) (int, error) {
buf[1] = byte(p.EncrAlg)
encrDataLen := 0
for _, sp := range p.SubPayloads {
encrDataLen += sp.marshalSize()
}
buf[2] = byte(encrDataLen >> 8)
buf[3] = byte(encrDataLen)
n := 4
for i, sp := range p.SubPayloads {
var nextPayloadType payloadType
if i != len(p.SubPayloads)-1 {
nextPayloadType = payloadTypeKeyData
} else {
nextPayloadType = 0
}
buf[n] = byte(nextPayloadType)
n2, err := sp.marshalTo(buf[n:])
if err != nil {
return 0, err
}
n += n2
}
buf[n] = byte(p.MacAlg)
n++
return n, nil
}

46
pkg/mikey/payload_rand.go Normal file
View File

@@ -0,0 +1,46 @@
package mikey
import "fmt"
// PayloadRAND is a payload with random data.
type PayloadRAND struct {
Data []byte
}
func (p *PayloadRAND) unmarshal(buf []byte) (int, error) {
if len(buf) < 2 {
return 0, fmt.Errorf("buffer too short")
}
n := 1
dataLen := int(buf[n])
n++
if dataLen < 16 {
return 0, fmt.Errorf("invalid data len: %v", dataLen)
}
if len(buf[n:]) < dataLen {
return 0, fmt.Errorf("buffer too short")
}
p.Data = buf[n : n+dataLen]
n += dataLen
return n, nil
}
func (*PayloadRAND) typ() payloadType {
return payloadTypeRAND
}
func (p *PayloadRAND) marshalSize() int {
return 2 + len(p.Data)
}
func (p *PayloadRAND) marshalTo(buf []byte) (int, error) {
buf[1] = uint8(len(p.Data))
n := 2
n += copy(buf[2:], p.Data)
return n, nil
}

129
pkg/mikey/payload_sp.go Normal file
View File

@@ -0,0 +1,129 @@
package mikey
import "fmt"
// PayloadSPProtType is a security protocol.
type PayloadSPProtType uint8
// RFC3830, Table 6.2.a
const (
PayloadSPProtTypeSRTP PayloadSPProtType = 0
)
// PayloadSPPolicyParamType is a policy param type.
type PayloadSPPolicyParamType uint8
// RFC3830, Table 6.10.1.a
const (
PayloadSPPolicyParamTypeEncrAlg PayloadSPPolicyParamType = 0
PayloadSPPolicyParamTypeSessionEncrKeyLen PayloadSPPolicyParamType = 1
PayloadSPPolicyParamTypeAuthAlg PayloadSPPolicyParamType = 2
PayloadSPPolicyParamTypeSessionAuthKeyLen PayloadSPPolicyParamType = 3
PayloadSPPolicyParamTypeSessionSaltKeyLen PayloadSPPolicyParamType = 4
PayloadSPPolicyParamTypeSRTPPseudoRandFun PayloadSPPolicyParamType = 5
PayloadSPPolicyParamTypeKeyDerRate PayloadSPPolicyParamType = 6
PayloadSPPolicyParamTypeSRTPEncrOffOn PayloadSPPolicyParamType = 7
PayloadSPPolicyParamTypeSRTCPEncrOffOn PayloadSPPolicyParamType = 8
PayloadSPPolicyParamTypeSenderFECOrder PayloadSPPolicyParamType = 9
PayloadSPPolicyParamTypeSRTPAuthOffOn PayloadSPPolicyParamType = 10
PayloadSPPolicyParamTypeAuthTagLen PayloadSPPolicyParamType = 11
PayloadSPPolicyParamTypeSRTPPrefixLen PayloadSPPolicyParamType = 12
)
// PayloadSPPolicyParam is a policy param.
type PayloadSPPolicyParam struct {
Type PayloadSPPolicyParamType
Value []byte
}
// PayloadSP is a security policy payload.
type PayloadSP struct {
PolicyNo uint8
ProtType PayloadSPProtType
PolicyParams []PayloadSPPolicyParam
}
func (p *PayloadSP) unmarshal(buf []byte) (int, error) {
if len(buf) < 5 {
return 0, fmt.Errorf("buffer too short")
}
n := 1
p.PolicyNo = buf[n]
n++
p.ProtType = PayloadSPProtType(buf[n])
n++
if p.ProtType != 0 {
return 0, fmt.Errorf("unsupported prot type: %v", p.ProtType)
}
policyParamLength := uint16(buf[n])<<8 | uint16(buf[n+1])
n += 2
end := n + int(policyParamLength)
for {
if n > end {
return 0, fmt.Errorf("policy param overflowed")
}
if n == end {
break
}
if len(buf[n:]) < 2 {
return 0, fmt.Errorf("buffer too short")
}
typ := PayloadSPPolicyParamType(buf[n])
n++
valueLen := int(buf[n])
n++
if len(buf[n:]) < valueLen {
return 0, fmt.Errorf("buffer too short")
}
value := buf[n : n+valueLen]
n += valueLen
p.PolicyParams = append(p.PolicyParams, PayloadSPPolicyParam{
Type: typ,
Value: value,
})
}
return n, nil
}
func (*PayloadSP) typ() payloadType {
return payloadTypeSP
}
func (p *PayloadSP) marshalSize() int {
n := 5 + 2*len(p.PolicyParams)
for _, pp := range p.PolicyParams {
n += len(pp.Value)
}
return n
}
func (p *PayloadSP) marshalTo(buf []byte) (int, error) {
buf[1] = p.PolicyNo
buf[2] = byte(p.ProtType)
policyParamLength := 0
for _, pp := range p.PolicyParams {
policyParamLength += 2 + len(pp.Value)
}
buf[3] = byte(policyParamLength >> 8)
buf[4] = byte(policyParamLength)
n := 5
for _, pp := range p.PolicyParams {
buf[n] = byte(pp.Type)
buf[n+1] = uint8(len(pp.Value))
n += 2
n += copy(buf[n:], pp.Value)
}
return n, nil
}

56
pkg/mikey/payload_t.go Normal file
View File

@@ -0,0 +1,56 @@
package mikey
import "fmt"
// PayloadT is a timestamp payload.
type PayloadT struct {
TSType uint8
TSValue uint64
}
func (p *PayloadT) unmarshal(buf []byte) (int, error) {
if len(buf) < 10 {
return 0, fmt.Errorf("buffer too short")
}
n := 1
p.TSType = buf[n]
n++
if p.TSType != 0 {
return 0, fmt.Errorf("unsupported TSType: %v", p.TSType)
}
p.TSValue = uint64(buf[n])<<56 |
uint64(buf[n+1])<<48 |
uint64(buf[n+2])<<40 |
uint64(buf[n+3])<<32 |
uint64(buf[n+4])<<24 |
uint64(buf[n+5])<<16 |
uint64(buf[n+6])<<8 |
uint64(buf[n+7])
n += 8
return n, nil
}
func (*PayloadT) typ() payloadType {
return payloadTypeT
}
func (p *PayloadT) marshalSize() int {
return 10
}
func (p *PayloadT) marshalTo(buf []byte) (int, error) {
buf[1] = p.TSType
buf[2] = byte(p.TSValue >> 56)
buf[3] = byte(p.TSValue >> 48)
buf[4] = byte(p.TSValue >> 40)
buf[5] = byte(p.TSValue >> 32)
buf[6] = byte(p.TSValue >> 24)
buf[7] = byte(p.TSValue >> 16)
buf[8] = byte(p.TSValue >> 8)
buf[9] = byte(p.TSValue)
return 10, nil
}

View File

@@ -0,0 +1,66 @@
package mikey
import "fmt"
// SubPayloadKeyDataKeyType is a data key type.
type SubPayloadKeyDataKeyType uint8
// RFC3830, table 6.13.a
const (
SubPayloadKeyDataKeyTypeTEK SubPayloadKeyDataKeyType = 2
)
// SubPayloadKeyData is a key data sub-payload.
type SubPayloadKeyData struct {
Type SubPayloadKeyDataKeyType
KV uint8
KeyData []byte
}
func (p *SubPayloadKeyData) unmarshal(buf []byte) (int, error) {
if len(buf) < 4 {
return 0, fmt.Errorf("buffer too short")
}
n := 1
p.Type = SubPayloadKeyDataKeyType(buf[n] >> 4)
p.KV = buf[n] & 0b00001111
n++
if p.Type != SubPayloadKeyDataKeyTypeTEK {
return 0, fmt.Errorf("unsupported key type: %v", p.Type)
}
if p.KV != 0 {
return 0, fmt.Errorf("unsupported KV: %v", p.KV)
}
keyDataLen := int(uint16(buf[n])<<8 | uint16(buf[n+1]))
n += 2
if len(buf[n:]) < keyDataLen {
return 0, fmt.Errorf("buffer too short")
}
p.KeyData = buf[n : n+keyDataLen]
n += keyDataLen
return n, nil
}
func (p *SubPayloadKeyData) marshalSize() int {
return 4 + len(p.KeyData)
}
func (p *SubPayloadKeyData) marshalTo(buf []byte) (int, error) {
buf[1] = byte(p.Type)<<4 | p.KV
keyDataLen := len(p.KeyData)
buf[2] = byte(keyDataLen >> 8)
buf[3] = byte(keyDataLen)
n := 4
n += copy(buf[n:], p.KeyData)
return n, nil
}

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x00\xe6\x9dQ\xf8\x01\x00\x000hW`\x00\x00\x00\x00\v\x00\xeb\xfeo-\xb1\xc1?\xd0\n\x10\xc2\xdd\xe4C\xa8I0\xa5uz~\xd9ä\x17\xfb\x01\x00\x00\x00\x15\x00\x01\x01\x01\x01\x10\x02\x8a\x86\xbc\x01\n\a\x01\x01\b\x01\x01\n\x01\x01\x00\x00\x00\"\x00 \x00\x1e\x90\x91x=\xfc\xe8\xdd\xcdD:SP\x8bdP\x9f5\xbd\x8a\x86\xbcM\x8bv7\xa5\x02I=\xaf\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x000000\x01\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n0")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\x01\x1000000000000000000\x00\x00\"\x00 \x00\x00000000000000000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x000000\x01")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\x00\x00000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x0000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 00000000000000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 000000000000000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x00000000000\x000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x000000000000\x00000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("0000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x00000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x0000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\x0100000000000000000000000000000000000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\"0 00000000000000000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\x020 \x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\"0 \x00\x00000000000000000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 00000000000000000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x000000\x01\x000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 0!000000000000000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x000000000000\x000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x000000000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x9d000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x010")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x00")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x000000\x010")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x0000000000")

View File

@@ -0,0 +1,2 @@
go test fuzz v1
[]byte("\x01\x000\x00000")

View File

@@ -161,7 +161,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet {
}
if rr.firstSenderReportReceived {
// middle 32 bits out of 64 in the NTP timestamp of last sender report
// middle 32 bits out of 64 in the NTP of last sender report
report.Reports[0].LastSenderReport = uint32(rr.lastSenderReportTimeNTP >> 16)
// delay, expressed in units of 1/65536 seconds, between
@@ -267,7 +267,7 @@ func (rr *RTCPReceiver) packetNTPUnsafe(ts uint32) (time.Time, bool) {
return ntpTimeRTCPToGo(rr.lastSenderReportTimeNTP).Add(timeDiffGo), true
}
// PacketNTP returns the NTP timestamp of the packet.
// PacketNTP returns the NTP (absolute timestamp) of the packet.
func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) {
rr.mutex.Lock()
defer rr.mutex.Unlock()

View File

@@ -196,14 +196,6 @@ func (s *Server) Start() error {
s.checkStreamPeriod = 1 * time.Second
}
if s.TLSConfig != nil && s.UDPRTPAddress != "" {
return fmt.Errorf("TLS can't be used with UDP")
}
if s.TLSConfig != nil && s.MulticastIPRange != "" {
return fmt.Errorf("TLS can't be used with UDP-multicast")
}
if s.RTSPAddress == "" {
return fmt.Errorf("RTSPAddress not provided")
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
@@ -17,6 +18,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
)
func getSessionID(header base.Header) string {
@@ -51,7 +53,97 @@ func checkBackChannelsEnabled(header base.Header) bool {
return false
}
func prepareForDescribe(d *description.Session, multicast bool, backChannels bool) *description.Session {
func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) {
csbID, err := randUint32()
if err != nil {
return nil, err
}
msg := &mikey.Message{
Header: mikey.Header{
Version: 1,
CSBID: csbID,
},
}
msg.Header.CSIDMapInfo = make([]mikey.SRTPIDEntry, len(ctx.ssrcs))
n := 0
for _, ssrc := range ctx.ssrcs {
msg.Header.CSIDMapInfo[n] = mikey.SRTPIDEntry{
PolicyNo: 0,
SSRC: ssrc,
ROC: ctx.roc(ssrc),
}
n++
}
randData := make([]byte, 16)
_, err = rand.Read(randData)
if err != nil {
return nil, err
}
msg.Payloads = []mikey.Payload{
&mikey.PayloadT{
TSType: 0,
TSValue: mikeyEncodeTime(time.Now()),
},
&mikey.PayloadRAND{
Data: randData,
},
&mikey.PayloadSP{
PolicyParams: []mikey.PayloadSPPolicyParam{
{
Type: mikey.PayloadSPPolicyParamTypeEncrAlg,
Value: []byte{1},
},
{
Type: mikey.PayloadSPPolicyParamTypeSessionEncrKeyLen,
Value: []byte{0x10},
},
{
Type: mikey.PayloadSPPolicyParamTypeAuthAlg,
Value: []byte{1},
},
{
Type: mikey.PayloadSPPolicyParamTypeSessionAuthKeyLen,
Value: []byte{0x0a},
},
{
Type: mikey.PayloadSPPolicyParamTypeSRTPEncrOffOn,
Value: []byte{1},
},
{
Type: mikey.PayloadSPPolicyParamTypeSRTCPEncrOffOn,
Value: []byte{1},
},
{
Type: mikey.PayloadSPPolicyParamTypeSRTPAuthOffOn,
Value: []byte{1},
},
},
},
&mikey.PayloadKEMAC{
SubPayloads: []*mikey.SubPayloadKeyData{
{
Type: mikey.SubPayloadKeyDataKeyTypeTEK,
KeyData: ctx.key,
},
},
},
}
return msg, nil
}
func prepareForDescribe(
d *description.Session,
multicast bool,
backChannels bool,
secure bool,
medias map[*description.Media]*serverStreamMedia,
) (*description.Session, error) {
out := &description.Session{
Title: d.Title,
Multicast: multicast,
@@ -60,19 +152,32 @@ func prepareForDescribe(d *description.Session, multicast bool, backChannels boo
for i, medi := range d.Medias {
if !medi.IsBackChannel || backChannels {
var keyMgmtMikey *mikey.Message
if secure {
sm := medias[medi]
var err error
keyMgmtMikey, err = mikeyGenerate(sm.srtpOutCtx)
if err != nil {
return nil, err
}
}
out.Medias = append(out.Medias, &description.Media{
Type: medi.Type,
ID: medi.ID,
IsBackChannel: medi.IsBackChannel,
// we have to use trackID=number in order to support clients
// like the Grandstream GXV3500.
Control: "trackID=" + strconv.FormatInt(int64(i), 10),
Formats: medi.Formats,
Control: "trackID=" + strconv.FormatInt(int64(i), 10),
Secure: secure,
KeyMgmtMikey: keyMgmtMikey,
Formats: medi.Formats,
})
}
}
return out
return out, nil
}
func credentialsProvided(req *base.Request) bool {
@@ -160,7 +265,7 @@ func (sc *ServerConn) UserData() interface{} {
return sc.userData
}
// Session returns associated session.
// Session returns the associated session.
func (sc *ServerConn) Session() *ServerSession {
return sc.session
}
@@ -370,13 +475,28 @@ func (sc *ServerConn) handleRequestInner(req *base.Request) (*base.Response, err
return res, err
}
desc := prepareForDescribe(
var desc *description.Session
desc, err = prepareForDescribe(
stream.Desc,
checkMulticastEnabled(sc.s.MulticastIPRange, query),
checkBackChannelsEnabled(req.Header),
sc.s.TLSConfig != nil,
stream.medias,
)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
var byts []byte
byts, err = desc.Marshal(false)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
byts, _ := desc.Marshal(false)
res.Body = byts
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"net"
"strconv"
@@ -21,6 +22,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -333,7 +335,7 @@ func TestServerPlaySetupErrors(t *testing.T) {
require.EqualError(t, ctx.Error, "media has already been setup")
case "different protocols":
require.EqualError(t, ctx.Error, "can't setup medias with different protocols")
require.EqualError(t, ctx.Error, "can't setup medias with different transports")
}
close(nconnClosed)
},
@@ -574,13 +576,48 @@ func TestServerPlaySetupErrorSameUDPPortsAndIP(t *testing.T) {
}
func TestServerPlay(t *testing.T) {
for _, transport := range []string{
"udp",
"multicast",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"multicast",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"multicast",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
var stream *ServerStream
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
@@ -598,10 +635,10 @@ func TestServerPlay(t *testing.T) {
},
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
s := ctx.Conn.Stats()
require.Greater(t, s.BytesSent, uint64(810))
require.Less(t, s.BytesSent, uint64(1150))
require.Greater(t, s.BytesReceived, uint64(440))
require.Less(t, s.BytesReceived, uint64(660))
require.Greater(t, s.BytesSent, uint64(800))
require.Less(t, s.BytesSent, uint64(1600))
require.Greater(t, s.BytesReceived, uint64(400))
require.Less(t, s.BytesReceived, uint64(950))
close(nconnClosed)
},
@@ -609,12 +646,12 @@ func TestServerPlay(t *testing.T) {
close(sessionOpened)
},
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
if transport != "multicast" {
if ca.transport != "multicast" {
s := ctx.Session.Stats()
require.Greater(t, s.BytesSent, uint64(50))
require.Less(t, s.BytesSent, uint64(60))
require.Less(t, s.BytesSent, uint64(130))
require.Greater(t, s.BytesReceived, uint64(15))
require.Less(t, s.BytesReceived, uint64(25))
require.Less(t, s.BytesReceived, uint64(35))
}
close(sessionClosed)
@@ -632,12 +669,12 @@ func TestServerPlay(t *testing.T) {
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
require.NotNil(t, ctx.Conn.Session())
switch transport {
switch ca.transport {
case "udp":
v := TransportUDP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
case "tcp", "tls":
case "tcp":
v := TransportTCP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
@@ -651,14 +688,14 @@ func TestServerPlay(t *testing.T) {
// send RTCP packets directly to the session.
// these are sent after the response, only if onPlay returns StatusOK.
if transport != "multicast" {
if ca.transport != "multicast" {
err := ctx.Session.WritePacketRTCP(stream.Description().Medias[0], &testRTCPPacket)
require.NoError(t, err)
}
ctx.Session.OnPacketRTCPAny(func(medi *description.Media, pkt rtcp.Packet) {
// ignore multicast loopback
if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
if ca.secure == "unsecure" && ca.transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
return
}
@@ -691,7 +728,7 @@ func TestServerPlay(t *testing.T) {
RTSPAddress: listenIP + ":8554",
}
switch transport {
switch ca.transport {
case "udp":
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
@@ -700,8 +737,9 @@ func TestServerPlay(t *testing.T) {
s.MulticastIPRange = "224.1.0.0/16"
s.MulticastRTPPort = 8000
s.MulticastRTCPPort = 8001
}
case "tls":
if ca.scheme == "rtsps" {
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
@@ -723,7 +761,7 @@ func TestServerPlay(t *testing.T) {
require.NoError(t, err)
nconn = func() net.Conn {
if transport == "tls" {
if ca.scheme == "rtsps" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
@@ -734,11 +772,16 @@ func TestServerPlay(t *testing.T) {
desc := doDescribe(t, conn, false)
if ca.secure == "secure" {
require.True(t, desc.Medias[0].Secure)
require.NotEmpty(t, desc.Medias[0].KeyMgmtMikey)
}
inTH := &headers.Transport{
Mode: transportModePtr(headers.TransportModePlay),
}
switch transport {
switch ca.transport {
case "udp":
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
@@ -750,19 +793,81 @@ func TestServerPlay(t *testing.T) {
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
default:
case "tcp":
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{5, 6} // odd value
}
res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
h := base.Header{
"CSeq": base.HeaderValue{"1"},
}
var srtpOutCtx *wrappedSRTPContext
if ca.secure == "secure" {
inTH.Secure = true
key := make([]byte, srtpKeyLength)
_, err = rand.Read(key)
require.NoError(t, err)
srtpOutCtx = &wrappedSRTPContext{
key: key,
ssrcs: []uint32{2345423},
}
err = srtpOutCtx.initialize()
require.NoError(t, err)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx)
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: mediaURL(t, desc.BaseURL, desc.Medias[0]).String(),
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
h["Transport"] = inTH.Marshal()
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mediaURL(t, desc.BaseURL, desc.Medias[0]),
Header: h,
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var th headers.Transport
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
var srtpInCtx *wrappedSRTPContext
if ca.secure == "secure" {
require.True(t, th.Secure)
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
require.NoError(t, err)
pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage)
pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc.Medias[0].KeyMgmtMikey)
require.Equal(t, pl1, pl2)
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err)
}
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { //nolint:dupl
switch ca.transport {
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
@@ -775,7 +880,7 @@ func TestServerPlay(t *testing.T) {
require.NoError(t, err)
defer l2.Close()
case "multicast":
case "multicast": //nolint:dupl
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery)
@@ -808,7 +913,7 @@ func TestServerPlay(t *testing.T) {
require.NoError(t, err)
}
default:
case "tcp":
require.Equal(t, headers.TransportProtocolTCP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
}
@@ -821,86 +926,122 @@ func TestServerPlay(t *testing.T) {
// server -> client (direct)
switch transport {
case "udp":
buf := make([]byte, 2048)
var n int
n, _, err = l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
if ca.transport != "multicast" {
var buf []byte
case "tcp", "tls":
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 6, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
switch ca.transport {
case "udp":
buf = make([]byte, 2048)
var n int
n, _, err = l2.ReadFrom(buf)
require.NoError(t, err)
buf = buf[:n]
case "tcp":
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 6, f.Channel)
buf = f.Payload
}
if ca.secure == "secure" {
buf, err = srtpInCtx.decryptRTCP(buf, buf, nil)
require.NoError(t, err)
}
require.Equal(t, testRTCPPacketMarshaled, buf)
}
// server -> client (through stream)
if transport == "udp" || transport == "multicast" {
buf := make([]byte, 2048)
var buf1 []byte
var buf2 []byte
switch ca.transport {
case "udp", "multicast":
buf1 = make([]byte, 2048)
var n int
n, _, err = l1.ReadFrom(buf)
n, _, err = l1.ReadFrom(buf1)
require.NoError(t, err)
buf1 = buf1[:n]
var pkt rtp.Packet
err = pkt.Unmarshal(buf[:n])
buf2 = make([]byte, 2048)
n, _, err = l2.ReadFrom(buf2)
require.NoError(t, err)
pkt.SSRC = testRTPPacket.SSRC
require.Equal(t, testRTPPacket, pkt)
buf2 = buf2[:n]
buf = make([]byte, 2048)
n, _, err = l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
case "tcp":
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 6, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
buf2 = f.Payload
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 5, f.Channel)
var pkt rtp.Packet
err = pkt.Unmarshal(f.Payload)
require.NoError(t, err)
pkt.SSRC = testRTPPacket.SSRC
require.Equal(t, testRTPPacket, pkt)
buf1 = f.Payload
}
if ca.secure == "secure" {
buf1, err = srtpInCtx.decryptRTP(buf1, buf1, nil)
require.NoError(t, err)
}
var pkt rtp.Packet
err = pkt.Unmarshal(buf1)
require.NoError(t, err)
pkt.SSRC = testRTPPacket.SSRC
require.Equal(t, testRTPPacket, pkt)
if ca.secure == "secure" {
buf2, err = srtpInCtx.decryptRTCP(buf2, buf2, nil)
require.NoError(t, err)
}
require.Equal(t, testRTCPPacketMarshaled, buf2)
// client -> server
switch transport {
buf := testRTCPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err = srtpOutCtx.encryptRTCP(encr, buf, nil)
require.NoError(t, err)
buf = encr
}
switch ca.transport {
case "udp":
_, err = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
_, err = l2.WriteTo(buf, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1],
})
require.NoError(t, err)
<-framesReceived
case "multicast":
_, err = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
_, err = l2.WriteTo(buf, &net.UDPAddr{
IP: *th.Destination,
Port: th.Ports[1],
})
require.NoError(t, err)
<-framesReceived
default:
case "tcp":
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 6,
Payload: testRTCPPacketMarshaled,
Payload: buf,
}, make([]byte, 1024))
require.NoError(t, err)
<-framesReceived
}
if transport == "udp" || transport == "multicast" {
<-framesReceived
// ping
switch ca.transport {
case "udp", "multicast":
// ping with OPTIONS
res, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
@@ -941,7 +1082,6 @@ func TestServerPlaySocketError(t *testing.T) {
"udp",
"multicast",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var stream *ServerStream
@@ -996,11 +1136,6 @@ func TestServerPlaySocketError(t *testing.T) {
s.MulticastIPRange = "224.1.0.0/16"
s.MulticastRTPPort = 8000
s.MulticastRTCPPort = 8001
case "tls":
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
err := s.Start()
@@ -1019,12 +1154,6 @@ func TestServerPlaySocketError(t *testing.T) {
require.NoError(t, err)
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
}()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn, false)
@@ -1057,7 +1186,7 @@ func TestServerPlaySocketError(t *testing.T) {
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { //nolint:dupl
switch transport {
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
@@ -1070,7 +1199,7 @@ func TestServerPlaySocketError(t *testing.T) {
require.NoError(t, err)
defer l2.Close()
case "multicast":
case "multicast": //nolint:dupl
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery)

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"bytes"
"crypto/rand"
"crypto/tls"
"net"
"strconv"
@@ -18,6 +19,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/sdp"
)
@@ -337,6 +339,9 @@ func TestServerRecordPath(t *testing.T) {
media := testH264Media
media.Control = ca.control
enc, err := media.Marshal2()
require.NoError(t, err)
sout := &sdp.SessionDescription{
SessionName: psdp.SessionName("Stream"),
Origin: psdp.Origin{
@@ -348,7 +353,7 @@ func TestServerRecordPath(t *testing.T) {
TimeDescriptions: []psdp.TimeDescription{
{Timing: psdp.Timing{}},
},
MediaDescriptions: []*psdp.MediaDescription{media.Marshal()},
MediaDescriptions: []*psdp.MediaDescription{enc},
}
byts, _ := sout.Marshal()
@@ -533,12 +538,38 @@ func TestServerRecordErrorRecordPartialMedias(t *testing.T) {
}
func TestServerRecord(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
"tls",
for _, ca := range []struct {
scheme string
transport string
secure string
}{
{
"rtsp",
"udp",
"unsecure",
},
{
"rtsp",
"tcp",
"unsecure",
},
{
"rtsps",
"tcp",
"unsecure",
},
{
"rtsps",
"udp",
"secure",
},
{
"rtsps",
"tcp",
"secure",
},
} {
t.Run(transport, func(t *testing.T) {
t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) {
nconnOpened := make(chan struct{})
nconnClosed := make(chan struct{})
sessionOpened := make(chan struct{})
@@ -552,9 +583,9 @@ func TestServerRecord(t *testing.T) {
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
s := ctx.Conn.Stats()
require.Greater(t, s.BytesSent, uint64(510))
require.Less(t, s.BytesSent, uint64(560))
require.Less(t, s.BytesSent, uint64(1100))
require.Greater(t, s.BytesReceived, uint64(1000))
require.Less(t, s.BytesReceived, uint64(1200))
require.Less(t, s.BytesReceived, uint64(1800))
close(nconnClosed)
},
@@ -564,9 +595,9 @@ func TestServerRecord(t *testing.T) {
onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) {
s := ctx.Session.Stats()
require.Greater(t, s.BytesSent, uint64(75))
require.Less(t, s.BytesSent, uint64(130))
require.Less(t, s.BytesSent, uint64(140))
require.Greater(t, s.BytesReceived, uint64(70))
require.Less(t, s.BytesReceived, uint64(80))
require.Less(t, s.BytesReceived, uint64(130))
close(sessionClosed)
},
@@ -581,12 +612,12 @@ func TestServerRecord(t *testing.T) {
}, nil, nil
},
onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) {
switch transport {
switch ca.transport {
case "udp":
v := TransportUDP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
case "tcp", "tls":
case "tcp":
v := TransportTCP
require.Equal(t, &v, ctx.Session.SetuppedTransport())
}
@@ -628,12 +659,12 @@ func TestServerRecord(t *testing.T) {
RTSPAddress: "localhost:8554",
}
switch transport {
case "udp":
if ca.transport == "udp" {
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
}
case "tls":
if ca.scheme == "rtsps" {
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
@@ -648,7 +679,7 @@ func TestServerRecord(t *testing.T) {
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
if ca.scheme == "rtsps" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
@@ -686,6 +717,8 @@ func TestServerRecord(t *testing.T) {
var l2s [2]net.PacketConn
var session string
var serverPorts [2]*[2]int
var srtpOutCtx [2]*wrappedSRTPContext
var srtpInCtx [2]*wrappedSRTPContext
for i := 0; i < 2; i++ {
inTH := &headers.Transport{
@@ -693,7 +726,7 @@ func TestServerRecord(t *testing.T) {
Mode: transportModePtr(headers.TransportModeRecord),
}
if transport == "udp" {
if ca.transport == "udp" {
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2}
@@ -709,84 +742,186 @@ func TestServerRecord(t *testing.T) {
inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2}
}
res, th := doSetup(t, conn, "rtsp://localhost:8554/teststream?param=value/"+medias[i].Control, inTH, "")
h := base.Header{
"CSeq": base.HeaderValue{"1"},
}
if session != "" {
h["Session"] = base.HeaderValue{session}
}
if ca.secure == "secure" {
inTH.Secure = true
key := make([]byte, srtpKeyLength)
_, err = rand.Read(key)
require.NoError(t, err)
srtpOutCtx[i] = &wrappedSRTPContext{
key: key,
ssrcs: []uint32{2345423},
}
err = srtpOutCtx[i].initialize()
require.NoError(t, err)
var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(srtpOutCtx[i])
require.NoError(t, err)
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: "rtsp://localhost:8554/teststream?param=value/" + medias[i].Control,
MikeyMessage: mikeyMsg,
}.Marshal()
require.NoError(t, err)
h["KeyMgmt"] = enc
}
h["Transport"] = inTH.Marshal()
var res *base.Response
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream?param=value/" + medias[i].Control),
Header: h,
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var th headers.Transport
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
session = readSession(t, res)
if transport == "udp" {
if ca.transport == "udp" {
serverPorts[i] = th.ServerPorts
}
if ca.secure == "secure" {
require.True(t, th.Secure)
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
require.NoError(t, err)
srtpInCtx[i], err = mikeyToContext(keyMgmt.MikeyMessage)
require.NoError(t, err)
}
}
doRecord(t, conn, "rtsp://localhost:8554/teststream", session)
for i := 0; i < 2; i++ {
// skip firewall opening
if transport == "udp" {
// skip firewall opening
if ca.transport == "udp" {
for i := 0; i < 2; i++ {
buf := make([]byte, 2048)
_, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
}
}
// server -> client
// server -> client
if transport == "udp" {
buf := make([]byte, 2048)
for i := 0; i < 2; i++ {
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
buf = buf[:n]
} else {
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
buf = f.Payload
}
// client -> server
if ca.secure == "secure" {
buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil)
require.NoError(t, err)
}
if transport == "udp" {
_, err = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
require.Equal(t, testRTCPPacketMarshaled, buf)
}
// client -> server
for i := 0; i < 2; i++ {
buf1 := testRTPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err = srtpOutCtx[i].encryptRTP(encr, buf1, nil)
require.NoError(t, err)
buf1 = encr
}
buf2 := testRTCPPacketMarshaled
if ca.secure == "secure" {
encr := make([]byte, 2000)
encr, err = srtpOutCtx[i].encryptRTCP(encr, buf2, nil)
require.NoError(t, err)
buf2 = encr
}
if ca.transport == "udp" {
_, err = l1s[i].WriteTo(buf1, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][0],
})
require.NoError(t, err)
_, err = l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
_, err = l2s[i].WriteTo(buf2, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][1],
})
require.NoError(t, err)
} else {
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 2 + i*2,
Payload: testRTPPacketMarshaled,
Payload: buf1,
}, make([]byte, 1024))
require.NoError(t, err)
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 3 + i*2,
Payload: testRTCPPacketMarshaled,
Payload: buf2,
}, make([]byte, 1024))
require.NoError(t, err)
}
}
for i := 0; i < 2; i++ {
// server -> client
// server -> client
if transport == "udp" {
buf := make([]byte, 2048)
n, _, err := l2s[i].ReadFrom(buf)
for i := 0; i < 2; i++ {
var buf []byte
if ca.transport == "udp" {
buf = make([]byte, 2048)
var n int
n, _, err = l2s[i].ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
buf = buf[:n]
} else {
f, err := conn.ReadInterleavedFrame()
var f *base.InterleavedFrame
f, err = conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
buf = f.Payload
}
if ca.secure == "secure" {
buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil)
require.NoError(t, err)
}
require.Equal(t, testRTCPPacketMarshaled, buf)
}
doTeardown(t, conn, "rtsp://localhost:8554/teststream", session)

View File

@@ -1,6 +1,7 @@
package gortsplib
import (
"bytes"
"context"
"fmt"
"log"
@@ -20,6 +21,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/mikey"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtptime"
@@ -165,21 +167,160 @@ func findMediaByTrackID(medias []*description.Media, trackID string) *descriptio
return medias[id]
}
func findFirstSupportedTransportHeader(s *Server, tsh headers.Transports) *headers.Transport {
// Per RFC2326 section 12.39, client specifies transports in order of preference.
// Filter out the ones we don't support and then pick first supported transport.
for _, tr := range tsh {
func isTransportSupported(s *Server, tr *headers.Transport) bool {
// prevent using UDP/UDP-multicast when listeners are disabled
if tr.Protocol == headers.TransportProtocolUDP {
isMulticast := tr.Delivery != nil && *tr.Delivery == headers.TransportDeliveryMulticast
if tr.Protocol == headers.TransportProtocolUDP &&
((!isMulticast && s.udpRTPListener == nil) ||
(isMulticast && s.MulticastIPRange == "")) {
continue
if !isMulticast && s.udpRTPListener == nil {
return false
}
if isMulticast && s.MulticastIPRange == "" {
return false
}
}
// prevent using unsecure UDP with RTSPS
if tr.Protocol == headers.TransportProtocolUDP && !tr.Secure && s.TLSConfig != nil {
return false
}
// prevent using secure profiles with plain RTSP, since keys are in plain
if tr.Secure && s.TLSConfig == nil {
return false
}
return true
}
func pickFirstSupportedTransport(s *Server, tsh headers.Transports) *headers.Transport {
for _, tr := range tsh {
if isTransportSupported(s, &tr) {
return &tr
}
return &tr
}
return nil
}
func mikeyDecodeTime(t uint64) time.Time {
sec := t >> 32
dec := t & 0xFFFFFFFF
sec -= 2208988800
return time.Unix(int64(sec), int64(dec))
}
func mikeyEncodeTime(n time.Time) uint64 {
nano := uint64(n.UnixNano())
sec := nano / 1000000000
dec := nano % 1000000000
sec += 2208988800
return sec<<32 | dec
}
func mikeyGetPayload[T mikey.Payload](mikeyMsg *mikey.Message) (T, bool) {
var zero T
for _, wrapped := range mikeyMsg.Payloads {
if val, ok := wrapped.(T); ok {
return val, true
}
}
return zero, false
}
func mikeyGetSPPolicy(spPayload *mikey.PayloadSP, typ mikey.PayloadSPPolicyParamType) ([]byte, bool) {
for _, pl := range spPayload.PolicyParams {
if pl.Type == typ {
return pl.Value, true
}
}
return nil, false
}
func mikeyToContext(mikeyMsg *mikey.Message) (*wrappedSRTPContext, error) {
timePayload, ok := mikeyGetPayload[*mikey.PayloadT](mikeyMsg)
if !ok {
return nil, fmt.Errorf("time payload not present")
}
ts := mikeyDecodeTime(timePayload.TSValue)
diff := time.Since(ts)
if diff < -time.Hour || diff > time.Hour {
return nil, fmt.Errorf("NTP difference is too high")
}
spPayload, ok := mikeyGetPayload[*mikey.PayloadSP](mikeyMsg)
if !ok {
return nil, fmt.Errorf("SP payload not present")
}
v, ok := mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeEncrAlg)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeEncrAlg")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionEncrKeyLen)
if !ok || !bytes.Equal(v, []byte{0x10}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionEncrKeyLen")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeAuthAlg)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeAuthAlg")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionAuthKeyLen)
if !ok || !bytes.Equal(v, []byte{0x0a}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionAuthKeyLen")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPEncrOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPEncrOffOn")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTCPEncrOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTCPEncrOffOn")
}
v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPAuthOffOn)
if !ok || !bytes.Equal(v, []byte{1}) {
return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPAuthOffOn")
}
kemacPayload, ok := mikeyGetPayload[*mikey.PayloadKEMAC](mikeyMsg)
if !ok {
return nil, fmt.Errorf("KEMAC payload not present")
}
if len(kemacPayload.SubPayloads) != 1 {
return nil, fmt.Errorf("multiple keys are present")
}
if len(kemacPayload.SubPayloads[0].KeyData) != srtpKeyLength {
return nil, fmt.Errorf("unexpected key size: %d", len(kemacPayload.SubPayloads[0].KeyData))
}
ssrcs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo))
startROCs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo))
for i, entry := range mikeyMsg.Header.CSIDMapInfo {
ssrcs[i] = entry.SSRC
startROCs[i] = entry.ROC
}
srtpCtx := &wrappedSRTPContext{
key: kemacPayload.SubPayloads[0].KeyData,
ssrcs: ssrcs,
startROCs: startROCs,
}
err := srtpCtx.initialize()
if err != nil {
return nil, err
}
return srtpCtx, nil
}
func generateRTPInfoEntry(ssm *serverStreamMedia, now time.Time) *headers.RTPInfoEntry {
// do not generate a RTP-Info entry when
// there are multiple formats inside a single media stream,
@@ -293,6 +434,7 @@ type ServerSession struct {
setuppedMediasOrdered []*serverSessionMedia
tcpCallbackByChannel map[int]readFunc
setuppedTransport *Transport
setuppedSecure bool
setuppedStream *ServerStream // play
setuppedPath string
setuppedQuery string
@@ -371,6 +513,13 @@ func (ss *ServerSession) SetuppedTransport() *Transport {
return ss.setuppedTransport
}
// SetuppedSecure returns whether a secure profile is in use.
// If this is false, it does not mean that the stream is not secure, since
// there are some combinations that are secure nonetheless, like RTSPS+TCP+unsecure.
func (ss *ServerSession) SetuppedSecure() bool {
return ss.setuppedSecure
}
// SetuppedStream returns the stream associated with the session.
func (ss *ServerSession) SetuppedStream() *ServerStream {
return ss.setuppedStream
@@ -947,7 +1096,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerTransportHeaderInvalid{Err: err}
}
inTH := findFirstSupportedTransportHeader(ss.s, transportHeaders)
// Per RFC2326 section 12.39, client specifies transports in order of preference.
// pick the first supported one.
inTH := pickFirstSupportedTransport(ss.s, transportHeaders)
if inTH == nil {
return &base.Response{
StatusCode: base.StatusUnsupportedTransport,
@@ -978,20 +1129,41 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
var transport Transport
if inTH.Protocol == headers.TransportProtocolUDP {
switch inTH.Protocol {
case headers.TransportProtocolUDP:
if inTH.Delivery != nil && *inTH.Delivery == headers.TransportDeliveryMulticast {
transport = TransportUDPMulticast
} else {
transport = TransportUDP
}
} else {
case headers.TransportProtocolTCP:
transport = TransportTCP
}
if ss.setuppedTransport != nil && *ss.setuppedTransport != transport {
var srtpInCtx *wrappedSRTPContext
if inTH.Secure {
var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(req.Header["KeyMgmt"])
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err}
}
srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err}
}
}
if ss.setuppedTransport != nil && (*ss.setuppedTransport != transport || ss.setuppedSecure != inTH.Secure) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerMediasDifferentProtocols{}
}, liberrors.ErrServerMediasDifferentTransports{}
}
switch transport {
@@ -1052,7 +1224,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// workaround to prevent a bug in rtspclientsink
// that makes impossible for the client to receive the response
// and send frames.
// this was causing problems during unit tests.
// this was causing problems during E2E tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") {
select {
@@ -1092,6 +1264,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
ss.setuppedTransport = &transport
ss.setuppedSecure = inTH.Secure
if ss.state == ServerSessionStateInitial {
err = stream.readerAdd(ss,
@@ -1109,7 +1282,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.setuppedStream = stream
}
th := headers.Transport{}
th := headers.Transport{
Secure: inTH.Secure,
}
if ss.state == ServerSessionStatePrePlay {
if stream != ss.setuppedStream {
@@ -1131,6 +1306,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
sm := &serverSessionMedia{
ss: ss,
media: medi,
srtpInCtx: srtpInCtx,
onPacketRTCP: func(_ rtcp.Packet) {},
}
err = sm.initialize()
@@ -1141,46 +1317,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
switch transport {
case TransportUDP:
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
}
case TransportUDP, TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
case TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
if transport == TransportUDP {
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
}
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
} else {
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
}
default: // TCP
th.Protocol = headers.TransportProtocolTCP
if inTH.InterleavedIDs != nil {
sm.tcpChannel = inTH.InterleavedIDs[0]
} else {
sm.tcpChannel = ss.findFreeChannelPair()
}
th.Protocol = headers.TransportProtocolTCP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
@@ -1193,6 +1371,38 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
res.Header["Transport"] = th.Marshal()
if inTH.Secure {
ssrcs := make([]uint32, len(sm.formats))
n := 0
for _, sf := range sm.formats {
ssrcs[n] = sf.localSSRC
n++
}
var mk *mikey.Message
mk, err = mikeyGenerate(sm.srtpOutCtx)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
var enc base.HeaderValue
enc, err = headers.KeyMgmt{
URL: req.URL.String(),
MikeyMessage: mk,
}.Marshal()
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
// always return KeyMgmt even if redundant when playing
// (since it's already present in the SDP)
res.Header["KeyMgmt"] = enc
}
}
return res, err
@@ -1239,7 +1449,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
err = sm.start()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
}
if *ss.setuppedTransport == TransportTCP {
@@ -1329,7 +1544,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
err = sm.start()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
}
if *ss.setuppedTransport == TransportTCP {
@@ -1523,12 +1743,6 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu
sm.onPacketRTCP = cb
}
func (ss *ServerSession) writePacketRTPEncoded(medi *description.Media, payloadType uint8, byts []byte) error {
sm := ss.setuppedMedias[medi]
sf := sm.formats[payloadType]
return sf.writePacketRTPEncoded(byts)
}
// WritePacketRTP writes a RTP packet to the session.
func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error {
sm := ss.setuppedMedias[medi]
@@ -1536,22 +1750,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
return sf.writePacketRTP(pkt)
}
func (ss *ServerSession) writePacketRTCPEncoded(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
return sm.writePacketRTCPEncoded(byts)
}
// WritePacketRTCP writes a RTCP packet to the session.
func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error {
byts, err := pkt.Marshal()
if err != nil {
return err
}
return ss.writePacketRTCPEncoded(medi, byts)
sm := ss.setuppedMedias[medi]
return sm.writePacketRTCP(pkt)
}
// PacketPTS returns the PTS of an incoming RTP packet.
// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
//
// Deprecated: replaced by PacketPTS2.
@@ -1567,7 +1772,7 @@ func (ss *ServerSession) PacketPTS(medi *description.Media, pkt *rtp.Packet) (ti
return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(sf.format.ClockRate())), true
}
// PacketPTS2 returns the PTS of an incoming RTP packet.
// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet.
// It is computed by decoding the packet timestamp and sychronizing it with other tracks.
func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) {
sm := ss.setuppedMedias[medi]
@@ -1575,8 +1780,8 @@ func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (i
return ss.timeDecoder.Decode(sf.format, pkt)
}
// PacketNTP returns the NTP timestamp of an incoming RTP packet.
// The NTP timestamp is computed from RTCP sender reports.
// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet.
// The NTP is computed from RTCP sender reports.
func (ss *ServerSession) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) {
sm := ss.setuppedMedias[medi]
sf := sm.formats[pkt.PayloadType]

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"log"
"slices"
"sync/atomic"
"time"
@@ -15,25 +16,26 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/rtpreorderer"
)
func isServerSessionLocalSSRCTaken(ssrc uint32, ss *ServerSession, exclude *serverSessionFormat) bool {
for _, sm := range ss.setuppedMedias {
func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, sm := range sf.sm.ss.setuppedMedias {
for _, sf := range sm.formats {
if sf != exclude && sf.localSSRC == ssrc {
return true
}
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
}
return false
}
func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) {
for _, sf := range sf.sm.formats {
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !isServerSessionLocalSSRCTaken(ssrc, sf.sm.ss, sf) {
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
@@ -188,14 +190,31 @@ func (sf *serverSessionFormat) onPacketRTPLost(lost uint64) {
func (sf *serverSessionFormat) writePacketRTP(pkt *rtp.Packet) error {
pkt.SSRC = sf.localSSRC
byts := make([]byte, sf.sm.ss.s.MaxPacketSize)
n, err := pkt.MarshalTo(byts)
maxPlainPacketSize := sf.sm.ss.s.MaxPacketSize
if sf.sm.ss.setuppedSecure {
maxPlainPacketSize -= srtpOverhead
}
plain := make([]byte, maxPlainPacketSize)
n, err := pkt.MarshalTo(plain)
if err != nil {
return err
}
byts = byts[:n]
plain = plain[:n]
return sf.writePacketRTPEncoded(byts)
var encr []byte
if sf.sm.ss.setuppedSecure {
encr = make([]byte, sf.sm.ss.s.MaxPacketSize)
encr, err = sf.sm.srtpOutCtx.encryptRTP(encr, plain, &pkt.Header)
if err != nil {
return err
}
}
if sf.sm.ss.setuppedSecure {
return sf.writePacketRTPEncoded(encr)
}
return sf.writePacketRTPEncoded(plain)
}
func (sf *serverSessionFormat) writePacketRTPEncoded(payload []byte) error {

View File

@@ -1,6 +1,8 @@
package gortsplib
import (
"crypto/rand"
"fmt"
"log"
"net"
"sync/atomic"
@@ -16,8 +18,10 @@ import (
type serverSessionMedia struct {
ss *ServerSession
media *description.Media
srtpInCtx *wrappedSRTPContext
onPacketRTCP OnPacketRTCPFunc
srtpOutCtx *wrappedSRTPContext
tcpChannel int
udpRTPReadPort int
udpRTPWriteAddr *net.UDPAddr
@@ -56,12 +60,41 @@ func (sm *serverSessionMedia) initialize() error {
sm.formats[forma.PayloadType()] = f
}
if sm.ss.s.TLSConfig != nil {
if sm.ss.state == ServerSessionStatePreRecord || sm.media.IsBackChannel {
srtpOutKey := make([]byte, srtpKeyLength)
_, err := rand.Read(srtpOutKey)
if err != nil {
return err
}
ssrcs := make([]uint32, len(sm.formats))
n := 0
for _, cf := range sm.formats {
ssrcs[n] = cf.localSSRC
n++
}
sm.srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcs,
}
err = sm.srtpOutCtx.initialize()
if err != nil {
return err
}
} else {
streamMedia := sm.ss.setuppedStream.medias[sm.media]
sm.srtpOutCtx = streamMedia.srtpOutCtx
}
}
return nil
}
func (sm *serverSessionMedia) start() {
func (sm *serverSessionMedia) start() error {
// allocate udpRTCPReceiver before udpRTCPListener
// otherwise udpRTCPReceiver.LastSSRC() can't be called.
// otherwise udpRTCPReceiver.LastSSRC() cannot be called.
for _, sf := range sm.formats {
sf.start()
}
@@ -78,11 +111,33 @@ func (sm *serverSessionMedia) start() {
sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPPlay)
} else {
// open the firewall by sending empty packets to the remote part.
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
sm.ss.s.udpRTPListener.write(byts, sm.udpRTPWriteAddr) //nolint:errcheck
buf, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
if sm.srtpOutCtx != nil {
encr := make([]byte, sm.ss.s.MaxPacketSize)
encr, err := sm.srtpOutCtx.encryptRTP(encr, buf, nil)
if err != nil {
return err
}
buf = encr
}
err := sm.ss.s.udpRTPListener.write(buf, sm.udpRTPWriteAddr)
if err != nil {
return err
}
byts, _ = (&rtcp.ReceiverReport{}).Marshal()
sm.ss.s.udpRTCPListener.write(byts, sm.udpRTCPWriteAddr) //nolint:errcheck
buf, _ = (&rtcp.ReceiverReport{}).Marshal()
if sm.srtpOutCtx != nil {
encr := make([]byte, sm.ss.s.MaxPacketSize)
encr, err = sm.srtpOutCtx.encryptRTCP(encr, buf, nil)
if err != nil {
return err
}
buf = encr
}
err = sm.ss.s.udpRTCPListener.write(buf, sm.udpRTCPWriteAddr)
if err != nil {
return err
}
sm.ss.s.udpRTPListener.addClient(sm.ss.author.ip(), sm.udpRTPReadPort, sm.readPacketRTPUDPRecord)
sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPRecord)
@@ -104,6 +159,8 @@ func (sm *serverSessionMedia) start() {
sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord
}
}
return nil
}
func (sm *serverSessionMedia) stop() {
@@ -127,6 +184,37 @@ func (sm *serverSessionMedia) findFormatByRemoteSSRC(ssrc uint32) *serverSession
return nil
}
func (sm *serverSessionMedia) decodeRTP(payload []byte) (*rtp.Packet, error) {
if sm.srtpInCtx != nil {
var err error
payload, err = sm.srtpInCtx.decryptRTP(payload, payload, nil)
if err != nil {
return nil, err
}
}
var pkt rtp.Packet
err := pkt.Unmarshal(payload)
return &pkt, err
}
func (sm *serverSessionMedia) decodeRTCP(payload []byte) ([]rtcp.Packet, error) {
if sm.srtpInCtx != nil {
var err error
payload, err = sm.srtpInCtx.decryptRTCP(payload, payload, nil)
if err != nil {
return nil, err
}
}
pkts, err := rtcp.Unmarshal(payload)
if err != nil {
return nil, err
}
return pkts, nil
}
func (sm *serverSessionMedia) readPacketRTPUDPPlay(payload []byte) bool {
atomic.AddUint64(sm.bytesReceived, uint64(len(payload)))
@@ -135,8 +223,7 @@ func (sm *serverSessionMedia) readPacketRTPUDPPlay(payload []byte) bool {
return false
}
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := sm.decodeRTP(payload)
if err != nil {
sm.onPacketRTPDecodeError(err)
return false
@@ -163,7 +250,7 @@ func (sm *serverSessionMedia) readPacketRTCPUDPPlay(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := sm.decodeRTCP(payload)
if err != nil {
sm.onPacketRTCPDecodeError(err)
return false
@@ -189,8 +276,7 @@ func (sm *serverSessionMedia) readPacketRTPUDPRecord(payload []byte) bool {
return false
}
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := sm.decodeRTP(payload)
if err != nil {
sm.onPacketRTPDecodeError(err)
return false
@@ -218,7 +304,7 @@ func (sm *serverSessionMedia) readPacketRTCPUDPRecord(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := sm.decodeRTCP(payload)
if err != nil {
sm.onPacketRTCPDecodeError(err)
return false
@@ -250,8 +336,7 @@ func (sm *serverSessionMedia) readPacketRTPTCPPlay(payload []byte) bool {
atomic.AddUint64(sm.bytesReceived, uint64(len(payload)))
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := sm.decodeRTP(payload)
if err != nil {
sm.onPacketRTPDecodeError(err)
return false
@@ -276,7 +361,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPPlay(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := sm.decodeRTCP(payload)
if err != nil {
sm.onPacketRTCPDecodeError(err)
return false
@@ -294,8 +379,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPPlay(payload []byte) bool {
func (sm *serverSessionMedia) readPacketRTPTCPRecord(payload []byte) bool {
atomic.AddUint64(sm.bytesReceived, uint64(len(payload)))
pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload)
pkt, err := sm.decodeRTP(payload)
if err != nil {
sm.onPacketRTPDecodeError(err)
return false
@@ -320,7 +404,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPRecord(payload []byte) bool {
return false
}
packets, err := rtcp.Unmarshal(payload)
packets, err := sm.decodeRTCP(payload)
if err != nil {
sm.onPacketRTCPDecodeError(err)
return false
@@ -370,6 +454,36 @@ func (sm *serverSessionMedia) onPacketRTCPDecodeError(err error) {
}
}
func (sm *serverSessionMedia) writePacketRTCP(pkt rtcp.Packet) error {
plain, err := pkt.Marshal()
if err != nil {
return err
}
maxPlainPacketSize := sm.ss.s.MaxPacketSize
if sm.ss.setuppedSecure {
maxPlainPacketSize -= srtcpOverhead
}
if len(plain) > maxPlainPacketSize {
return fmt.Errorf("packet is too big")
}
var encr []byte
if sm.ss.setuppedSecure {
encr = make([]byte, sm.ss.s.MaxPacketSize)
encr, err = sm.srtpOutCtx.encryptRTCP(encr, plain, nil)
if err != nil {
return err
}
}
if sm.ss.setuppedSecure {
return sm.writePacketRTCPEncoded(encr)
}
return sm.writePacketRTCPEncoded(plain)
}
func (sm *serverSessionMedia) writePacketRTCPEncoded(payload []byte) error {
sm.ss.writerMutex.RLock()
defer sm.ss.writerMutex.RUnlock()

Some files were not shown because too many files have changed in this diff Show More