Don't drop packets when probing Simulcast

Before any packets that we read during the probe would get lost

Co-authored-by: cptpcrd <31829097+cptpcrd@users.noreply.github.com>
This commit is contained in:
Sean DuBois
2025-11-28 12:16:16 -05:00
committed by Sean DuBois
parent c457479a9b
commit 71b8a13dc9
8 changed files with 172 additions and 62 deletions

View File

@@ -12,7 +12,6 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"slices"
"strconv"
"strings"
@@ -1688,7 +1687,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() {
}
}
func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop
func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop,lll
remoteDescription := pc.RemoteDescription()
if remoteDescription == nil {
return errPeerConnRemoteDescriptionNil
@@ -1725,7 +1724,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
// We read the RTP packet to determine the payload type
b := make([]byte, pc.api.settingEngine.getReceiveMTU())
i, err := rtpStream.Read(b)
i, err := rtpStream.Peek(b)
if err != nil {
return err
}
@@ -1802,6 +1801,8 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
return err
}
peekedPackets := []*peekedPacket{}
// if the first packet didn't contain simuilcast IDs, then probe more packets
var paddingOnly bool
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
@@ -1811,11 +1812,16 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
readCount--
}
i, _, err := interceptor.Read(b, nil)
i, attributes, err := interceptor.Read(b, nil)
if err != nil {
return err
}
peekedPackets = append(peekedPackets, &peekedPacket{
payload: slices.Clone(b[:i]),
attributes: attributes,
})
if paddingOnly, err = handleUnknownRTPPacket(
b[:i], uint8(midExtensionID), //nolint:gosec // G115
uint8(streamIDExtensionID), //nolint:gosec // G115
@@ -1851,6 +1857,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
interceptor,
rtcpReadStream,
rtcpInterceptor,
peekedPackets,
)
if err != nil {
return err
@@ -1930,7 +1937,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop
continue
}
go func(rtpStream io.Reader, ssrc SSRC) {
go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) {
if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil {
pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err)
}

View File

@@ -10,6 +10,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
@@ -2062,14 +2063,13 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { //nolint:cyclop
assert.NotZero(t, ridID)
assert.NotZero(t, rsid)
err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")
return res
})
assert.NoError(t, err)
}))
// padding only packets should not affect simulcast probe
var sequenceNumber uint16
@@ -2493,3 +2493,104 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop
closePairNow(t, pcOffer, pcAnswer)
assert.NoError(t, wan.Stop())
}
// Assert that we don't drop any packets during the probe.
func TestPeerConnection_Simulcast_Probe_PacketLoss(t *testing.T) { //nolint:cyclop
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
report := test.CheckRoutines(t)
defer report()
const rtpPktCount = 10
pcOffer, pcAnswer, wan := createVNetPair(t, nil)
rids := []string{"a", "b", "c"}
vp8WriterA, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]),
)
assert.NoError(t, err)
vp8WriterB, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]),
)
assert.NoError(t, err)
vp8WriterC, err := NewTrackLocalStaticRTP(
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[2]),
)
assert.NoError(t, err)
sender, err := pcOffer.AddTrack(vp8WriterA)
assert.NoError(t, err)
assert.NotNil(t, sender)
assert.NoError(t, sender.AddEncoding(vp8WriterB))
assert.NoError(t, sender.AddEncoding(vp8WriterC))
expectedBuffer := make([]byte, outboundMTU*rtpPktCount)
_, err = rand.Read(expectedBuffer)
assert.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
actualBuffer := []byte{}
for i := 0; i < rtpPktCount; i++ {
pkt, _, err := trackRemote.ReadRTP()
assert.NoError(t, err)
actualBuffer = append(actualBuffer, pkt.Payload...)
}
assert.Equal(t, actualBuffer, expectedBuffer)
cancel()
})
var midID, ridID uint8
for _, extension := range sender.GetParameters().HeaderExtensions {
switch extension.URI {
case sdp.SDESMidURI:
midID = uint8(extension.ID) //nolint:gosec // G115
case sdp.SDESRTPStreamIDURI:
ridID = uint8(extension.ID) //nolint:gosec // G115
}
}
assert.NotZero(t, midID)
assert.NotZero(t, ridID)
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")
return res
}))
peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
peerConnectionConnected.Wait()
for sequenceNumber := uint16(0); sequenceNumber < rtpPktCount; sequenceNumber++ {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
SequenceNumber: sequenceNumber,
},
}
// Make sure that packets for Stream received before MID/RID don't get dropped
if sequenceNumber > 3 {
assert.NoError(t, pkt.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.SetExtension(ridID, []byte(vp8WriterA.RID())))
}
offset := int(sequenceNumber) * outboundMTU
pkt.Payload = expectedBuffer[offset : offset+outboundMTU]
assert.NoError(t, vp8WriterA.WriteRTP(pkt))
}
<-ctx.Done()
assert.NoError(t, wan.Stop())
closePairNow(t, pcOffer, pcAnswer)
}

View File

@@ -1105,12 +1105,6 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) {
for _, track := range trackMap {
_, _, err := track.ReadRTP()
// Ignore first Read, this was our peeked data
if err == nil {
_, _, err = track.ReadRTP()
}
assert.Equal(t, err, io.EOF)
}
}

View File

@@ -12,6 +12,7 @@ import (
"io"
"math"
"sync"
"sync/atomic"
"time"
"github.com/pion/interceptor"
@@ -64,8 +65,9 @@ type RTPReceiver struct {
tracks []trackStreams
closed, received chan any
mu sync.RWMutex
closed atomic.Bool
closedChan, received chan any
mu sync.RWMutex
tr *RTPTransceiver
@@ -84,12 +86,12 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT
}
rtpReceiver := &RTPReceiver{
kind: kind,
transport: transport,
api: api,
closed: make(chan any),
received: make(chan any),
tracks: []trackStreams{},
kind: kind,
transport: transport,
api: api,
closedChan: make(chan any),
received: make(chan any),
tracks: []trackStreams{},
rtxPool: sync.Pool{New: func() any {
return make([]byte, api.settingEngine.getReceiveMTU())
}},
@@ -290,7 +292,7 @@ func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error
}
return r.tracks[0].rtcpInterceptor.Read(b, a)
case <-r.closed:
case <-r.closedChan:
return 0, nil, io.ErrClosedPipe
}
}
@@ -315,7 +317,7 @@ func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.
return rtcpInterceptor.Read(b, a)
case <-r.closed:
case <-r.closedChan:
return 0, nil, io.ErrClosedPipe
}
}
@@ -359,6 +361,10 @@ func (r *RTPReceiver) haveReceived() bool {
}
}
func (r *RTPReceiver) haveClosed() bool {
return r.closed.Load()
}
// Stop irreversibly stops the RTPReceiver.
func (r *RTPReceiver) Stop() error { //nolint:cyclop
r.mu.Lock()
@@ -366,7 +372,7 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
var err error
select {
case <-r.closed:
case <-r.closedChan:
return err
default:
}
@@ -405,7 +411,8 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
default:
}
close(r.closed)
close(r.closedChan)
r.closed.Store(true)
return err
}
@@ -519,7 +526,7 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
select {
case <-r.received:
case <-r.closed:
case <-r.closedChan:
return 0, nil, io.EOF
}
@@ -540,6 +547,7 @@ func (r *RTPReceiver) receiveForRid(
rtpInterceptor interceptor.RTPReader,
rtcpReadStream *srtp.ReadStreamSRTCP,
rtcpInterceptor interceptor.RTCPReader,
peekedPackets []*peekedPacket,
) (*TrackRemote, error) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -551,6 +559,7 @@ func (r *RTPReceiver) receiveForRid(
r.tracks[i].track.codec = params.Codecs[0]
r.tracks[i].track.params = params
r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
r.tracks[i].track.peekedPackets = peekedPackets
r.tracks[i].track.mu.Unlock()
r.tracks[i].streamInfo = streamInfo
@@ -651,7 +660,7 @@ func (r *RTPReceiver) receiveForRtx(
copy(b[headerLength:i-2], b[headerLength+2:i])
select {
case <-r.closed:
case <-r.closedChan:
r.rtxPool.Put(b) // nolint:staticcheck
return

View File

@@ -215,7 +215,7 @@ func (p *defaultAudioPlayoutStatsProvider) AddTrack(track *TrackRemote) error {
}
select {
case <-receiver.closed:
case <-receiver.closedChan:
p.removeTrackInternal(track)
case <-ctx.Done():
return

View File

@@ -2349,7 +2349,7 @@ func TestDefaultAudioPlayoutStatsProvider_AccumulateSnapshot(t *testing.T) {
}
func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
receiver := &RTPReceiver{closed: make(chan any)}
receiver := &RTPReceiver{closedChan: make(chan any)}
track := newTrackRemote(RTPCodecTypeAudio, 1234, 0, "", receiver)
samplesPerBatch := 960
@@ -2371,7 +2371,7 @@ func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
}
func TestDefaultAudioPlayoutStatsProvider_MultipleProviders(t *testing.T) {
receiver := &RTPReceiver{closed: make(chan any)}
receiver := &RTPReceiver{closedChan: make(chan any)}
track := newTrackRemote(RTPCodecTypeAudio, 5555, 0, "", receiver)
samplesPerBatch := 960

View File

@@ -827,8 +827,7 @@ func Test_TrackRemote_ReadRTP_UnmarshalError(t *testing.T) {
tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv)
tr.mu.Lock()
tr.peeked = []byte{0x80, 96}
tr.peekedAttributes = nil
tr.peekedPackets = []*peekedPacket{{payload: []byte{0x80, 96}}}
tr.mu.Unlock()
pkt, attrs, err := tr.ReadRTP()

View File

@@ -8,6 +8,7 @@ package webrtc
import (
"fmt"
"io"
"sync"
"time"
@@ -15,6 +16,11 @@ import (
"github.com/pion/rtp"
)
type peekedPacket struct {
payload []byte
attributes interceptor.Attributes
}
// TrackRemote represents a single inbound source of media.
type TrackRemote struct {
mu sync.RWMutex
@@ -30,9 +36,9 @@ type TrackRemote struct {
params RTPParameters
rid string
receiver *RTPReceiver
peeked []byte
peekedAttributes interceptor.Attributes
receiver *RTPReceiver
peekedPackets []*peekedPacket
audioPlayoutStatsProviders []AudioPlayoutStatsProvider
}
@@ -116,25 +122,22 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) {
t.mu.RLock()
receiver := t.receiver
peeked := t.peeked != nil
var peekedPkt *peekedPacket
if len(t.peekedPackets) != 0 {
peekedPkt = t.peekedPackets[0]
t.peekedPackets = t.peekedPackets[1:]
}
t.mu.RUnlock()
if peeked {
t.mu.Lock()
data := t.peeked
attributes = t.peekedAttributes
if receiver.haveClosed() {
return 0, nil, io.EOF
}
t.peeked = nil
t.peekedAttributes = nil
t.mu.Unlock()
// someone else may have stolen our packet when we
// released the lock. Deal with it.
if data != nil {
n = copy(b, data)
err = t.checkAndUpdateTrack(b)
if peekedPkt != nil {
n = copy(b, peekedPkt.payload)
err = t.checkAndUpdateTrack(b)
return n, attributes, err
}
return n, peekedPkt.attributes, err
}
// If there's a separate RTX track and an RTX packet is available, return that
@@ -142,18 +145,16 @@ func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes,
n = copy(b, rtxPacketReceived.pkt)
attributes = rtxPacketReceived.attributes
rtxPacketReceived.release()
err = nil
} else {
// If there's no separate RTX track (or there's a separate RTX track but no RTX packet waiting), wait for and return
// a packet from the main track
n, attributes, err = receiver.readRTP(b, t)
if err != nil {
return n, attributes, err
}
err = t.checkAndUpdateTrack(b)
return n, attributes, nil
}
n, attributes, err = receiver.readRTP(b, t)
if err != nil {
return n, attributes, err
}
err = t.checkAndUpdateTrack(b)
return n, attributes, err
}
@@ -212,8 +213,7 @@ func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error
// that case.
data := make([]byte, n)
n = copy(data, b[:n])
t.peeked = data
t.peekedAttributes = a
t.peekedPackets = append(t.peekedPackets, &peekedPacket{payload: data, attributes: a})
t.mu.Unlock()
return