Handle Simulcast RepairStream

Read + Discard packets from the Simulcast repair stream. When a
Simulcast stream is enabled the remote will send packets via the repair
stream for probing. We can't ignore these packets anymore because it
will cause gaps in the feedback reports

Resolves #1957
This commit is contained in:
Sean DuBois
2021-09-14 22:31:10 -04:00
parent f8fa792477
commit 11b8873da2
9 changed files with 161 additions and 72 deletions

View File

@@ -31,6 +31,8 @@ const (
incomingUnhandledRTPSsrc = "Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v"
generatedCertificateOrigin = "WebRTC"
sdesRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"
)
func defaultSrtpProtectionProfiles() []dtls.SRTPProtectionProfile {

View File

@@ -17,6 +17,7 @@ import (
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
"github.com/pion/interceptor"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/srtp/v2"
@@ -459,3 +460,37 @@ func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {
t.simulcastStreams = append(t.simulcastStreams, s)
}
func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
srtpSession, err := t.getSRTPSession()
if err != nil {
return nil, nil, nil, nil, err
}
rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}
rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)
return n, a, err
}))
srtcpSession, err := t.getSRTCPSession()
if err != nil {
return nil, nil, nil, nil, err
}
rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}
rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtcpReadStream.Read(in)
return n, a, err
}))
return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
}

View File

@@ -200,7 +200,6 @@ var (
errRTPReceiverDTLSTransportNil = errors.New("DTLSTransport must not be nil")
errRTPReceiverReceiveAlreadyCalled = errors.New("Receive has already been called")
errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC")
errRTPReceiverForSSRCTrackStreamNotFound = errors.New("no trackStreams found for SSRC")
errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID")
errRTPSenderTrackNil = errors.New("Track must not be nil")

View File

@@ -117,7 +117,7 @@ func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) {
return i.WriteRTP(&packet.Header, packet.Payload)
}
func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo {
func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) *interceptor.StreamInfo {
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions))
for _, h := range webrtcHeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
@@ -128,7 +128,7 @@ func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCo
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}
return interceptor.StreamInfo{
return &interceptor.StreamInfo{
ID: id,
Attributes: interceptor.Attributes{},
SSRC: uint32(ssrc),

View File

@@ -1397,42 +1397,57 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
return errPeerConnSimulcastStreamIDRTPExtensionRequired
}
repairStreamIDExtensionID, _, _ := pc.api.mediaEngine.getHeaderExtensionID(RTPHeaderExtensionCapability{sdesRepairRTPStreamIDURI})
b := make([]byte, pc.api.settingEngine.getReceiveMTU())
var mid, rid string
i, err := rtpStream.Read(b)
if err != nil {
return err
}
var mid, rid, rsid string
payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid)
if err != nil {
return err
}
params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType)
if err != nil {
return err
}
streamInfo := createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions)
readStream, interceptor, rtcpReadStream, rtcpInterceptor, err := pc.dtlsTransport.streamsForSSRC(ssrc, *streamInfo)
if err != nil {
return err
}
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
i, err := rtpStream.Read(b)
if err != nil {
return err
}
if mid == "" || (rid == "" && rsid == "") {
i, _, err := interceptor.Read(b, nil)
if err != nil {
return err
}
maybeMid, maybeRid, payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID))
if err != nil {
return err
}
if _, err = handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid); err != nil {
return err
}
if maybeMid != "" {
mid = maybeMid
}
if maybeRid != "" {
rid = maybeRid
}
if mid == "" || rid == "" {
continue
}
params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType)
if err != nil {
return err
}
for _, t := range pc.GetTransceivers() {
receiver := t.Receiver()
if t.Mid() != mid || receiver == nil {
continue
}
track, err := receiver.receiveForRid(rid, params, ssrc)
if rsid != "" {
return receiver.receiveForRsid(rsid, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor)
}
track, err := receiver.receiveForRid(rid, params, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor)
if err != nil {
return err
}
@@ -1441,6 +1456,13 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
}
}
if readStream != nil {
_ = readStream.Close()
}
if rtcpReadStream != nil {
_ = rtcpReadStream.Close()
}
pc.api.interceptor.UnbindRemoteStream(streamInfo)
return errPeerConnSimulcastIncomingSSRCFailed
}

View File

@@ -1175,7 +1175,14 @@ func TestPeerConnection_Simulcast(t *testing.T) {
PayloadType: 96,
}
assert.NoError(t, header.SetExtension(1, []byte("0")))
assert.NoError(t, header.SetExtension(2, []byte(rid)))
// Send RSID for first 10 packets
if sequenceNumber >= 10 {
assert.NoError(t, header.SetExtension(2, []byte(rid)))
} else {
assert.NoError(t, header.SetExtension(3, []byte(rid)))
header.SSRC += 10
}
_, err := vp8Writer.bindings[0].writeStream.WriteRTP(header, []byte{0x00})
assert.NoError(t, err)

View File

@@ -19,13 +19,19 @@ import (
type trackStreams struct {
track *TrackRemote
streamInfo interceptor.StreamInfo
streamInfo, repairStreamInfo *interceptor.StreamInfo
rtpReadStream *srtp.ReadStreamSRTP
rtpInterceptor interceptor.RTPReader
rtcpReadStream *srtp.ReadStreamSRTCP
rtcpInterceptor interceptor.RTCPReader
repairReadStream *srtp.ReadStreamSRTP
repairInterceptor interceptor.RTPReader
repairRtcpReadStream *srtp.ReadStreamSRTCP
repairRtcpInterceptor interceptor.RTCPReader
}
// RTPReceiver allows an application to inspect the receipt of a TrackRemote
@@ -146,7 +152,7 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
if parameters.Encodings[i].SSRC != 0 {
t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, codec, globalParams.HeaderExtensions)
var err error
if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.streamsForSSRC(parameters.Encodings[i].SSRC, t.streamInfo); err != nil {
if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil {
return err
}
}
@@ -245,8 +251,23 @@ func (r *RTPReceiver) Stop() error {
errs = append(errs, r.tracks[i].rtpReadStream.Close())
}
if r.tracks[i].repairReadStream != nil {
errs = append(errs, r.tracks[i].repairReadStream.Close())
}
if r.tracks[i].repairRtcpReadStream != nil {
errs = append(errs, r.tracks[i].repairRtcpReadStream.Close())
}
if r.tracks[i].streamInfo != nil {
r.api.interceptor.UnbindRemoteStream(r.tracks[i].streamInfo)
}
if r.tracks[i].repairStreamInfo != nil {
r.api.interceptor.UnbindRemoteStream(r.tracks[i].repairStreamInfo)
}
err = util.FlattenErrs(errs)
r.api.interceptor.UnbindRemoteStream(&r.tracks[i].streamInfo)
}
default:
}
@@ -276,7 +297,7 @@ func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a intercept
// receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs
// It populates all the internal state for the given RID
func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) (*TrackRemote, error) {
func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) (*TrackRemote, error) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -286,54 +307,53 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC)
r.tracks[i].track.kind = r.kind
r.tracks[i].track.codec = params.Codecs[0]
r.tracks[i].track.params = params
r.tracks[i].track.ssrc = ssrc
r.tracks[i].streamInfo = createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions)
r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
r.tracks[i].track.mu.Unlock()
var err error
if r.tracks[i].rtpReadStream, r.tracks[i].rtpInterceptor, r.tracks[i].rtcpReadStream, r.tracks[i].rtcpInterceptor, err = r.streamsForSSRC(ssrc, r.tracks[i].streamInfo); err != nil {
return nil, err
}
r.tracks[i].streamInfo = streamInfo
r.tracks[i].rtpReadStream = rtpReadStream
r.tracks[i].rtpInterceptor = rtpInterceptor
r.tracks[i].rtcpReadStream = rtcpReadStream
r.tracks[i].rtcpInterceptor = rtcpInterceptor
return r.tracks[i].track, nil
}
}
return nil, fmt.Errorf("%w: %d", errRTPReceiverForSSRCTrackStreamNotFound, ssrc)
return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
}
func (r *RTPReceiver) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
srtpSession, err := r.transport.getSRTPSession()
if err != nil {
return nil, nil, nil, nil, err
// receiveForRsid starts a routine that processes the repair stream for a RID
// These packets aren't exposed to the user yet, but we need to process them for
// TWCC
func (r *RTPReceiver) receiveForRsid(rsid string, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) error {
r.mu.Lock()
defer r.mu.Unlock()
for i := range r.tracks {
if r.tracks[i].track.RID() == rsid {
var err error
r.tracks[i].repairStreamInfo = streamInfo
r.tracks[i].repairReadStream = rtpReadStream
r.tracks[i].repairInterceptor = rtpInterceptor
r.tracks[i].repairRtcpReadStream = rtcpReadStream
r.tracks[i].repairRtcpInterceptor = rtcpInterceptor
go func() {
b := make([]byte, r.api.settingEngine.getReceiveMTU())
for {
if _, _, readErr := r.tracks[i].repairInterceptor.Read(b, nil); readErr != nil {
return
}
}
}()
return err
}
}
rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}
rtpInterceptor := r.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)
return n, a, err
}))
srtcpSession, err := r.transport.getSRTCPSession()
if err != nil {
return nil, nil, nil, nil, err
}
rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}
rtcpInterceptor := r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtcpReadStream.Read(in)
return n, a, err
}))
return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rsid)
}
// SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever.

View File

@@ -211,7 +211,7 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
}
r.context.params.Codecs = []RTPCodecParameters{codec}
r.streamInfo = createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions)
r.streamInfo = *createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions)
rtpInterceptor := r.api.interceptor.BindLocalStream(&r.streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(header, payload)
}))

View File

@@ -247,7 +247,7 @@ func satisfyTypeAndDirection(remoteKind RTPCodecType, remoteDirection RTPTransce
// handleUnknownRTPPacket consumes a single RTP Packet and returns information that is helpful
// for demuxing and handling an unknown SSRC (usually for Simulcast)
func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint8) (mid, rid string, payloadType PayloadType, err error) {
func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID, repairStreamIDExtensionID uint8, mid, rid, rsid *string) (payloadType PayloadType, err error) {
rp := &rtp.Packet{}
if err = rp.Unmarshal(buf); err != nil {
return
@@ -259,11 +259,15 @@ func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint
payloadType = PayloadType(rp.PayloadType)
if payload := rp.GetExtension(midExtensionID); payload != nil {
mid = string(payload)
*mid = string(payload)
}
if payload := rp.GetExtension(streamIDExtensionID); payload != nil {
rid = string(payload)
*rid = string(payload)
}
if payload := rp.GetExtension(repairStreamIDExtensionID); payload != nil {
*rsid = string(payload)
}
return