mirror of
https://github.com/pion/webrtc.git
synced 2025-12-24 11:51:03 +08:00
Don't drop packets when probing Simulcast
Before any packets that we read during the probe would get lost Co-authored-by: cptpcrd <31829097+cptpcrd@users.noreply.github.com>
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -1688,7 +1687,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() {
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop
|
||||
func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop,lll
|
||||
remoteDescription := pc.RemoteDescription()
|
||||
if remoteDescription == nil {
|
||||
return errPeerConnRemoteDescriptionNil
|
||||
@@ -1725,7 +1724,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
|
||||
// We read the RTP packet to determine the payload type
|
||||
b := make([]byte, pc.api.settingEngine.getReceiveMTU())
|
||||
|
||||
i, err := rtpStream.Read(b)
|
||||
i, err := rtpStream.Peek(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1802,6 +1801,8 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
|
||||
return err
|
||||
}
|
||||
|
||||
peekedPackets := []*peekedPacket{}
|
||||
|
||||
// if the first packet didn't contain simuilcast IDs, then probe more packets
|
||||
var paddingOnly bool
|
||||
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
|
||||
@@ -1811,11 +1812,16 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
|
||||
readCount--
|
||||
}
|
||||
|
||||
i, _, err := interceptor.Read(b, nil)
|
||||
i, attributes, err := interceptor.Read(b, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peekedPackets = append(peekedPackets, &peekedPacket{
|
||||
payload: slices.Clone(b[:i]),
|
||||
attributes: attributes,
|
||||
})
|
||||
|
||||
if paddingOnly, err = handleUnknownRTPPacket(
|
||||
b[:i], uint8(midExtensionID), //nolint:gosec // G115
|
||||
uint8(streamIDExtensionID), //nolint:gosec // G115
|
||||
@@ -1851,6 +1857,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
|
||||
interceptor,
|
||||
rtcpReadStream,
|
||||
rtcpInterceptor,
|
||||
peekedPackets,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1930,7 +1937,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop
|
||||
continue
|
||||
}
|
||||
|
||||
go func(rtpStream io.Reader, ssrc SSRC) {
|
||||
go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) {
|
||||
if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil {
|
||||
pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -2062,14 +2063,13 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { //nolint:cyclop
|
||||
assert.NotZero(t, ridID)
|
||||
assert.NotZero(t, rsid)
|
||||
|
||||
err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
|
||||
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
|
||||
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
|
||||
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
|
||||
res := re.ReplaceAllString(sdp, "")
|
||||
|
||||
return res
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
||||
// padding only packets should not affect simulcast probe
|
||||
var sequenceNumber uint16
|
||||
@@ -2493,3 +2493,104 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop
|
||||
closePairNow(t, pcOffer, pcAnswer)
|
||||
assert.NoError(t, wan.Stop())
|
||||
}
|
||||
|
||||
// Assert that we don't drop any packets during the probe.
|
||||
func TestPeerConnection_Simulcast_Probe_PacketLoss(t *testing.T) { //nolint:cyclop
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
const rtpPktCount = 10
|
||||
pcOffer, pcAnswer, wan := createVNetPair(t, nil)
|
||||
|
||||
rids := []string{"a", "b", "c"}
|
||||
vp8WriterA, err := NewTrackLocalStaticRTP(
|
||||
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
vp8WriterB, err := NewTrackLocalStaticRTP(
|
||||
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
vp8WriterC, err := NewTrackLocalStaticRTP(
|
||||
RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[2]),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sender, err := pcOffer.AddTrack(vp8WriterA)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sender)
|
||||
|
||||
assert.NoError(t, sender.AddEncoding(vp8WriterB))
|
||||
assert.NoError(t, sender.AddEncoding(vp8WriterC))
|
||||
|
||||
expectedBuffer := make([]byte, outboundMTU*rtpPktCount)
|
||||
_, err = rand.Read(expectedBuffer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
|
||||
actualBuffer := []byte{}
|
||||
|
||||
for i := 0; i < rtpPktCount; i++ {
|
||||
pkt, _, err := trackRemote.ReadRTP()
|
||||
assert.NoError(t, err)
|
||||
|
||||
actualBuffer = append(actualBuffer, pkt.Payload...)
|
||||
}
|
||||
|
||||
assert.Equal(t, actualBuffer, expectedBuffer)
|
||||
cancel()
|
||||
})
|
||||
|
||||
var midID, ridID uint8
|
||||
for _, extension := range sender.GetParameters().HeaderExtensions {
|
||||
switch extension.URI {
|
||||
case sdp.SDESMidURI:
|
||||
midID = uint8(extension.ID) //nolint:gosec // G115
|
||||
case sdp.SDESRTPStreamIDURI:
|
||||
ridID = uint8(extension.ID) //nolint:gosec // G115
|
||||
}
|
||||
}
|
||||
assert.NotZero(t, midID)
|
||||
assert.NotZero(t, ridID)
|
||||
|
||||
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
|
||||
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
|
||||
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
|
||||
res := re.ReplaceAllString(sdp, "")
|
||||
|
||||
return res
|
||||
}))
|
||||
|
||||
peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
|
||||
peerConnectionConnected.Wait()
|
||||
|
||||
for sequenceNumber := uint16(0); sequenceNumber < rtpPktCount; sequenceNumber++ {
|
||||
pkt := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: sequenceNumber,
|
||||
},
|
||||
}
|
||||
|
||||
// Make sure that packets for Stream received before MID/RID don't get dropped
|
||||
if sequenceNumber > 3 {
|
||||
assert.NoError(t, pkt.SetExtension(midID, []byte("0")))
|
||||
assert.NoError(t, pkt.SetExtension(ridID, []byte(vp8WriterA.RID())))
|
||||
}
|
||||
|
||||
offset := int(sequenceNumber) * outboundMTU
|
||||
pkt.Payload = expectedBuffer[offset : offset+outboundMTU]
|
||||
assert.NoError(t, vp8WriterA.WriteRTP(pkt))
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
assert.NoError(t, wan.Stop())
|
||||
closePairNow(t, pcOffer, pcAnswer)
|
||||
}
|
||||
|
||||
@@ -1105,12 +1105,6 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) {
|
||||
|
||||
for _, track := range trackMap {
|
||||
_, _, err := track.ReadRTP()
|
||||
|
||||
// Ignore first Read, this was our peeked data
|
||||
if err == nil {
|
||||
_, _, err = track.ReadRTP()
|
||||
}
|
||||
|
||||
assert.Equal(t, err, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/interceptor"
|
||||
@@ -64,8 +65,9 @@ type RTPReceiver struct {
|
||||
|
||||
tracks []trackStreams
|
||||
|
||||
closed, received chan any
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
closedChan, received chan any
|
||||
mu sync.RWMutex
|
||||
|
||||
tr *RTPTransceiver
|
||||
|
||||
@@ -84,12 +86,12 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT
|
||||
}
|
||||
|
||||
rtpReceiver := &RTPReceiver{
|
||||
kind: kind,
|
||||
transport: transport,
|
||||
api: api,
|
||||
closed: make(chan any),
|
||||
received: make(chan any),
|
||||
tracks: []trackStreams{},
|
||||
kind: kind,
|
||||
transport: transport,
|
||||
api: api,
|
||||
closedChan: make(chan any),
|
||||
received: make(chan any),
|
||||
tracks: []trackStreams{},
|
||||
rtxPool: sync.Pool{New: func() any {
|
||||
return make([]byte, api.settingEngine.getReceiveMTU())
|
||||
}},
|
||||
@@ -290,7 +292,7 @@ func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error
|
||||
}
|
||||
|
||||
return r.tracks[0].rtcpInterceptor.Read(b, a)
|
||||
case <-r.closed:
|
||||
case <-r.closedChan:
|
||||
return 0, nil, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
@@ -315,7 +317,7 @@ func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.
|
||||
|
||||
return rtcpInterceptor.Read(b, a)
|
||||
|
||||
case <-r.closed:
|
||||
case <-r.closedChan:
|
||||
return 0, nil, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
@@ -359,6 +361,10 @@ func (r *RTPReceiver) haveReceived() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RTPReceiver) haveClosed() bool {
|
||||
return r.closed.Load()
|
||||
}
|
||||
|
||||
// Stop irreversibly stops the RTPReceiver.
|
||||
func (r *RTPReceiver) Stop() error { //nolint:cyclop
|
||||
r.mu.Lock()
|
||||
@@ -366,7 +372,7 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
|
||||
var err error
|
||||
|
||||
select {
|
||||
case <-r.closed:
|
||||
case <-r.closedChan:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
@@ -405,7 +411,8 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
|
||||
default:
|
||||
}
|
||||
|
||||
close(r.closed)
|
||||
close(r.closedChan)
|
||||
r.closed.Store(true)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -519,7 +526,7 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
|
||||
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
|
||||
select {
|
||||
case <-r.received:
|
||||
case <-r.closed:
|
||||
case <-r.closedChan:
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
@@ -540,6 +547,7 @@ func (r *RTPReceiver) receiveForRid(
|
||||
rtpInterceptor interceptor.RTPReader,
|
||||
rtcpReadStream *srtp.ReadStreamSRTCP,
|
||||
rtcpInterceptor interceptor.RTCPReader,
|
||||
peekedPackets []*peekedPacket,
|
||||
) (*TrackRemote, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
@@ -551,6 +559,7 @@ func (r *RTPReceiver) receiveForRid(
|
||||
r.tracks[i].track.codec = params.Codecs[0]
|
||||
r.tracks[i].track.params = params
|
||||
r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
|
||||
r.tracks[i].track.peekedPackets = peekedPackets
|
||||
r.tracks[i].track.mu.Unlock()
|
||||
|
||||
r.tracks[i].streamInfo = streamInfo
|
||||
@@ -651,7 +660,7 @@ func (r *RTPReceiver) receiveForRtx(
|
||||
copy(b[headerLength:i-2], b[headerLength+2:i])
|
||||
|
||||
select {
|
||||
case <-r.closed:
|
||||
case <-r.closedChan:
|
||||
r.rtxPool.Put(b) // nolint:staticcheck
|
||||
|
||||
return
|
||||
|
||||
@@ -215,7 +215,7 @@ func (p *defaultAudioPlayoutStatsProvider) AddTrack(track *TrackRemote) error {
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiver.closed:
|
||||
case <-receiver.closedChan:
|
||||
p.removeTrackInternal(track)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
@@ -2349,7 +2349,7 @@ func TestDefaultAudioPlayoutStatsProvider_AccumulateSnapshot(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
|
||||
receiver := &RTPReceiver{closed: make(chan any)}
|
||||
receiver := &RTPReceiver{closedChan: make(chan any)}
|
||||
track := newTrackRemote(RTPCodecTypeAudio, 1234, 0, "", receiver)
|
||||
samplesPerBatch := 960
|
||||
|
||||
@@ -2371,7 +2371,7 @@ func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAudioPlayoutStatsProvider_MultipleProviders(t *testing.T) {
|
||||
receiver := &RTPReceiver{closed: make(chan any)}
|
||||
receiver := &RTPReceiver{closedChan: make(chan any)}
|
||||
track := newTrackRemote(RTPCodecTypeAudio, 5555, 0, "", receiver)
|
||||
samplesPerBatch := 960
|
||||
|
||||
|
||||
@@ -827,8 +827,7 @@ func Test_TrackRemote_ReadRTP_UnmarshalError(t *testing.T) {
|
||||
tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv)
|
||||
|
||||
tr.mu.Lock()
|
||||
tr.peeked = []byte{0x80, 96}
|
||||
tr.peekedAttributes = nil
|
||||
tr.peekedPackets = []*peekedPacket{{payload: []byte{0x80, 96}}}
|
||||
tr.mu.Unlock()
|
||||
|
||||
pkt, attrs, err := tr.ReadRTP()
|
||||
|
||||
@@ -8,6 +8,7 @@ package webrtc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +16,11 @@ import (
|
||||
"github.com/pion/rtp"
|
||||
)
|
||||
|
||||
type peekedPacket struct {
|
||||
payload []byte
|
||||
attributes interceptor.Attributes
|
||||
}
|
||||
|
||||
// TrackRemote represents a single inbound source of media.
|
||||
type TrackRemote struct {
|
||||
mu sync.RWMutex
|
||||
@@ -30,9 +36,9 @@ type TrackRemote struct {
|
||||
params RTPParameters
|
||||
rid string
|
||||
|
||||
receiver *RTPReceiver
|
||||
peeked []byte
|
||||
peekedAttributes interceptor.Attributes
|
||||
receiver *RTPReceiver
|
||||
|
||||
peekedPackets []*peekedPacket
|
||||
|
||||
audioPlayoutStatsProviders []AudioPlayoutStatsProvider
|
||||
}
|
||||
@@ -116,25 +122,22 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
|
||||
func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) {
|
||||
t.mu.RLock()
|
||||
receiver := t.receiver
|
||||
peeked := t.peeked != nil
|
||||
var peekedPkt *peekedPacket
|
||||
if len(t.peekedPackets) != 0 {
|
||||
peekedPkt = t.peekedPackets[0]
|
||||
t.peekedPackets = t.peekedPackets[1:]
|
||||
}
|
||||
t.mu.RUnlock()
|
||||
|
||||
if peeked {
|
||||
t.mu.Lock()
|
||||
data := t.peeked
|
||||
attributes = t.peekedAttributes
|
||||
if receiver.haveClosed() {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
t.peeked = nil
|
||||
t.peekedAttributes = nil
|
||||
t.mu.Unlock()
|
||||
// someone else may have stolen our packet when we
|
||||
// released the lock. Deal with it.
|
||||
if data != nil {
|
||||
n = copy(b, data)
|
||||
err = t.checkAndUpdateTrack(b)
|
||||
if peekedPkt != nil {
|
||||
n = copy(b, peekedPkt.payload)
|
||||
err = t.checkAndUpdateTrack(b)
|
||||
|
||||
return n, attributes, err
|
||||
}
|
||||
return n, peekedPkt.attributes, err
|
||||
}
|
||||
|
||||
// If there's a separate RTX track and an RTX packet is available, return that
|
||||
@@ -142,18 +145,16 @@ func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes,
|
||||
n = copy(b, rtxPacketReceived.pkt)
|
||||
attributes = rtxPacketReceived.attributes
|
||||
rtxPacketReceived.release()
|
||||
err = nil
|
||||
} else {
|
||||
// If there's no separate RTX track (or there's a separate RTX track but no RTX packet waiting), wait for and return
|
||||
// a packet from the main track
|
||||
n, attributes, err = receiver.readRTP(b, t)
|
||||
if err != nil {
|
||||
return n, attributes, err
|
||||
}
|
||||
|
||||
err = t.checkAndUpdateTrack(b)
|
||||
return n, attributes, nil
|
||||
}
|
||||
|
||||
n, attributes, err = receiver.readRTP(b, t)
|
||||
if err != nil {
|
||||
return n, attributes, err
|
||||
}
|
||||
err = t.checkAndUpdateTrack(b)
|
||||
|
||||
return n, attributes, err
|
||||
}
|
||||
|
||||
@@ -212,8 +213,7 @@ func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error
|
||||
// that case.
|
||||
data := make([]byte, n)
|
||||
n = copy(data, b[:n])
|
||||
t.peeked = data
|
||||
t.peekedAttributes = a
|
||||
t.peekedPackets = append(t.peekedPackets, &peekedPacket{payload: data, attributes: a})
|
||||
t.mu.Unlock()
|
||||
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user