diff --git a/constants.go b/constants.go index d1966b10..825601dd 100644 --- a/constants.go +++ b/constants.go @@ -31,6 +31,8 @@ const ( incomingUnhandledRTPSsrc = "Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v" generatedCertificateOrigin = "WebRTC" + + sdesRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" ) func defaultSrtpProtectionProfiles() []dtls.SRTPProtectionProfile { diff --git a/dtlstransport.go b/dtlstransport.go index 9b69b2d7..1a0505f4 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -17,6 +17,7 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/fingerprint" + "github.com/pion/interceptor" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/srtp/v2" @@ -459,3 +460,37 @@ func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) { t.simulcastStreams = append(t.simulcastStreams, s) } + +func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { + srtpSession, err := t.getSRTPSession() + if err != nil { + return nil, nil, nil, nil, err + } + + rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) + if err != nil { + return nil, nil, nil, nil, err + } + + rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtpReadStream.Read(in) + return n, a, err + })) + + srtcpSession, err := t.getSRTCPSession() + if err != nil { + return nil, nil, nil, nil, err + } + + rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) + if err != nil { + return nil, nil, nil, nil, err + } + + rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtcpReadStream.Read(in) + return n, a, err + })) + + return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil +} diff --git a/errors.go b/errors.go index 1a1a491f..339edfa7 100644 --- a/errors.go +++ b/errors.go @@ -200,7 +200,6 @@ var ( errRTPReceiverDTLSTransportNil = errors.New("DTLSTransport must not be nil") errRTPReceiverReceiveAlreadyCalled = errors.New("Receive has already been called") errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC") - errRTPReceiverForSSRCTrackStreamNotFound = errors.New("no trackStreams found for SSRC") errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID") errRTPSenderTrackNil = errors.New("Track must not be nil") diff --git a/interceptor.go b/interceptor.go index a7f8b432..864afcae 100644 --- a/interceptor.go +++ b/interceptor.go @@ -117,7 +117,7 @@ func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) { return i.WriteRTP(&packet.Header, packet.Payload) } -func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo { +func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) *interceptor.StreamInfo { headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions)) for _, h := range webrtcHeaderExtensions { headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) @@ -128,7 +128,7 @@ func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCo feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) } - return interceptor.StreamInfo{ + return &interceptor.StreamInfo{ ID: id, Attributes: interceptor.Attributes{}, SSRC: uint32(ssrc), diff --git a/peerconnection.go b/peerconnection.go index aaa3143b..67f0c2a0 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1397,42 +1397,57 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return errPeerConnSimulcastStreamIDRTPExtensionRequired } + repairStreamIDExtensionID, _, _ := pc.api.mediaEngine.getHeaderExtensionID(RTPHeaderExtensionCapability{sdesRepairRTPStreamIDURI}) + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) - var mid, rid string + + i, err := rtpStream.Read(b) + if err != nil { + return err + } + + var mid, rid, rsid string + payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid) + if err != nil { + return err + } + + params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType) + if err != nil { + return err + } + + streamInfo := createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions) + readStream, interceptor, rtcpReadStream, rtcpInterceptor, err := pc.dtlsTransport.streamsForSSRC(ssrc, *streamInfo) + if err != nil { + return err + } + for readCount := 0; readCount <= simulcastProbeCount; readCount++ { - i, err := rtpStream.Read(b) - if err != nil { - return err - } + if mid == "" || (rid == "" && rsid == "") { + i, _, err := interceptor.Read(b, nil) + if err != nil { + return err + } - maybeMid, maybeRid, payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID)) - if err != nil { - return err - } + if _, err = handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid); err != nil { + return err + } - if maybeMid != "" { - mid = maybeMid - } - if maybeRid != "" { - rid = maybeRid - } - - if mid == "" || rid == "" { continue } - params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType) - if err != nil { - return err - } - for _, t := range pc.GetTransceivers() { receiver := t.Receiver() if t.Mid() != mid || receiver == nil { continue } - track, err := receiver.receiveForRid(rid, params, ssrc) + if rsid != "" { + return receiver.receiveForRsid(rsid, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor) + } + + track, err := receiver.receiveForRid(rid, params, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor) if err != nil { return err } @@ -1441,6 +1456,13 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err } } + if readStream != nil { + _ = readStream.Close() + } + if rtcpReadStream != nil { + _ = rtcpReadStream.Close() + } + pc.api.interceptor.UnbindRemoteStream(streamInfo) return errPeerConnSimulcastIncomingSSRCFailed } diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 729378ae..c847d642 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -1175,7 +1175,14 @@ func TestPeerConnection_Simulcast(t *testing.T) { PayloadType: 96, } assert.NoError(t, header.SetExtension(1, []byte("0"))) - assert.NoError(t, header.SetExtension(2, []byte(rid))) + + // Send RSID for first 10 packets + if sequenceNumber >= 10 { + assert.NoError(t, header.SetExtension(2, []byte(rid))) + } else { + assert.NoError(t, header.SetExtension(3, []byte(rid))) + header.SSRC += 10 + } _, err := vp8Writer.bindings[0].writeStream.WriteRTP(header, []byte{0x00}) assert.NoError(t, err) diff --git a/rtpreceiver.go b/rtpreceiver.go index ff1a429a..cae82cbf 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -19,13 +19,19 @@ import ( type trackStreams struct { track *TrackRemote - streamInfo interceptor.StreamInfo + streamInfo, repairStreamInfo *interceptor.StreamInfo rtpReadStream *srtp.ReadStreamSRTP rtpInterceptor interceptor.RTPReader rtcpReadStream *srtp.ReadStreamSRTCP rtcpInterceptor interceptor.RTCPReader + + repairReadStream *srtp.ReadStreamSRTP + repairInterceptor interceptor.RTPReader + + repairRtcpReadStream *srtp.ReadStreamSRTCP + repairRtcpInterceptor interceptor.RTCPReader } // RTPReceiver allows an application to inspect the receipt of a TrackRemote @@ -146,7 +152,7 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { if parameters.Encodings[i].SSRC != 0 { t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, codec, globalParams.HeaderExtensions) var err error - if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.streamsForSSRC(parameters.Encodings[i].SSRC, t.streamInfo); err != nil { + if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil { return err } } @@ -245,8 +251,23 @@ func (r *RTPReceiver) Stop() error { errs = append(errs, r.tracks[i].rtpReadStream.Close()) } + if r.tracks[i].repairReadStream != nil { + errs = append(errs, r.tracks[i].repairReadStream.Close()) + } + + if r.tracks[i].repairRtcpReadStream != nil { + errs = append(errs, r.tracks[i].repairRtcpReadStream.Close()) + } + + if r.tracks[i].streamInfo != nil { + r.api.interceptor.UnbindRemoteStream(r.tracks[i].streamInfo) + } + + if r.tracks[i].repairStreamInfo != nil { + r.api.interceptor.UnbindRemoteStream(r.tracks[i].repairStreamInfo) + } + err = util.FlattenErrs(errs) - r.api.interceptor.UnbindRemoteStream(&r.tracks[i].streamInfo) } default: } @@ -276,7 +297,7 @@ func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a intercept // receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs // It populates all the internal state for the given RID -func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) (*TrackRemote, error) { +func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) (*TrackRemote, error) { r.mu.Lock() defer r.mu.Unlock() @@ -286,54 +307,53 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) r.tracks[i].track.kind = r.kind r.tracks[i].track.codec = params.Codecs[0] r.tracks[i].track.params = params - r.tracks[i].track.ssrc = ssrc - r.tracks[i].streamInfo = createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions) + r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC) r.tracks[i].track.mu.Unlock() - var err error - if r.tracks[i].rtpReadStream, r.tracks[i].rtpInterceptor, r.tracks[i].rtcpReadStream, r.tracks[i].rtcpInterceptor, err = r.streamsForSSRC(ssrc, r.tracks[i].streamInfo); err != nil { - return nil, err - } + r.tracks[i].streamInfo = streamInfo + r.tracks[i].rtpReadStream = rtpReadStream + r.tracks[i].rtpInterceptor = rtpInterceptor + r.tracks[i].rtcpReadStream = rtcpReadStream + r.tracks[i].rtcpInterceptor = rtcpInterceptor return r.tracks[i].track, nil } } - return nil, fmt.Errorf("%w: %d", errRTPReceiverForSSRCTrackStreamNotFound, ssrc) + return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) } -func (r *RTPReceiver) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { - srtpSession, err := r.transport.getSRTPSession() - if err != nil { - return nil, nil, nil, nil, err +// receiveForRsid starts a routine that processes the repair stream for a RID +// These packets aren't exposed to the user yet, but we need to process them for +// TWCC +func (r *RTPReceiver) receiveForRsid(rsid string, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) error { + r.mu.Lock() + defer r.mu.Unlock() + + for i := range r.tracks { + if r.tracks[i].track.RID() == rsid { + var err error + + r.tracks[i].repairStreamInfo = streamInfo + r.tracks[i].repairReadStream = rtpReadStream + r.tracks[i].repairInterceptor = rtpInterceptor + r.tracks[i].repairRtcpReadStream = rtcpReadStream + r.tracks[i].repairRtcpInterceptor = rtcpInterceptor + + go func() { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + for { + if _, _, readErr := r.tracks[i].repairInterceptor.Read(b, nil); readErr != nil { + return + } + } + }() + + return err + } } - rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) - if err != nil { - return nil, nil, nil, nil, err - } - - rtpInterceptor := r.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { - n, err = rtpReadStream.Read(in) - return n, a, err - })) - - srtcpSession, err := r.transport.getSRTCPSession() - if err != nil { - return nil, nil, nil, nil, err - } - - rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) - if err != nil { - return nil, nil, nil, nil, err - } - - rtcpInterceptor := r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { - n, err = rtcpReadStream.Read(in) - return n, a, err - })) - - return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil + return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rsid) } // SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever. diff --git a/rtpsender.go b/rtpsender.go index dadcc913..95c97a0d 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -211,7 +211,7 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { } r.context.params.Codecs = []RTPCodecParameters{codec} - r.streamInfo = createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions) + r.streamInfo = *createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions) rtpInterceptor := r.api.interceptor.BindLocalStream(&r.streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { return r.srtpStream.WriteRTP(header, payload) })) diff --git a/rtptransceiver.go b/rtptransceiver.go index d596f43e..a42de68d 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -247,7 +247,7 @@ func satisfyTypeAndDirection(remoteKind RTPCodecType, remoteDirection RTPTransce // handleUnknownRTPPacket consumes a single RTP Packet and returns information that is helpful // for demuxing and handling an unknown SSRC (usually for Simulcast) -func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint8) (mid, rid string, payloadType PayloadType, err error) { +func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID, repairStreamIDExtensionID uint8, mid, rid, rsid *string) (payloadType PayloadType, err error) { rp := &rtp.Packet{} if err = rp.Unmarshal(buf); err != nil { return @@ -259,11 +259,15 @@ func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint payloadType = PayloadType(rp.PayloadType) if payload := rp.GetExtension(midExtensionID); payload != nil { - mid = string(payload) + *mid = string(payload) } if payload := rp.GetExtension(streamIDExtensionID); payload != nil { - rid = string(payload) + *rid = string(payload) + } + + if payload := rp.GetExtension(repairStreamIDExtensionID); payload != nil { + *rsid = string(payload) } return