mirror of
https://github.com/pion/webrtc.git
synced 2025-09-26 19:21:12 +08:00

handleIncomingSSRC calls streamsForSSRC, which opens RTP/RTCP streams that leak if nothing handles them. We now proactively open them before calling handleIncomingSSRC and close them later. In the future it would be better to do this inside handleIncomingSSRC to protect other callers.
523 lines
14 KiB
Go
523 lines
14 KiB
Go
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
//go:build !js
|
|
// +build !js
|
|
|
|
package webrtc
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/pion/dtls/v2"
|
|
"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
|
|
"github.com/pion/interceptor"
|
|
"github.com/pion/logging"
|
|
"github.com/pion/rtcp"
|
|
"github.com/pion/srtp/v2"
|
|
"github.com/pion/webrtc/v3/internal/mux"
|
|
"github.com/pion/webrtc/v3/internal/util"
|
|
"github.com/pion/webrtc/v3/pkg/rtcerr"
|
|
)
|
|
|
|
// DTLSTransport allows an application access to information about the DTLS
|
|
// transport over which RTP and RTCP packets are sent and received by
|
|
// RTPSender and RTPReceiver, as well other data such as SCTP packets sent
|
|
// and received by data channels.
|
|
type DTLSTransport struct {
|
|
lock sync.RWMutex
|
|
|
|
iceTransport *ICETransport
|
|
certificates []Certificate
|
|
remoteParameters DTLSParameters
|
|
remoteCertificate []byte
|
|
state DTLSTransportState
|
|
srtpProtectionProfile srtp.ProtectionProfile
|
|
|
|
onStateChangeHandler func(DTLSTransportState)
|
|
|
|
conn *dtls.Conn
|
|
|
|
srtpSession, srtcpSession atomic.Value
|
|
srtpEndpoint, srtcpEndpoint *mux.Endpoint
|
|
simulcastStreams []simulcastStreamPair
|
|
srtpReady chan struct{}
|
|
|
|
dtlsMatcher mux.MatchFunc
|
|
|
|
api *API
|
|
log logging.LeveledLogger
|
|
}
|
|
|
|
type simulcastStreamPair struct {
|
|
srtp *srtp.ReadStreamSRTP
|
|
srtcp *srtp.ReadStreamSRTCP
|
|
}
|
|
|
|
// NewDTLSTransport creates a new DTLSTransport.
|
|
// This constructor is part of the ORTC API. It is not
|
|
// meant to be used together with the basic WebRTC API.
|
|
func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
|
|
t := &DTLSTransport{
|
|
iceTransport: transport,
|
|
api: api,
|
|
state: DTLSTransportStateNew,
|
|
dtlsMatcher: mux.MatchDTLS,
|
|
srtpReady: make(chan struct{}),
|
|
log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
|
|
}
|
|
|
|
if len(certificates) > 0 {
|
|
now := time.Now()
|
|
for _, x509Cert := range certificates {
|
|
if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
|
|
return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
|
|
}
|
|
t.certificates = append(t.certificates, x509Cert)
|
|
}
|
|
} else {
|
|
sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
return nil, &rtcerr.UnknownError{Err: err}
|
|
}
|
|
certificate, err := GenerateCertificate(sk)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.certificates = []Certificate{*certificate}
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
// ICETransport returns the currently-configured *ICETransport or nil
|
|
// if one has not been configured
|
|
func (t *DTLSTransport) ICETransport() *ICETransport {
|
|
t.lock.RLock()
|
|
defer t.lock.RUnlock()
|
|
return t.iceTransport
|
|
}
|
|
|
|
// onStateChange requires the caller holds the lock
|
|
func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
|
|
t.state = state
|
|
handler := t.onStateChangeHandler
|
|
if handler != nil {
|
|
handler(state)
|
|
}
|
|
}
|
|
|
|
// OnStateChange sets a handler that is fired when the DTLS
|
|
// connection state changes.
|
|
func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
|
|
t.lock.Lock()
|
|
defer t.lock.Unlock()
|
|
t.onStateChangeHandler = f
|
|
}
|
|
|
|
// State returns the current dtls transport state.
|
|
func (t *DTLSTransport) State() DTLSTransportState {
|
|
t.lock.RLock()
|
|
defer t.lock.RUnlock()
|
|
return t.state
|
|
}
|
|
|
|
// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
|
|
// packet is discarded.
|
|
func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) {
|
|
raw, err := rtcp.Marshal(pkts)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
srtcpSession, err := t.getSRTCPSession()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
writeStream, err := srtcpSession.OpenWriteStream()
|
|
if err != nil {
|
|
// nolint
|
|
return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
|
|
}
|
|
|
|
return writeStream.Write(raw)
|
|
}
|
|
|
|
// GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
|
|
func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
|
|
fingerprints := []DTLSFingerprint{}
|
|
|
|
for _, c := range t.certificates {
|
|
prints, err := c.GetFingerprints()
|
|
if err != nil {
|
|
return DTLSParameters{}, err
|
|
}
|
|
|
|
fingerprints = append(fingerprints, prints...)
|
|
}
|
|
|
|
return DTLSParameters{
|
|
Role: DTLSRoleAuto, // always returns the default role
|
|
Fingerprints: fingerprints,
|
|
}, nil
|
|
}
|
|
|
|
// GetRemoteCertificate returns the certificate chain in use by the remote side
|
|
// returns an empty list prior to selection of the remote certificate
|
|
func (t *DTLSTransport) GetRemoteCertificate() []byte {
|
|
t.lock.RLock()
|
|
defer t.lock.RUnlock()
|
|
return t.remoteCertificate
|
|
}
|
|
|
|
func (t *DTLSTransport) startSRTP() error {
|
|
srtpConfig := &srtp.Config{
|
|
Profile: t.srtpProtectionProfile,
|
|
BufferFactory: t.api.settingEngine.BufferFactory,
|
|
LoggerFactory: t.api.settingEngine.LoggerFactory,
|
|
}
|
|
if t.api.settingEngine.replayProtection.SRTP != nil {
|
|
srtpConfig.RemoteOptions = append(
|
|
srtpConfig.RemoteOptions,
|
|
srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
|
|
)
|
|
}
|
|
|
|
if t.api.settingEngine.disableSRTPReplayProtection {
|
|
srtpConfig.RemoteOptions = append(
|
|
srtpConfig.RemoteOptions,
|
|
srtp.SRTPNoReplayProtection(),
|
|
)
|
|
}
|
|
|
|
if t.api.settingEngine.replayProtection.SRTCP != nil {
|
|
srtpConfig.RemoteOptions = append(
|
|
srtpConfig.RemoteOptions,
|
|
srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
|
|
)
|
|
}
|
|
|
|
if t.api.settingEngine.disableSRTCPReplayProtection {
|
|
srtpConfig.RemoteOptions = append(
|
|
srtpConfig.RemoteOptions,
|
|
srtp.SRTCPNoReplayProtection(),
|
|
)
|
|
}
|
|
|
|
connState := t.conn.ConnectionState()
|
|
err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
|
|
if err != nil {
|
|
// nolint
|
|
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
|
|
}
|
|
|
|
srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
|
|
if err != nil {
|
|
// nolint
|
|
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
|
|
}
|
|
|
|
srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
|
|
if err != nil {
|
|
// nolint
|
|
return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
|
|
}
|
|
|
|
t.srtpSession.Store(srtpSession)
|
|
t.srtcpSession.Store(srtcpSession)
|
|
close(t.srtpReady)
|
|
return nil
|
|
}
|
|
|
|
func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
|
|
if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
|
|
return value, nil
|
|
}
|
|
|
|
return nil, errDtlsTransportNotStarted
|
|
}
|
|
|
|
func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
|
|
if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
|
|
return value, nil
|
|
}
|
|
|
|
return nil, errDtlsTransportNotStarted
|
|
}
|
|
|
|
func (t *DTLSTransport) role() DTLSRole {
|
|
// If remote has an explicit role use the inverse
|
|
switch t.remoteParameters.Role {
|
|
case DTLSRoleClient:
|
|
return DTLSRoleServer
|
|
case DTLSRoleServer:
|
|
return DTLSRoleClient
|
|
default:
|
|
}
|
|
|
|
// If SettingEngine has an explicit role
|
|
switch t.api.settingEngine.answeringDTLSRole {
|
|
case DTLSRoleServer:
|
|
return DTLSRoleServer
|
|
case DTLSRoleClient:
|
|
return DTLSRoleClient
|
|
default:
|
|
}
|
|
|
|
// Remote was auto and no explicit role was configured via SettingEngine
|
|
if t.iceTransport.Role() == ICERoleControlling {
|
|
return DTLSRoleServer
|
|
}
|
|
return defaultDtlsRoleAnswer
|
|
}
|
|
|
|
// Start DTLS transport negotiation with the parameters of the remote DTLS transport
|
|
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
|
|
// Take lock and prepare connection, we must not hold the lock
|
|
// when connecting
|
|
prepareTransport := func() (DTLSRole, *dtls.Config, error) {
|
|
t.lock.Lock()
|
|
defer t.lock.Unlock()
|
|
|
|
if err := t.ensureICEConn(); err != nil {
|
|
return DTLSRole(0), nil, err
|
|
}
|
|
|
|
if t.state != DTLSTransportStateNew {
|
|
return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
|
|
}
|
|
|
|
t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
|
|
t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
|
|
t.remoteParameters = remoteParameters
|
|
|
|
cert := t.certificates[0]
|
|
t.onStateChange(DTLSTransportStateConnecting)
|
|
|
|
return t.role(), &dtls.Config{
|
|
Certificates: []tls.Certificate{
|
|
{
|
|
Certificate: [][]byte{cert.x509Cert.Raw},
|
|
PrivateKey: cert.privateKey,
|
|
},
|
|
},
|
|
SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
|
|
if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
|
|
return t.api.settingEngine.srtpProtectionProfiles
|
|
}
|
|
|
|
return defaultSrtpProtectionProfiles()
|
|
}(),
|
|
ClientAuth: dtls.RequireAnyClientCert,
|
|
LoggerFactory: t.api.settingEngine.LoggerFactory,
|
|
InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify,
|
|
}, nil
|
|
}
|
|
|
|
var dtlsConn *dtls.Conn
|
|
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
|
|
role, dtlsConfig, err := prepareTransport()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if t.api.settingEngine.replayProtection.DTLS != nil {
|
|
dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
|
|
}
|
|
|
|
if t.api.settingEngine.dtls.clientAuth != nil {
|
|
dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth
|
|
}
|
|
|
|
dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
|
|
dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify
|
|
dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves
|
|
dtlsConfig.ConnectContextMaker = t.api.settingEngine.dtls.connectContextMaker
|
|
dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret
|
|
dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
|
|
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
|
|
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
|
|
|
|
// Connect as DTLS Client/Server, function is blocking and we
|
|
// must not hold the DTLSTransport lock
|
|
if role == DTLSRoleClient {
|
|
dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig)
|
|
} else {
|
|
dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig)
|
|
}
|
|
|
|
// Re-take the lock, nothing beyond here is blocking
|
|
t.lock.Lock()
|
|
defer t.lock.Unlock()
|
|
|
|
if err != nil {
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return err
|
|
}
|
|
|
|
srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
|
|
if !ok {
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return ErrNoSRTPProtectionProfile
|
|
}
|
|
|
|
switch srtpProfile {
|
|
case dtls.SRTP_AEAD_AES_128_GCM:
|
|
t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
|
|
case dtls.SRTP_AEAD_AES_256_GCM:
|
|
t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
|
|
case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
|
|
t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
|
|
case dtls.SRTP_NULL_HMAC_SHA1_80:
|
|
t.srtpProtectionProfile = srtp.ProtectionProfileNullHmacSha1_80
|
|
default:
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return ErrNoSRTPProtectionProfile
|
|
}
|
|
|
|
// Check the fingerprint if a certificate was exchanged
|
|
remoteCerts := dtlsConn.ConnectionState().PeerCertificates
|
|
if len(remoteCerts) == 0 {
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return errNoRemoteCertificate
|
|
}
|
|
t.remoteCertificate = remoteCerts[0]
|
|
|
|
if !t.api.settingEngine.disableCertificateFingerprintVerification {
|
|
parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
|
|
if err != nil {
|
|
if closeErr := dtlsConn.Close(); closeErr != nil {
|
|
t.log.Error(err.Error())
|
|
}
|
|
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return err
|
|
}
|
|
|
|
if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
|
|
if closeErr := dtlsConn.Close(); closeErr != nil {
|
|
t.log.Error(err.Error())
|
|
}
|
|
|
|
t.onStateChange(DTLSTransportStateFailed)
|
|
return err
|
|
}
|
|
}
|
|
|
|
t.conn = dtlsConn
|
|
t.onStateChange(DTLSTransportStateConnected)
|
|
|
|
return t.startSRTP()
|
|
}
|
|
|
|
// Stop stops and closes the DTLSTransport object.
|
|
func (t *DTLSTransport) Stop() error {
|
|
t.lock.Lock()
|
|
defer t.lock.Unlock()
|
|
|
|
// Try closing everything and collect the errors
|
|
var closeErrs []error
|
|
|
|
if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil {
|
|
closeErrs = append(closeErrs, srtpSession.Close())
|
|
}
|
|
|
|
if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil {
|
|
closeErrs = append(closeErrs, srtcpSession.Close())
|
|
}
|
|
|
|
for i := range t.simulcastStreams {
|
|
closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close())
|
|
closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close())
|
|
}
|
|
|
|
if t.conn != nil {
|
|
// dtls connection may be closed on sctp close.
|
|
if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
|
|
closeErrs = append(closeErrs, err)
|
|
}
|
|
}
|
|
t.onStateChange(DTLSTransportStateClosed)
|
|
return util.FlattenErrs(closeErrs)
|
|
}
|
|
|
|
func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
|
|
for _, fp := range t.remoteParameters.Fingerprints {
|
|
hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if strings.EqualFold(remoteValue, fp.Value) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errNoMatchingCertificateFingerprint
|
|
}
|
|
|
|
func (t *DTLSTransport) ensureICEConn() error {
|
|
if t.iceTransport == nil {
|
|
return errICEConnectionNotStarted
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *DTLSTransport) storeSimulcastStream(srtpReadStream *srtp.ReadStreamSRTP, srtcpReadStream *srtp.ReadStreamSRTCP) {
|
|
t.lock.Lock()
|
|
defer t.lock.Unlock()
|
|
|
|
t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream})
|
|
}
|
|
|
|
func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
|
|
srtpSession, err := t.getSRTPSession()
|
|
if err != nil {
|
|
return nil, nil, nil, nil, err
|
|
}
|
|
|
|
rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
|
|
if err != nil {
|
|
return nil, nil, nil, nil, err
|
|
}
|
|
|
|
rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
|
|
n, err = rtpReadStream.Read(in)
|
|
return n, a, err
|
|
}))
|
|
|
|
srtcpSession, err := t.getSRTCPSession()
|
|
if err != nil {
|
|
return nil, nil, nil, nil, err
|
|
}
|
|
|
|
rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
|
|
if err != nil {
|
|
return nil, nil, nil, nil, err
|
|
}
|
|
|
|
rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
|
|
n, err = rtcpReadStream.Read(in)
|
|
return n, a, err
|
|
}))
|
|
|
|
return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
|
|
}
|