From 616fa7ea8985f155c9cd643e2ab0fbb26694d5a4 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sat, 5 Jul 2025 12:48:13 +0200 Subject: [PATCH] support encrypted streams with SRTP and MIKEY (#520) (#809) --- README.md | 22 +- client.go | 473 +++++++++++++----- client_format.go | 56 ++- client_media.go | 148 +++++- client_play_test.go | 417 +++++++++++++-- client_record_test.go | 202 +++++--- client_test.go | 7 + client_udp_listener.go | 40 -- constants.go | 9 + .../{server-tls => server-secure}/main.go | 13 +- go.mod | 3 + go.sum | 6 + internal/teste2e/client_vs_server_test.go | 12 + internal/teste2e/sample_server_test.go | 19 +- internal/teste2e/server_vs_external_test.go | 40 ++ pkg/auth/verify.go | 16 +- pkg/base/header.go | 3 + pkg/base/header_test.go | 3 + pkg/description/media.go | 61 ++- pkg/description/session.go | 44 +- pkg/description/session_test.go | 189 +++++++ pkg/format/rtpac3/decoder.go | 1 - pkg/format/rtplpcm/decoder.go | 1 - pkg/format/rtpmpeg4audio/decoder.go | 2 - pkg/headers/key_mgmt.go | 86 ++++ pkg/headers/key_mgmt_test.go | 143 ++++++ pkg/headers/range.go | 6 +- pkg/headers/{rtpinfo.go => rtp_info.go} | 0 .../{rtpinfo_test.go => rtp_info_test.go} | 0 pkg/headers/session.go | 2 +- .../FuzzKeyMgmtUnmarshal/531ecc27fef0609a | 2 + .../FuzzKeyMgmtUnmarshal/771e938e4458e983 | 2 + .../FuzzKeyMgmtUnmarshal/90d404dbb91eead6 | 2 + .../FuzzKeyMgmtUnmarshal/a8857c4807d99b81 | 2 + .../FuzzKeyMgmtUnmarshal/d015a7c61a819cac | 2 + .../FuzzKeyMgmtUnmarshal/f98f6f990321cbbb | 2 + .../FuzzTransportsUnmarshal/249c6737cb7d7159 | 2 + .../FuzzTransportsUnmarshal/302fc5e96ed32a08 | 2 + .../FuzzTransportsUnmarshal/488626fc6b0fd159 | 2 + .../FuzzTransportsUnmarshal/a4fe0bdca2a17b9c | 2 + pkg/headers/transport.go | 120 +++-- pkg/headers/transport_test.go | 97 +--- pkg/headers/transports.go | 48 ++ pkg/headers/transports_test.go | 97 ++++ pkg/liberrors/client.go | 2 + pkg/liberrors/server.go | 19 +- pkg/mikey/header.go | 143 ++++++ pkg/mikey/message.go | 91 ++++ pkg/mikey/message_test.go | 324 ++++++++++++ pkg/mikey/payload.go | 20 + pkg/mikey/payload_kemac.go | 131 +++++ pkg/mikey/payload_rand.go | 46 ++ pkg/mikey/payload_sp.go | 129 +++++ pkg/mikey/payload_t.go | 56 +++ pkg/mikey/sub_payload_key_data.go | 66 +++ .../fuzz/FuzzUnmarshal/0e2aa46892dbf440 | 2 + .../fuzz/FuzzUnmarshal/0e400a18088f3ef1 | 2 + .../fuzz/FuzzUnmarshal/1308b4f12c633cfe | 2 + .../fuzz/FuzzUnmarshal/15a72f9c17aa83eb | 2 + .../fuzz/FuzzUnmarshal/21d92b615e38b74d | 2 + .../fuzz/FuzzUnmarshal/2968ff6b6c107cd2 | 2 + .../fuzz/FuzzUnmarshal/2c79d879e91381fc | 2 + .../fuzz/FuzzUnmarshal/3274006efd885c07 | 2 + .../fuzz/FuzzUnmarshal/34785eb47d444797 | 2 + .../fuzz/FuzzUnmarshal/34b60fae6a60ed23 | 2 + .../fuzz/FuzzUnmarshal/3cf3060ecf42827a | 2 + .../fuzz/FuzzUnmarshal/4305b65c1705da06 | 2 + .../fuzz/FuzzUnmarshal/450a97754aa91bb3 | 2 + .../fuzz/FuzzUnmarshal/49b410a3c47b1687 | 2 + .../fuzz/FuzzUnmarshal/4afa2df02fd9d00a | 2 + .../fuzz/FuzzUnmarshal/4cece13ffff9b317 | 2 + .../fuzz/FuzzUnmarshal/68b704daf492f697 | 2 + .../fuzz/FuzzUnmarshal/6e357cea2c7a331b | 2 + .../fuzz/FuzzUnmarshal/730d6c62c6f77fdc | 2 + .../fuzz/FuzzUnmarshal/7827f50473cee286 | 2 + .../fuzz/FuzzUnmarshal/7bf4d4e3d8104096 | 2 + .../fuzz/FuzzUnmarshal/7e509fbe194aa190 | 2 + .../fuzz/FuzzUnmarshal/87828f4d0c4c5b02 | 2 + .../fuzz/FuzzUnmarshal/8bfad049e124e765 | 2 + .../fuzz/FuzzUnmarshal/93e877be95b75ad0 | 2 + .../fuzz/FuzzUnmarshal/a10c26126d5754cc | 2 + .../fuzz/FuzzUnmarshal/a6078eb8043d4763 | 2 + .../fuzz/FuzzUnmarshal/af43d20f8944989e | 2 + .../fuzz/FuzzUnmarshal/c20cecf53e42d68f | 2 + .../fuzz/FuzzUnmarshal/cd41233d242004e0 | 2 + .../fuzz/FuzzUnmarshal/cde5035558ff62ef | 2 + .../fuzz/FuzzUnmarshal/d43da2b1fa37becf | 2 + .../fuzz/FuzzUnmarshal/dbd8b97d5c1e0809 | 2 + .../fuzz/FuzzUnmarshal/dd7554c158a4bf76 | 2 + .../fuzz/FuzzUnmarshal/f039aafb6c8fbf5c | 2 + .../fuzz/FuzzUnmarshal/f41ce99116022ac9 | 2 + .../fuzz/FuzzUnmarshal/fccb75db092e7fe6 | 2 + pkg/rtcpreceiver/rtcpreceiver.go | 4 +- server.go | 8 - server_conn.go | 134 ++++- server_play_test.go | 291 ++++++++--- server_record_test.go | 219 ++++++-- server_session.go | 343 ++++++++++--- server_session_format.go | 45 +- server_session_media.go | 150 +++++- server_stream.go | 2 +- server_stream_format.go | 88 +++- server_stream_media.go | 91 +++- wrapped_srtp_context.go | 63 +++ 104 files changed, 4179 insertions(+), 766 deletions(-) rename examples/{server-tls => server-secure}/main.go (94%) create mode 100644 pkg/headers/key_mgmt.go create mode 100644 pkg/headers/key_mgmt_test.go rename pkg/headers/{rtpinfo.go => rtp_info.go} (100%) rename pkg/headers/{rtpinfo_test.go => rtp_info_test.go} (100%) create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/531ecc27fef0609a create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/771e938e4458e983 create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/90d404dbb91eead6 create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/a8857c4807d99b81 create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/d015a7c61a819cac create mode 100644 pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/f98f6f990321cbbb create mode 100644 pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/249c6737cb7d7159 create mode 100644 pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/302fc5e96ed32a08 create mode 100644 pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/488626fc6b0fd159 create mode 100644 pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/a4fe0bdca2a17b9c create mode 100644 pkg/headers/transports.go create mode 100644 pkg/headers/transports_test.go create mode 100644 pkg/mikey/header.go create mode 100644 pkg/mikey/message.go create mode 100644 pkg/mikey/message_test.go create mode 100644 pkg/mikey/payload.go create mode 100644 pkg/mikey/payload_kemac.go create mode 100644 pkg/mikey/payload_rand.go create mode 100644 pkg/mikey/payload_sp.go create mode 100644 pkg/mikey/payload_t.go create mode 100644 pkg/mikey/sub_payload_key_data.go create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e2aa46892dbf440 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e400a18088f3ef1 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/1308b4f12c633cfe create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/15a72f9c17aa83eb create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/21d92b615e38b74d create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/2968ff6b6c107cd2 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/2c79d879e91381fc create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/3274006efd885c07 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/34785eb47d444797 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/34b60fae6a60ed23 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/3cf3060ecf42827a create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/4305b65c1705da06 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/450a97754aa91bb3 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/49b410a3c47b1687 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/4afa2df02fd9d00a create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/4cece13ffff9b317 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/68b704daf492f697 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/6e357cea2c7a331b create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/730d6c62c6f77fdc create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/7827f50473cee286 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/7bf4d4e3d8104096 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/7e509fbe194aa190 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/87828f4d0c4c5b02 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/8bfad049e124e765 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/93e877be95b75ad0 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/a10c26126d5754cc create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/a6078eb8043d4763 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/af43d20f8944989e create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/c20cecf53e42d68f create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/cd41233d242004e0 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/cde5035558ff62ef create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/d43da2b1fa37becf create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/dbd8b97d5c1e0809 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/dd7554c158a4bf76 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/f039aafb6c8fbf5c create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/f41ce99116022ac9 create mode 100644 pkg/mikey/testdata/fuzz/FuzzUnmarshal/fccb75db092e7fe6 create mode 100644 wrapped_srtp_context.go diff --git a/README.md b/README.md index 41bf5bb6..f62c9d55 100644 --- a/README.md +++ b/README.md @@ -6,39 +6,37 @@ [![CodeCov](https://codecov.io/gh/bluenviron/gortsplib/branch/main/graph/badge.svg)](https://app.codecov.io/gh/bluenviron/gortsplib/tree/main) [![PkgGoDev](https://pkg.go.dev/badge/github.com/bluenviron/gortsplib/v4)](https://pkg.go.dev/github.com/bluenviron/gortsplib/v4#pkg-index) -RTSP 1.0 client and server library for the Go programming language, written for [MediaMTX](https://github.com/bluenviron/mediamtx). +RTSP client and server library for the Go programming language, written for [MediaMTX](https://github.com/bluenviron/mediamtx). Go ≥ 1.23 is required. Features: * Client + * Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY) * Query servers about available media streams * Read media streams from a server ("play") * Read streams with the UDP, UDP-multicast or TCP transport protocol - * Read TLS-encrypted streams (TCP only) * Switch transport protocol automatically * Read selected media streams * Pause or seek without disconnecting from the server * Write to ONVIF back channels - * Get PTS (relative) timestamp of incoming packets - * Get NTP (absolute) timestamp of incoming packets + * Get PTS (presentation timestamp) of incoming packets + * Get NTP (absolute timestamp) of incoming packets * Write media streams to a server ("record") * Write streams with the UDP or TCP transport protocol - * Write TLS-encrypted streams (TCP only) * Switch transport protocol automatically * Pause without disconnecting from the server * Server + * Support secure protocol variants (RTSPS, TLS, SRTP, MIKEY) * Handle requests from clients * Validate client credentials * Read media streams from clients ("record") * Read streams with the UDP or TCP transport protocol - * Read TLS-encrypted streams (TCP only) - * Get PTS (relative) timestamp of incoming packets - * Get NTP (absolute) timestamp of incoming packets + * Get PTS (presentation timestamp) of incoming packets + * Get NTP (absolute timestamp) of incoming packets * Serve media streams to clients ("play") * Write streams with the UDP, UDP-multicast or TCP transport protocol - * Write TLS-encrypted streams (TCP only) * Compute and provide SSRC, RTP-Info to clients * Read ONVIF back channels * Utilities @@ -94,7 +92,7 @@ Features: * [client-record-format-vp8](examples/client-record-format-vp8/main.go) * [client-record-format-vp9](examples/client-record-format-vp9/main.go) * [server](examples/server/main.go) -* [server-tls](examples/server-tls/main.go) +* [server-secure](examples/server-secure/main.go) * [server-auth](examples/server-auth/main.go) * [server-record-format-h264-to-disk](examples/server-record-format-h264-to-disk/main.go) * [server-play-format-h264-from-disk](examples/server-play-format-h264-from-disk/main.go) @@ -150,7 +148,10 @@ In RTSP, media streams are transmitted by using RTP packets, which are encoded i |----|----| |[RFC2326, RTSP 1.0](https://datatracker.ietf.org/doc/html/rfc2326)|protocol| |[RFC7826, RTSP 2.0](https://datatracker.ietf.org/doc/html/rfc7826)|protocol| +|[ONVIF Streaming Specification 23.06](https://www.onvif.org/specs/stream/ONVIF-Streaming-Spec.pdf)|protocol| |[RFC8866, SDP: Session Description Protocol](https://datatracker.ietf.org/doc/html/rfc8866)|SDP| +|[RFC4567, Key Management Extensions for Session Description Protocol (SDP) and Real Time Streaming Protocol (RTSP)](https://datatracker.ietf.org/doc/html/rfc4567)|secure variants| +|[RFC3830, MIKEY: Multimedia Internet KEYing](https://datatracker.ietf.org/doc/html/rfc3830)|secure variants| |[RTP Payload Format For AV1 (v1.0)](https://aomediacodec.github.io/av1-rtp-spec/)|payload formats / AV1| |[RTP Payload Format for VP9 Video](https://datatracker.ietf.org/doc/html/draft-ietf-payload-vp9-16)|payload formats / VP9| |[RFC7741, RTP Payload Format for VP8 Video](https://datatracker.ietf.org/doc/html/rfc7741)|payload formats / VP8| @@ -178,3 +179,4 @@ In RTSP, media streams are transmitted by using RTP packets, which are encoded i * [pion/sdp (SDP library used internally)](https://github.com/pion/sdp) * [pion/rtp (RTP library used internally)](https://github.com/pion/rtp) * [pion/rtcp (RTCP library used internally)](https://github.com/pion/rtcp) +* [pion/srtp (SRTP library used internally)](https://github.com/pion/srtp) diff --git a/client.go b/client.go index 4e2ebae4..543c8501 100644 --- a/client.go +++ b/client.go @@ -1,5 +1,5 @@ /* -Package gortsplib is a RTSP 1.0 library for the Go programming language. +Package gortsplib is a RTSP library for the Go programming language. Examples are available at https://github.com/bluenviron/gortsplib/tree/main/examples */ @@ -7,10 +7,12 @@ package gortsplib import ( "context" + "crypto/rand" "crypto/tls" "fmt" "log" "net" + "slices" "strconv" "strings" "sync" @@ -28,6 +30,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/liberrors" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver" "github.com/bluenviron/gortsplib/v4/pkg/rtcpsender" "github.com/bluenviron/gortsplib/v4/pkg/rtptime" @@ -118,10 +121,121 @@ func findBaseURL(sd *sdp.SessionDescription, res *base.Response, u *base.URL) (* return u, nil } -func prepareForAnnounce(desc *description.Session) { - for i, media := range desc.Medias { - media.Control = "trackID=" + strconv.FormatInt(int64(i), 10) +type clientAnnounceDataFormat struct { + localSSRC uint32 +} + +type clientAnnounceDataMedia struct { + srtpOutKey []byte + formats map[uint8]*clientAnnounceDataFormat +} + +func announceDataPickLocalSSRC( + am *clientAnnounceDataMedia, + data map[*description.Media]*clientAnnounceDataMedia, +) (uint32, error) { + var takenSSRCs []uint32 //nolint:prealloc + + for _, am := range data { + for _, af := range am.formats { + takenSSRCs = append(takenSSRCs, af.localSSRC) + } } + + for _, af := range am.formats { + takenSSRCs = append(takenSSRCs, af.localSSRC) + } + + for { + ssrc, err := randUint32() + if err != nil { + return 0, err + } + + if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { + return ssrc, nil + } + } +} + +func generateAnnounceData( + desc *description.Session, + secure bool, +) (map[*description.Media]*clientAnnounceDataMedia, error) { + data := make(map[*description.Media]*clientAnnounceDataMedia) + + for _, medi := range desc.Medias { + am := &clientAnnounceDataMedia{ + formats: make(map[uint8]*clientAnnounceDataFormat), + } + + for _, format := range medi.Formats { + dataFormat := &clientAnnounceDataFormat{} + + var err error + dataFormat.localSSRC, err = announceDataPickLocalSSRC(am, data) + if err != nil { + return nil, err + } + + am.formats[format.PayloadType()] = dataFormat + } + + if secure { + am.srtpOutKey = make([]byte, srtpKeyLength) + _, err := rand.Read(am.srtpOutKey) + if err != nil { + return nil, err + } + } + + data[medi] = am + } + + return data, nil +} + +func prepareForAnnounce( + desc *description.Session, + announceData map[*description.Media]*clientAnnounceDataMedia, + secure bool, +) error { + for i, m := range desc.Medias { + m.Control = "trackID=" + strconv.FormatInt(int64(i), 10) + m.Secure = secure + + if secure { + announceDataMedia := announceData[m] + + ssrcs := make([]uint32, len(m.Formats)) + n := 0 + for _, af := range announceDataMedia.formats { + ssrcs[n] = af.localSSRC + n++ + } + + // use a dummy Context. + // Context is needed to extract ROC, but since client has not started streaming, + // ROC is always zero, therefore a dummy Context can be used. + srtpCtx := &wrappedSRTPContext{ + key: announceDataMedia.srtpOutKey, + ssrcs: ssrcs, + } + err := srtpCtx.initialize() + if err != nil { + return err + } + + mikeyMsg, err := mikeyGenerate(srtpCtx) + if err != nil { + return err + } + + m.KeyMgmtMikey = mikeyMsg + } + } + + return nil } func supportsGetParameter(header base.Header) bool { @@ -330,7 +444,8 @@ type Client struct { receiverReportPeriod time.Duration checkTimeoutPeriod time.Duration - connURL *base.URL + scheme string + host string ctx context.Context ctxCancel func() state clientState @@ -342,8 +457,11 @@ type Client struct { optionsSent bool useGetParameter bool lastDescribeURL *base.URL + lastDescribeDesc *description.Session baseURL *base.URL + announceData map[*description.Media]*clientAnnounceDataMedia // record effectiveTransport *Transport + effectiveSecure bool backChannelSetupped bool stdChannelSetupped bool setuppedMedias map[*description.Media]*clientMedia @@ -474,10 +592,8 @@ func (c *Client) Start(scheme string, host string) error { ctx, ctxCancel := context.WithCancel(context.Background()) - c.connURL = &base.URL{ - Scheme: scheme, - Host: host, - } + c.scheme = scheme + c.host = host c.ctx = ctx c.ctxCancel = ctxCancel c.checkTimeoutTimer = emptyTimer() @@ -820,7 +936,6 @@ func (c *Client) checkState(allowed map[clientState]struct{}) error { func (c *Client) trySwitchingProtocol() error { c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP{}) - prevConnURL := c.connURL prevBaseURL := c.baseURL prevMedias := c.setuppedMedias @@ -828,7 +943,6 @@ func (c *Client) trySwitchingProtocol() error { v := TransportTCP c.effectiveTransport = &v - c.connURL = prevConnURL // some Hikvision cameras require a describe before a setup _, _, err := c.doDescribe(c.lastDescribeURL) @@ -856,26 +970,6 @@ func (c *Client) trySwitchingProtocol() error { return nil } -func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.URL) (*base.Response, error) { - c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) - - prevConnURL := c.connURL - - c.reset() - - v := TransportTCP - c.effectiveTransport = &v - c.connURL = prevConnURL - - // some Hikvision cameras require a describe before a setup - _, _, err := c.doDescribe(c.lastDescribeURL) - if err != nil { - return nil, err - } - - return c.doSetup(baseURL, medi, 0, 0) -} - func (c *Client) startTransportRoutines() { c.timeDecoder = &rtptime.GlobalDecoder2{} c.timeDecoder.Initialize() @@ -968,28 +1062,30 @@ func (c *Client) connOpen() error { return nil } - if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" { - return liberrors.ErrClientUnsupportedScheme{Scheme: c.connURL.Scheme} - } - - if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP { - return liberrors.ErrClientRTSPSTCP{} + if c.scheme != "rtsp" && c.scheme != "rtsps" { + return liberrors.ErrClientUnsupportedScheme{Scheme: c.scheme} } dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout) defer dialCtxCancel() - nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(c.connURL)) + nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(&base.URL{ + Scheme: c.scheme, + Host: c.host, + })) if err != nil { return err } - if c.connURL.Scheme == "rtsps" { + if c.scheme == "rtsps" { tlsConfig := c.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{} } - tlsConfig.ServerName = c.connURL.Hostname() + tlsConfig.ServerName = (&base.URL{ + Scheme: c.scheme, + Host: c.host, + }).Hostname() nconn = tls.Client(nconn, tlsConfig) } @@ -1256,7 +1352,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response, return nil, nil, err } - if c.connURL.Scheme == "rtsps" && ru.Scheme != "rtsps" { + if c.scheme == "rtsps" && ru.Scheme != "rtsps" { return nil, nil, fmt.Errorf("connection cannot be downgraded from RTSPS to RTSP") } @@ -1264,10 +1360,8 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response, ru.User = u.User } - c.connURL = &base.URL{ - Scheme: ru.Scheme, - Host: ru.Host, - } + c.scheme = ru.Scheme + c.host = ru.Host return c.doDescribe(ru) } @@ -1306,6 +1400,7 @@ func (c *Client) doDescribe(u *base.URL) (*description.Session, *base.Response, desc.BaseURL = baseURL c.lastDescribeURL = u + c.lastDescribeDesc = &desc return &desc, res, nil } @@ -1340,7 +1435,15 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo return nil, err } - prepareForAnnounce(desc) + announceData, err := generateAnnounceData(desc, c.scheme == "rtsps") + if err != nil { + return nil, err + } + + err = prepareForAnnounce(desc, announceData, c.scheme == "rtsps") + if err != nil { + return nil, err + } byts, err := desc.Marshal(false) if err != nil { @@ -1367,6 +1470,7 @@ func (c *Client) doAnnounce(u *base.URL, desc *description.Session) (*base.Respo c.baseURL = u.Clone() c.state = clientStatePreRecord + c.announceData = announceData return res, nil } @@ -1408,72 +1512,91 @@ func (c *Client) doSetup( return nil, liberrors.ErrClientCannotSetupMediasDifferentURLs{} } - th := headers.Transport{ - Mode: func() *headers.TransportMode { - if c.state == clientStatePreRecord { - v := headers.TransportModeRecord - return &v - } - // when playing, omit mode, since it causes errors with some servers. - return nil - }(), + th := headers.Transport{} + + // when playing, omit mode, since it causes errors with some servers. + if c.state == clientStatePreRecord { + v := headers.TransportModeRecord + th.Mode = &v + } + + var transport Transport + + switch { + // use transport from previous SETUP calls + case c.effectiveTransport != nil: + transport = *c.effectiveTransport + th.Secure = c.effectiveSecure + + if th.Secure && !medi.Secure { + return nil, fmt.Errorf("previous media was setupped securely but current cannot") + } + + // use transport from config, secure flag from server + case c.Transport != nil: + transport = *c.Transport + th.Secure = medi.Secure && c.scheme == "rtsps" + + // try UDP if unencrypted or secure is supported by server, otherwise try TCP + default: + th.Secure = medi.Secure && c.scheme == "rtsps" + + if th.Secure || c.scheme == "rtsp" { + transport = TransportUDP + } else { + transport = TransportTCP + } } cm := &clientMedia{ - c: c, - media: medi, + c: c, + media: medi, + secure: th.Secure, } err = cm.initialize() if err != nil { return nil, err } - if c.effectiveTransport == nil { - if c.connURL.Scheme == "rtsps" { // always use TCP if encrypted - v := TransportTCP - c.effectiveTransport = &v - } else if c.Transport != nil { // take transport from config - c.effectiveTransport = c.Transport - } - } - - var desiredTransport Transport - if c.effectiveTransport != nil { - desiredTransport = *c.effectiveTransport - } else { - desiredTransport = TransportUDP - } - - switch desiredTransport { - case TransportUDP: - if (rtpPort == 0 && rtcpPort != 0) || - (rtpPort != 0 && rtcpPort == 0) { - return nil, liberrors.ErrClientUDPPortsZero{} + switch transport { + case TransportUDP, TransportUDPMulticast: + if c.scheme == "rtsps" && !medi.Secure { + cm.close() + return nil, fmt.Errorf("server does not support secure UDP") } - if rtpPort != 0 && rtcpPort != (rtpPort+1) { - return nil, liberrors.ErrClientUDPPortsNotConsecutive{} - } - - err = cm.createUDPListeners( - false, - nil, - net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), - net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), - ) - if err != nil { - return nil, err - } - - v1 := headers.TransportDeliveryUnicast - th.Delivery = &v1 th.Protocol = headers.TransportProtocolUDP - th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()} - case TransportUDPMulticast: - v1 := headers.TransportDeliveryMulticast - th.Delivery = &v1 - th.Protocol = headers.TransportProtocolUDP + if transport == TransportUDP { + if (rtpPort == 0 && rtcpPort != 0) || + (rtpPort != 0 && rtcpPort == 0) { + cm.close() + return nil, liberrors.ErrClientUDPPortsZero{} + } + + if rtpPort != 0 && rtcpPort != (rtpPort+1) { + cm.close() + return nil, liberrors.ErrClientUDPPortsNotConsecutive{} + } + + err = cm.createUDPListeners( + false, + nil, + net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), + net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), + ) + if err != nil { + cm.close() + return nil, err + } + + v1 := headers.TransportDeliveryUnicast + th.Delivery = &v1 + th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()} + } else { + v1 := headers.TransportDeliveryMulticast + th.Delivery = &v1 + } case TransportTCP: v1 := headers.TransportDeliveryUnicast @@ -1497,6 +1620,34 @@ func (c *Client) doSetup( header["Require"] = base.HeaderValue{"www.onvif.org/ver20/backchannel"} } + if th.Secure { + ssrcs := make([]uint32, len(cm.formats)) + n := 0 + for _, cf := range cm.formats { + ssrcs[n] = cf.localSSRC + n++ + } + + var mikeyMsg *mikey.Message + mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx) + if err != nil { + cm.close() + return nil, err + } + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: mediaURL.String(), + MikeyMessage: mikeyMsg, + }.Marshal() + if err != nil { + cm.close() + return nil, err + } + + header["KeyMgmt"] = enc + } + res, err := c.do(&base.Request{ Method: base.Setup, URL: mediaURL, @@ -1512,10 +1663,12 @@ func (c *Client) doSetup( // switch transport automatically if res.StatusCode == base.StatusUnsupportedTransport && - c.effectiveTransport == nil { + c.effectiveTransport == nil && c.Transport == nil { c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) v := TransportTCP c.effectiveTransport = &v + c.effectiveSecure = th.Secure + return c.doSetup(baseURL, medi, 0, 0) } @@ -1529,23 +1682,37 @@ func (c *Client) doSetup( return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err} } - switch desiredTransport { + switch transport { case TransportUDP, TransportUDPMulticast: if thRes.Protocol == headers.TransportProtocolTCP { cm.close() // switch transport automatically - if c.effectiveTransport == nil && - c.Transport == nil { + if c.effectiveTransport == nil && c.Transport == nil { + c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) + c.baseURL = baseURL - return c.trySwitchingProtocol2(medi, baseURL) + + c.reset() + + v := TransportTCP + c.effectiveTransport = &v + c.effectiveSecure = th.Secure + + // some Hikvision cameras require a describe before a setup + _, _, err = c.doDescribe(c.lastDescribeURL) + if err != nil { + return nil, err + } + + return c.doSetup(baseURL, medi, 0, 0) } return nil, liberrors.ErrClientServerRequestedTCP{} } } - switch desiredTransport { + switch transport { case TransportUDP: if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { cm.close() @@ -1592,14 +1759,17 @@ func (c *Client) doSetup( case TransportUDPMulticast: if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast { + cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } if thRes.Ports == nil { + cm.close() return nil, liberrors.ErrClientTransportHeaderNoPorts{} } if thRes.Destination == nil { + cm.close() return nil, liberrors.ErrClientTransportHeaderNoDestination{} } @@ -1617,6 +1787,7 @@ func (c *Client) doSetup( net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), ) if err != nil { + cm.close() return nil, err } @@ -1636,22 +1807,27 @@ func (c *Client) doSetup( case TransportTCP: if thRes.Protocol != headers.TransportProtocolTCP { + cm.close() return nil, liberrors.ErrClientServerRequestedUDP{} } if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { + cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } if thRes.InterleavedIDs == nil { + cm.close() return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{} } if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] { + cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{} } if c.isChannelPairInUse(thRes.InterleavedIDs[0]) { + cm.close() return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{} @@ -1660,6 +1836,48 @@ func (c *Client) doSetup( cm.tcpChannel = thRes.InterleavedIDs[0] } + if cm.secure { + if !thRes.Secure { + cm.close() + return nil, fmt.Errorf("transport was not setupped securely") + } + + var mikeyMsg *mikey.Message + + // extract key-mgmt from (in order of priority): + // - response + // - media SDP attributes + // - session SDP attributes + switch { + case res.Header["KeyMgmt"] != nil: + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) + if err != nil { + cm.close() + return nil, err + } + mikeyMsg = keyMgmt.MikeyMessage + + case medi.KeyMgmtMikey != nil: + mikeyMsg = medi.KeyMgmtMikey + + case c.lastDescribeDesc.KeyMgmtMikey != nil: + mikeyMsg = c.lastDescribeDesc.KeyMgmtMikey + + default: + return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way") + } + + cm.srtpInCtx, err = mikeyToContext(mikeyMsg) + if err != nil { + cm.close() + return nil, err + } + } else if thRes.Secure { + cm.close() + return nil, fmt.Errorf("received unexpected secure profile") + } + if c.setuppedMedias == nil { c.setuppedMedias = make(map[*description.Media]*clientMedia) } @@ -1667,7 +1885,8 @@ func (c *Client) doSetup( c.setuppedMedias[medi] = cm c.baseURL = baseURL - c.effectiveTransport = &desiredTransport + c.effectiveTransport = &transport + c.effectiveSecure = th.Secure if medi.IsBackChannel { c.backChannelSetupped = true @@ -1770,12 +1989,34 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { // do this before sending the PLAY request. if *c.effectiveTransport == TransportUDP { for _, cm := range c.setuppedMedias { - if !cm.media.IsBackChannel { - byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() - cm.udpRTPListener.write(byts) //nolint:errcheck + if !cm.media.IsBackChannel && cm.udpRTPListener.writeAddr != nil { + buf, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() + if cm.srtpOutCtx != nil { + encr := make([]byte, cm.c.MaxPacketSize) + encr, err = cm.srtpOutCtx.encryptRTP(encr, buf, nil) + if err != nil { + return nil, err + } + buf = encr + } + err = cm.udpRTPListener.write(buf) + if err != nil { + return nil, err + } - byts, _ = (&rtcp.ReceiverReport{}).Marshal() - cm.udpRTCPListener.write(byts) //nolint:errcheck + buf, _ = (&rtcp.ReceiverReport{}).Marshal() + if cm.srtpOutCtx != nil { + encr := make([]byte, cm.c.MaxPacketSize) + encr, err = cm.srtpOutCtx.encryptRTCP(encr, buf, nil) + if err != nil { + return nil, err + } + buf = encr + } + err = cm.udpRTCPListener.write(buf) + if err != nil { + return nil, err + } } } } @@ -1981,7 +2222,7 @@ func (c *Client) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error } // WritePacketRTPWithNTP writes a RTP packet to the server. -// ntp is the absolute time of the packet, and is sent with periodic RTCP sender reports. +// ntp is the absolute timestamp of the packet, and is sent with periodic RTCP sender reports. func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, ntp time.Time) error { select { case <-c.done: @@ -2020,7 +2261,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error return cm.writePacketRTCP(pkt) } -// PacketPTS returns the PTS of an incoming RTP packet. +// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet. // It is computed by decoding the packet timestamp and sychronizing it with other tracks. // // Deprecated: replaced by PacketPTS2. @@ -2036,7 +2277,7 @@ func (c *Client) PacketPTS(medi *description.Media, pkt *rtp.Packet) (time.Durat return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(ct.format.ClockRate())), true } -// PacketPTS2 returns the PTS of an incoming RTP packet. +// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet. // It is computed by decoding the packet timestamp and sychronizing it with other tracks. func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) { cm := c.setuppedMedias[medi] @@ -2044,8 +2285,8 @@ func (c *Client) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bo return c.timeDecoder.Decode(ct.format, pkt) } -// PacketNTP returns the NTP timestamp of an incoming RTP packet. -// The NTP timestamp is computed from RTCP sender reports. +// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet. +// The NTP is computed from RTCP sender reports. func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) { cm := c.setuppedMedias[medi] ct := cm.formats[pkt.PayloadType] diff --git a/client_format.go b/client_format.go index a6e41087..978ed77d 100644 --- a/client_format.go +++ b/client_format.go @@ -1,6 +1,7 @@ package gortsplib import ( + "slices" "sync/atomic" "time" @@ -15,25 +16,26 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/rtpreorderer" ) -func isClientLocalSSRCTaken(ssrc uint32, c *Client, exclude *clientFormat) bool { - for _, cm := range c.setuppedMedias { +func clientPickLocalSSRC(cf *clientFormat) (uint32, error) { + var takenSSRCs []uint32 //nolint:prealloc + + for _, cm := range cf.cm.c.setuppedMedias { for _, cf := range cm.formats { - if cf != exclude && cf.localSSRC == ssrc { - return true - } + takenSSRCs = append(takenSSRCs, cf.localSSRC) } } - return false -} -func clientPickLocalSSRC(cf *clientFormat) (uint32, error) { + for _, cf := range cf.cm.formats { + takenSSRCs = append(takenSSRCs, cf.localSSRC) + } + for { ssrc, err := randUint32() if err != nil { return 0, err } - if ssrc != 0 && !isClientLocalSSRCTaken(ssrc, cf.cm.c, cf) { + if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { return ssrc, nil } } @@ -56,10 +58,14 @@ type clientFormat struct { } func (cf *clientFormat) initialize() error { - var err error - cf.localSSRC, err = clientPickLocalSSRC(cf) - if err != nil { - return err + if cf.cm.c.state == clientStatePreRecord { + cf.localSSRC = cf.cm.c.announceData[cf.cm.media].formats[cf.format.PayloadType()].localSSRC + } else { + var err error + cf.localSSRC, err = clientPickLocalSSRC(cf) + if err != nil { + return err + } } cf.rtpPacketsReceived = new(uint64) @@ -181,17 +187,31 @@ func (cf *clientFormat) handlePacketsLost(lost uint64) { func (cf *clientFormat) writePacketRTP(pkt *rtp.Packet, ntp time.Time) error { pkt.SSRC = cf.localSSRC - byts := make([]byte, cf.cm.c.MaxPacketSize) - n, err := pkt.MarshalTo(byts) + cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt)) + + maxPlainPacketSize := cf.cm.c.MaxPacketSize + if cf.cm.srtpOutCtx != nil { + maxPlainPacketSize -= srtpOverhead + } + + buf := make([]byte, maxPlainPacketSize) + n, err := pkt.MarshalTo(buf) if err != nil { return err } - byts = byts[:n] + buf = buf[:n] - cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt)) + if cf.cm.srtpOutCtx != nil { + encr := make([]byte, cf.cm.c.MaxPacketSize) + encr, err = cf.cm.srtpOutCtx.encryptRTP(encr, buf, &pkt.Header) + if err != nil { + return err + } + buf = encr + } ok := cf.cm.c.writer.push(func() error { - return cf.writePacketRTPInQueue(byts) + return cf.writePacketRTPInQueue(buf) }) if !ok { return liberrors.ErrClientWriteQueueFull{} diff --git a/client_media.go b/client_media.go index ead508d7..0ff963ee 100644 --- a/client_media.go +++ b/client_media.go @@ -1,7 +1,10 @@ package gortsplib import ( + "crypto/rand" + "fmt" "net" + "strconv" "sync/atomic" "time" @@ -13,9 +16,12 @@ import ( ) type clientMedia struct { - c *Client - media *description.Media + c *Client + media *description.Media + secure bool + srtpOutCtx *wrappedSRTPContext + srtpInCtx *wrappedSRTPContext onPacketRTCP OnPacketRTCPFunc formats map[uint8]*clientFormat tcpChannel int @@ -55,6 +61,35 @@ func (cm *clientMedia) initialize() error { cm.formats[forma.PayloadType()] = f } + if cm.secure { + var srtpOutKey []byte + if cm.c.state == clientStatePreRecord { + srtpOutKey = cm.c.announceData[cm.media].srtpOutKey + } else { + srtpOutKey = make([]byte, srtpKeyLength) + _, err := rand.Read(srtpOutKey) + if err != nil { + return err + } + } + + ssrcs := make([]uint32, len(cm.formats)) + n := 0 + for _, cf := range cm.formats { + ssrcs[n] = cf.localSSRC + n++ + } + + cm.srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcs, + } + err := cm.srtpOutCtx.initialize() + if err != nil { + return err + } + } + return nil } @@ -99,9 +134,45 @@ func (cm *clientMedia) createUDPListeners( return nil } - var err error - cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c) - return err + // pick two consecutive ports in range 65535-10000 + // RTP port must be even and RTCP port odd + for { + v, err := randInRange((65535 - 10000) / 2) + if err != nil { + return err + } + + rtpPort := v*2 + 10000 + rtcpPort := rtpPort + 1 + + cm.udpRTPListener = &clientUDPListener{ + c: cm.c, + multicastEnable: false, + multicastSourceIP: nil, + address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), + } + err = cm.udpRTPListener.initialize() + if err != nil { + cm.udpRTPListener = nil + continue + } + + cm.udpRTCPListener = &clientUDPListener{ + c: cm.c, + multicastEnable: false, + multicastSourceIP: nil, + address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), + } + err = cm.udpRTCPListener.initialize() + if err != nil { + cm.udpRTPListener.close() + cm.udpRTPListener = nil + cm.udpRTCPListener = nil + continue + } + + return nil + } } func (cm *clientMedia) start() { @@ -161,14 +232,44 @@ func (cm *clientMedia) findFormatByRemoteSSRC(ssrc uint32) *clientFormat { return nil } +func (cm *clientMedia) decodeRTP(payload []byte) (*rtp.Packet, error) { + if cm.srtpInCtx != nil { + var err error + payload, err = cm.srtpInCtx.decryptRTP(payload, payload, nil) + if err != nil { + return nil, err + } + } + + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + return &pkt, err +} + +func (cm *clientMedia) decodeRTCP(payload []byte) ([]rtcp.Packet, error) { + if cm.srtpInCtx != nil { + var err error + payload, err = cm.srtpInCtx.decryptRTCP(payload, payload, nil) + if err != nil { + return nil, err + } + } + + pkts, err := rtcp.Unmarshal(payload) + if err != nil { + return nil, err + } + + return pkts, nil +} + func (cm *clientMedia) readPacketRTPTCPPlay(payload []byte) bool { atomic.AddUint64(cm.bytesReceived, uint64(len(payload))) now := cm.c.timeNow() atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix()) - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := cm.decodeRTP(payload) if err != nil { cm.onPacketRTPDecodeError(err) return false @@ -196,7 +297,7 @@ func (cm *clientMedia) readPacketRTCPTCPPlay(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := cm.decodeRTCP(payload) if err != nil { cm.onPacketRTCPDecodeError(err) return false @@ -230,7 +331,7 @@ func (cm *clientMedia) readPacketRTCPTCPRecord(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := cm.decodeRTCP(payload) if err != nil { cm.onPacketRTCPDecodeError(err) return false @@ -253,8 +354,7 @@ func (cm *clientMedia) readPacketRTPUDPPlay(payload []byte) bool { return false } - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := cm.decodeRTP(payload) if err != nil { cm.onPacketRTPDecodeError(err) return false @@ -279,7 +379,7 @@ func (cm *clientMedia) readPacketRTCPUDPPlay(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := cm.decodeRTCP(payload) if err != nil { cm.onPacketRTCPDecodeError(err) return false @@ -315,7 +415,7 @@ func (cm *clientMedia) readPacketRTCPUDPRecord(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := cm.decodeRTCP(payload) if err != nil { cm.onPacketRTCPDecodeError(err) return false @@ -341,13 +441,31 @@ func (cm *clientMedia) onPacketRTCPDecodeError(err error) { } func (cm *clientMedia) writePacketRTCP(pkt rtcp.Packet) error { - byts, err := pkt.Marshal() + buf, err := pkt.Marshal() if err != nil { return err } + maxPlainPacketSize := cm.c.MaxPacketSize + if cm.srtpOutCtx != nil { + maxPlainPacketSize -= srtcpOverhead + } + + if len(buf) > maxPlainPacketSize { + return fmt.Errorf("packet is too big") + } + + if cm.srtpOutCtx != nil { + encr := make([]byte, cm.c.MaxPacketSize) + encr, err = cm.srtpOutCtx.encryptRTCP(encr, buf, nil) + if err != nil { + return err + } + buf = encr + } + ok := cm.c.writer.push(func() error { - return cm.writePacketRTCPInQueue(byts) + return cm.writePacketRTCPInQueue(buf) }) if !ok { return liberrors.ErrClientWriteQueueFull{} diff --git a/client_play_test.go b/client_play_test.go index b9668fdb..399a0d19 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -2,7 +2,9 @@ package gortsplib import ( "bytes" + "crypto/rand" "crypto/tls" + "encoding/base64" "net" "strconv" "strings" @@ -21,6 +23,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/mediacommon/v2/pkg/codecs/mpeg4audio" ) @@ -45,7 +48,9 @@ func mediasToSDP(medias []*description.Media) []byte { Medias: medias, } - prepareForAnnounce(desc) + for i, m := range desc.Medias { + m.Control = "trackID=" + strconv.FormatInt(int64(i), 10) + } byts, err := desc.Marshal(false) if err != nil { @@ -146,6 +151,7 @@ func TestClientPlayFormats(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -242,23 +248,55 @@ func TestClientPlayFormats(t *testing.T) { } func TestClientPlay(t *testing.T) { - for _, transport := range []string{ - "udp", - "multicast", - "tcp", - "tls", + for _, ca := range []struct { + scheme string + transport string + secure string + }{ + { + "rtsp", + "udp", + "unsecure", + }, + { + "rtsp", + "multicast", + "unsecure", + }, + { + "rtsp", + "tcp", + "unsecure", + }, + { + "rtsps", + "tcp", + "unsecure", + }, + { + "rtsps", + "udp", + "secure", + }, + { + "rtsps", + "multicast", + "secure", + }, + { + "rtsps", + "tcp", + "secure", + }, } { - t.Run(transport, func(t *testing.T) { + t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) { packetRecv := make(chan struct{}) listenIP := multicastCapableIP(t) var l net.Listener var err error - var scheme string - if transport == "tls" { - scheme = "rtsps" - + if ca.scheme == "rtsps" { var cert tls.Certificate cert, err = tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) @@ -267,8 +305,6 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err) defer l.Close() } else { - scheme = "rtsp" - l, err = net.Listen("tcp", listenIP+":8554") require.NoError(t, err) defer l.Close() @@ -276,6 +312,7 @@ func TestClientPlay(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -287,7 +324,7 @@ func TestClientPlay(t *testing.T) { req, err2 := conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Options, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -304,7 +341,7 @@ func TestClientPlay(t *testing.T) { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Describe, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) forma := &format.Generic{ PayloadTyp: 96, @@ -317,10 +354,12 @@ func TestClientPlay(t *testing.T) { { Type: "application", Formats: []format.Format{forma}, + Secure: ca.secure == "secure", }, { Type: "application", Formats: []format.Format{forma}, + Secure: ca.secure == "secure", }, } @@ -328,7 +367,7 @@ func TestClientPlay(t *testing.T) { StatusCode: base.StatusOK, Header: base.Header{ "Content-Type": base.HeaderValue{"application/sdp"}, - "Content-Base": base.HeaderValue{scheme + "://" + listenIP + ":8554/test/stream?param=value/"}, + "Content-Base": base.HeaderValue{ca.scheme + "://" + listenIP + ":8554/test/stream?param=value/"}, }, Body: mediasToSDP(medias), }) @@ -337,13 +376,15 @@ func TestClientPlay(t *testing.T) { var l1s [2]net.PacketConn var l2s [2]net.PacketConn var clientPorts [2]*[2]int + var srtpInCtx [2]*wrappedSRTPContext + var srtpOutCtx [2]*wrappedSRTPContext for i := 0; i < 2; i++ { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL( - scheme+"://"+listenIP+":8554/test/stream?param=value/"+medias[i].Control), req.URL) + ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"+medias[i].Control), req.URL) var inTH headers.Transport err2 = inTH.Unmarshal(req.Header["Transport"]) @@ -351,9 +392,48 @@ func TestClientPlay(t *testing.T) { require.Equal(t, (*headers.TransportMode)(nil), inTH.Mode) - var th headers.Transport + h := base.Header{} - switch transport { + th := headers.Transport{ + Secure: inTH.Secure, + } + + if ca.secure == "secure" { + require.True(t, inTH.Secure) + + var keyMgmt headers.KeyMgmt + err2 = keyMgmt.Unmarshal(req.Header["KeyMgmt"]) + require.NoError(t, err2) + + srtpInCtx[i], err = mikeyToContext(keyMgmt.MikeyMessage) + require.NoError(t, err2) + + outKey := make([]byte, srtpKeyLength) + _, err2 = rand.Read(outKey) + require.NoError(t, err2) + + srtpOutCtx[i] = &wrappedSRTPContext{ + key: outKey, + ssrcs: []uint32{2345423}, + } + err2 = srtpOutCtx[i].initialize() + require.NoError(t, err2) + + var mikeyMsg *mikey.Message + mikeyMsg, err = mikeyGenerate(srtpOutCtx[i]) + require.NoError(t, err) + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: req.URL.String(), + MikeyMessage: mikeyMsg, + }.Marshal() + require.NoError(t, err) + + h["KeyMgmt"] = enc + } + + switch ca.transport { case "udp": v := headers.TransportDeliveryUnicast th.Delivery = &v @@ -409,18 +489,18 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err2) } - case "tcp", "tls": + case "tcp": v := headers.TransportDeliveryUnicast th.Delivery = &v th.Protocol = headers.TransportProtocolTCP th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2} } + h["Transport"] = th.Marshal() + err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, + Header: h, }) require.NoError(t, err2) } @@ -428,7 +508,7 @@ func TestClientPlay(t *testing.T) { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Play, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"]) err2 = conn.WriteResponse(&base.Response{ @@ -439,25 +519,34 @@ func TestClientPlay(t *testing.T) { // server -> client for i := 0; i < 2; i++ { - switch transport { + buf := testRTPPacketMarshaled + + if ca.secure == "secure" { + encr := make([]byte, 2000) + encr, err2 = srtpOutCtx[i].encryptRTP(encr, buf, nil) + require.NoError(t, err2) + buf = encr + } + + switch ca.transport { case "udp": - _, err2 = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + _, err2 = l1s[i].WriteTo(buf, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: clientPorts[i][0], }) require.NoError(t, err2) case "multicast": - _, err2 = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + _, err2 = l1s[i].WriteTo(buf, &net.UDPAddr{ IP: net.ParseIP("224.1.0.1"), Port: 25000 + i*2, }) require.NoError(t, err2) - case "tcp", "tls": + case "tcp": err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0 + i*2, - Payload: testRTPPacketMarshaled, + Payload: buf, }, make([]byte, 1024)) require.NoError(t, err2) } @@ -465,7 +554,7 @@ func TestClientPlay(t *testing.T) { // skip firewall opening - if transport == "udp" { + if ca.transport == "udp" { for i := 0; i < 2; i++ { buf := make([]byte, 2048) _, _, err2 = l2s[i].ReadFrom(buf) @@ -476,27 +565,30 @@ func TestClientPlay(t *testing.T) { // client -> server for i := 0; i < 2; i++ { - switch transport { + var buf []byte + + switch ca.transport { case "udp", "multicast": - buf := make([]byte, 2048) + buf = make([]byte, 2048) var n int n, _, err2 = l2s[i].ReadFrom(buf) require.NoError(t, err2) - var packets []rtcp.Packet - packets, err2 = rtcp.Unmarshal(buf[:n]) - require.NoError(t, err2) - require.Equal(t, &testRTCPPacket, packets[0]) + buf = buf[:n] - case "tcp", "tls": + case "tcp": var f *base.InterleavedFrame f, err2 = conn.ReadInterleavedFrame() require.NoError(t, err2) require.Equal(t, 1+i*2, f.Channel) - var packets []rtcp.Packet - packets, err2 = rtcp.Unmarshal(f.Payload) - require.NoError(t, err2) - require.Equal(t, &testRTCPPacket, packets[0]) + buf = f.Payload } + + if ca.secure == "secure" { + buf, err2 = srtpInCtx[i].decryptRTCP(buf, buf, nil) + require.NoError(t, err2) + } + + require.Equal(t, testRTCPPacketMarshaled, buf) } close(packetRecv) @@ -504,7 +596,7 @@ func TestClientPlay(t *testing.T) { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Teardown, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://"+listenIP+":8554/test/stream?param=value/"), req.URL) err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -513,11 +605,9 @@ func TestClientPlay(t *testing.T) { }() c := Client{ - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, Transport: func() *Transport { - switch transport { + switch ca.transport { case "udp": v := TransportUDP return &v @@ -526,14 +616,14 @@ func TestClientPlay(t *testing.T) { v := TransportUDPMulticast return &v - default: // tcp, tls + default: // tcp v := TransportTCP return &v } }(), } - u, err := base.ParseURL(scheme + "://" + listenIP + ":8554/test/stream?param=value") + u, err := base.ParseURL(ca.scheme + "://" + listenIP + ":8554/test/stream?param=value") require.NoError(t, err) err = c.Start(u.Scheme, u.Host) @@ -601,9 +691,221 @@ func TestClientPlay(t *testing.T) { }, s) require.Greater(t, s.Session.BytesSent, uint64(19)) - require.Less(t, s.Session.BytesSent, uint64(41)) + require.Less(t, s.Session.BytesSent, uint64(70)) require.Greater(t, s.Session.BytesReceived, uint64(31)) - require.Less(t, s.Session.BytesReceived, uint64(37)) + require.Less(t, s.Session.BytesReceived, uint64(80)) + }) + } +} + +func TestClientPlaySRTPVariants(t *testing.T) { + for _, ca := range []string{ + "key-mgmt in sdp session", + "key-mgmt in sdp media", + "key-mgmt in setup response", + } { + t.Run(ca, func(t *testing.T) { + cert, err := tls.X509KeyPair(serverCert, serverKey) + require.NoError(t, err) + + l, err := tls.Listen("tcp", "127.0.0.1:8554", &tls.Config{Certificates: []tls.Certificate{cert}}) + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + + go func() { + defer close(serverDone) + + nconn, err2 := l.Accept() + require.NoError(t, err2) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err2 := conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Describe, req.Method) + + outKey := make([]byte, srtpKeyLength) + _, err2 = rand.Read(outKey) + require.NoError(t, err2) + + srtpOutCtx := &wrappedSRTPContext{ + key: outKey, + ssrcs: []uint32{845234432}, + } + err2 = srtpOutCtx.initialize() + require.NoError(t, err2) + + mikeyMsg, err2 := mikeyGenerate(srtpOutCtx) + require.NoError(t, err2) + + enc, err2 := mikeyMsg.Marshal() + require.NoError(t, err2) + + var sdp string + + switch ca { + case "key-mgmt in sdp session": + sdp = "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "a=key-mgmt:mikey " + base64.StdEncoding.EncodeToString(enc) + "\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n" + + case "key-mgmt in sdp media": + sdp = "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=key-mgmt:mikey " + base64.StdEncoding.EncodeToString(enc) + "\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n" + + case "key-mgmt in setup response": + sdp = "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n" + } + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsps://127.0.0.1:8554/stream/"}, + }, + Body: []byte(sdp), + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err2 = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err2) + + require.Equal(t, (*headers.TransportMode)(nil), inTH.Mode) + + th := headers.Transport{ + Secure: true, + } + + v := headers.TransportDeliveryUnicast + th.Delivery = &v + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{34556, 34557} + + h := base.Header{ + "Transport": th.Marshal(), + } + + if ca == "key-mgmt in setup response" { + var enc base.HeaderValue + enc, err2 = headers.KeyMgmt{ + URL: req.URL.String(), + MikeyMessage: mikeyMsg, + }.Marshal() + require.NoError(t, err2) + + h["KeyMgmt"] = enc + } + + l1, err2 := net.ListenPacket( + "udp", net.JoinHostPort("127.0.0.1", strconv.FormatInt(int64(th.ServerPorts[0]), 10))) + require.NoError(t, err2) + defer l1.Close() + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: h, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Play, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + buf := testRTPPacketMarshaled + + encr := make([]byte, 2000) + encr, err2 = srtpOutCtx.encryptRTP(encr, buf, nil) + require.NoError(t, err2) + buf = encr + + _, err2 = l1.WriteTo(buf, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Teardown, req.Method) + }() + + c := Client{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + } + + u, err := base.ParseURL("rtsps://127.0.0.1:8554/stream") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) + require.NoError(t, err) + defer c.Close() + + sd, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAll(sd.BaseURL, sd.Medias) + require.NoError(t, err) + + packetRecv := make(chan struct{}) + + c.OnPacketRTPAny(func(_ *description.Media, _ format.Format, _ *rtp.Packet) { + close(packetRecv) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + + <-packetRecv }) } } @@ -616,6 +918,7 @@ func TestClientPlayPartial(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -768,6 +1071,7 @@ func TestClientPlayContentBase(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1072,6 +1376,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1186,6 +1491,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1359,6 +1665,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1594,6 +1901,7 @@ func TestClientPlayDifferentInterleavedIDs(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1713,6 +2021,7 @@ func TestClientPlayRedirect(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2000,6 +2309,7 @@ func TestClientPlayPausePlay(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2164,6 +2474,7 @@ func TestClientPlayRTCPReport(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2334,6 +2645,7 @@ func TestClientPlayErrorTimeout(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2476,6 +2788,7 @@ func TestClientPlayIgnoreTCPInvalidMedia(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2594,6 +2907,7 @@ func TestClientPlayKeepAlive(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2765,6 +3079,7 @@ func TestClientPlayDifferentSource(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -2909,6 +3224,7 @@ func TestClientPlayDecodeErrors(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -3169,6 +3485,7 @@ func TestClientPlayPacketNTP(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) diff --git a/client_record_test.go b/client_record_test.go index 5b0489ca..f01c28af 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -2,6 +2,7 @@ package gortsplib import ( "bytes" + "crypto/rand" "crypto/tls" "fmt" "net" @@ -19,6 +20,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/sdp" ) @@ -126,19 +128,37 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) { } func TestClientRecord(t *testing.T) { - for _, transport := range []string{ - "udp", - "tcp", - "tls", + for _, ca := range []struct { + scheme string + transport string + secure string + }{ + { + "rtsp", + "udp", + "unsecure", + }, + { + "rtsp", + "tcp", + "unsecure", + }, + { + "rtsps", + "udp", + "secure", + }, + { + "rtsps", + "tcp", + "secure", + }, } { - t.Run(transport, func(t *testing.T) { + t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) { var l net.Listener var err error - var scheme string - if transport == "tls" { - scheme = "rtsps" - + if ca.scheme == "rtsps" { var cert tls.Certificate cert, err = tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) @@ -147,8 +167,6 @@ func TestClientRecord(t *testing.T) { require.NoError(t, err) defer l.Close() } else { - scheme = "rtsp" - l, err = net.Listen("tcp", "localhost:8554") require.NoError(t, err) defer l.Close() @@ -156,6 +174,7 @@ func TestClientRecord(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -167,7 +186,7 @@ func TestClientRecord(t *testing.T) { req, err2 := conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Options, req.Method) - require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL) err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -184,7 +203,7 @@ func TestClientRecord(t *testing.T) { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Announce, req.Method) - require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL) var desc sdp.SessionDescription err = desc.Unmarshal(req.Body) @@ -194,6 +213,13 @@ func TestClientRecord(t *testing.T) { err = desc2.Unmarshal(&desc) require.NoError(t, err2) + if ca.secure == "secure" { + require.True(t, desc2.Medias[0].Secure) + + _, err = mikeyToContext(desc2.Medias[0].KeyMgmtMikey) + require.NoError(t, err) + } + err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, }) @@ -203,7 +229,7 @@ func TestClientRecord(t *testing.T) { require.NoError(t, err2) require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL( - scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL) + ca.scheme+"://localhost:8554/teststream/"+desc2.Medias[0].Control), req.URL) var inTH headers.Transport err2 = inTH.Unmarshal(req.Header["Transport"]) @@ -213,7 +239,7 @@ func TestClientRecord(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - if transport == "udp" { + if ca.transport == "udp" { l1, err2 = net.ListenPacket("udp", "localhost:34556") require.NoError(t, err2) defer l1.Close() @@ -223,11 +249,62 @@ func TestClientRecord(t *testing.T) { defer l2.Close() } - th := headers.Transport{ - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + h := base.Header{ + "Session": headers.Session{ + Session: "ABCDE", + Timeout: uintPtr(1), + }.Marshal(), } - if transport == "udp" { + th := headers.Transport{ + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Secure: inTH.Secure, + } + + var srtpInCtx *wrappedSRTPContext + var srtpOutCtx *wrappedSRTPContext + + if ca.secure == "secure" { + th.Secure = true + + require.True(t, th.Secure) + + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(req.Header["KeyMgmt"]) + require.NoError(t, err) + + pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage) + pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc2.Medias[0].KeyMgmtMikey) + require.Equal(t, pl1, pl2) + + srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage) + require.NoError(t, err) + + outKey := make([]byte, srtpKeyLength) + _, err = rand.Read(outKey) + require.NoError(t, err) + + srtpOutCtx = &wrappedSRTPContext{ + key: outKey, + ssrcs: []uint32{2345423}, + } + err = srtpOutCtx.initialize() + require.NoError(t, err) + + var mikeyMsg *mikey.Message + mikeyMsg, err = mikeyGenerate(srtpOutCtx) + require.NoError(t, err) + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: req.URL.String(), + MikeyMessage: mikeyMsg, + }.Marshal() + require.NoError(t, err) + h["KeyMgmt"] = enc + } + + if ca.transport == "udp" { th.Protocol = headers.TransportProtocolUDP th.ServerPorts = &[2]int{34556, 34557} th.ClientPorts = inTH.ClientPorts @@ -236,54 +313,55 @@ func TestClientRecord(t *testing.T) { th.InterleavedIDs = inTH.InterleavedIDs } + h["Transport"] = th.Marshal() + err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - "Session": headers.Session{ - Session: "ABCDE", - Timeout: uintPtr(1), - }.Marshal(), - }, + Header: h, }) require.NoError(t, err2) req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Record, req.Method) - require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL) err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, }) require.NoError(t, err2) - var pl []byte - // client -> server - if transport == "udp" { - buf := make([]byte, 2048) + var buf []byte + + if ca.transport == "udp" { + buf = make([]byte, 2048) var n int n, _, err2 = l1.ReadFrom(buf) require.NoError(t, err2) - pl = buf[:n] + buf = buf[:n] } else { var f *base.InterleavedFrame f, err2 = conn.ReadInterleavedFrame() require.NoError(t, err2) require.Equal(t, 0, f.Channel) - pl = f.Payload + buf = f.Payload + } + + if ca.secure == "secure" { + buf, err2 = srtpInCtx.decryptRTP(buf, buf, nil) + require.NoError(t, err2) } var pkt rtp.Packet - err2 = pkt.Unmarshal(pl) + err2 = pkt.Unmarshal(buf) require.NoError(t, err2) require.Equal(t, testRTPPacket, pkt) // client -> server keepalive - if transport == "udp" { + if ca.transport == "udp" { recv := make(chan struct{}) go func() { defer close(recv) @@ -301,8 +379,17 @@ func TestClientRecord(t *testing.T) { // server -> client - if transport == "udp" { - _, err2 = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + buf = testRTCPPacketMarshaled + + if ca.secure == "secure" { + encr := make([]byte, 2000) + encr, err2 = srtpOutCtx.encryptRTCP(encr, buf, nil) + require.NoError(t, err2) + buf = encr + } + + if ca.transport == "udp" { + _, err2 = l2.WriteTo(buf, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[1], }) @@ -310,7 +397,7 @@ func TestClientRecord(t *testing.T) { } else { err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 1, - Payload: testRTCPPacketMarshaled, + Payload: buf, }, make([]byte, 1024)) require.NoError(t, err2) } @@ -318,7 +405,7 @@ func TestClientRecord(t *testing.T) { req, err2 = conn.ReadRequest() require.NoError(t, err2) require.Equal(t, base.Teardown, req.Method) - require.Equal(t, mustParseURL(scheme+"://localhost:8554/teststream"), req.URL) + require.Equal(t, mustParseURL(ca.scheme+"://localhost:8554/teststream"), req.URL) err2 = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -333,7 +420,7 @@ func TestClientRecord(t *testing.T) { InsecureSkipVerify: true, }, Transport: func() *Transport { - if transport == "udp" { + if ca.transport == "udp" { v := TransportUDP return &v } @@ -345,7 +432,7 @@ func TestClientRecord(t *testing.T) { medi := testH264Media medias := []*description.Media{medi} - err = record(&c, scheme+"://localhost:8554/teststream", medias, + err = record(&c, ca.scheme+"://localhost:8554/teststream", medias, func(_ *description.Media, pkt rtcp.Packet) { require.Equal(t, &testRTCPPacket, pkt) close(recvDone) @@ -397,9 +484,9 @@ func TestClientRecord(t *testing.T) { }, s) require.Greater(t, s.Session.BytesSent, uint64(15)) - require.Less(t, s.Session.BytesSent, uint64(17)) + require.Less(t, s.Session.BytesSent, uint64(30)) require.Greater(t, s.Session.BytesReceived, uint64(19)) - require.Less(t, s.Session.BytesReceived, uint64(21)) + require.Less(t, s.Session.BytesReceived, uint64(40)) c.Close() <-done @@ -414,33 +501,18 @@ func TestClientRecordSocketError(t *testing.T) { for _, transport := range []string{ "udp", "tcp", - "tls", } { t.Run(transport, func(t *testing.T) { var l net.Listener var err error - var scheme string - if transport == "tls" { - scheme = "rtsps" - - var cert tls.Certificate - cert, err = tls.X509KeyPair(serverCert, serverKey) - require.NoError(t, err) - - l, err = tls.Listen("tcp", "localhost:8554", &tls.Config{Certificates: []tls.Certificate{cert}}) - require.NoError(t, err) - defer l.Close() - } else { - scheme = "rtsp" - - l, err = net.Listen("tcp", "localhost:8554") - require.NoError(t, err) - defer l.Close() - } + l, err = net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -530,7 +602,7 @@ func TestClientRecordSocketError(t *testing.T) { medi := testH264Media medias := []*description.Media{medi} - err = record(&c, scheme+"://localhost:8554/teststream", medias, nil) + err = record(&c, "rtsp://localhost:8554/teststream", medias, nil) require.NoError(t, err) defer c.Close() @@ -559,6 +631,7 @@ func TestClientRecordPauseRecordSerial(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -707,6 +780,7 @@ func TestClientRecordPauseRecordParallel(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -885,6 +959,7 @@ func TestClientRecordAutomaticProtocol(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1016,6 +1091,7 @@ func TestClientRecordDecodeErrors(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1186,6 +1262,7 @@ func TestClientRecordRTCPReport(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -1371,6 +1448,7 @@ func TestClientRecordIgnoreTCPRTPPackets(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) diff --git a/client_test.go b/client_test.go index e13da6e7..eb679eb1 100644 --- a/client_test.go +++ b/client_test.go @@ -58,6 +58,7 @@ func TestClientTLSSetServerName(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -141,6 +142,7 @@ func TestClientCloseDuringRequest(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -185,6 +187,7 @@ func TestClientSession(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -246,6 +249,7 @@ func TestClientAuth(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -327,6 +331,7 @@ func TestClientCSeq(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -399,6 +404,7 @@ func TestClientDescribeCharset(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) @@ -549,6 +555,7 @@ func TestClientRelativeContentBase(t *testing.T) { serverDone := make(chan struct{}) defer func() { <-serverDone }() + go func() { defer close(serverDone) diff --git a/client_udp_listener.go b/client_udp_listener.go index 0c7bb2de..b0243555 100644 --- a/client_udp_listener.go +++ b/client_udp_listener.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "math/big" "net" - "strconv" "sync/atomic" "time" @@ -24,45 +23,6 @@ func randInRange(maxVal int) (int, error) { return int(n.Int64()), nil } -func createUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) { - // choose two consecutive ports in range 65535-10000 - // RTP port must be even and RTCP port odd - for { - v, err := randInRange((65535 - 10000) / 2) - if err != nil { - return nil, nil, err - } - - rtpPort := v*2 + 10000 - rtcpPort := rtpPort + 1 - - rtpListener := &clientUDPListener{ - c: c, - multicastEnable: false, - multicastSourceIP: nil, - address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), - } - err = rtpListener.initialize() - if err != nil { - continue - } - - rtcpListener := &clientUDPListener{ - c: c, - multicastEnable: false, - multicastSourceIP: nil, - address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), - } - err = rtcpListener.initialize() - if err != nil { - rtpListener.close() - continue - } - - return rtpListener, rtcpListener, nil - } -} - type packetConn interface { net.PacketConn SetReadBuffer(int) error diff --git a/constants.go b/constants.go index 9aa3da01..44165676 100644 --- a/constants.go +++ b/constants.go @@ -6,4 +6,13 @@ const ( // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header) udpMaxPayloadSize = 1472 + + // 16 master key + 14 master salt + srtpKeyLength = 30 + + // 10 (HMAC SHA1 authentication tag) + srtpOverhead = 10 + + // 10 (HMAC SHA1 authentication tag) + 4 (sequence number) + srtcpOverhead = 14 ) diff --git a/examples/server-tls/main.go b/examples/server-secure/main.go similarity index 94% rename from examples/server-tls/main.go rename to examples/server-secure/main.go index 36577244..bd86aa92 100644 --- a/examples/server-tls/main.go +++ b/examples/server-secure/main.go @@ -15,7 +15,7 @@ import ( ) // This example shows how to -// 1. create a RTSP server which uses secure protocols only (RTSPS, TLS). +// 1. create a RTSP server which uses secure protocols only (RTSPS, TLS, SRTP). // 2. allow a single client to publish a stream. // 3. allow several clients to read the stream. @@ -175,9 +175,14 @@ func main() { // when TLSConfig is set, only secure protocols are used. h := &serverHandler{} h.server = &gortsplib.Server{ - Handler: h, - TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, - RTSPAddress: ":8322", + Handler: h, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, + RTSPAddress: ":8322", + UDPRTPAddress: ":8004", + UDPRTCPAddress: ":8005", + MulticastIPRange: "224.1.0.0/16", + MulticastRTPPort: 8006, + MulticastRTCPPort: 8007, } // start server and wait until a fatal error diff --git a/go.mod b/go.mod index 9e98368d..2d747cee 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/pion/rtcp v1.2.15 github.com/pion/rtp v1.8.20 github.com/pion/sdp/v3 v3.0.14 + github.com/pion/srtp/v3 v3.0.6 github.com/stretchr/testify v1.10.0 golang.org/x/net v0.41.0 ) @@ -16,7 +17,9 @@ require ( require ( github.com/asticode/go-astikit v0.30.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pion/logging v0.2.3 // indirect github.com/pion/randutil v0.1.0 // indirect + github.com/pion/transport/v3 v3.0.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 1711e5bc..8c60c8eb 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= +github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= @@ -17,6 +19,10 @@ github.com/pion/rtp v1.8.20 h1:8zcyqohadZE8FCBeGdyEvHiclPIezcwRQH9zfapFyYI= github.com/pion/rtp v1.8.20/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk= github.com/pion/sdp/v3 v3.0.14 h1:1h7gBr9FhOWH5GjWWY5lcw/U85MtdcibTyt/o6RxRUI= github.com/pion/sdp/v3 v3.0.14/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= +github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4= +github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/teste2e/client_vs_server_test.go b/internal/teste2e/client_vs_server_test.go index a212b019..0308ee83 100644 --- a/internal/teste2e/client_vs_server_test.go +++ b/internal/teste2e/client_vs_server_test.go @@ -85,6 +85,18 @@ func TestClientVsServer(t *testing.T) { readerScheme: "rtsps", readerProto: "tcp", }, + { + publisherScheme: "rtsps", + publisherProto: "udp", + readerScheme: "rtsps", + readerProto: "tcp", + }, + { + publisherScheme: "rtsps", + publisherProto: "udp", + readerScheme: "rtsps", + readerProto: "multicast", + }, } { t.Run(ca.publisherScheme+"_"+ca.publisherProto+"_"+ ca.readerScheme+"_"+ca.readerProto, func(t *testing.T) { diff --git a/internal/teste2e/sample_server_test.go b/internal/teste2e/sample_server_test.go index 3b16d9ae..500acbea 100644 --- a/internal/teste2e/sample_server_test.go +++ b/internal/teste2e/sample_server_test.go @@ -223,17 +223,14 @@ func (sh *sampleServer) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base func (sh *sampleServer) initialize() error { sh.s = &gortsplib.Server{ - Handler: sh, - TLSConfig: sh.tlsConfig, - RTSPAddress: "0.0.0.0:8554", - } - - if sh.tlsConfig == nil { - sh.s.UDPRTPAddress = "0.0.0.0:8000" - sh.s.UDPRTCPAddress = "0.0.0.0:8001" - sh.s.MulticastIPRange = "224.1.0.0/16" - sh.s.MulticastRTPPort = 8002 - sh.s.MulticastRTCPPort = 8003 + Handler: sh, + TLSConfig: sh.tlsConfig, + RTSPAddress: "0.0.0.0:8554", + UDPRTPAddress: "0.0.0.0:8000", + UDPRTCPAddress: "0.0.0.0:8001", + MulticastIPRange: "224.1.0.0/16", + MulticastRTPPort: 8002, + MulticastRTCPPort: 8003, } err := sh.s.Start() diff --git a/internal/teste2e/server_vs_external_test.go b/internal/teste2e/server_vs_external_test.go index 3ccb9cad..0c0e11e3 100644 --- a/internal/teste2e/server_vs_external_test.go +++ b/internal/teste2e/server_vs_external_test.go @@ -248,6 +248,46 @@ func TestServerVsExternal(t *testing.T) { readerProto: "tcp", readerSecure: "unsecure", }, + { + publisherSoft: "gstreamer", + publisherScheme: "rtsps", + publisherProto: "tcp", + publisherSecure: "unsecure", + readerSoft: "gstreamer", + readerScheme: "rtsps", + readerProto: "tcp", + readerSecure: "unsecure", + }, + { + publisherSoft: "ffmpeg", + publisherScheme: "rtsps", + publisherProto: "tcp", + publisherSecure: "unsecure", + readerSoft: "gstreamer", + readerScheme: "rtsps", + readerProto: "udp", + readerSecure: "secure", + }, + { + publisherSoft: "gstreamer", + publisherScheme: "rtsps", + publisherProto: "udp", + publisherSecure: "secure", + readerSoft: "gstreamer", + readerScheme: "rtsps", + readerProto: "udp", + readerSecure: "secure", + }, + { + publisherSoft: "gstreamer", + publisherScheme: "rtsps", + publisherProto: "udp", + publisherSecure: "secure", + readerSoft: "gstreamer", + readerScheme: "rtsps", + readerProto: "multicast", + readerSecure: "secure", + }, } { t.Run(ca.publisherSoft+"_"+ca.publisherScheme+"_"+ca.publisherProto+"_"+ca.publisherSecure+"_"+ ca.readerSoft+"_"+ca.readerScheme+"_"+ca.readerProto+"_"+ca.readerSecure, func(t *testing.T) { diff --git a/pkg/auth/verify.go b/pkg/auth/verify.go index af298e6f..0a5d3621 100644 --- a/pkg/auth/verify.go +++ b/pkg/auth/verify.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "regexp" + "slices" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/headers" @@ -25,15 +26,6 @@ func sha256Hex(in string) string { return hex.EncodeToString(h.Sum(nil)) } -func contains(list []VerifyMethod, item VerifyMethod) bool { - for _, i := range list { - if i == item { - return true - } - } - return false -} - func urlMatches(expected string, received string, isSetup bool) bool { if received == expected { return true @@ -84,9 +76,9 @@ func Verify( switch { case auth.Method == headers.AuthMethodDigest && - (contains(methods, VerifyMethodDigestMD5) && + (slices.Contains(methods, VerifyMethodDigestMD5) && (auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5) || - contains(methods, VerifyMethodDigestSHA256) && + slices.Contains(methods, VerifyMethodDigestSHA256) && auth.Algorithm != nil && *auth.Algorithm == headers.AuthAlgorithmSHA256): if auth.Nonce != nonce { return fmt.Errorf("wrong nonce") @@ -118,7 +110,7 @@ func Verify( return fmt.Errorf("authentication failed") } - case auth.Method == headers.AuthMethodBasic && contains(methods, VerifyMethodBasic): + case auth.Method == headers.AuthMethodBasic && slices.Contains(methods, VerifyMethodBasic): if auth.Username != user { return fmt.Errorf("authentication failed") } diff --git a/pkg/base/header.go b/pkg/base/header.go index 4e925fad..67ad705a 100644 --- a/pkg/base/header.go +++ b/pkg/base/header.go @@ -24,6 +24,9 @@ func headerKeyNormalize(in string) string { case "cseq": return "CSeq" + + case "keymgmt": + return "KeyMgmt" } return http.CanonicalHeaderKey(in) } diff --git a/pkg/base/header_test.go b/pkg/base/header_test.go index 48660cfa..96205bc6 100644 --- a/pkg/base/header_test.go +++ b/pkg/base/header_test.go @@ -92,8 +92,10 @@ var cases = []struct { []byte("www-authenticate: value\r\n" + "cseq: value\r\n" + "rtp-info: value\r\n" + + "keymgmt: value\r\n" + "\r\n"), []byte("CSeq: value\r\n" + + "KeyMgmt: value\r\n" + "RTP-Info: value\r\n" + "WWW-Authenticate: value\r\n" + "\r\n"), @@ -101,6 +103,7 @@ var cases = []struct { "CSeq": HeaderValue{"value"}, "RTP-Info": HeaderValue{"value"}, "WWW-Authenticate": HeaderValue{"value"}, + "KeyMgmt": HeaderValue{"value"}, }, }, } diff --git a/pkg/description/media.go b/pkg/description/media.go index 2e2d9540..1979ef1c 100644 --- a/pkg/description/media.go +++ b/pkg/description/media.go @@ -2,8 +2,10 @@ package description import ( + "encoding/base64" "fmt" "reflect" + "slices" "sort" "strconv" "strings" @@ -13,6 +15,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/format" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" ) func getAttribute(attributes []psdp.Attribute, key string) string { @@ -78,6 +81,12 @@ type Media struct { // Control attribute. Control string + // Whether the transport is secure. + Secure bool + + // key-mgmt attribute. + KeyMgmtMikey *mikey.Message + // Formats contained into the media. Formats []format.Format } @@ -93,6 +102,24 @@ func (m *Media) Unmarshal(md *psdp.MediaDescription) error { m.IsBackChannel = isBackChannel(md.Attributes) m.Control = getAttribute(md.Attributes, "control") + m.Secure = slices.Contains(md.MediaName.Protos, "SAVP") + + if enc := getAttribute(md.Attributes, "key-mgmt"); enc != "" { + if !strings.HasPrefix(enc, "mikey ") { + return fmt.Errorf("unsupported key-mgmt: %v", enc) + } + + enc2, err := base64.StdEncoding.DecodeString(enc[len("mikey "):]) + if err != nil { + return err + } + + m.KeyMgmtMikey = &mikey.Message{} + err = m.KeyMgmtMikey.Unmarshal(enc2) + if err != nil { + return err + } + } m.Formats = nil @@ -113,11 +140,29 @@ func (m *Media) Unmarshal(md *psdp.MediaDescription) error { } // Marshal encodes the media in SDP format. +// +// Deprecated: replaced by Marshal2. func (m Media) Marshal() *psdp.MediaDescription { + ret, err := m.Marshal2() + if err != nil { + panic(err) + } + return ret +} + +// Marshal2 encodes the media in SDP format. +func (m Media) Marshal2() (*psdp.MediaDescription, error) { + var protos []string + if !m.Secure { + protos = []string{"RTP", "AVP"} + } else { + protos = []string{"RTP", "SAVP"} + } + md := &psdp.MediaDescription{ MediaName: psdp.MediaName{ Media: string(m.Type), - Protos: []string{"RTP", "AVP"}, + Protos: protos, }, } @@ -134,6 +179,18 @@ func (m Media) Marshal() *psdp.MediaDescription { }) } + if m.KeyMgmtMikey != nil { + keyEnc, err := m.KeyMgmtMikey.Marshal() + if err != nil { + return nil, err + } + + md.Attributes = append(md.Attributes, psdp.Attribute{ + Key: "key-mgmt", + Value: "mikey " + base64.StdEncoding.EncodeToString(keyEnc), + }) + } + md.Attributes = append(md.Attributes, psdp.Attribute{ Key: "control", Value: m.Control, @@ -165,7 +222,7 @@ func (m Media) Marshal() *psdp.MediaDescription { } } - return md + return md, nil } // URL returns the absolute URL of the media. diff --git a/pkg/description/session.go b/pkg/description/session.go index 6fed6f22..cc022bc1 100644 --- a/pkg/description/session.go +++ b/pkg/description/session.go @@ -1,12 +1,14 @@ package description import ( + "encoding/base64" "fmt" "strings" psdp "github.com/pion/sdp/v3" "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/sdp" ) @@ -51,6 +53,9 @@ type Session struct { // Whether to use multicast. Multicast bool + // key-mgmt attribute. + KeyMgmtMikey *mikey.Message + // FEC groups (RFC5109). FECGroups []SessionFECGroup @@ -77,6 +82,23 @@ func (d *Session) Unmarshal(ssd *sdp.SessionDescription) error { d.Title = "" } + if enc := getAttribute(ssd.Attributes, "key-mgmt"); enc != "" { + if !strings.HasPrefix(enc, "mikey ") { + return fmt.Errorf("unsupported key-mgmt: %v", enc) + } + + enc2, err := base64.StdEncoding.DecodeString(enc[len("mikey "):]) + if err != nil { + return err + } + + d.KeyMgmtMikey = &mikey.Message{} + err = d.KeyMgmtMikey.Unmarshal(enc2) + if err != nil { + return err + } + } + if len(ssd.MediaDescriptions) == 0 { return fmt.Errorf("no media streams are present in SDP") } @@ -163,11 +185,29 @@ func (d Session) Marshal(_ bool) ([]byte, error) { }) } + if d.KeyMgmtMikey != nil { + keyEnc, err := d.KeyMgmtMikey.Marshal() + if err != nil { + return nil, err + } + + sout.Attributes = append(sout.Attributes, psdp.Attribute{ + Key: "key-mgmt", + Value: "mikey " + base64.StdEncoding.EncodeToString(keyEnc), + }) + } + sout.MediaDescriptions = make([]*psdp.MediaDescription, len(d.Medias)) for i, media := range d.Medias { - sout.MediaDescriptions[i] = media.Marshal() + med, err := media.Marshal2() + if err != nil { + return nil, err + } + sout.MediaDescriptions[i] = med } - return sout.Marshal() + out, _ := sout.Marshal() + + return out, nil } diff --git a/pkg/description/session_test.go b/pkg/description/session_test.go index 5db0bcde..713260f6 100644 --- a/pkg/description/session_test.go +++ b/pkg/description/session_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/bluenviron/gortsplib/v4/pkg/format" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/sdp" ) @@ -638,6 +639,194 @@ var casesSession = []struct { }, }, }, + { + "key-mgmt in session", + "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" + + "A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n", + "v=0\r\n" + + "o=- 0 0 IN IP4 127.0.0.1\r\n" + + "s=Action Movie\r\n" + + "c=IN IP4 0.0.0.0\r\n" + + "t=0 0\r\n" + + "a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" + + "A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" + + "m=video 0 RTP/SAVP 96\r\n" + + "a=control:trackID=0\r\n" + + "a=rtpmap:96 H264/90000\r\n", + Session{ + Title: "Action Movie", + KeyMgmtMikey: &mikey.Message{ //nolint:dupl + Header: mikey.Header{ + Version: 1, + CSBID: 2049124702, + CSIDMapInfo: []mikey.SRTPIDEntry{{ + SSRC: 2508989223, + }}, + }, + Payloads: []mikey.Payload{ + &mikey.PayloadT{ + TSValue: 17003794820816085580, + }, + &mikey.PayloadRAND{ + Data: []byte{ + 0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf, + 0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7, + }, + }, + &mikey.PayloadSP{ + PolicyParams: []mikey.PayloadSPPolicyParam{ + { + Type: 0, Value: []byte{1}, + }, + { + Type: 1, Value: []byte{0x10}, + }, + { + Type: 2, Value: []byte{1}, + }, + { + Type: 3, Value: []byte{0x0a}, + }, + { + Type: 7, Value: []byte{1}, + }, + { + Type: 8, Value: []byte{1}, + }, + { + Type: 10, Value: []byte{1}, + }, + }, + }, + &mikey.PayloadKEMAC{ + SubPayloads: []*mikey.SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d, + 0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11, + 0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6, + 0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43, + }, + }, + }, + }, + }, + }, + Medias: []*Media{ + { + Type: "video", + Control: "trackID=0", + Secure: true, + Formats: []format.Format{&format.H264{ + PayloadTyp: 96, + }}, + }, + }, + }, + }, + { + "key-mgmt in media", + "v=0\n" + + "o=actionmovie 2891092738 2891092738 IN IP4 movie.example.com\n" + + "s=Action Movie\n" + + "t=0 0\n" + + "c=IN IP4 movie.example.com\n" + + "m=video 0 RTP/SAVP 96\n" + + "a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" + + "A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" + + "a=rtpmap:96 H264/90000\n" + + "a=control:trackID=0\n", + "v=0\r\n" + + "o=- 0 0 IN IP4 127.0.0.1\r\n" + + "s=Action Movie\r\n" + + "c=IN IP4 0.0.0.0\r\n" + + "t=0 0\r\n" + + "m=video 0 RTP/SAVP 96\r\n" + + "a=key-mgmt:mikey AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+" + + "A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA==\r\n" + + "a=control:trackID=0\r\n" + + "a=rtpmap:96 H264/90000\r\n", + Session{ + Title: "Action Movie", + Medias: []*Media{ + { + Type: "video", + Control: "trackID=0", + Secure: true, + KeyMgmtMikey: &mikey.Message{ //nolint:dupl + Header: mikey.Header{ + Version: 1, + CSBID: 2049124702, + CSIDMapInfo: []mikey.SRTPIDEntry{{ + SSRC: 2508989223, + }}, + }, + Payloads: []mikey.Payload{ + &mikey.PayloadT{ + TSValue: 17003794820816085580, + }, + &mikey.PayloadRAND{ + Data: []byte{ + 0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf, + 0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7, + }, + }, + &mikey.PayloadSP{ + PolicyParams: []mikey.PayloadSPPolicyParam{ + { + Type: 0, Value: []byte{1}, + }, + { + Type: 1, Value: []byte{0x10}, + }, + { + Type: 2, Value: []byte{1}, + }, + { + Type: 3, Value: []byte{0x0a}, + }, + { + Type: 7, Value: []byte{1}, + }, + { + Type: 8, Value: []byte{1}, + }, + { + Type: 10, Value: []byte{1}, + }, + }, + }, + &mikey.PayloadKEMAC{ + SubPayloads: []*mikey.SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d, + 0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11, + 0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6, + 0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43, + }, + }, + }, + }, + }, + }, + Formats: []format.Format{&format.H264{ + PayloadTyp: 96, + }}, + }, + }, + }, + }, } func TestSessionUnmarshal(t *testing.T) { diff --git a/pkg/format/rtpac3/decoder.go b/pkg/format/rtpac3/decoder.go index 60bce298..931ff162 100644 --- a/pkg/format/rtpac3/decoder.go +++ b/pkg/format/rtpac3/decoder.go @@ -49,7 +49,6 @@ func (d *Decoder) resetFragments() { } // Decode decodes frames from a RTP packet. -// It returns the frames and the PTS of the first frame. func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) { if len(pkt.Payload) < 2 { d.resetFragments() diff --git a/pkg/format/rtplpcm/decoder.go b/pkg/format/rtplpcm/decoder.go index 535a472f..11a2aeed 100644 --- a/pkg/format/rtplpcm/decoder.go +++ b/pkg/format/rtplpcm/decoder.go @@ -22,7 +22,6 @@ func (d *Decoder) Init() error { } // Decode decodes audio samples from a RTP packet. -// It returns audio samples and PTS of the first sample. func (d *Decoder) Decode(pkt *rtp.Packet) ([]byte, error) { plen := len(pkt.Payload) if (plen % d.sampleSize) != 0 { diff --git a/pkg/format/rtpmpeg4audio/decoder.go b/pkg/format/rtpmpeg4audio/decoder.go index 4d980e8e..a87429ec 100644 --- a/pkg/format/rtpmpeg4audio/decoder.go +++ b/pkg/format/rtpmpeg4audio/decoder.go @@ -52,8 +52,6 @@ func (d *Decoder) resetFragments() { } // Decode decodes AUs from a RTP packet. -// It returns the AUs and the PTS of the first AU. -// The PTS of subsequent AUs can be calculated by adding time.Second*mpeg4audio.SamplesPerAccessUnit/clockRate. func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) { if !d.LATM { return d.decodeGeneric(pkt) diff --git a/pkg/headers/key_mgmt.go b/pkg/headers/key_mgmt.go new file mode 100644 index 00000000..34cb78be --- /dev/null +++ b/pkg/headers/key_mgmt.go @@ -0,0 +1,86 @@ +package headers + +import ( + "encoding/base64" + "fmt" + + "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" +) + +// KeyMgmt is a KeyMgmt header. +type KeyMgmt struct { + URL string + MikeyMessage *mikey.Message +} + +// Unmarshal decodes a KeyMgmt header. +func (h *KeyMgmt) Unmarshal(v base.HeaderValue) error { + if len(v) == 0 { + return fmt.Errorf("value not provided") + } + + if len(v) > 1 { + return fmt.Errorf("value provided multiple times (%v)", v) + } + + kvs, err := keyValParse(v[0], ';') + if err != nil { + return err + } + + protocolProvided := false + uriProvided := false + + for k, v := range kvs { + switch k { + case "prot": + if v != "mikey" { + return fmt.Errorf("unsupported protocol: %v", v) + } + protocolProvided = true + + case "uri": + h.URL = v + uriProvided = true + + case "data": + byts, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return fmt.Errorf("invalid data: %w", err) + } + + h.MikeyMessage = &mikey.Message{} + err = h.MikeyMessage.Unmarshal(byts) + if err != nil { + return fmt.Errorf("invalid data: %w", err) + } + } + } + + if !protocolProvided { + return fmt.Errorf("protocol not provided") + } + + if !uriProvided { + return fmt.Errorf("URI not provided") + } + + if h.MikeyMessage == nil { + return fmt.Errorf("mikey message not provided") + } + + return nil +} + +// Marshal encodes a KeyMgmt header. +func (h KeyMgmt) Marshal() (base.HeaderValue, error) { + buf, err := h.MikeyMessage.Marshal() + if err != nil { + return nil, err + } + + encData := base64.StdEncoding.EncodeToString(buf) + + return base.HeaderValue{`prot=mikey;uri="` + h.URL + `";data="` + encData + `"`}, nil +} diff --git a/pkg/headers/key_mgmt_test.go b/pkg/headers/key_mgmt_test.go new file mode 100644 index 00000000..7a685e64 --- /dev/null +++ b/pkg/headers/key_mgmt_test.go @@ -0,0 +1,143 @@ +package headers + +import ( + "testing" + + "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" + "github.com/stretchr/testify/require" +) + +var casesKeyMgmt = []struct { + name string + vin base.HeaderValue + vout base.HeaderValue + h KeyMgmt +}{ + { + "standard", + base.HeaderValue{`prot=mikey;` + + `uri="rtsps://127.0.0.1:8322/stream/trackID=0";` + + `data="AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+` + + `A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA=="`}, + base.HeaderValue{`prot=mikey;` + + `uri="rtsps://127.0.0.1:8322/stream/trackID=0";` + + `data="AQAFAHojKV4BAACVjCMnAAAAAAsA6/mdTLBeokwKEGwcAuPrxj6/enyb+` + + `A2+rNcBAAAAFQABAQEBEAIBAQMBCgcBAQgBAQoBAQAAACIAIAAeX8XvOCzIMh0JTOWivWLxEflTUSp1fjj2i8xG7D9DAA=="`}, + KeyMgmt{ + URL: "rtsps://127.0.0.1:8322/stream/trackID=0", + MikeyMessage: &mikey.Message{ + Header: mikey.Header{ + Version: 1, + CSBID: 2049124702, + CSIDMapInfo: []mikey.SRTPIDEntry{ + { + SSRC: 2508989223, + }, + }, + }, + Payloads: []mikey.Payload{ + &mikey.PayloadT{ + TSValue: 17003794820816085580, + }, + &mikey.PayloadRAND{ + Data: []byte{ + 0x6c, 0x1c, 0x02, 0xe3, 0xeb, 0xc6, 0x3e, 0xbf, + 0x7a, 0x7c, 0x9b, 0xf8, 0x0d, 0xbe, 0xac, 0xd7, + }, + }, + &mikey.PayloadSP{ + PolicyParams: []mikey.PayloadSPPolicyParam{ + { + Type: 0, Value: []byte{1}, + }, + { + Type: 1, Value: []byte{0x10}, + }, + { + Type: 2, Value: []byte{1}, + }, + { + Type: 3, Value: []byte{0x0a}, + }, + { + Type: 7, Value: []byte{1}, + }, + { + Type: 8, Value: []byte{1}, + }, + { + Type: 10, Value: []byte{1}, + }, + }, + }, + &mikey.PayloadKEMAC{ + SubPayloads: []*mikey.SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x5f, 0xc5, 0xef, 0x38, 0x2c, 0xc8, 0x32, 0x1d, + 0x09, 0x4c, 0xe5, 0xa2, 0xbd, 0x62, 0xf1, 0x11, + 0xf9, 0x53, 0x51, 0x2a, 0x75, 0x7e, 0x38, 0xf6, + 0x8b, 0xcc, 0x46, 0xec, 0x3f, 0x43, + }, + }, + }, + }, + }, + }, + }, + }, +} + +func TestKeyMgmtUnmarshal(t *testing.T) { + for _, ca := range casesKeyMgmt { + t.Run(ca.name, func(t *testing.T) { + var h KeyMgmt + err := h.Unmarshal(ca.vin) + require.NoError(t, err) + require.Equal(t, ca.h, h) + }) + } +} + +func TestKeyMgmtMarshal(t *testing.T) { + for _, ca := range casesKeyMgmt { + t.Run(ca.name, func(t *testing.T) { + req, err := ca.h.Marshal() + require.NoError(t, err) + require.Equal(t, ca.vout, req) + }) + } +} + +func FuzzKeyMgmtUnmarshal(f *testing.F) { + for _, ca := range casesKeyMgmt { + f.Add(ca.vin[0]) + } + + f.Fuzz(func(t *testing.T, b string) { + var h KeyMgmt + err := h.Unmarshal(base.HeaderValue{b}) + if err != nil { + return + } + + _, err = h.Marshal() + require.NoError(t, err) + }) +} + +func TestKeyMgmtAdditionalErrors(t *testing.T) { + func() { + var h KeyMgmt + err := h.Unmarshal(base.HeaderValue{}) + require.Error(t, err) + }() + + func() { + var h KeyMgmt + err := h.Unmarshal(base.HeaderValue{"a", "b"}) + require.Error(t, err) + }() +} diff --git a/pkg/headers/range.go b/pkg/headers/range.go index a6665ece..815d434b 100644 --- a/pkg/headers/range.go +++ b/pkg/headers/range.go @@ -268,7 +268,7 @@ func rangeValueUnmarshal(s RangeValue, v string) error { // Range is a Range header. type Range struct { - // range expressed in a certain unit. + // range expressed in some measurement units. Value RangeValue // time at which the operation is to be made effective. @@ -285,9 +285,7 @@ func (h *Range) Unmarshal(v base.HeaderValue) error { return fmt.Errorf("value provided multiple times (%v)", v) } - v0 := v[0] - - kvs, err := keyValParse(v0, ';') + kvs, err := keyValParse(v[0], ';') if err != nil { return err } diff --git a/pkg/headers/rtpinfo.go b/pkg/headers/rtp_info.go similarity index 100% rename from pkg/headers/rtpinfo.go rename to pkg/headers/rtp_info.go diff --git a/pkg/headers/rtpinfo_test.go b/pkg/headers/rtp_info_test.go similarity index 100% rename from pkg/headers/rtpinfo_test.go rename to pkg/headers/rtp_info_test.go diff --git a/pkg/headers/session.go b/pkg/headers/session.go index 683b3c19..9ff8f4d1 100644 --- a/pkg/headers/session.go +++ b/pkg/headers/session.go @@ -13,7 +13,7 @@ type Session struct { // session id Session string - // (optional) a timeout + // (optional) timeout Timeout *uint } diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/531ecc27fef0609a b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/531ecc27fef0609a new file mode 100644 index 00000000..d2396a14 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/531ecc27fef0609a @@ -0,0 +1,2 @@ +go test fuzz v1 +string("data") diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/771e938e4458e983 b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/771e938e4458e983 new file mode 100644 index 00000000..ee3f3399 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/771e938e4458e983 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0") diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/90d404dbb91eead6 b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/90d404dbb91eead6 new file mode 100644 index 00000000..73b21d06 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/90d404dbb91eead6 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("prot") diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/a8857c4807d99b81 b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/a8857c4807d99b81 new file mode 100644 index 00000000..a0d05139 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/a8857c4807d99b81 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("prot=\"0000") diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/d015a7c61a819cac b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/d015a7c61a819cac new file mode 100644 index 00000000..3974011c --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/d015a7c61a819cac @@ -0,0 +1,2 @@ +go test fuzz v1 +string("data=00") diff --git a/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/f98f6f990321cbbb b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/f98f6f990321cbbb new file mode 100644 index 00000000..0ffaa20d --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzKeyMgmtUnmarshal/f98f6f990321cbbb @@ -0,0 +1,2 @@ +go test fuzz v1 +string("prot=mikey") diff --git a/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/249c6737cb7d7159 b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/249c6737cb7d7159 new file mode 100644 index 00000000..596785b4 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/249c6737cb7d7159 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("port=0-") diff --git a/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/302fc5e96ed32a08 b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/302fc5e96ed32a08 new file mode 100644 index 00000000..f55f21fb --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/302fc5e96ed32a08 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("port=--") diff --git a/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/488626fc6b0fd159 b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/488626fc6b0fd159 new file mode 100644 index 00000000..cc73fa07 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/488626fc6b0fd159 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("port=-") diff --git a/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/a4fe0bdca2a17b9c b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/a4fe0bdca2a17b9c new file mode 100644 index 00000000..b83ef465 --- /dev/null +++ b/pkg/headers/testdata/fuzz/FuzzTransportsUnmarshal/a4fe0bdca2a17b9c @@ -0,0 +1,2 @@ +go test fuzz v1 +string("source=0.0.0.0.A.0") diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index e53d2fe5..3721ff96 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -48,6 +48,8 @@ const ( ) // String implements fmt.Stringer. +// +// Deprecated: not used anymore. func (p TransportProtocol) String() string { if p == TransportProtocolUDP { return "RTP/AVP" @@ -65,6 +67,8 @@ const ( ) // String implements fmt.Stringer. +// +// Deprecated: not used anymore. func (d TransportDelivery) String() string { if d == TransportDeliveryUnicast { return "unicast" @@ -112,37 +116,40 @@ func (m TransportMode) String() string { // Transport is a Transport header. type Transport struct { - // protocol of the stream + // protocol of the stream. Protocol TransportProtocol - // (optional) delivery method of the stream + // Whether the secure variant is active. + Secure bool + + // (optional) delivery method of the stream. Delivery *TransportDelivery - // (optional) Source IP + // (optional) Source IP. Source *net.IP - // (optional) destination IP + // (optional) destination IP. Destination *net.IP - // (optional) interleaved frame ids + // (optional) interleaved frame IDs. InterleavedIDs *[2]int - // (optional) TTL + // (optional) TTL. TTL *uint - // (optional) ports + // (optional) ports. Ports *[2]int - // (optional) client ports + // (optional) client ports. ClientPorts *[2]int - // (optional) server ports + // (optional) server ports. ServerPorts *[2]int - // (optional) SSRC of the packets of the stream + // (optional) SSRC of the packets of the stream. SSRC *uint32 - // (optional) mode + // (optional) mode. Mode *TransportMode } @@ -156,14 +163,12 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { return fmt.Errorf("value provided multiple times (%v)", v) } - v0 := v[0] - - kvs, err := keyValParse(v0, ';') + kvs, err := keyValParse(v[0], ';') if err != nil { return err } - protocolFound := false + profileFound := false for k, rv := range kvs { v := rv @@ -171,11 +176,21 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { switch k { case "RTP/AVP", "RTP/AVP/UDP": h.Protocol = TransportProtocolUDP - protocolFound = true + profileFound = true case "RTP/AVP/TCP": h.Protocol = TransportProtocolTCP - protocolFound = true + profileFound = true + + case "RTP/SAVP", "RTP/SAVP/UDP": + h.Protocol = TransportProtocolUDP + h.Secure = true + profileFound = true + + case "RTP/SAVP/TCP": + h.Protocol = TransportProtocolTCP + h.Secure = true + profileFound = true case "unicast": v := TransportDeliveryUnicast @@ -273,8 +288,8 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { } } - if !protocolFound { - return fmt.Errorf("protocol not found (%v)", v[0]) + if !profileFound { + return fmt.Errorf("profile is missing: %v", v[0]) } return nil @@ -284,10 +299,33 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error { func (h Transport) Marshal() base.HeaderValue { var rets []string - rets = append(rets, h.Protocol.String()) + var profile string + + switch { + case h.Protocol == TransportProtocolUDP && !h.Secure: + profile = "RTP/AVP" + case h.Protocol == TransportProtocolTCP && !h.Secure: + profile = "RTP/AVP/TCP" + case h.Protocol == TransportProtocolUDP && h.Secure: + profile = "RTP/SAVP" + case h.Protocol == TransportProtocolTCP && h.Secure: + profile = "RTP/SAVP/TCP" + } + + rets = append(rets, profile) if h.Delivery != nil { - rets = append(rets, h.Delivery.String()) + var delivery string + + switch *h.Delivery { + case TransportDeliveryUnicast: + delivery = "unicast" + + case TransportDeliveryMulticast: + delivery = "multicast" + } + + rets = append(rets, delivery) } if h.Source != nil { @@ -337,43 +375,3 @@ func (h Transport) Marshal() base.HeaderValue { return base.HeaderValue{strings.Join(rets, ";")} } - -// Transports is a Transport header with multiple transports. -type Transports []Transport - -// Unmarshal decodes a Transport header. -func (ts *Transports) Unmarshal(v base.HeaderValue) error { - if len(v) == 0 { - return fmt.Errorf("value not provided") - } - - if len(v) > 1 { - return fmt.Errorf("value provided multiple times (%v)", v) - } - - v0 := v[0] - transports := strings.Split(v0, ",") // , separated per RFC2326 section 12.39 - *ts = make([]Transport, len(transports)) - - for i, transport := range transports { - var tr Transport - err := tr.Unmarshal(base.HeaderValue{strings.TrimLeft(transport, " ")}) - if err != nil { - return err - } - (*ts)[i] = tr - } - - return nil -} - -// Marshal encodes a Transport header. -func (ts Transports) Marshal() base.HeaderValue { - vals := make([]string, len(ts)) - - for i, th := range ts { - vals[i] = th.Marshal()[0] - } - - return base.HeaderValue{strings.Join(vals, ",")} -} diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index dcf7b0d3..d537ba03 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -168,6 +168,28 @@ var casesTransport = []struct { ServerPorts: &[2]int{56002, 56003}, }, }, + { + "secure udp unicast play request", + base.HeaderValue{`RTP/SAVP;unicast;client_port=3456-3457;mode="PLAY"`}, + base.HeaderValue{`RTP/SAVP;unicast;client_port=3456-3457;mode=play`}, + Transport{ + Protocol: TransportProtocolUDP, + Secure: true, + Delivery: deliveryPtr(TransportDeliveryUnicast), + ClientPorts: &[2]int{3456, 3457}, + Mode: transportModePtr(TransportModePlay), + }, + }, + { + "secure tcp play request / response", + base.HeaderValue{`RTP/SAVP/TCP;interleaved=0-1`}, + base.HeaderValue{`RTP/SAVP/TCP;interleaved=0-1`}, + Transport{ + Protocol: TransportProtocolTCP, + Secure: true, + InterleavedIDs: &[2]int{0, 1}, + }, + }, } func TestTransportUnmarshal(t *testing.T) { @@ -190,81 +212,6 @@ func TestTransportMarshal(t *testing.T) { } } -var casesTransports = []struct { - name string - vin base.HeaderValue - vout base.HeaderValue - h Transports -}{ - { - "a", - base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY", RTP/AVP/TCP;unicast;interleaved=0-1`}, - base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play,RTP/AVP/TCP;unicast;interleaved=0-1`}, - Transports{ - { - Protocol: TransportProtocolUDP, - Delivery: deliveryPtr(TransportDeliveryUnicast), - ClientPorts: &[2]int{3456, 3457}, - Mode: transportModePtr(TransportModePlay), - }, - Transport{ - Protocol: TransportProtocolTCP, - Delivery: deliveryPtr(TransportDeliveryUnicast), - InterleavedIDs: &[2]int{0, 1}, - }, - }, - }, -} - -func TestTransportsUnmarshal(t *testing.T) { - for _, ca := range casesTransports { - t.Run(ca.name, func(t *testing.T) { - var h Transports - err := h.Unmarshal(ca.vin) - require.NoError(t, err) - require.Equal(t, ca.h, h) - }) - } -} - -func TestTransportsMarshal(t *testing.T) { - for _, ca := range casesTransports { - t.Run(ca.name, func(t *testing.T) { - req := ca.h.Marshal() - require.Equal(t, ca.vout, req) - }) - } -} - -func FuzzTransportsUnmarshal(f *testing.F) { - for _, ca := range casesTransports { - f.Add(ca.vin[0]) - } - - for _, ca := range casesTransport { - f.Add(ca.vin[0]) - } - - f.Add("source=aa-14187") - f.Add("destination=aa") - f.Add("interleaved=") - f.Add("ttl=") - f.Add("port=") - f.Add("client_port=") - f.Add("server_port=") - f.Add("mode=") - - f.Fuzz(func(_ *testing.T, b string) { - var h Transports - err := h.Unmarshal(base.HeaderValue{b}) - if err != nil { - return - } - - h.Marshal() - }) -} - func TestTransportAdditionalErrors(t *testing.T) { func() { var h Transport diff --git a/pkg/headers/transports.go b/pkg/headers/transports.go new file mode 100644 index 00000000..d6a39a19 --- /dev/null +++ b/pkg/headers/transports.go @@ -0,0 +1,48 @@ +package headers + +import ( + "fmt" + "strings" + + "github.com/bluenviron/gortsplib/v4/pkg/base" +) + +// Transports is a Transport header with multiple transports. +type Transports []Transport + +// Unmarshal decodes a Transport header. +func (ts *Transports) Unmarshal(v base.HeaderValue) error { + if len(v) == 0 { + return fmt.Errorf("value not provided") + } + + if len(v) > 1 { + return fmt.Errorf("value provided multiple times (%v)", v) + } + + v0 := v[0] + transports := strings.Split(v0, ",") // , separated per RFC2326 section 12.39 + *ts = make([]Transport, len(transports)) + + for i, transport := range transports { + var tr Transport + err := tr.Unmarshal(base.HeaderValue{strings.TrimLeft(transport, " ")}) + if err != nil { + return err + } + (*ts)[i] = tr + } + + return nil +} + +// Marshal encodes a Transport header. +func (ts Transports) Marshal() base.HeaderValue { + vals := make([]string, len(ts)) + + for i, th := range ts { + vals[i] = th.Marshal()[0] + } + + return base.HeaderValue{strings.Join(vals, ",")} +} diff --git a/pkg/headers/transports_test.go b/pkg/headers/transports_test.go new file mode 100644 index 00000000..9bc9f7e3 --- /dev/null +++ b/pkg/headers/transports_test.go @@ -0,0 +1,97 @@ +package headers + +import ( + "testing" + + "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/stretchr/testify/require" +) + +var casesTransports = []struct { + name string + vin base.HeaderValue + vout base.HeaderValue + h Transports +}{ + { + "a", + base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY", RTP/AVP/TCP;unicast;interleaved=0-1`}, + base.HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play,RTP/AVP/TCP;unicast;interleaved=0-1`}, + Transports{ + { + Protocol: TransportProtocolUDP, + Delivery: deliveryPtr(TransportDeliveryUnicast), + ClientPorts: &[2]int{3456, 3457}, + Mode: transportModePtr(TransportModePlay), + }, + Transport{ + Protocol: TransportProtocolTCP, + Delivery: deliveryPtr(TransportDeliveryUnicast), + InterleavedIDs: &[2]int{0, 1}, + }, + }, + }, +} + +func TestTransportsUnmarshal(t *testing.T) { + for _, ca := range casesTransports { + t.Run(ca.name, func(t *testing.T) { + var h Transports + err := h.Unmarshal(ca.vin) + require.NoError(t, err) + require.Equal(t, ca.h, h) + }) + } +} + +func TestTransportsMarshal(t *testing.T) { + for _, ca := range casesTransports { + t.Run(ca.name, func(t *testing.T) { + req := ca.h.Marshal() + require.Equal(t, ca.vout, req) + }) + } +} + +func FuzzTransportsUnmarshal(f *testing.F) { + for _, ca := range casesTransports { + f.Add(ca.vin[0]) + } + + for _, ca := range casesTransport { + f.Add(ca.vin[0]) + } + + f.Add("source=aa-14187") + f.Add("destination=aa") + f.Add("interleaved=") + f.Add("ttl=") + f.Add("port=") + f.Add("client_port=") + f.Add("server_port=") + f.Add("mode=") + + f.Fuzz(func(_ *testing.T, b string) { + var h Transports + err := h.Unmarshal(base.HeaderValue{b}) + if err != nil { + return + } + + h.Marshal() + }) +} + +func TestTransportsAdditionalErrors(t *testing.T) { + func() { + var h Transports + err := h.Unmarshal(base.HeaderValue{}) + require.Error(t, err) + }() + + func() { + var h Transports + err := h.Unmarshal(base.HeaderValue{"a", "b"}) + require.Error(t, err) + }() +} diff --git a/pkg/liberrors/client.go b/pkg/liberrors/client.go index 41805722..5b7b1c8b 100644 --- a/pkg/liberrors/client.go +++ b/pkg/liberrors/client.go @@ -224,6 +224,8 @@ func (e ErrClientUnsupportedScheme) Error() string { } // ErrClientRTSPSTCP is an error that can be returned by a client. +// +// Deprecated: not used anymore. type ErrClientRTSPSTCP struct{} // Error implements the error interface. diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 932f58d3..987886a3 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -128,12 +128,27 @@ func (e ErrServerMediasDifferentPaths) Error() string { return "can't setup medias with different paths" } +// ErrServerInvalidKeyMgmtHeader is an error that can be returned by a server. +type ErrServerInvalidKeyMgmtHeader struct { + Wrapped error +} + +// Error implements the error interface. +func (e ErrServerInvalidKeyMgmtHeader) Error() string { + return fmt.Sprintf("invalid KeyMgmt header: %s", e.Wrapped.Error()) +} + // ErrServerMediasDifferentProtocols is an error that can be returned by a server. -type ErrServerMediasDifferentProtocols struct{} +// +// Deprecated: replaced by ErrServerMediasDifferentTransports. +type ErrServerMediasDifferentProtocols = ErrServerMediasDifferentTransports + +// ErrServerMediasDifferentTransports is an error that can be returned by a server. +type ErrServerMediasDifferentTransports struct{} // Error implements the error interface. func (e ErrServerMediasDifferentProtocols) Error() string { - return "can't setup medias with different protocols" + return "can't setup medias with different transports" } // ErrServerNoMediasSetup is an error that can be returned by a server. diff --git a/pkg/mikey/header.go b/pkg/mikey/header.go new file mode 100644 index 00000000..59d6f2ce --- /dev/null +++ b/pkg/mikey/header.go @@ -0,0 +1,143 @@ +package mikey + +import "fmt" + +func boolToUint8(v bool) uint8 { + if v { + return 1 + } + return 0 +} + +// DataType is a message data type. +type DataType uint8 + +// RFC3830, Table 6.1.a +const ( + DataTypeInitiatorPSK DataType = 0 +) + +// CSIDMapType is a CS ID map type. +type CSIDMapType uint8 + +// RFC3830, Table 6.1.d +const ( + CSIDMapTypeSRTPID CSIDMapType = 0 +) + +// SRTPIDEntry is an entry of a SRTP-ID map. +type SRTPIDEntry struct { + PolicyNo uint8 + SSRC uint32 + ROC uint32 +} + +// Header is a MIKEY header. +type Header struct { + Version uint8 + DataType DataType + V bool + PRFFunc uint8 + CSBID uint32 + CSIDMapType CSIDMapType + CSIDMapInfo []SRTPIDEntry +} + +func (h *Header) unmarshal(buf []byte) (int, payloadType, error) { + if len(buf) < 10 { + return 0, 0, fmt.Errorf("header too short") + } + + n := 0 + h.Version = buf[n] + n++ + + if h.Version != 1 { + return 0, 0, fmt.Errorf("unsupported version: %v", h.Version) + } + + h.DataType = DataType(buf[n]) + n++ + + if h.DataType != DataTypeInitiatorPSK { + return 0, 0, fmt.Errorf("unsupported data type: %v", h.DataType) + } + + nextPayload := payloadType(buf[n]) + n++ + + h.V = (buf[n] >> 7) != 0 + h.PRFFunc = buf[n] & 0b01111111 + n++ + + if h.V { + return 0, 0, fmt.Errorf("unsupported V: %v", h.V) + } + + if h.PRFFunc != 0 { + return 0, 0, fmt.Errorf("unsupported PRFFunc: %v", h.PRFFunc) + } + + h.CSBID = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3]) + n += 4 + + numCS := buf[n] + n++ + + h.CSIDMapType = CSIDMapType(buf[n]) + n++ + + if h.CSIDMapType != CSIDMapTypeSRTPID { + return 0, 0, fmt.Errorf("unsupported map type: %d", h.CSIDMapType) + } + + if len(buf[n:]) < (int(numCS) * 9) { + return 0, 0, fmt.Errorf("header too short") + } + + h.CSIDMapInfo = make([]SRTPIDEntry, numCS) + + for i := range numCS { + h.CSIDMapInfo[i].PolicyNo = buf[n] + n++ + h.CSIDMapInfo[i].SSRC = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3]) + n += 4 + h.CSIDMapInfo[i].ROC = uint32(buf[n])<<24 | uint32(buf[n+1])<<16 | uint32(buf[n+2])<<8 | uint32(buf[n+3]) + n += 4 + } + + return n, nextPayload, nil +} + +func (h *Header) marshalSize() int { + return 10 + len(h.CSIDMapInfo)*9 +} + +func (h *Header) marshalTo(buf []byte, nextPayload payloadType) (int, error) { + buf[0] = h.Version + buf[1] = byte(h.DataType) + buf[2] = byte(nextPayload) + buf[3] = boolToUint8(h.V)<<7 | h.PRFFunc + buf[4] = byte(h.CSBID >> 24) + buf[5] = byte(h.CSBID >> 16) + buf[6] = byte(h.CSBID >> 8) + buf[7] = byte(h.CSBID) + buf[8] = byte(len(h.CSIDMapInfo)) + buf[9] = byte(h.CSIDMapType) + n := 10 + + for _, mi := range h.CSIDMapInfo { + buf[n] = mi.PolicyNo + buf[n+1] = byte(mi.SSRC >> 24) + buf[n+2] = byte(mi.SSRC >> 16) + buf[n+3] = byte(mi.SSRC >> 8) + buf[n+4] = byte(mi.SSRC) + buf[n+5] = byte(mi.ROC >> 24) + buf[n+6] = byte(mi.ROC >> 16) + buf[n+7] = byte(mi.ROC >> 8) + buf[n+8] = byte(mi.ROC) + n += 9 + } + + return n, nil +} diff --git a/pkg/mikey/message.go b/pkg/mikey/message.go new file mode 100644 index 00000000..56d3e326 --- /dev/null +++ b/pkg/mikey/message.go @@ -0,0 +1,91 @@ +// Package mikey contains functions to decode and encode MIKEY messages. +package mikey + +import "fmt" + +// Message is a MIKEY message. +type Message struct { + Header Header + Payloads []Payload +} + +// Unmarshal decodes a Message. +func (m *Message) Unmarshal(buf []byte) error { + n, nextPayloadType, err := m.Header.unmarshal(buf) + if err != nil { + return err + } + + for nextPayloadType != 0 { + var payload Payload + + switch nextPayloadType { + case payloadTypeKEMAC: + payload = &PayloadKEMAC{} + case payloadTypeT: + payload = &PayloadT{} + case payloadTypeSP: + payload = &PayloadSP{} + case payloadTypeRAND: + payload = &PayloadRAND{} + default: + return fmt.Errorf("unsupported payload type: %d", nextPayloadType) + } + + payloadLen, err := payload.unmarshal(buf[n:]) + if err != nil { + return fmt.Errorf("unable to parse payload %d: %w", nextPayloadType, err) + } + + nextPayloadType = payloadType(buf[n]) + n += payloadLen + m.Payloads = append(m.Payloads, payload) + } + + if n < len(buf) { + return fmt.Errorf("detected %d unparsed bytes", len(buf)-n) + } + + return nil +} + +func (m *Message) marshalSize() int { + n := m.Header.marshalSize() + for _, pl := range m.Payloads { + n += pl.marshalSize() + } + return n +} + +// Marshal encodes a Message. +func (m *Message) Marshal() ([]byte, error) { + buf := make([]byte, m.marshalSize()) + + var nextPayloadType payloadType + if len(m.Payloads) != 0 { + nextPayloadType = m.Payloads[0].typ() + } + + n, err := m.Header.marshalTo(buf, nextPayloadType) + if err != nil { + return nil, err + } + + for i, pl := range m.Payloads { + if i != len(m.Payloads)-1 { + nextPayloadType = m.Payloads[i+1].typ() + } else { + nextPayloadType = 0 + } + + buf[n] = byte(nextPayloadType) + + n2, err := pl.marshalTo(buf[n:]) + if err != nil { + return nil, err + } + n += n2 + } + + return buf, nil +} diff --git a/pkg/mikey/message_test.go b/pkg/mikey/message_test.go new file mode 100644 index 00000000..8dcbf38c --- /dev/null +++ b/pkg/mikey/message_test.go @@ -0,0 +1,324 @@ +package mikey + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +var cases = []struct { + name string + enc []byte + dec Message +}{ + { + "a", + []byte{ + 0x01, 0x00, 0x05, 0x00, 0xe6, 0x9d, 0x51, 0xf8, + 0x01, 0x00, 0x00, 0x30, 0x68, 0x57, 0x60, 0x00, + 0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe, 0x6f, + 0x2d, 0xb1, 0xc1, 0x3f, 0xd0, 0x0a, 0x10, 0xc2, + 0xdd, 0xe4, 0x43, 0xa8, 0x49, 0x30, 0xa5, 0x75, + 0x7a, 0x7e, 0xd9, 0xc3, 0xa4, 0x17, 0xfb, 0x01, + 0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01, 0x01, + 0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01, 0x0a, + 0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a, 0x01, + 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20, 0x00, + 0x1e, 0x90, 0x91, 0x78, 0x3d, 0xfc, 0xe8, 0xdd, + 0xcd, 0x44, 0x3a, 0x53, 0x50, 0x8b, 0x64, 0x50, + 0x9f, 0x35, 0xbd, 0x8a, 0x86, 0xbc, 0x4d, 0x8b, + 0x76, 0x37, 0xa5, 0x02, 0x49, 0x3d, 0xaf, 0x00, + }, + Message{ + Header: Header{ + Version: 1, + CSBID: 3869069816, + CSIDMapInfo: []SRTPIDEntry{ + { + PolicyNo: 0, + SSRC: 812144480, + ROC: 0, + }, + }, + }, + Payloads: []Payload{ + &PayloadT{ + TSType: 0, + TSValue: 17005151485044015056, + }, + &PayloadRAND{ + Data: []byte{ + 0xc2, 0xdd, 0xe4, 0x43, 0xa8, 0x49, 0x30, 0xa5, + 0x75, 0x7a, 0x7e, 0xd9, 0xc3, 0xa4, 0x17, 0xfb, + }, + }, + &PayloadSP{ + PolicyParams: []PayloadSPPolicyParam{ + { + Type: PayloadSPPolicyParamTypeEncrAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionEncrKeyLen, + Value: []byte{0x10}, + }, + { + Type: PayloadSPPolicyParamTypeAuthAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionAuthKeyLen, + Value: []byte{0x0a}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPAuthOffOn, + Value: []byte{1}, + }, + }, + }, + &PayloadKEMAC{ + SubPayloads: []*SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x90, 0x91, 0x78, 0x3d, 0xfc, 0xe8, 0xdd, 0xcd, + 0x44, 0x3a, 0x53, 0x50, 0x8b, 0x64, 0x50, 0x9f, + 0x35, 0xbd, 0x8a, 0x86, 0xbc, 0x4d, 0x8b, 0x76, + 0x37, 0xa5, 0x02, 0x49, 0x3d, 0xaf, + }, + }, + }, + }, + }, + }, + }, + { + "b", + []byte{ + 0x01, 0x00, 0x05, 0x00, 0xfe, 0xaf, 0x97, 0x52, + 0x01, 0x00, 0x00, 0xcc, 0x83, 0x62, 0x37, 0x00, + 0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe, 0xf6, + 0x6b, 0xa2, 0x8c, 0x9b, 0x84, 0x0a, 0x10, 0x27, + 0x6e, 0x94, 0x18, 0x0e, 0x88, 0x75, 0xc2, 0xea, + 0xad, 0x31, 0xd8, 0x2f, 0x86, 0x46, 0x20, 0x01, + 0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01, 0x01, + 0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01, 0x0a, + 0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a, 0x01, + 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20, 0x00, + 0x1e, 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b, + 0x4e, 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76, + 0x87, 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41, + 0x77, 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b, 0x00, + }, + Message{ + Header: Header{ + Version: 1, + CSBID: 4272920402, + CSIDMapInfo: []SRTPIDEntry{ + { + PolicyNo: 0, + SSRC: 3431162423, + ROC: 0, + }, + }, + }, + Payloads: []Payload{ //nolint:dupl + &PayloadT{ + TSValue: 17005300185146628996, + }, + &PayloadRAND{ + Data: []byte{ + 0x27, 0x6e, 0x94, 0x18, 0x0e, 0x88, 0x75, 0xc2, + 0xea, 0xad, 0x31, 0xd8, 0x2f, 0x86, 0x46, 0x20, + }, + }, + &PayloadSP{ + PolicyParams: []PayloadSPPolicyParam{ + { + Type: PayloadSPPolicyParamTypeEncrAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionEncrKeyLen, + Value: []byte{0x10}, + }, + { + Type: PayloadSPPolicyParamTypeAuthAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionAuthKeyLen, + Value: []byte{0x0a}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPAuthOffOn, + Value: []byte{1}, + }, + }, + }, + &PayloadKEMAC{ + SubPayloads: []*SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b, 0x4e, + 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76, 0x87, + 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41, 0x77, + 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b, + }, + }, + }, + }, + }, + }, + }, + { + "c", + []byte{ + 0x01, 0x00, 0x05, 0x00, 0x7d, 0xe1, 0x27, 0xa6, + 0x02, 0x00, 0x00, 0xcc, 0x83, 0x62, 0x37, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xb5, 0xcc, 0x3b, 0xf2, + 0x00, 0x00, 0x00, 0x00, 0x0b, 0x00, 0xeb, 0xfe, + 0xf6, 0x6b, 0xa2, 0xb1, 0xf6, 0x87, 0x0a, 0x10, + 0x61, 0xbb, 0x19, 0x94, 0x32, 0x53, 0x03, 0x56, + 0xa2, 0xd1, 0x88, 0x07, 0x15, 0x23, 0x75, 0x95, + 0x01, 0x00, 0x00, 0x00, 0x15, 0x00, 0x01, 0x01, + 0x01, 0x01, 0x10, 0x02, 0x01, 0x01, 0x03, 0x01, + 0x0a, 0x07, 0x01, 0x01, 0x08, 0x01, 0x01, 0x0a, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x20, + 0x00, 0x1e, 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, + 0x4b, 0x4e, 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, + 0x76, 0x87, 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, + 0x41, 0x77, 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b, + 0x00, + }, + Message{ + Header: Header{ + Version: 1, + CSBID: 2111907750, + CSIDMapInfo: []SRTPIDEntry{ + { + PolicyNo: 0, + SSRC: 3431162423, + ROC: 0, + }, + { + PolicyNo: 0, + SSRC: 3050060786, + ROC: 0, + }, + }, + }, + Payloads: []Payload{ //nolint:dupl + &PayloadT{ + TSValue: 17005300185149077127, + }, + &PayloadRAND{ + Data: []byte{ + 0x61, 0xbb, 0x19, 0x94, 0x32, 0x53, 0x03, 0x56, + 0xa2, 0xd1, 0x88, 0x07, 0x15, 0x23, 0x75, 0x95, + }, + }, + &PayloadSP{ + PolicyParams: []PayloadSPPolicyParam{ + { + Type: PayloadSPPolicyParamTypeEncrAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionEncrKeyLen, + Value: []byte{0x10}, + }, + { + Type: PayloadSPPolicyParamTypeAuthAlg, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSessionAuthKeyLen, + Value: []byte{0x0a}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTCPEncrOffOn, + Value: []byte{1}, + }, + { + Type: PayloadSPPolicyParamTypeSRTPAuthOffOn, + Value: []byte{1}, + }, + }, + }, + &PayloadKEMAC{ + SubPayloads: []*SubPayloadKeyData{ + { + Type: 2, + KeyData: []byte{ + 0x99, 0x1b, 0x0f, 0x14, 0x8f, 0x09, 0x4b, 0x4e, + 0x5b, 0x8b, 0x30, 0x53, 0xcd, 0x62, 0x76, 0x87, + 0x7f, 0xcc, 0xed, 0x18, 0x66, 0xf1, 0x41, 0x77, + 0x2a, 0xdd, 0xdd, 0xe7, 0x06, 0x4b, + }, + }, + }, + }, + }, + }, + }, +} + +func TestUnmarshal(t *testing.T) { + for _, ca := range cases { + t.Run(ca.name, func(t *testing.T) { + var dec Message + err := dec.Unmarshal(ca.enc) + require.NoError(t, err) + require.Equal(t, ca.dec, dec) + }) + } +} + +func TestMarshal(t *testing.T) { + for _, ca := range cases { + t.Run(ca.name, func(t *testing.T) { + enc, err := ca.dec.Marshal() + require.NoError(t, err) + require.Equal(t, ca.enc, enc) + }) + } +} + +func FuzzUnmarshal(f *testing.F) { + for _, ca := range cases { + f.Add(ca.enc) + } + + f.Fuzz(func(t *testing.T, b []byte) { + var msg Message + err := msg.Unmarshal(b) + if err != nil { + return + } + + _, err = msg.Marshal() + require.NoError(t, err) + }) +} diff --git a/pkg/mikey/payload.go b/pkg/mikey/payload.go new file mode 100644 index 00000000..efd95f55 --- /dev/null +++ b/pkg/mikey/payload.go @@ -0,0 +1,20 @@ +package mikey + +type payloadType uint8 + +// RFC3830, table 6.1.b +const ( + payloadTypeKEMAC payloadType = 1 + payloadTypeT payloadType = 5 + payloadTypeSP payloadType = 10 + payloadTypeRAND payloadType = 11 + payloadTypeKeyData payloadType = 20 +) + +// Payload is a MIKEY payload. +type Payload interface { + unmarshal(buf []byte) (int, error) + typ() payloadType + marshalSize() int + marshalTo(buf []byte) (int, error) +} diff --git a/pkg/mikey/payload_kemac.go b/pkg/mikey/payload_kemac.go new file mode 100644 index 00000000..12f16d7a --- /dev/null +++ b/pkg/mikey/payload_kemac.go @@ -0,0 +1,131 @@ +package mikey + +import "fmt" + +// PayloadKEMACEncrAlg is a encryption algorithm. +type PayloadKEMACEncrAlg uint8 + +// RFC3830, Table 6.2.a +const ( + PayloadKEMACEncrAlgNULL PayloadKEMACEncrAlg = 0 +) + +// PayloadKEMACMacAlg is a authentication algorithm. +type PayloadKEMACMacAlg uint8 + +// RFC3830, Table 6.2.b +const ( + PayloadKEMACMacAlgNULL PayloadKEMACMacAlg = 0 +) + +// PayloadKEMAC is a Key data transport payload. +type PayloadKEMAC struct { + EncrAlg PayloadKEMACEncrAlg + SubPayloads []*SubPayloadKeyData + MacAlg PayloadKEMACMacAlg +} + +func (p *PayloadKEMAC) unmarshal(buf []byte) (int, error) { + if len(buf) < 4 { + return 0, fmt.Errorf("buffer too short") + } + + n := 1 + p.EncrAlg = PayloadKEMACEncrAlg(buf[n]) + n++ + + if p.EncrAlg != PayloadKEMACEncrAlgNULL { + return 0, fmt.Errorf("unsupported encr alg: %v", p.EncrAlg) + } + + encrDataLen := int(uint16(buf[n])<<8 | uint16(buf[n+1])) + n += 2 + + if len(buf[n:]) < (encrDataLen + 1) { + return 0, fmt.Errorf("buffer too short") + } + + encrData := buf[n : n+encrDataLen] + n += encrDataLen + + sn := 0 + + for { + sp := &SubPayloadKeyData{} + spLen, err := sp.unmarshal(encrData[sn:]) + if err != nil { + return 0, err + } + + nextPayloadType := payloadType(encrData[sn]) + sn += spLen + p.SubPayloads = append(p.SubPayloads, sp) + + if nextPayloadType == 0 { + break + } + if nextPayloadType != payloadTypeKeyData { + return 0, fmt.Errorf("unsupported payload type: %v", nextPayloadType) + } + } + + if sn != len(encrData) { + return 0, fmt.Errorf("detected unread bytes") + } + + p.MacAlg = PayloadKEMACMacAlg(buf[n]) + n++ + + if p.MacAlg != PayloadKEMACMacAlgNULL { + return 0, fmt.Errorf("unsupported mac alg: %v", p.MacAlg) + } + + return n, nil +} + +func (*PayloadKEMAC) typ() payloadType { + return payloadTypeKEMAC +} + +func (p *PayloadKEMAC) marshalSize() int { + n := 5 + for _, sp := range p.SubPayloads { + n += sp.marshalSize() + } + return n +} + +func (p *PayloadKEMAC) marshalTo(buf []byte) (int, error) { + buf[1] = byte(p.EncrAlg) + + encrDataLen := 0 + for _, sp := range p.SubPayloads { + encrDataLen += sp.marshalSize() + } + + buf[2] = byte(encrDataLen >> 8) + buf[3] = byte(encrDataLen) + n := 4 + + for i, sp := range p.SubPayloads { + var nextPayloadType payloadType + if i != len(p.SubPayloads)-1 { + nextPayloadType = payloadTypeKeyData + } else { + nextPayloadType = 0 + } + + buf[n] = byte(nextPayloadType) + + n2, err := sp.marshalTo(buf[n:]) + if err != nil { + return 0, err + } + n += n2 + } + + buf[n] = byte(p.MacAlg) + n++ + + return n, nil +} diff --git a/pkg/mikey/payload_rand.go b/pkg/mikey/payload_rand.go new file mode 100644 index 00000000..5c400ba0 --- /dev/null +++ b/pkg/mikey/payload_rand.go @@ -0,0 +1,46 @@ +package mikey + +import "fmt" + +// PayloadRAND is a payload with random data. +type PayloadRAND struct { + Data []byte +} + +func (p *PayloadRAND) unmarshal(buf []byte) (int, error) { + if len(buf) < 2 { + return 0, fmt.Errorf("buffer too short") + } + + n := 1 + dataLen := int(buf[n]) + n++ + + if dataLen < 16 { + return 0, fmt.Errorf("invalid data len: %v", dataLen) + } + + if len(buf[n:]) < dataLen { + return 0, fmt.Errorf("buffer too short") + } + + p.Data = buf[n : n+dataLen] + n += dataLen + + return n, nil +} + +func (*PayloadRAND) typ() payloadType { + return payloadTypeRAND +} + +func (p *PayloadRAND) marshalSize() int { + return 2 + len(p.Data) +} + +func (p *PayloadRAND) marshalTo(buf []byte) (int, error) { + buf[1] = uint8(len(p.Data)) + n := 2 + n += copy(buf[2:], p.Data) + return n, nil +} diff --git a/pkg/mikey/payload_sp.go b/pkg/mikey/payload_sp.go new file mode 100644 index 00000000..81a58cbe --- /dev/null +++ b/pkg/mikey/payload_sp.go @@ -0,0 +1,129 @@ +package mikey + +import "fmt" + +// PayloadSPProtType is a security protocol. +type PayloadSPProtType uint8 + +// RFC3830, Table 6.2.a +const ( + PayloadSPProtTypeSRTP PayloadSPProtType = 0 +) + +// PayloadSPPolicyParamType is a policy param type. +type PayloadSPPolicyParamType uint8 + +// RFC3830, Table 6.10.1.a +const ( + PayloadSPPolicyParamTypeEncrAlg PayloadSPPolicyParamType = 0 + PayloadSPPolicyParamTypeSessionEncrKeyLen PayloadSPPolicyParamType = 1 + PayloadSPPolicyParamTypeAuthAlg PayloadSPPolicyParamType = 2 + PayloadSPPolicyParamTypeSessionAuthKeyLen PayloadSPPolicyParamType = 3 + PayloadSPPolicyParamTypeSessionSaltKeyLen PayloadSPPolicyParamType = 4 + PayloadSPPolicyParamTypeSRTPPseudoRandFun PayloadSPPolicyParamType = 5 + PayloadSPPolicyParamTypeKeyDerRate PayloadSPPolicyParamType = 6 + PayloadSPPolicyParamTypeSRTPEncrOffOn PayloadSPPolicyParamType = 7 + PayloadSPPolicyParamTypeSRTCPEncrOffOn PayloadSPPolicyParamType = 8 + PayloadSPPolicyParamTypeSenderFECOrder PayloadSPPolicyParamType = 9 + PayloadSPPolicyParamTypeSRTPAuthOffOn PayloadSPPolicyParamType = 10 + PayloadSPPolicyParamTypeAuthTagLen PayloadSPPolicyParamType = 11 + PayloadSPPolicyParamTypeSRTPPrefixLen PayloadSPPolicyParamType = 12 +) + +// PayloadSPPolicyParam is a policy param. +type PayloadSPPolicyParam struct { + Type PayloadSPPolicyParamType + Value []byte +} + +// PayloadSP is a security policy payload. +type PayloadSP struct { + PolicyNo uint8 + ProtType PayloadSPProtType + PolicyParams []PayloadSPPolicyParam +} + +func (p *PayloadSP) unmarshal(buf []byte) (int, error) { + if len(buf) < 5 { + return 0, fmt.Errorf("buffer too short") + } + + n := 1 + p.PolicyNo = buf[n] + n++ + p.ProtType = PayloadSPProtType(buf[n]) + n++ + + if p.ProtType != 0 { + return 0, fmt.Errorf("unsupported prot type: %v", p.ProtType) + } + + policyParamLength := uint16(buf[n])<<8 | uint16(buf[n+1]) + n += 2 + end := n + int(policyParamLength) + + for { + if n > end { + return 0, fmt.Errorf("policy param overflowed") + } + if n == end { + break + } + if len(buf[n:]) < 2 { + return 0, fmt.Errorf("buffer too short") + } + + typ := PayloadSPPolicyParamType(buf[n]) + n++ + valueLen := int(buf[n]) + n++ + + if len(buf[n:]) < valueLen { + return 0, fmt.Errorf("buffer too short") + } + + value := buf[n : n+valueLen] + n += valueLen + + p.PolicyParams = append(p.PolicyParams, PayloadSPPolicyParam{ + Type: typ, + Value: value, + }) + } + + return n, nil +} + +func (*PayloadSP) typ() payloadType { + return payloadTypeSP +} + +func (p *PayloadSP) marshalSize() int { + n := 5 + 2*len(p.PolicyParams) + for _, pp := range p.PolicyParams { + n += len(pp.Value) + } + return n +} + +func (p *PayloadSP) marshalTo(buf []byte) (int, error) { + buf[1] = p.PolicyNo + buf[2] = byte(p.ProtType) + + policyParamLength := 0 + for _, pp := range p.PolicyParams { + policyParamLength += 2 + len(pp.Value) + } + buf[3] = byte(policyParamLength >> 8) + buf[4] = byte(policyParamLength) + n := 5 + + for _, pp := range p.PolicyParams { + buf[n] = byte(pp.Type) + buf[n+1] = uint8(len(pp.Value)) + n += 2 + n += copy(buf[n:], pp.Value) + } + + return n, nil +} diff --git a/pkg/mikey/payload_t.go b/pkg/mikey/payload_t.go new file mode 100644 index 00000000..df2ca47a --- /dev/null +++ b/pkg/mikey/payload_t.go @@ -0,0 +1,56 @@ +package mikey + +import "fmt" + +// PayloadT is a timestamp payload. +type PayloadT struct { + TSType uint8 + TSValue uint64 +} + +func (p *PayloadT) unmarshal(buf []byte) (int, error) { + if len(buf) < 10 { + return 0, fmt.Errorf("buffer too short") + } + + n := 1 + p.TSType = buf[n] + n++ + + if p.TSType != 0 { + return 0, fmt.Errorf("unsupported TSType: %v", p.TSType) + } + + p.TSValue = uint64(buf[n])<<56 | + uint64(buf[n+1])<<48 | + uint64(buf[n+2])<<40 | + uint64(buf[n+3])<<32 | + uint64(buf[n+4])<<24 | + uint64(buf[n+5])<<16 | + uint64(buf[n+6])<<8 | + uint64(buf[n+7]) + n += 8 + + return n, nil +} + +func (*PayloadT) typ() payloadType { + return payloadTypeT +} + +func (p *PayloadT) marshalSize() int { + return 10 +} + +func (p *PayloadT) marshalTo(buf []byte) (int, error) { + buf[1] = p.TSType + buf[2] = byte(p.TSValue >> 56) + buf[3] = byte(p.TSValue >> 48) + buf[4] = byte(p.TSValue >> 40) + buf[5] = byte(p.TSValue >> 32) + buf[6] = byte(p.TSValue >> 24) + buf[7] = byte(p.TSValue >> 16) + buf[8] = byte(p.TSValue >> 8) + buf[9] = byte(p.TSValue) + return 10, nil +} diff --git a/pkg/mikey/sub_payload_key_data.go b/pkg/mikey/sub_payload_key_data.go new file mode 100644 index 00000000..2356a843 --- /dev/null +++ b/pkg/mikey/sub_payload_key_data.go @@ -0,0 +1,66 @@ +package mikey + +import "fmt" + +// SubPayloadKeyDataKeyType is a data key type. +type SubPayloadKeyDataKeyType uint8 + +// RFC3830, table 6.13.a +const ( + SubPayloadKeyDataKeyTypeTEK SubPayloadKeyDataKeyType = 2 +) + +// SubPayloadKeyData is a key data sub-payload. +type SubPayloadKeyData struct { + Type SubPayloadKeyDataKeyType + KV uint8 + KeyData []byte +} + +func (p *SubPayloadKeyData) unmarshal(buf []byte) (int, error) { + if len(buf) < 4 { + return 0, fmt.Errorf("buffer too short") + } + + n := 1 + p.Type = SubPayloadKeyDataKeyType(buf[n] >> 4) + p.KV = buf[n] & 0b00001111 + n++ + + if p.Type != SubPayloadKeyDataKeyTypeTEK { + return 0, fmt.Errorf("unsupported key type: %v", p.Type) + } + + if p.KV != 0 { + return 0, fmt.Errorf("unsupported KV: %v", p.KV) + } + + keyDataLen := int(uint16(buf[n])<<8 | uint16(buf[n+1])) + n += 2 + + if len(buf[n:]) < keyDataLen { + return 0, fmt.Errorf("buffer too short") + } + + p.KeyData = buf[n : n+keyDataLen] + n += keyDataLen + + return n, nil +} + +func (p *SubPayloadKeyData) marshalSize() int { + return 4 + len(p.KeyData) +} + +func (p *SubPayloadKeyData) marshalTo(buf []byte) (int, error) { + buf[1] = byte(p.Type)<<4 | p.KV + + keyDataLen := len(p.KeyData) + buf[2] = byte(keyDataLen >> 8) + buf[3] = byte(keyDataLen) + n := 4 + + n += copy(buf[n:], p.KeyData) + + return n, nil +} diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e2aa46892dbf440 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e2aa46892dbf440 new file mode 100644 index 00000000..c22c100e --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e2aa46892dbf440 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x00\xe6\x9dQ\xf8\x01\x00\x000hW`\x00\x00\x00\x00\v\x00\xeb\xfeo-\xb1\xc1?\xd0\n\x10\xc2\xdd\xe4C\xa8I0\xa5uz~\xd9ä\x17\xfb\x01\x00\x00\x00\x15\x00\x01\x01\x01\x01\x10\x02\x8a\x86\xbc\x01\n\a\x01\x01\b\x01\x01\n\x01\x01\x00\x00\x00\"\x00 \x00\x1e\x90\x91x=\xfc\xe8\xdd\xcdD:SP\x8bdP\x9f5\xbd\x8a\x86\xbcM\x8bv7\xa5\x02I=\xaf\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e400a18088f3ef1 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e400a18088f3ef1 new file mode 100644 index 00000000..a50beed0 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/0e400a18088f3ef1 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x000000\x01\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/1308b4f12c633cfe b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/1308b4f12c633cfe new file mode 100644 index 00000000..02ddc037 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/1308b4f12c633cfe @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n0") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/15a72f9c17aa83eb b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/15a72f9c17aa83eb new file mode 100644 index 00000000..e57edaaa --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/15a72f9c17aa83eb @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\x01\x1000000000000000000\x00\x00\"\x00 \x00\x00000000000000000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/21d92b615e38b74d b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/21d92b615e38b74d new file mode 100644 index 00000000..6bb42c3a --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/21d92b615e38b74d @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x000000\x01") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2968ff6b6c107cd2 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2968ff6b6c107cd2 new file mode 100644 index 00000000..1cb2313c --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2968ff6b6c107cd2 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\x00\x00000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2c79d879e91381fc b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2c79d879e91381fc new file mode 100644 index 00000000..b047e523 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/2c79d879e91381fc @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x0000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3274006efd885c07 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3274006efd885c07 new file mode 100644 index 00000000..b8192378 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3274006efd885c07 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 00000000000000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34785eb47d444797 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34785eb47d444797 new file mode 100644 index 00000000..4615438f --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34785eb47d444797 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34b60fae6a60ed23 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34b60fae6a60ed23 new file mode 100644 index 00000000..1f922415 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/34b60fae6a60ed23 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 000000000000000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3cf3060ecf42827a b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3cf3060ecf42827a new file mode 100644 index 00000000..ddf5c16c --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/3cf3060ecf42827a @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x00000000000\x000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4305b65c1705da06 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4305b65c1705da06 new file mode 100644 index 00000000..8e650567 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4305b65c1705da06 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x000000000000\x00000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/450a97754aa91bb3 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/450a97754aa91bb3 new file mode 100644 index 00000000..ca3b9321 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/450a97754aa91bb3 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/49b410a3c47b1687 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/49b410a3c47b1687 new file mode 100644 index 00000000..d59151a0 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/49b410a3c47b1687 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x00000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4afa2df02fd9d00a b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4afa2df02fd9d00a new file mode 100644 index 00000000..6f973f0f --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4afa2df02fd9d00a @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4cece13ffff9b317 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4cece13ffff9b317 new file mode 100644 index 00000000..d4a5c636 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/4cece13ffff9b317 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x0000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/68b704daf492f697 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/68b704daf492f697 new file mode 100644 index 00000000..9ca03f8d --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/68b704daf492f697 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/6e357cea2c7a331b b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/6e357cea2c7a331b new file mode 100644 index 00000000..939c2469 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/6e357cea2c7a331b @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\x0100000000000000000000000000000000000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/730d6c62c6f77fdc b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/730d6c62c6f77fdc new file mode 100644 index 00000000..940c817c --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/730d6c62c6f77fdc @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7827f50473cee286 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7827f50473cee286 new file mode 100644 index 00000000..dfb70136 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7827f50473cee286 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\"0 00000000000000000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7bf4d4e3d8104096 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7bf4d4e3d8104096 new file mode 100644 index 00000000..98dea120 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7bf4d4e3d8104096 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\x020 \x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7e509fbe194aa190 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7e509fbe194aa190 new file mode 100644 index 00000000..729a064b --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/7e509fbe194aa190 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00\"0 \x00\x00000000000000000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/87828f4d0c4c5b02 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/87828f4d0c4c5b02 new file mode 100644 index 00000000..3c06b4db --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/87828f4d0c4c5b02 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 00000000000000000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/8bfad049e124e765 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/8bfad049e124e765 new file mode 100644 index 00000000..c72b32f4 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/8bfad049e124e765 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x000000\x01\x000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/93e877be95b75ad0 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/93e877be95b75ad0 new file mode 100644 index 00000000..e6ac4515 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/93e877be95b75ad0 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x0100\x00\x00 0!000000000000000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a10c26126d5754cc b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a10c26126d5754cc new file mode 100644 index 00000000..35b7eb5b --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a10c26126d5754cc @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x000000000000\x000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a6078eb8043d4763 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a6078eb8043d4763 new file mode 100644 index 00000000..1e7b73a3 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/a6078eb8043d4763 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x000000000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/af43d20f8944989e b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/af43d20f8944989e new file mode 100644 index 00000000..687a0efd --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/af43d20f8944989e @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x9d000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/c20cecf53e42d68f b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/c20cecf53e42d68f new file mode 100644 index 00000000..9d0b508d --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/c20cecf53e42d68f @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cd41233d242004e0 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cd41233d242004e0 new file mode 100644 index 00000000..c73f16e5 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cd41233d242004e0 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cde5035558ff62ef b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cde5035558ff62ef new file mode 100644 index 00000000..7f627854 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/cde5035558ff62ef @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x100000000000000000\x010\x00\x00\x150\x0100\x0100\x0100\x0100\x0100\x0100\x010") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/d43da2b1fa37becf b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/d43da2b1fa37becf new file mode 100644 index 00000000..9ea847c5 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/d43da2b1fa37becf @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x00") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dbd8b97d5c1e0809 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dbd8b97d5c1e0809 new file mode 100644 index 00000000..23511e59 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dbd8b97d5c1e0809 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dd7554c158a4bf76 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dd7554c158a4bf76 new file mode 100644 index 00000000..49c633c1 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/dd7554c158a4bf76 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x00\x05\x000000\x01\x00000000000\v\x0000000000\n\x10000000000000000000\x000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f039aafb6c8fbf5c b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f039aafb6c8fbf5c new file mode 100644 index 00000000..e9361693 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f039aafb6c8fbf5c @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x000000\x010") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f41ce99116022ac9 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f41ce99116022ac9 new file mode 100644 index 00000000..37f22795 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/f41ce99116022ac9 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x0000000000") diff --git a/pkg/mikey/testdata/fuzz/FuzzUnmarshal/fccb75db092e7fe6 b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/fccb75db092e7fe6 new file mode 100644 index 00000000..ddd30743 --- /dev/null +++ b/pkg/mikey/testdata/fuzz/FuzzUnmarshal/fccb75db092e7fe6 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x01\x000\x00000") diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index 15893975..ddaf4e9c 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -161,7 +161,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet { } if rr.firstSenderReportReceived { - // middle 32 bits out of 64 in the NTP timestamp of last sender report + // middle 32 bits out of 64 in the NTP of last sender report report.Reports[0].LastSenderReport = uint32(rr.lastSenderReportTimeNTP >> 16) // delay, expressed in units of 1/65536 seconds, between @@ -267,7 +267,7 @@ func (rr *RTCPReceiver) packetNTPUnsafe(ts uint32) (time.Time, bool) { return ntpTimeRTCPToGo(rr.lastSenderReportTimeNTP).Add(timeDiffGo), true } -// PacketNTP returns the NTP timestamp of the packet. +// PacketNTP returns the NTP (absolute timestamp) of the packet. func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) { rr.mutex.Lock() defer rr.mutex.Unlock() diff --git a/server.go b/server.go index 5096c646..19ecc91b 100644 --- a/server.go +++ b/server.go @@ -196,14 +196,6 @@ func (s *Server) Start() error { s.checkStreamPeriod = 1 * time.Second } - if s.TLSConfig != nil && s.UDPRTPAddress != "" { - return fmt.Errorf("TLS can't be used with UDP") - } - - if s.TLSConfig != nil && s.MulticastIPRange != "" { - return fmt.Errorf("TLS can't be used with UDP-multicast") - } - if s.RTSPAddress == "" { return fmt.Errorf("RTSPAddress not provided") } diff --git a/server_conn.go b/server_conn.go index 90ea1adb..a200f1d2 100644 --- a/server_conn.go +++ b/server_conn.go @@ -2,6 +2,7 @@ package gortsplib import ( "context" + "crypto/rand" "crypto/tls" "errors" "net" @@ -17,6 +18,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/liberrors" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" ) func getSessionID(header base.Header) string { @@ -51,7 +53,97 @@ func checkBackChannelsEnabled(header base.Header) bool { return false } -func prepareForDescribe(d *description.Session, multicast bool, backChannels bool) *description.Session { +func mikeyGenerate(ctx *wrappedSRTPContext) (*mikey.Message, error) { + csbID, err := randUint32() + if err != nil { + return nil, err + } + + msg := &mikey.Message{ + Header: mikey.Header{ + Version: 1, + CSBID: csbID, + }, + } + + msg.Header.CSIDMapInfo = make([]mikey.SRTPIDEntry, len(ctx.ssrcs)) + + n := 0 + for _, ssrc := range ctx.ssrcs { + msg.Header.CSIDMapInfo[n] = mikey.SRTPIDEntry{ + PolicyNo: 0, + SSRC: ssrc, + ROC: ctx.roc(ssrc), + } + n++ + } + + randData := make([]byte, 16) + _, err = rand.Read(randData) + if err != nil { + return nil, err + } + + msg.Payloads = []mikey.Payload{ + &mikey.PayloadT{ + TSType: 0, + TSValue: mikeyEncodeTime(time.Now()), + }, + &mikey.PayloadRAND{ + Data: randData, + }, + &mikey.PayloadSP{ + PolicyParams: []mikey.PayloadSPPolicyParam{ + { + Type: mikey.PayloadSPPolicyParamTypeEncrAlg, + Value: []byte{1}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeSessionEncrKeyLen, + Value: []byte{0x10}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeAuthAlg, + Value: []byte{1}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeSessionAuthKeyLen, + Value: []byte{0x0a}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeSRTPEncrOffOn, + Value: []byte{1}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeSRTCPEncrOffOn, + Value: []byte{1}, + }, + { + Type: mikey.PayloadSPPolicyParamTypeSRTPAuthOffOn, + Value: []byte{1}, + }, + }, + }, + &mikey.PayloadKEMAC{ + SubPayloads: []*mikey.SubPayloadKeyData{ + { + Type: mikey.SubPayloadKeyDataKeyTypeTEK, + KeyData: ctx.key, + }, + }, + }, + } + + return msg, nil +} + +func prepareForDescribe( + d *description.Session, + multicast bool, + backChannels bool, + secure bool, + medias map[*description.Media]*serverStreamMedia, +) (*description.Session, error) { out := &description.Session{ Title: d.Title, Multicast: multicast, @@ -60,19 +152,32 @@ func prepareForDescribe(d *description.Session, multicast bool, backChannels boo for i, medi := range d.Medias { if !medi.IsBackChannel || backChannels { + var keyMgmtMikey *mikey.Message + if secure { + sm := medias[medi] + + var err error + keyMgmtMikey, err = mikeyGenerate(sm.srtpOutCtx) + if err != nil { + return nil, err + } + } + out.Medias = append(out.Medias, &description.Media{ Type: medi.Type, ID: medi.ID, IsBackChannel: medi.IsBackChannel, // we have to use trackID=number in order to support clients // like the Grandstream GXV3500. - Control: "trackID=" + strconv.FormatInt(int64(i), 10), - Formats: medi.Formats, + Control: "trackID=" + strconv.FormatInt(int64(i), 10), + Secure: secure, + KeyMgmtMikey: keyMgmtMikey, + Formats: medi.Formats, }) } } - return out + return out, nil } func credentialsProvided(req *base.Request) bool { @@ -160,7 +265,7 @@ func (sc *ServerConn) UserData() interface{} { return sc.userData } -// Session returns associated session. +// Session returns the associated session. func (sc *ServerConn) Session() *ServerSession { return sc.session } @@ -370,13 +475,28 @@ func (sc *ServerConn) handleRequestInner(req *base.Request) (*base.Response, err return res, err } - desc := prepareForDescribe( + var desc *description.Session + desc, err = prepareForDescribe( stream.Desc, checkMulticastEnabled(sc.s.MulticastIPRange, query), checkBackChannelsEnabled(req.Header), + sc.s.TLSConfig != nil, + stream.medias, ) + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + + var byts []byte + byts, err = desc.Marshal(false) + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } - byts, _ := desc.Marshal(false) res.Body = byts } diff --git a/server_play_test.go b/server_play_test.go index 3bc77a89..df48f240 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -2,6 +2,7 @@ package gortsplib import ( "bytes" + "crypto/rand" "crypto/tls" "net" "strconv" @@ -21,6 +22,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/sdp" ) @@ -333,7 +335,7 @@ func TestServerPlaySetupErrors(t *testing.T) { require.EqualError(t, ctx.Error, "media has already been setup") case "different protocols": - require.EqualError(t, ctx.Error, "can't setup medias with different protocols") + require.EqualError(t, ctx.Error, "can't setup medias with different transports") } close(nconnClosed) }, @@ -574,13 +576,48 @@ func TestServerPlaySetupErrorSameUDPPortsAndIP(t *testing.T) { } func TestServerPlay(t *testing.T) { - for _, transport := range []string{ - "udp", - "multicast", - "tcp", - "tls", + for _, ca := range []struct { + scheme string + transport string + secure string + }{ + { + "rtsp", + "udp", + "unsecure", + }, + { + "rtsp", + "multicast", + "unsecure", + }, + { + "rtsp", + "tcp", + "unsecure", + }, + { + "rtsps", + "tcp", + "unsecure", + }, + { + "rtsps", + "udp", + "secure", + }, + { + "rtsps", + "multicast", + "secure", + }, + { + "rtsps", + "tcp", + "secure", + }, } { - t.Run(transport, func(t *testing.T) { + t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) { var stream *ServerStream nconnOpened := make(chan struct{}) nconnClosed := make(chan struct{}) @@ -598,10 +635,10 @@ func TestServerPlay(t *testing.T) { }, onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { s := ctx.Conn.Stats() - require.Greater(t, s.BytesSent, uint64(810)) - require.Less(t, s.BytesSent, uint64(1150)) - require.Greater(t, s.BytesReceived, uint64(440)) - require.Less(t, s.BytesReceived, uint64(660)) + require.Greater(t, s.BytesSent, uint64(800)) + require.Less(t, s.BytesSent, uint64(1600)) + require.Greater(t, s.BytesReceived, uint64(400)) + require.Less(t, s.BytesReceived, uint64(950)) close(nconnClosed) }, @@ -609,12 +646,12 @@ func TestServerPlay(t *testing.T) { close(sessionOpened) }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { - if transport != "multicast" { + if ca.transport != "multicast" { s := ctx.Session.Stats() require.Greater(t, s.BytesSent, uint64(50)) - require.Less(t, s.BytesSent, uint64(60)) + require.Less(t, s.BytesSent, uint64(130)) require.Greater(t, s.BytesReceived, uint64(15)) - require.Less(t, s.BytesReceived, uint64(25)) + require.Less(t, s.BytesReceived, uint64(35)) } close(sessionClosed) @@ -632,12 +669,12 @@ func TestServerPlay(t *testing.T) { onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { require.NotNil(t, ctx.Conn.Session()) - switch transport { + switch ca.transport { case "udp": v := TransportUDP require.Equal(t, &v, ctx.Session.SetuppedTransport()) - case "tcp", "tls": + case "tcp": v := TransportTCP require.Equal(t, &v, ctx.Session.SetuppedTransport()) @@ -651,14 +688,14 @@ func TestServerPlay(t *testing.T) { // send RTCP packets directly to the session. // these are sent after the response, only if onPlay returns StatusOK. - if transport != "multicast" { + if ca.transport != "multicast" { err := ctx.Session.WritePacketRTCP(stream.Description().Medias[0], &testRTCPPacket) require.NoError(t, err) } ctx.Session.OnPacketRTCPAny(func(medi *description.Media, pkt rtcp.Packet) { // ignore multicast loopback - if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { + if ca.secure == "unsecure" && ca.transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { return } @@ -691,7 +728,7 @@ func TestServerPlay(t *testing.T) { RTSPAddress: listenIP + ":8554", } - switch transport { + switch ca.transport { case "udp": s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" @@ -700,8 +737,9 @@ func TestServerPlay(t *testing.T) { s.MulticastIPRange = "224.1.0.0/16" s.MulticastRTPPort = 8000 s.MulticastRTCPPort = 8001 + } - case "tls": + if ca.scheme == "rtsps" { cert, err := tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} @@ -723,7 +761,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) nconn = func() net.Conn { - if transport == "tls" { + if ca.scheme == "rtsps" { return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } return nconn @@ -734,11 +772,16 @@ func TestServerPlay(t *testing.T) { desc := doDescribe(t, conn, false) + if ca.secure == "secure" { + require.True(t, desc.Medias[0].Secure) + require.NotEmpty(t, desc.Medias[0].KeyMgmtMikey) + } + inTH := &headers.Transport{ Mode: transportModePtr(headers.TransportModePlay), } - switch transport { + switch ca.transport { case "udp": v := headers.TransportDeliveryUnicast inTH.Delivery = &v @@ -750,19 +793,81 @@ func TestServerPlay(t *testing.T) { inTH.Delivery = &v inTH.Protocol = headers.TransportProtocolUDP - default: + case "tcp": v := headers.TransportDeliveryUnicast inTH.Delivery = &v inTH.Protocol = headers.TransportProtocolTCP inTH.InterleavedIDs = &[2]int{5, 6} // odd value } - res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") + h := base.Header{ + "CSeq": base.HeaderValue{"1"}, + } + + var srtpOutCtx *wrappedSRTPContext + + if ca.secure == "secure" { + inTH.Secure = true + + key := make([]byte, srtpKeyLength) + _, err = rand.Read(key) + require.NoError(t, err) + + srtpOutCtx = &wrappedSRTPContext{ + key: key, + ssrcs: []uint32{2345423}, + } + err = srtpOutCtx.initialize() + require.NoError(t, err) + + var mikeyMsg *mikey.Message + mikeyMsg, err = mikeyGenerate(srtpOutCtx) + require.NoError(t, err) + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), + MikeyMessage: mikeyMsg, + }.Marshal() + require.NoError(t, err) + h["KeyMgmt"] = enc + } + + h["Transport"] = inTH.Marshal() + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mediaURL(t, desc.BaseURL, desc.Medias[0]), + Header: h, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var th headers.Transport + err = th.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + var srtpInCtx *wrappedSRTPContext + + if ca.secure == "secure" { + require.True(t, th.Secure) + + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) + require.NoError(t, err) + + pl1, _ := mikeyGetPayload[*mikey.PayloadKEMAC](keyMgmt.MikeyMessage) + pl2, _ := mikeyGetPayload[*mikey.PayloadKEMAC](desc.Medias[0].KeyMgmtMikey) + require.Equal(t, pl1, pl2) + + srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage) + require.NoError(t, err) + } var l1 net.PacketConn var l2 net.PacketConn - switch transport { //nolint:dupl + switch ca.transport { case "udp": require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) @@ -775,7 +880,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) defer l2.Close() - case "multicast": + case "multicast": //nolint:dupl require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery) @@ -808,7 +913,7 @@ func TestServerPlay(t *testing.T) { require.NoError(t, err) } - default: + case "tcp": require.Equal(t, headers.TransportProtocolTCP, th.Protocol) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) } @@ -821,86 +926,122 @@ func TestServerPlay(t *testing.T) { // server -> client (direct) - switch transport { - case "udp": - buf := make([]byte, 2048) - var n int - n, _, err = l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + if ca.transport != "multicast" { + var buf []byte - case "tcp", "tls": - var f *base.InterleavedFrame - f, err = conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 6, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + switch ca.transport { + case "udp": + buf = make([]byte, 2048) + var n int + n, _, err = l2.ReadFrom(buf) + require.NoError(t, err) + buf = buf[:n] + + case "tcp": + var f *base.InterleavedFrame + f, err = conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 6, f.Channel) + buf = f.Payload + } + + if ca.secure == "secure" { + buf, err = srtpInCtx.decryptRTCP(buf, buf, nil) + require.NoError(t, err) + } + + require.Equal(t, testRTCPPacketMarshaled, buf) } // server -> client (through stream) - if transport == "udp" || transport == "multicast" { - buf := make([]byte, 2048) + var buf1 []byte + var buf2 []byte + + switch ca.transport { + case "udp", "multicast": + buf1 = make([]byte, 2048) var n int - n, _, err = l1.ReadFrom(buf) + n, _, err = l1.ReadFrom(buf1) require.NoError(t, err) + buf1 = buf1[:n] - var pkt rtp.Packet - err = pkt.Unmarshal(buf[:n]) + buf2 = make([]byte, 2048) + n, _, err = l2.ReadFrom(buf2) require.NoError(t, err) - pkt.SSRC = testRTPPacket.SSRC - require.Equal(t, testRTPPacket, pkt) + buf2 = buf2[:n] - buf = make([]byte, 2048) - n, _, err = l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) - } else { + case "tcp": var f *base.InterleavedFrame f, err = conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 6, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + buf2 = f.Payload f, err = conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 5, f.Channel) - var pkt rtp.Packet - err = pkt.Unmarshal(f.Payload) - require.NoError(t, err) - pkt.SSRC = testRTPPacket.SSRC - require.Equal(t, testRTPPacket, pkt) + buf1 = f.Payload } + if ca.secure == "secure" { + buf1, err = srtpInCtx.decryptRTP(buf1, buf1, nil) + require.NoError(t, err) + } + + var pkt rtp.Packet + err = pkt.Unmarshal(buf1) + require.NoError(t, err) + pkt.SSRC = testRTPPacket.SSRC + require.Equal(t, testRTPPacket, pkt) + + if ca.secure == "secure" { + buf2, err = srtpInCtx.decryptRTCP(buf2, buf2, nil) + require.NoError(t, err) + } + + require.Equal(t, testRTCPPacketMarshaled, buf2) + // client -> server - switch transport { + buf := testRTCPPacketMarshaled + + if ca.secure == "secure" { + encr := make([]byte, 2000) + encr, err = srtpOutCtx.encryptRTCP(encr, buf, nil) + require.NoError(t, err) + buf = encr + } + + switch ca.transport { case "udp": - _, err = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + _, err = l2.WriteTo(buf, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) require.NoError(t, err) - <-framesReceived case "multicast": - _, err = l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + _, err = l2.WriteTo(buf, &net.UDPAddr{ IP: *th.Destination, Port: th.Ports[1], }) require.NoError(t, err) - <-framesReceived - default: + case "tcp": err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 6, - Payload: testRTCPPacketMarshaled, + Payload: buf, }, make([]byte, 1024)) require.NoError(t, err) - <-framesReceived } - if transport == "udp" || transport == "multicast" { + <-framesReceived + + // ping + + switch ca.transport { + case "udp", "multicast": // ping with OPTIONS res, err = writeReqReadRes(conn, base.Request{ Method: base.Options, @@ -941,7 +1082,6 @@ func TestServerPlaySocketError(t *testing.T) { "udp", "multicast", "tcp", - "tls", } { t.Run(transport, func(t *testing.T) { var stream *ServerStream @@ -996,11 +1136,6 @@ func TestServerPlaySocketError(t *testing.T) { s.MulticastIPRange = "224.1.0.0/16" s.MulticastRTPPort = 8000 s.MulticastRTCPPort = 8001 - - case "tls": - cert, err := tls.X509KeyPair(serverCert, serverKey) - require.NoError(t, err) - s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} } err := s.Start() @@ -1019,12 +1154,6 @@ func TestServerPlaySocketError(t *testing.T) { require.NoError(t, err) defer nconn.Close() - nconn = func() net.Conn { - if transport == "tls" { - return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) - } - return nconn - }() conn := conn.NewConn(nconn) desc := doDescribe(t, conn, false) @@ -1057,7 +1186,7 @@ func TestServerPlaySocketError(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - switch transport { //nolint:dupl + switch transport { case "udp": require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) @@ -1070,7 +1199,7 @@ func TestServerPlaySocketError(t *testing.T) { require.NoError(t, err) defer l2.Close() - case "multicast": + case "multicast": //nolint:dupl require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery) diff --git a/server_record_test.go b/server_record_test.go index c0b1a6be..a31180fe 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -2,6 +2,7 @@ package gortsplib import ( "bytes" + "crypto/rand" "crypto/tls" "net" "strconv" @@ -18,6 +19,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/sdp" ) @@ -337,6 +339,9 @@ func TestServerRecordPath(t *testing.T) { media := testH264Media media.Control = ca.control + enc, err := media.Marshal2() + require.NoError(t, err) + sout := &sdp.SessionDescription{ SessionName: psdp.SessionName("Stream"), Origin: psdp.Origin{ @@ -348,7 +353,7 @@ func TestServerRecordPath(t *testing.T) { TimeDescriptions: []psdp.TimeDescription{ {Timing: psdp.Timing{}}, }, - MediaDescriptions: []*psdp.MediaDescription{media.Marshal()}, + MediaDescriptions: []*psdp.MediaDescription{enc}, } byts, _ := sout.Marshal() @@ -533,12 +538,38 @@ func TestServerRecordErrorRecordPartialMedias(t *testing.T) { } func TestServerRecord(t *testing.T) { - for _, transport := range []string{ - "udp", - "tcp", - "tls", + for _, ca := range []struct { + scheme string + transport string + secure string + }{ + { + "rtsp", + "udp", + "unsecure", + }, + { + "rtsp", + "tcp", + "unsecure", + }, + { + "rtsps", + "tcp", + "unsecure", + }, + { + "rtsps", + "udp", + "secure", + }, + { + "rtsps", + "tcp", + "secure", + }, } { - t.Run(transport, func(t *testing.T) { + t.Run(ca.scheme+"_"+ca.transport+"_"+ca.secure, func(t *testing.T) { nconnOpened := make(chan struct{}) nconnClosed := make(chan struct{}) sessionOpened := make(chan struct{}) @@ -552,9 +583,9 @@ func TestServerRecord(t *testing.T) { onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { s := ctx.Conn.Stats() require.Greater(t, s.BytesSent, uint64(510)) - require.Less(t, s.BytesSent, uint64(560)) + require.Less(t, s.BytesSent, uint64(1100)) require.Greater(t, s.BytesReceived, uint64(1000)) - require.Less(t, s.BytesReceived, uint64(1200)) + require.Less(t, s.BytesReceived, uint64(1800)) close(nconnClosed) }, @@ -564,9 +595,9 @@ func TestServerRecord(t *testing.T) { onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { s := ctx.Session.Stats() require.Greater(t, s.BytesSent, uint64(75)) - require.Less(t, s.BytesSent, uint64(130)) + require.Less(t, s.BytesSent, uint64(140)) require.Greater(t, s.BytesReceived, uint64(70)) - require.Less(t, s.BytesReceived, uint64(80)) + require.Less(t, s.BytesReceived, uint64(130)) close(sessionClosed) }, @@ -581,12 +612,12 @@ func TestServerRecord(t *testing.T) { }, nil, nil }, onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { - switch transport { + switch ca.transport { case "udp": v := TransportUDP require.Equal(t, &v, ctx.Session.SetuppedTransport()) - case "tcp", "tls": + case "tcp": v := TransportTCP require.Equal(t, &v, ctx.Session.SetuppedTransport()) } @@ -628,12 +659,12 @@ func TestServerRecord(t *testing.T) { RTSPAddress: "localhost:8554", } - switch transport { - case "udp": + if ca.transport == "udp" { s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" + } - case "tls": + if ca.scheme == "rtsps" { cert, err := tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} @@ -648,7 +679,7 @@ func TestServerRecord(t *testing.T) { defer nconn.Close() nconn = func() net.Conn { - if transport == "tls" { + if ca.scheme == "rtsps" { return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) } return nconn @@ -686,6 +717,8 @@ func TestServerRecord(t *testing.T) { var l2s [2]net.PacketConn var session string var serverPorts [2]*[2]int + var srtpOutCtx [2]*wrappedSRTPContext + var srtpInCtx [2]*wrappedSRTPContext for i := 0; i < 2; i++ { inTH := &headers.Transport{ @@ -693,7 +726,7 @@ func TestServerRecord(t *testing.T) { Mode: transportModePtr(headers.TransportModeRecord), } - if transport == "udp" { + if ca.transport == "udp" { inTH.Protocol = headers.TransportProtocolUDP inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2} @@ -709,84 +742,186 @@ func TestServerRecord(t *testing.T) { inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2} } - res, th := doSetup(t, conn, "rtsp://localhost:8554/teststream?param=value/"+medias[i].Control, inTH, "") + h := base.Header{ + "CSeq": base.HeaderValue{"1"}, + } + + if session != "" { + h["Session"] = base.HeaderValue{session} + } + + if ca.secure == "secure" { + inTH.Secure = true + + key := make([]byte, srtpKeyLength) + _, err = rand.Read(key) + require.NoError(t, err) + + srtpOutCtx[i] = &wrappedSRTPContext{ + key: key, + ssrcs: []uint32{2345423}, + } + err = srtpOutCtx[i].initialize() + require.NoError(t, err) + + var mikeyMsg *mikey.Message + mikeyMsg, err = mikeyGenerate(srtpOutCtx[i]) + require.NoError(t, err) + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: "rtsp://localhost:8554/teststream?param=value/" + medias[i].Control, + MikeyMessage: mikeyMsg, + }.Marshal() + require.NoError(t, err) + h["KeyMgmt"] = enc + } + + h["Transport"] = inTH.Marshal() + + var res *base.Response + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream?param=value/" + medias[i].Control), + Header: h, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var th headers.Transport + err = th.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) session = readSession(t, res) - if transport == "udp" { + if ca.transport == "udp" { serverPorts[i] = th.ServerPorts } + + if ca.secure == "secure" { + require.True(t, th.Secure) + + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) + require.NoError(t, err) + + srtpInCtx[i], err = mikeyToContext(keyMgmt.MikeyMessage) + require.NoError(t, err) + } } doRecord(t, conn, "rtsp://localhost:8554/teststream", session) - for i := 0; i < 2; i++ { - // skip firewall opening - if transport == "udp" { + // skip firewall opening + + if ca.transport == "udp" { + for i := 0; i < 2; i++ { buf := make([]byte, 2048) _, _, err = l2s[i].ReadFrom(buf) require.NoError(t, err) } + } - // server -> client + // server -> client - if transport == "udp" { - buf := make([]byte, 2048) + for i := 0; i < 2; i++ { + var buf []byte + + if ca.transport == "udp" { + buf = make([]byte, 2048) var n int n, _, err = l2s[i].ReadFrom(buf) require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + buf = buf[:n] } else { var f *base.InterleavedFrame f, err = conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 3+i*2, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + buf = f.Payload } - // client -> server + if ca.secure == "secure" { + buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil) + require.NoError(t, err) + } - if transport == "udp" { - _, err = l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + require.Equal(t, testRTCPPacketMarshaled, buf) + } + + // client -> server + + for i := 0; i < 2; i++ { + buf1 := testRTPPacketMarshaled + + if ca.secure == "secure" { + encr := make([]byte, 2000) + encr, err = srtpOutCtx[i].encryptRTP(encr, buf1, nil) + require.NoError(t, err) + buf1 = encr + } + + buf2 := testRTCPPacketMarshaled + + if ca.secure == "secure" { + encr := make([]byte, 2000) + encr, err = srtpOutCtx[i].encryptRTCP(encr, buf2, nil) + require.NoError(t, err) + buf2 = encr + } + + if ca.transport == "udp" { + _, err = l1s[i].WriteTo(buf1, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: serverPorts[i][0], }) require.NoError(t, err) - _, err = l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + _, err = l2s[i].WriteTo(buf2, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: serverPorts[i][1], }) require.NoError(t, err) } else { - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 2 + i*2, - Payload: testRTPPacketMarshaled, + Payload: buf1, }, make([]byte, 1024)) require.NoError(t, err) err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 3 + i*2, - Payload: testRTCPPacketMarshaled, + Payload: buf2, }, make([]byte, 1024)) require.NoError(t, err) } } - for i := 0; i < 2; i++ { - // server -> client + // server -> client - if transport == "udp" { - buf := make([]byte, 2048) - n, _, err := l2s[i].ReadFrom(buf) + for i := 0; i < 2; i++ { + var buf []byte + + if ca.transport == "udp" { + buf = make([]byte, 2048) + var n int + n, _, err = l2s[i].ReadFrom(buf) require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + buf = buf[:n] } else { - f, err := conn.ReadInterleavedFrame() + var f *base.InterleavedFrame + f, err = conn.ReadInterleavedFrame() require.NoError(t, err) require.Equal(t, 3+i*2, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + buf = f.Payload } + + if ca.secure == "secure" { + buf, err = srtpInCtx[i].decryptRTCP(buf, buf, nil) + require.NoError(t, err) + } + + require.Equal(t, testRTCPPacketMarshaled, buf) } doTeardown(t, conn, "rtsp://localhost:8554/teststream", session) diff --git a/server_session.go b/server_session.go index c4d7a4dd..ee59bfba 100644 --- a/server_session.go +++ b/server_session.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bytes" "context" "fmt" "log" @@ -20,6 +21,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/liberrors" + "github.com/bluenviron/gortsplib/v4/pkg/mikey" "github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver" "github.com/bluenviron/gortsplib/v4/pkg/rtcpsender" "github.com/bluenviron/gortsplib/v4/pkg/rtptime" @@ -165,21 +167,160 @@ func findMediaByTrackID(medias []*description.Media, trackID string) *descriptio return medias[id] } -func findFirstSupportedTransportHeader(s *Server, tsh headers.Transports) *headers.Transport { - // Per RFC2326 section 12.39, client specifies transports in order of preference. - // Filter out the ones we don't support and then pick first supported transport. - for _, tr := range tsh { +func isTransportSupported(s *Server, tr *headers.Transport) bool { + // prevent using UDP/UDP-multicast when listeners are disabled + if tr.Protocol == headers.TransportProtocolUDP { isMulticast := tr.Delivery != nil && *tr.Delivery == headers.TransportDeliveryMulticast - if tr.Protocol == headers.TransportProtocolUDP && - ((!isMulticast && s.udpRTPListener == nil) || - (isMulticast && s.MulticastIPRange == "")) { - continue + if !isMulticast && s.udpRTPListener == nil { + return false + } + if isMulticast && s.MulticastIPRange == "" { + return false + } + } + + // prevent using unsecure UDP with RTSPS + if tr.Protocol == headers.TransportProtocolUDP && !tr.Secure && s.TLSConfig != nil { + return false + } + + // prevent using secure profiles with plain RTSP, since keys are in plain + if tr.Secure && s.TLSConfig == nil { + return false + } + + return true +} + +func pickFirstSupportedTransport(s *Server, tsh headers.Transports) *headers.Transport { + for _, tr := range tsh { + if isTransportSupported(s, &tr) { + return &tr } - return &tr } return nil } +func mikeyDecodeTime(t uint64) time.Time { + sec := t >> 32 + dec := t & 0xFFFFFFFF + sec -= 2208988800 + return time.Unix(int64(sec), int64(dec)) +} + +func mikeyEncodeTime(n time.Time) uint64 { + nano := uint64(n.UnixNano()) + sec := nano / 1000000000 + dec := nano % 1000000000 + sec += 2208988800 + return sec<<32 | dec +} + +func mikeyGetPayload[T mikey.Payload](mikeyMsg *mikey.Message) (T, bool) { + var zero T + for _, wrapped := range mikeyMsg.Payloads { + if val, ok := wrapped.(T); ok { + return val, true + } + } + return zero, false +} + +func mikeyGetSPPolicy(spPayload *mikey.PayloadSP, typ mikey.PayloadSPPolicyParamType) ([]byte, bool) { + for _, pl := range spPayload.PolicyParams { + if pl.Type == typ { + return pl.Value, true + } + } + return nil, false +} + +func mikeyToContext(mikeyMsg *mikey.Message) (*wrappedSRTPContext, error) { + timePayload, ok := mikeyGetPayload[*mikey.PayloadT](mikeyMsg) + if !ok { + return nil, fmt.Errorf("time payload not present") + } + + ts := mikeyDecodeTime(timePayload.TSValue) + diff := time.Since(ts) + if diff < -time.Hour || diff > time.Hour { + return nil, fmt.Errorf("NTP difference is too high") + } + + spPayload, ok := mikeyGetPayload[*mikey.PayloadSP](mikeyMsg) + if !ok { + return nil, fmt.Errorf("SP payload not present") + } + + v, ok := mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeEncrAlg) + if !ok || !bytes.Equal(v, []byte{1}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeEncrAlg") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionEncrKeyLen) + if !ok || !bytes.Equal(v, []byte{0x10}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionEncrKeyLen") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeAuthAlg) + if !ok || !bytes.Equal(v, []byte{1}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeAuthAlg") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSessionAuthKeyLen) + if !ok || !bytes.Equal(v, []byte{0x0a}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSessionAuthKeyLen") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPEncrOffOn) + if !ok || !bytes.Equal(v, []byte{1}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPEncrOffOn") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTCPEncrOffOn) + if !ok || !bytes.Equal(v, []byte{1}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTCPEncrOffOn") + } + + v, ok = mikeyGetSPPolicy(spPayload, mikey.PayloadSPPolicyParamTypeSRTPAuthOffOn) + if !ok || !bytes.Equal(v, []byte{1}) { + return nil, fmt.Errorf("missing or unsupported policy: PayloadSPPolicyParamTypeSRTPAuthOffOn") + } + + kemacPayload, ok := mikeyGetPayload[*mikey.PayloadKEMAC](mikeyMsg) + if !ok { + return nil, fmt.Errorf("KEMAC payload not present") + } + + if len(kemacPayload.SubPayloads) != 1 { + return nil, fmt.Errorf("multiple keys are present") + } + + if len(kemacPayload.SubPayloads[0].KeyData) != srtpKeyLength { + return nil, fmt.Errorf("unexpected key size: %d", len(kemacPayload.SubPayloads[0].KeyData)) + } + + ssrcs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo)) + startROCs := make([]uint32, len(mikeyMsg.Header.CSIDMapInfo)) + + for i, entry := range mikeyMsg.Header.CSIDMapInfo { + ssrcs[i] = entry.SSRC + startROCs[i] = entry.ROC + } + + srtpCtx := &wrappedSRTPContext{ + key: kemacPayload.SubPayloads[0].KeyData, + ssrcs: ssrcs, + startROCs: startROCs, + } + err := srtpCtx.initialize() + if err != nil { + return nil, err + } + + return srtpCtx, nil +} + func generateRTPInfoEntry(ssm *serverStreamMedia, now time.Time) *headers.RTPInfoEntry { // do not generate a RTP-Info entry when // there are multiple formats inside a single media stream, @@ -293,6 +434,7 @@ type ServerSession struct { setuppedMediasOrdered []*serverSessionMedia tcpCallbackByChannel map[int]readFunc setuppedTransport *Transport + setuppedSecure bool setuppedStream *ServerStream // play setuppedPath string setuppedQuery string @@ -371,6 +513,13 @@ func (ss *ServerSession) SetuppedTransport() *Transport { return ss.setuppedTransport } +// SetuppedSecure returns whether a secure profile is in use. +// If this is false, it does not mean that the stream is not secure, since +// there are some combinations that are secure nonetheless, like RTSPS+TCP+unsecure. +func (ss *ServerSession) SetuppedSecure() bool { + return ss.setuppedSecure +} + // SetuppedStream returns the stream associated with the session. func (ss *ServerSession) SetuppedStream() *ServerStream { return ss.setuppedStream @@ -947,7 +1096,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, liberrors.ErrServerTransportHeaderInvalid{Err: err} } - inTH := findFirstSupportedTransportHeader(ss.s, transportHeaders) + // Per RFC2326 section 12.39, client specifies transports in order of preference. + // pick the first supported one. + inTH := pickFirstSupportedTransport(ss.s, transportHeaders) if inTH == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, @@ -978,20 +1129,41 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( var transport Transport - if inTH.Protocol == headers.TransportProtocolUDP { + switch inTH.Protocol { + case headers.TransportProtocolUDP: if inTH.Delivery != nil && *inTH.Delivery == headers.TransportDeliveryMulticast { transport = TransportUDPMulticast } else { transport = TransportUDP } - } else { + + case headers.TransportProtocolTCP: transport = TransportTCP } - if ss.setuppedTransport != nil && *ss.setuppedTransport != transport { + var srtpInCtx *wrappedSRTPContext + + if inTH.Secure { + var keyMgmt headers.KeyMgmt + err = keyMgmt.Unmarshal(req.Header["KeyMgmt"]) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err} + } + + srtpInCtx, err = mikeyToContext(keyMgmt.MikeyMessage) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerInvalidKeyMgmtHeader{Wrapped: err} + } + } + + if ss.setuppedTransport != nil && (*ss.setuppedTransport != transport || ss.setuppedSecure != inTH.Secure) { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerMediasDifferentProtocols{} + }, liberrors.ErrServerMediasDifferentTransports{} } switch transport { @@ -1052,7 +1224,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( // workaround to prevent a bug in rtspclientsink // that makes impossible for the client to receive the response // and send frames. - // this was causing problems during unit tests. + // this was causing problems during E2E tests. if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 && strings.HasPrefix(ua[0], "GStreamer") { select { @@ -1092,6 +1264,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } ss.setuppedTransport = &transport + ss.setuppedSecure = inTH.Secure if ss.state == ServerSessionStateInitial { err = stream.readerAdd(ss, @@ -1109,7 +1282,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( ss.setuppedStream = stream } - th := headers.Transport{} + th := headers.Transport{ + Secure: inTH.Secure, + } if ss.state == ServerSessionStatePrePlay { if stream != ss.setuppedStream { @@ -1131,6 +1306,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( sm := &serverSessionMedia{ ss: ss, media: medi, + srtpInCtx: srtpInCtx, onPacketRTCP: func(_ rtcp.Packet) {}, } err = sm.initialize() @@ -1141,46 +1317,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } switch transport { - case TransportUDP: - sm.udpRTPReadPort = inTH.ClientPorts[0] - sm.udpRTCPReadPort = inTH.ClientPorts[1] - - sm.udpRTPWriteAddr = &net.UDPAddr{ - IP: ss.author.ip(), - Zone: ss.author.zone(), - Port: sm.udpRTPReadPort, - } - - sm.udpRTCPWriteAddr = &net.UDPAddr{ - IP: ss.author.ip(), - Zone: ss.author.zone(), - Port: sm.udpRTCPReadPort, - } - + case TransportUDP, TransportUDPMulticast: th.Protocol = headers.TransportProtocolUDP - de := headers.TransportDeliveryUnicast - th.Delivery = &de - th.ClientPorts = inTH.ClientPorts - th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()} - case TransportUDPMulticast: - th.Protocol = headers.TransportProtocolUDP - de := headers.TransportDeliveryMulticast - th.Delivery = &de - v := uint(127) - th.TTL = &v - d := stream.medias[medi].multicastWriter.ip() - th.Destination = &d - th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} + if transport == TransportUDP { + sm.udpRTPReadPort = inTH.ClientPorts[0] + sm.udpRTCPReadPort = inTH.ClientPorts[1] + + sm.udpRTPWriteAddr = &net.UDPAddr{ + IP: ss.author.ip(), + Zone: ss.author.zone(), + Port: sm.udpRTPReadPort, + } + + sm.udpRTCPWriteAddr = &net.UDPAddr{ + IP: ss.author.ip(), + Zone: ss.author.zone(), + Port: sm.udpRTCPReadPort, + } + + de := headers.TransportDeliveryUnicast + th.Delivery = &de + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()} + } else { + de := headers.TransportDeliveryMulticast + th.Delivery = &de + v := uint(127) + th.TTL = &v + d := stream.medias[medi].multicastWriter.ip() + th.Destination = &d + th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} + } default: // TCP + th.Protocol = headers.TransportProtocolTCP + if inTH.InterleavedIDs != nil { sm.tcpChannel = inTH.InterleavedIDs[0] } else { sm.tcpChannel = ss.findFreeChannelPair() } - th.Protocol = headers.TransportProtocolTCP de := headers.TransportDeliveryUnicast th.Delivery = &de th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1} @@ -1193,6 +1371,38 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm) res.Header["Transport"] = th.Marshal() + + if inTH.Secure { + ssrcs := make([]uint32, len(sm.formats)) + n := 0 + for _, sf := range sm.formats { + ssrcs[n] = sf.localSSRC + n++ + } + + var mk *mikey.Message + mk, err = mikeyGenerate(sm.srtpOutCtx) + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + + var enc base.HeaderValue + enc, err = headers.KeyMgmt{ + URL: req.URL.String(), + MikeyMessage: mk, + }.Marshal() + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + + // always return KeyMgmt even if redundant when playing + // (since it's already present in the SDP) + res.Header["KeyMgmt"] = enc + } } return res, err @@ -1239,7 +1449,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( ss.timeDecoder.Initialize() for _, sm := range ss.setuppedMedias { - sm.start() + err = sm.start() + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } } if *ss.setuppedTransport == TransportTCP { @@ -1329,7 +1544,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( ss.timeDecoder.Initialize() for _, sm := range ss.setuppedMedias { - sm.start() + err = sm.start() + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } } if *ss.setuppedTransport == TransportTCP { @@ -1523,12 +1743,6 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu sm.onPacketRTCP = cb } -func (ss *ServerSession) writePacketRTPEncoded(medi *description.Media, payloadType uint8, byts []byte) error { - sm := ss.setuppedMedias[medi] - sf := sm.formats[payloadType] - return sf.writePacketRTPEncoded(byts) -} - // WritePacketRTP writes a RTP packet to the session. func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) error { sm := ss.setuppedMedias[medi] @@ -1536,22 +1750,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet return sf.writePacketRTP(pkt) } -func (ss *ServerSession) writePacketRTCPEncoded(medi *description.Media, byts []byte) error { - sm := ss.setuppedMedias[medi] - return sm.writePacketRTCPEncoded(byts) -} - // WritePacketRTCP writes a RTCP packet to the session. func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error { - byts, err := pkt.Marshal() - if err != nil { - return err - } - - return ss.writePacketRTCPEncoded(medi, byts) + sm := ss.setuppedMedias[medi] + return sm.writePacketRTCP(pkt) } -// PacketPTS returns the PTS of an incoming RTP packet. +// PacketPTS returns the PTS (presentation timestamp) of an incoming RTP packet. // It is computed by decoding the packet timestamp and sychronizing it with other tracks. // // Deprecated: replaced by PacketPTS2. @@ -1567,7 +1772,7 @@ func (ss *ServerSession) PacketPTS(medi *description.Media, pkt *rtp.Packet) (ti return multiplyAndDivide(time.Duration(v), time.Second, time.Duration(sf.format.ClockRate())), true } -// PacketPTS2 returns the PTS of an incoming RTP packet. +// PacketPTS2 returns the PTS (presentation timestamp) of an incoming RTP packet. // It is computed by decoding the packet timestamp and sychronizing it with other tracks. func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (int64, bool) { sm := ss.setuppedMedias[medi] @@ -1575,8 +1780,8 @@ func (ss *ServerSession) PacketPTS2(medi *description.Media, pkt *rtp.Packet) (i return ss.timeDecoder.Decode(sf.format, pkt) } -// PacketNTP returns the NTP timestamp of an incoming RTP packet. -// The NTP timestamp is computed from RTCP sender reports. +// PacketNTP returns the NTP (absolute timestamp) of an incoming RTP packet. +// The NTP is computed from RTCP sender reports. func (ss *ServerSession) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, bool) { sm := ss.setuppedMedias[medi] sf := sm.formats[pkt.PayloadType] diff --git a/server_session_format.go b/server_session_format.go index 34046a10..29bc8269 100644 --- a/server_session_format.go +++ b/server_session_format.go @@ -2,6 +2,7 @@ package gortsplib import ( "log" + "slices" "sync/atomic" "time" @@ -15,25 +16,26 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/rtpreorderer" ) -func isServerSessionLocalSSRCTaken(ssrc uint32, ss *ServerSession, exclude *serverSessionFormat) bool { - for _, sm := range ss.setuppedMedias { +func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) { + var takenSSRCs []uint32 //nolint:prealloc + + for _, sm := range sf.sm.ss.setuppedMedias { for _, sf := range sm.formats { - if sf != exclude && sf.localSSRC == ssrc { - return true - } + takenSSRCs = append(takenSSRCs, sf.localSSRC) } } - return false -} -func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) { + for _, sf := range sf.sm.formats { + takenSSRCs = append(takenSSRCs, sf.localSSRC) + } + for { ssrc, err := randUint32() if err != nil { return 0, err } - if ssrc != 0 && !isServerSessionLocalSSRCTaken(ssrc, sf.sm.ss, sf) { + if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { return ssrc, nil } } @@ -188,14 +190,31 @@ func (sf *serverSessionFormat) onPacketRTPLost(lost uint64) { func (sf *serverSessionFormat) writePacketRTP(pkt *rtp.Packet) error { pkt.SSRC = sf.localSSRC - byts := make([]byte, sf.sm.ss.s.MaxPacketSize) - n, err := pkt.MarshalTo(byts) + maxPlainPacketSize := sf.sm.ss.s.MaxPacketSize + if sf.sm.ss.setuppedSecure { + maxPlainPacketSize -= srtpOverhead + } + + plain := make([]byte, maxPlainPacketSize) + n, err := pkt.MarshalTo(plain) if err != nil { return err } - byts = byts[:n] + plain = plain[:n] - return sf.writePacketRTPEncoded(byts) + var encr []byte + if sf.sm.ss.setuppedSecure { + encr = make([]byte, sf.sm.ss.s.MaxPacketSize) + encr, err = sf.sm.srtpOutCtx.encryptRTP(encr, plain, &pkt.Header) + if err != nil { + return err + } + } + + if sf.sm.ss.setuppedSecure { + return sf.writePacketRTPEncoded(encr) + } + return sf.writePacketRTPEncoded(plain) } func (sf *serverSessionFormat) writePacketRTPEncoded(payload []byte) error { diff --git a/server_session_media.go b/server_session_media.go index a7e829bf..a696f676 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -1,6 +1,8 @@ package gortsplib import ( + "crypto/rand" + "fmt" "log" "net" "sync/atomic" @@ -16,8 +18,10 @@ import ( type serverSessionMedia struct { ss *ServerSession media *description.Media + srtpInCtx *wrappedSRTPContext onPacketRTCP OnPacketRTCPFunc + srtpOutCtx *wrappedSRTPContext tcpChannel int udpRTPReadPort int udpRTPWriteAddr *net.UDPAddr @@ -56,12 +60,41 @@ func (sm *serverSessionMedia) initialize() error { sm.formats[forma.PayloadType()] = f } + if sm.ss.s.TLSConfig != nil { + if sm.ss.state == ServerSessionStatePreRecord || sm.media.IsBackChannel { + srtpOutKey := make([]byte, srtpKeyLength) + _, err := rand.Read(srtpOutKey) + if err != nil { + return err + } + + ssrcs := make([]uint32, len(sm.formats)) + n := 0 + for _, cf := range sm.formats { + ssrcs[n] = cf.localSSRC + n++ + } + + sm.srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcs, + } + err = sm.srtpOutCtx.initialize() + if err != nil { + return err + } + } else { + streamMedia := sm.ss.setuppedStream.medias[sm.media] + sm.srtpOutCtx = streamMedia.srtpOutCtx + } + } + return nil } -func (sm *serverSessionMedia) start() { +func (sm *serverSessionMedia) start() error { // allocate udpRTCPReceiver before udpRTCPListener - // otherwise udpRTCPReceiver.LastSSRC() can't be called. + // otherwise udpRTCPReceiver.LastSSRC() cannot be called. for _, sf := range sm.formats { sf.start() } @@ -78,11 +111,33 @@ func (sm *serverSessionMedia) start() { sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPPlay) } else { // open the firewall by sending empty packets to the remote part. - byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() - sm.ss.s.udpRTPListener.write(byts, sm.udpRTPWriteAddr) //nolint:errcheck + buf, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() + if sm.srtpOutCtx != nil { + encr := make([]byte, sm.ss.s.MaxPacketSize) + encr, err := sm.srtpOutCtx.encryptRTP(encr, buf, nil) + if err != nil { + return err + } + buf = encr + } + err := sm.ss.s.udpRTPListener.write(buf, sm.udpRTPWriteAddr) + if err != nil { + return err + } - byts, _ = (&rtcp.ReceiverReport{}).Marshal() - sm.ss.s.udpRTCPListener.write(byts, sm.udpRTCPWriteAddr) //nolint:errcheck + buf, _ = (&rtcp.ReceiverReport{}).Marshal() + if sm.srtpOutCtx != nil { + encr := make([]byte, sm.ss.s.MaxPacketSize) + encr, err = sm.srtpOutCtx.encryptRTCP(encr, buf, nil) + if err != nil { + return err + } + buf = encr + } + err = sm.ss.s.udpRTCPListener.write(buf, sm.udpRTCPWriteAddr) + if err != nil { + return err + } sm.ss.s.udpRTPListener.addClient(sm.ss.author.ip(), sm.udpRTPReadPort, sm.readPacketRTPUDPRecord) sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPRecord) @@ -104,6 +159,8 @@ func (sm *serverSessionMedia) start() { sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord } } + + return nil } func (sm *serverSessionMedia) stop() { @@ -127,6 +184,37 @@ func (sm *serverSessionMedia) findFormatByRemoteSSRC(ssrc uint32) *serverSession return nil } +func (sm *serverSessionMedia) decodeRTP(payload []byte) (*rtp.Packet, error) { + if sm.srtpInCtx != nil { + var err error + payload, err = sm.srtpInCtx.decryptRTP(payload, payload, nil) + if err != nil { + return nil, err + } + } + + var pkt rtp.Packet + err := pkt.Unmarshal(payload) + return &pkt, err +} + +func (sm *serverSessionMedia) decodeRTCP(payload []byte) ([]rtcp.Packet, error) { + if sm.srtpInCtx != nil { + var err error + payload, err = sm.srtpInCtx.decryptRTCP(payload, payload, nil) + if err != nil { + return nil, err + } + } + + pkts, err := rtcp.Unmarshal(payload) + if err != nil { + return nil, err + } + + return pkts, nil +} + func (sm *serverSessionMedia) readPacketRTPUDPPlay(payload []byte) bool { atomic.AddUint64(sm.bytesReceived, uint64(len(payload))) @@ -135,8 +223,7 @@ func (sm *serverSessionMedia) readPacketRTPUDPPlay(payload []byte) bool { return false } - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := sm.decodeRTP(payload) if err != nil { sm.onPacketRTPDecodeError(err) return false @@ -163,7 +250,7 @@ func (sm *serverSessionMedia) readPacketRTCPUDPPlay(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := sm.decodeRTCP(payload) if err != nil { sm.onPacketRTCPDecodeError(err) return false @@ -189,8 +276,7 @@ func (sm *serverSessionMedia) readPacketRTPUDPRecord(payload []byte) bool { return false } - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := sm.decodeRTP(payload) if err != nil { sm.onPacketRTPDecodeError(err) return false @@ -218,7 +304,7 @@ func (sm *serverSessionMedia) readPacketRTCPUDPRecord(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := sm.decodeRTCP(payload) if err != nil { sm.onPacketRTCPDecodeError(err) return false @@ -250,8 +336,7 @@ func (sm *serverSessionMedia) readPacketRTPTCPPlay(payload []byte) bool { atomic.AddUint64(sm.bytesReceived, uint64(len(payload))) - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := sm.decodeRTP(payload) if err != nil { sm.onPacketRTPDecodeError(err) return false @@ -276,7 +361,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPPlay(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := sm.decodeRTCP(payload) if err != nil { sm.onPacketRTCPDecodeError(err) return false @@ -294,8 +379,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPPlay(payload []byte) bool { func (sm *serverSessionMedia) readPacketRTPTCPRecord(payload []byte) bool { atomic.AddUint64(sm.bytesReceived, uint64(len(payload))) - pkt := &rtp.Packet{} - err := pkt.Unmarshal(payload) + pkt, err := sm.decodeRTP(payload) if err != nil { sm.onPacketRTPDecodeError(err) return false @@ -320,7 +404,7 @@ func (sm *serverSessionMedia) readPacketRTCPTCPRecord(payload []byte) bool { return false } - packets, err := rtcp.Unmarshal(payload) + packets, err := sm.decodeRTCP(payload) if err != nil { sm.onPacketRTCPDecodeError(err) return false @@ -370,6 +454,36 @@ func (sm *serverSessionMedia) onPacketRTCPDecodeError(err error) { } } +func (sm *serverSessionMedia) writePacketRTCP(pkt rtcp.Packet) error { + plain, err := pkt.Marshal() + if err != nil { + return err + } + + maxPlainPacketSize := sm.ss.s.MaxPacketSize + if sm.ss.setuppedSecure { + maxPlainPacketSize -= srtcpOverhead + } + + if len(plain) > maxPlainPacketSize { + return fmt.Errorf("packet is too big") + } + + var encr []byte + if sm.ss.setuppedSecure { + encr = make([]byte, sm.ss.s.MaxPacketSize) + encr, err = sm.srtpOutCtx.encryptRTCP(encr, plain, nil) + if err != nil { + return err + } + } + + if sm.ss.setuppedSecure { + return sm.writePacketRTCPEncoded(encr) + } + return sm.writePacketRTCPEncoded(plain) +} + func (sm *serverSessionMedia) writePacketRTCPEncoded(payload []byte) error { sm.ss.writerMutex.RLock() defer sm.ss.writerMutex.RUnlock() diff --git a/server_stream.go b/server_stream.go index 97cea81a..4278b4eb 100644 --- a/server_stream.go +++ b/server_stream.go @@ -273,7 +273,7 @@ func (st *ServerStream) WritePacketRTP(medi *description.Media, pkt *rtp.Packet) } // WritePacketRTPWithNTP writes a RTP packet to all the readers of the stream. -// ntp is the absolute time of the packet, and is sent with periodic RTCP sender reports. +// ntp is the absolute timestamp of the packet, and is sent with periodic RTCP sender reports. func (st *ServerStream) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, ntp time.Time) error { st.mutex.RLock() defer st.mutex.RUnlock() diff --git a/server_stream_format.go b/server_stream_format.go index 30d26048..428cef43 100644 --- a/server_stream_format.go +++ b/server_stream_format.go @@ -2,6 +2,7 @@ package gortsplib import ( "crypto/rand" + "slices" "sync/atomic" "time" @@ -21,25 +22,26 @@ func randUint32() (uint32, error) { return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil } -func isServerStreamLocalSSRCTaken(ssrc uint32, stream *ServerStream, exclude *serverStreamFormat) bool { - for _, sm := range stream.medias { +func serverStreamPickLocalSSRC(sf *serverStreamFormat) (uint32, error) { + var takenSSRCs []uint32 //nolint:prealloc + + for _, sm := range sf.sm.st.medias { for _, sf := range sm.formats { - if sf != exclude && sf.localSSRC == ssrc { - return true - } + takenSSRCs = append(takenSSRCs, sf.localSSRC) } } - return false -} -func serverStreamPickLocalSSRC(sf *serverStreamFormat) (uint32, error) { + for _, sf := range sf.sm.formats { + takenSSRCs = append(takenSSRCs, sf.localSSRC) + } + for { ssrc, err := randUint32() if err != nil { return 0, err } - if ssrc != 0 && !isServerStreamLocalSSRCTaken(ssrc, sf.sm.st, sf) { + if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { return ssrc, nil } } @@ -87,39 +89,77 @@ func (sf *serverStreamFormat) close() { func (sf *serverStreamFormat) writePacketRTP(pkt *rtp.Packet, ntp time.Time) error { pkt.SSRC = sf.localSSRC - byts := make([]byte, sf.sm.st.Server.MaxPacketSize) - n, err := pkt.MarshalTo(byts) + sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt)) + + maxPlainPacketSize := sf.sm.st.Server.MaxPacketSize + if sf.sm.srtpOutCtx != nil { + maxPlainPacketSize -= srtpOverhead + } + + plain := make([]byte, maxPlainPacketSize) + n, err := pkt.MarshalTo(plain) if err != nil { return err } - byts = byts[:n] + plain = plain[:n] - sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt)) + var encr []byte + if sf.sm.srtpOutCtx != nil { + encr = make([]byte, sf.sm.st.Server.MaxPacketSize) + encr, err = sf.sm.srtpOutCtx.encryptRTP(encr, plain, &pkt.Header) + if err != nil { + return err + } + } - le := uint64(len(byts)) + encrLen := uint64(len(encr)) + plainLen := uint64(len(plain)) // send unicast for r := range sf.sm.st.activeUnicastReaders { - if _, ok := r.setuppedMedias[sf.sm.media]; ok { - err := r.writePacketRTPEncoded(sf.sm.media, pkt.PayloadType, byts) - if err != nil { - r.onStreamWriteError(err) - continue + if rsm, ok := r.setuppedMedias[sf.sm.media]; ok { + rsf := rsm.formats[pkt.PayloadType] + + if r.setuppedSecure { + err := rsf.writePacketRTPEncoded(encr) + if err != nil { + r.onStreamWriteError(err) + continue + } + + atomic.AddUint64(sf.sm.bytesSent, encrLen) + } else { + err := rsf.writePacketRTPEncoded(plain) + if err != nil { + r.onStreamWriteError(err) + continue + } + + atomic.AddUint64(sf.sm.bytesSent, plainLen) } - atomic.AddUint64(sf.sm.bytesSent, le) atomic.AddUint64(sf.rtpPacketsSent, 1) } } // send multicast if sf.sm.multicastWriter != nil { - err := sf.sm.multicastWriter.writePacketRTP(byts) - if err != nil { - return err + if sf.sm.srtpOutCtx != nil { + err := sf.sm.multicastWriter.writePacketRTP(encr) + if err != nil { + return err + } + + atomic.AddUint64(sf.sm.bytesSent, encrLen) + } else { + err := sf.sm.multicastWriter.writePacketRTP(plain) + if err != nil { + return err + } + + atomic.AddUint64(sf.sm.bytesSent, plainLen) } - atomic.AddUint64(sf.sm.bytesSent, le) atomic.AddUint64(sf.rtpPacketsSent, 1) } diff --git a/server_stream_media.go b/server_stream_media.go index 5a6ed050..f5692147 100644 --- a/server_stream_media.go +++ b/server_stream_media.go @@ -1,6 +1,8 @@ package gortsplib import ( + "crypto/rand" + "fmt" "sync/atomic" "github.com/bluenviron/gortsplib/v4/pkg/description" @@ -12,6 +14,7 @@ type serverStreamMedia struct { media *description.Media trackID int + srtpOutCtx *wrappedSRTPContext formats map[uint8]*serverStreamFormat multicastWriter *serverMulticastWriter bytesSent *uint64 @@ -40,6 +43,30 @@ func (sm *serverStreamMedia) initialize() error { sm.formats[forma.PayloadType()] = sf } + if sm.st.Server.TLSConfig != nil { + srtpOutKey := make([]byte, srtpKeyLength) + _, err := rand.Read(srtpOutKey) + if err != nil { + return err + } + + ssrcs := make([]uint32, len(sm.formats)) + n := 0 + for _, cf := range sm.formats { + ssrcs[n] = cf.localSSRC + n++ + } + + sm.srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcs, + } + err = sm.srtpOutCtx.initialize() + if err != nil { + return err + } + } + return nil } @@ -54,35 +81,75 @@ func (sm *serverStreamMedia) close() { } func (sm *serverStreamMedia) writePacketRTCP(pkt rtcp.Packet) error { - byts, err := pkt.Marshal() + plain, err := pkt.Marshal() if err != nil { return err } - le := len(byts) + maxPlainPacketSize := sm.st.Server.MaxPacketSize + if sm.srtpOutCtx != nil { + maxPlainPacketSize -= srtcpOverhead + } + + if len(plain) > maxPlainPacketSize { + return fmt.Errorf("packet is too big") + } + + var encr []byte + if sm.srtpOutCtx != nil { + encr = make([]byte, sm.st.Server.MaxPacketSize) + encr, err = sm.srtpOutCtx.encryptRTCP(encr, plain, nil) + if err != nil { + return err + } + } + + encrLen := uint64(len(encr)) + plainLen := uint64(len(plain)) // send unicast for r := range sm.st.activeUnicastReaders { - if _, ok := r.setuppedMedias[sm.media]; ok { - err := r.writePacketRTCPEncoded(sm.media, byts) - if err != nil { - r.onStreamWriteError(err) - continue + if sm, ok := r.setuppedMedias[sm.media]; ok { + if r.setuppedSecure { + err := sm.writePacketRTCPEncoded(encr) + if err != nil { + r.onStreamWriteError(err) + continue + } + + atomic.AddUint64(sm.bytesSent, encrLen) + } else { + err := sm.writePacketRTCPEncoded(plain) + if err != nil { + r.onStreamWriteError(err) + continue + } + + atomic.AddUint64(sm.bytesSent, plainLen) } - atomic.AddUint64(sm.bytesSent, uint64(le)) atomic.AddUint64(sm.rtcpPacketsSent, 1) } } // send multicast if sm.multicastWriter != nil { - err := sm.multicastWriter.writePacketRTCP(byts) - if err != nil { - return err + if sm.srtpOutCtx != nil { + err := sm.multicastWriter.writePacketRTCP(encr) + if err != nil { + return err + } + + atomic.AddUint64(sm.bytesSent, encrLen) + } else { + err := sm.multicastWriter.writePacketRTCP(plain) + if err != nil { + return err + } + + atomic.AddUint64(sm.bytesSent, plainLen) } - atomic.AddUint64(sm.bytesSent, uint64(le)) atomic.AddUint64(sm.rtcpPacketsSent, 1) } diff --git a/wrapped_srtp_context.go b/wrapped_srtp_context.go new file mode 100644 index 00000000..2a2a5238 --- /dev/null +++ b/wrapped_srtp_context.go @@ -0,0 +1,63 @@ +package gortsplib + +import ( + "sync" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/srtp/v3" +) + +// srtp.Context with +// - accessible key +// - accessible SSRCs +// - mutex around Encrypt*, ROC* +type wrappedSRTPContext struct { + key []byte + ssrcs []uint32 + startROCs []uint32 + + w *srtp.Context + mutex sync.RWMutex +} + +func (ctx *wrappedSRTPContext) initialize() error { + var err error + ctx.w, err = srtp.CreateContext(ctx.key[:16], ctx.key[16:], srtp.ProtectionProfileAes128CmHmacSha1_80) + if err != nil { + return err + } + + for i, roc := range ctx.startROCs { + ctx.w.SetROC(ctx.ssrcs[i], roc) + } + + return nil +} + +func (ctx *wrappedSRTPContext) decryptRTP(dst []byte, encrypted []byte, header *rtp.Header) ([]byte, error) { + return ctx.w.DecryptRTP(dst, encrypted, header) +} + +func (ctx *wrappedSRTPContext) decryptRTCP(dst []byte, encrypted []byte, header *rtcp.Header) ([]byte, error) { + return ctx.w.DecryptRTCP(dst, encrypted, header) +} + +func (ctx *wrappedSRTPContext) encryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ([]byte, error) { + ctx.mutex.Lock() + defer ctx.mutex.Unlock() + return ctx.w.EncryptRTP(dst, plaintext, header) +} + +func (ctx *wrappedSRTPContext) encryptRTCP(dst []byte, decrypted []byte, header *rtcp.Header) ([]byte, error) { + ctx.mutex.Lock() + defer ctx.mutex.Unlock() + return ctx.w.EncryptRTCP(dst, decrypted, header) +} + +func (ctx *wrappedSRTPContext) roc(ssrc uint32) uint32 { + ctx.mutex.RLock() + defer ctx.mutex.RUnlock() + v, _ := ctx.w.ROC(ssrc) + return v +}