Add inbound-rtp stats

Remove comments

Add collectStats test

Fix linter issues

Remove comment

Fix tests

Address comments

Fix comment

Fix function comment
This commit is contained in:
Shreyas Jaganmohan
2025-09-13 12:32:39 -04:00
parent 370412f694
commit 4c1261ff83
6 changed files with 301 additions and 2 deletions

View File

@@ -7,6 +7,7 @@
package webrtc
import (
"sync"
"sync/atomic"
"github.com/pion/interceptor"
@@ -14,6 +15,7 @@ import (
"github.com/pion/interceptor/pkg/nack"
"github.com/pion/interceptor/pkg/report"
"github.com/pion/interceptor/pkg/rfc8888"
"github.com/pion/interceptor/pkg/stats"
"github.com/pion/interceptor/pkg/twcc"
"github.com/pion/rtp"
"github.com/pion/sdp/v3"
@@ -35,9 +37,41 @@ func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry *
return err
}
if err := ConfigureStatsInterceptor(interceptorRegistry); err != nil {
return err
}
return ConfigureTWCCSender(mediaEngine, interceptorRegistry)
}
// ConfigureStatsInterceptor will setup everything necessary for generating RTP stream statistics.
func ConfigureStatsInterceptor(interceptorRegistry *interceptor.Registry) error {
statsInterceptor, err := stats.NewInterceptor()
if err != nil {
return err
}
statsInterceptor.OnNewPeerConnection(func(id string, stats stats.Getter) {
statsGetter.Store(id, stats)
})
interceptorRegistry.Add(statsInterceptor)
return nil
}
// lookupStats returns the stats getter for a given peerconnection.statsId.
func lookupStats(id string) (stats.Getter, bool) {
if value, exists := statsGetter.Load(id); exists {
if getter, ok := value.(stats.Getter); ok {
return getter, true
}
}
return nil, false
}
// key: string (peerconnection.statsId), value: stats.Getter
var statsGetter sync.Map // nolint:gochecknoglobals
// ConfigureRTCPReports will setup everything necessary for generating Sender and Receiver Reports.
func ConfigureRTCPReports(interceptorRegistry *interceptor.Registry) error {
reciver, err := report.NewReceiverInterceptor()

View File

@@ -11,6 +11,7 @@ import (
"context"
"fmt"
"io"
"reflect"
"sync/atomic"
"testing"
"time"
@@ -353,6 +354,36 @@ func Test_Interceptor_ZeroSSRC(t *testing.T) {
closePairNow(t, offerer, answerer)
}
// TestStatsInterceptorIsAddedByDefault tests that the stats interceptor
// is automatically added when creating a PeerConnection with the default API
// and that its Getter is properly captured.
func TestStatsInterceptorIsAddedByDefault(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
report := test.CheckRoutines(t)
defer report()
pc, err := NewPeerConnection(Configuration{})
assert.NoError(t, err)
defer func() {
assert.NoError(t, pc.Close())
}()
assert.NotNil(t, pc.statsGetter, "statsGetter should be non-nil with NewPeerConnection")
// Also assert that the getter stored during interceptor Build matches
// the one attached to this PeerConnection.
getter, ok := lookupStats(pc.statsID)
assert.True(t, ok, "lookupStats should return a getter for this statsID")
assert.NotNil(t, getter)
assert.Equal(t,
reflect.ValueOf(getter).Pointer(),
reflect.ValueOf(pc.statsGetter).Pointer(),
"getter returned by lookup should match pc.statsGetter",
)
}
// TestInterceptorNack is an end-to-end test for the NACK sender.
// It tests that:
// - we get a NACK if we negotiated generic NACks;

View File

@@ -22,6 +22,7 @@ import (
"github.com/pion/ice/v4"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/stats"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/sdp/v3"
@@ -92,6 +93,7 @@ type PeerConnection struct {
log logging.LeveledLogger
interceptorRTCPWriter interceptor.RTCPWriter
statsGetter stats.Getter
}
// NewPeerConnection creates a PeerConnection with the default codecs and interceptors.
@@ -143,11 +145,15 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
pc.iceConnectionState.Store(ICEConnectionStateNew)
pc.connectionState.Store(PeerConnectionStateNew)
i, err := api.interceptorRegistry.Build("")
i, err := api.interceptorRegistry.Build(pc.statsID)
if err != nil {
return nil, err
}
if getter, ok := lookupStats(pc.statsID); ok {
pc.statsGetter = getter
}
pc.api = &API{
settingEngine: api.settingEngine,
interceptor: i,
@@ -2631,6 +2637,11 @@ func (pc *PeerConnection) GetStats() StatsReport {
}
pc.mu.Unlock()
receivers := pc.GetReceivers()
for _, receiver := range receivers {
receiver.collectStats(statsCollector, pc.statsGetter)
}
pc.api.mediaEngine.collectStats(statsCollector)
return statsCollector.Ready()

View File

@@ -10,10 +10,12 @@ import (
"encoding/binary"
"fmt"
"io"
"math"
"sync"
"time"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/stats"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/srtp/v3"
@@ -401,6 +403,71 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop
return err
}
func (r *RTPReceiver) collectStats(collector *statsReportCollector, statsGetter stats.Getter) {
r.mu.Lock()
defer r.mu.Unlock()
// Emit inbound-rtp stats for each track
mid := ""
if r.tr != nil {
mid = r.tr.Mid()
}
now := statsTimestampNow()
for trackIndex := range r.tracks {
remoteTrack := r.tracks[trackIndex].track
if remoteTrack == nil {
continue
}
collector.Collecting()
inboundID := fmt.Sprintf("inbound-rtp-%d", uint32(remoteTrack.SSRC()))
codecID := ""
if remoteTrack.codec.statsID != "" {
codecID = remoteTrack.codec.statsID
}
inboundStats := InboundRTPStreamStats{
Mid: mid,
Timestamp: now,
Type: StatsTypeInboundRTP,
ID: inboundID,
SSRC: remoteTrack.SSRC(),
Kind: r.kind.String(),
TransportID: "iceTransport",
CodecID: codecID,
}
stats := statsGetter.Get(uint32(remoteTrack.SSRC()))
if stats != nil { //nolint:nestif // nested to keep mapping local
// Wrap-around casting by design, with warnings if overflow/underflow is detected.
pr := stats.InboundRTPStreamStats.PacketsReceived
if pr > math.MaxUint32 {
r.log.Warnf("Inbound PacketsReceived exceeds uint32 and will wrap: %d", pr)
}
inboundStats.PacketsReceived = uint32(pr) //nolint:gosec
pl := stats.InboundRTPStreamStats.PacketsLost
if pl > math.MaxInt32 || pl < math.MinInt32 {
r.log.Warnf("Inbound PacketsLost exceeds int32 range and will wrap: %d", pl)
}
inboundStats.PacketsLost = int32(pl) //nolint:gosec
inboundStats.Jitter = stats.InboundRTPStreamStats.Jitter
inboundStats.BytesReceived = stats.InboundRTPStreamStats.BytesReceived
inboundStats.HeaderBytesReceived = stats.InboundRTPStreamStats.HeaderBytesReceived
timestamp := stats.InboundRTPStreamStats.LastPacketReceivedTimestamp
inboundStats.LastPacketReceivedTimestamp = StatsTimestamp(
timestamp.UnixNano() / int64(time.Millisecond))
inboundStats.FIRCount = stats.InboundRTPStreamStats.FIRCount
inboundStats.PLICount = stats.InboundRTPStreamStats.PLICount
inboundStats.NACKCount = stats.InboundRTPStreamStats.NACKCount
}
collector.Collect(inboundID, inboundStats)
}
}
func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
for i := range r.tracks {
if r.tracks[i].track == t {

View File

@@ -8,13 +8,17 @@ package webrtc
import (
"context"
"math"
"testing"
"time"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/stats"
"github.com/pion/logging"
"github.com/pion/transport/v3/test"
"github.com/pion/webrtc/v4/pkg/media"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Assert that SetReadDeadline works as expected
@@ -64,3 +68,70 @@ func Test_RTPReceiver_SetReadDeadline(t *testing.T) {
assert.NoError(t, wan.Stop())
closePairNow(t, sender, receiver)
}
// TestRTPReceiver_CollectStats_Mapping validates that collectStats maps
// interceptor/pkg/stats values into InboundRTPStreamStats.
func TestRTPReceiver_CollectStats_Mapping(t *testing.T) {
ssrc := SSRC(1234)
now := time.Now()
pr := uint64(math.MaxUint32) + 42
pl := int64(math.MaxInt32) + 7
jitter := 0.123
bytes := uint64(98765)
hdrBytes := uint64(4321)
fir := uint32(3)
pli := uint32(5)
nack := uint32(7)
fg := &fakeGetter{s: stats.Stats{
InboundRTPStreamStats: stats.InboundRTPStreamStats{
ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{
PacketsReceived: pr,
PacketsLost: pl,
Jitter: jitter,
},
LastPacketReceivedTimestamp: now,
HeaderBytesReceived: hdrBytes,
BytesReceived: bytes,
FIRCount: fir,
PLICount: pli,
NACKCount: nack,
},
}}
// Minimal RTPReceiver with one track
r := &RTPReceiver{
kind: RTPCodecTypeVideo,
log: logging.NewDefaultLoggerFactory().NewLogger("RTPReceiverTest"),
}
tr := newTrackRemote(RTPCodecTypeVideo, ssrc, 0, "", r)
r.tracks = []trackStreams{{track: tr}}
collector := newStatsReportCollector()
r.collectStats(collector, fg)
report := collector.Ready()
// Fetch the generated inbound-rtp stat by ID
statID := "inbound-rtp-1234"
got, ok := report[statID]
require.True(t, ok, "missing inbound stat")
inbound, ok := got.(InboundRTPStreamStats)
require.True(t, ok)
// Wrap-around semantics for casts
assert.Equal(t, uint32(pr), inbound.PacketsReceived) //nolint:gosec
assert.Equal(t, int32(pl), inbound.PacketsLost) //nolint:gosec
assert.Equal(t, jitter, inbound.Jitter)
assert.Equal(t, bytes, inbound.BytesReceived)
assert.Equal(t, hdrBytes, inbound.HeaderBytesReceived)
assert.Equal(t, fir, inbound.FIRCount)
assert.Equal(t, pli, inbound.PLICount)
assert.Equal(t, nack, inbound.NACKCount)
// Timestamp should be set (millisecond precision)
assert.Greater(t, float64(inbound.LastPacketReceivedTimestamp), 0.0)
}
type fakeGetter struct{ s stats.Stats }
func (f *fakeGetter) Get(uint32) *stats.Stats { return &f.s }

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/pion/ice/v4"
"github.com/pion/webrtc/v4/pkg/media"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -1277,6 +1278,28 @@ func findCandidatePairStats(t *testing.T, report StatsReport) []ICECandidatePair
return result
}
func findInboundRTPStats(report StatsReport) []InboundRTPStreamStats {
result := []InboundRTPStreamStats{}
for _, s := range report {
if stats, ok := s.(InboundRTPStreamStats); ok {
result = append(result, stats)
}
}
return result
}
func findInboundRTPStatsBySSRC(report StatsReport, ssrc SSRC) []InboundRTPStreamStats {
result := []InboundRTPStreamStats{}
for _, s := range report {
if stats, ok := s.(InboundRTPStreamStats); ok && stats.SSRC == ssrc {
result = append(result, stats)
}
}
return result
}
func signalPairForStats(pcOffer *PeerConnection, pcAnswer *PeerConnection) error {
offerChan := make(chan SessionDescription)
pcOffer.OnICECandidate(func(candidate *ICECandidate) {
@@ -1360,7 +1383,7 @@ func TestStatsConvertState(t *testing.T) {
}
}
func TestPeerConnection_GetStats(t *testing.T) {
func TestPeerConnection_GetStats(t *testing.T) { //nolint:cyclop // involves multiple branches and waits
offerPC, answerPC, err := newPair()
assert.NoError(t, err)
@@ -1438,6 +1461,68 @@ func TestPeerConnection_GetStats(t *testing.T) {
assert.NotEmpty(t, findLocalCandidateStats(reportPCAnswer))
assert.NotEmpty(t, findRemoteCandidateStats(reportPCAnswer))
assert.NotEmpty(t, findCandidatePairStats(t, reportPCAnswer))
inboundAnswer := findInboundRTPStats(reportPCAnswer)
assert.NotEmpty(t, inboundAnswer)
// Send a sample frame to generate RTP packets
sample := media.Sample{
Data: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05},
Duration: time.Second / 30, // 30 FPS
Timestamp: time.Now(),
}
assert.NoError(t, track1.WriteSample(sample))
// Poll for packets to arrive rather than using a fixed wait time.
// This ensures the test is deterministic and fails fast with clear context
// if packets don't arrive within the timeout period.
assert.Eventually(t, func() bool {
reportPCAnswer = answerPC.GetStats()
receivers := answerPC.GetReceivers()
for _, r := range receivers {
for _, tr := range r.Tracks() {
if tr.SSRC() == 0 {
continue
}
matches := findInboundRTPStatsBySSRC(reportPCAnswer, tr.SSRC())
if len(matches) > 0 && matches[0].PacketsReceived > 0 {
return true
}
}
}
return false
}, time.Second, 10*time.Millisecond, "Expected packets to be received")
// Get fresh stats after sending the sample
reportPCAnswer = answerPC.GetStats()
receivers := answerPC.GetReceivers()
for _, r := range receivers {
for _, tr := range r.Tracks() {
if tr.SSRC() == 0 {
continue
}
matches := findInboundRTPStatsBySSRC(reportPCAnswer, tr.SSRC())
require.NotEmpty(t, matches)
for _, inboundStats := range matches {
assert.Equal(t, StatsTypeInboundRTP, inboundStats.Type)
assert.Equal(t, tr.SSRC(), inboundStats.SSRC)
assert.NotEmpty(t, inboundStats.Kind)
assert.NotEmpty(t, inboundStats.TransportID)
assert.Greater(t, inboundStats.PacketsReceived, uint32(0))
assert.GreaterOrEqual(t, inboundStats.PacketsLost, int32(0))
assert.Greater(t, inboundStats.BytesReceived, uint64(0))
assert.GreaterOrEqual(t, inboundStats.Jitter, 0.0)
assert.GreaterOrEqual(t, inboundStats.HeaderBytesReceived, uint64(0))
assert.GreaterOrEqual(t, inboundStats.LastPacketReceivedTimestamp, StatsTimestamp(0))
assert.GreaterOrEqual(t, inboundStats.FIRCount, uint32(0))
assert.GreaterOrEqual(t, inboundStats.PLICount, uint32(0))
assert.GreaterOrEqual(t, inboundStats.NACKCount, uint32(0))
}
}
}
assert.NoError(t, err)
for i := range offerPC.api.mediaEngine.videoCodecs {
codecStat := getCodecStats(t, reportPCOffer, &(offerPC.api.mediaEngine.videoCodecs[i]))