diff --git a/interceptor.go b/interceptor.go index 5d7dea4c..41041cb9 100644 --- a/interceptor.go +++ b/interceptor.go @@ -7,6 +7,7 @@ package webrtc import ( + "sync" "sync/atomic" "github.com/pion/interceptor" @@ -14,6 +15,7 @@ import ( "github.com/pion/interceptor/pkg/nack" "github.com/pion/interceptor/pkg/report" "github.com/pion/interceptor/pkg/rfc8888" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/interceptor/pkg/twcc" "github.com/pion/rtp" "github.com/pion/sdp/v3" @@ -35,9 +37,41 @@ func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry * return err } + if err := ConfigureStatsInterceptor(interceptorRegistry); err != nil { + return err + } + return ConfigureTWCCSender(mediaEngine, interceptorRegistry) } +// ConfigureStatsInterceptor will setup everything necessary for generating RTP stream statistics. +func ConfigureStatsInterceptor(interceptorRegistry *interceptor.Registry) error { + statsInterceptor, err := stats.NewInterceptor() + if err != nil { + return err + } + statsInterceptor.OnNewPeerConnection(func(id string, stats stats.Getter) { + statsGetter.Store(id, stats) + }) + interceptorRegistry.Add(statsInterceptor) + + return nil +} + +// lookupStats returns the stats getter for a given peerconnection.statsId. +func lookupStats(id string) (stats.Getter, bool) { + if value, exists := statsGetter.Load(id); exists { + if getter, ok := value.(stats.Getter); ok { + return getter, true + } + } + + return nil, false +} + +// key: string (peerconnection.statsId), value: stats.Getter +var statsGetter sync.Map // nolint:gochecknoglobals + // ConfigureRTCPReports will setup everything necessary for generating Sender and Receiver Reports. func ConfigureRTCPReports(interceptorRegistry *interceptor.Registry) error { reciver, err := report.NewReceiverInterceptor() diff --git a/interceptor_test.go b/interceptor_test.go index 335ee0f3..8453b106 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "io" + "reflect" "sync/atomic" "testing" "time" @@ -353,6 +354,36 @@ func Test_Interceptor_ZeroSSRC(t *testing.T) { closePairNow(t, offerer, answerer) } +// TestStatsInterceptorIsAddedByDefault tests that the stats interceptor +// is automatically added when creating a PeerConnection with the default API +// and that its Getter is properly captured. +func TestStatsInterceptorIsAddedByDefault(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + pc, err := NewPeerConnection(Configuration{}) + assert.NoError(t, err) + defer func() { + assert.NoError(t, pc.Close()) + }() + + assert.NotNil(t, pc.statsGetter, "statsGetter should be non-nil with NewPeerConnection") + + // Also assert that the getter stored during interceptor Build matches + // the one attached to this PeerConnection. + getter, ok := lookupStats(pc.statsID) + assert.True(t, ok, "lookupStats should return a getter for this statsID") + assert.NotNil(t, getter) + assert.Equal(t, + reflect.ValueOf(getter).Pointer(), + reflect.ValueOf(pc.statsGetter).Pointer(), + "getter returned by lookup should match pc.statsGetter", + ) +} + // TestInterceptorNack is an end-to-end test for the NACK sender. // It tests that: // - we get a NACK if we negotiated generic NACks; diff --git a/peerconnection.go b/peerconnection.go index 086229b9..0d761c06 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -22,6 +22,7 @@ import ( "github.com/pion/ice/v4" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/sdp/v3" @@ -92,6 +93,7 @@ type PeerConnection struct { log logging.LeveledLogger interceptorRTCPWriter interceptor.RTCPWriter + statsGetter stats.Getter } // NewPeerConnection creates a PeerConnection with the default codecs and interceptors. @@ -143,11 +145,15 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, pc.iceConnectionState.Store(ICEConnectionStateNew) pc.connectionState.Store(PeerConnectionStateNew) - i, err := api.interceptorRegistry.Build("") + i, err := api.interceptorRegistry.Build(pc.statsID) if err != nil { return nil, err } + if getter, ok := lookupStats(pc.statsID); ok { + pc.statsGetter = getter + } + pc.api = &API{ settingEngine: api.settingEngine, interceptor: i, @@ -2631,6 +2637,11 @@ func (pc *PeerConnection) GetStats() StatsReport { } pc.mu.Unlock() + receivers := pc.GetReceivers() + for _, receiver := range receivers { + receiver.collectStats(statsCollector, pc.statsGetter) + } + pc.api.mediaEngine.collectStats(statsCollector) return statsCollector.Ready() diff --git a/rtpreceiver.go b/rtpreceiver.go index dacef04f..9e369616 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -10,10 +10,12 @@ import ( "encoding/binary" "fmt" "io" + "math" "sync" "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/srtp/v3" @@ -401,6 +403,71 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop return err } +func (r *RTPReceiver) collectStats(collector *statsReportCollector, statsGetter stats.Getter) { + r.mu.Lock() + defer r.mu.Unlock() + + // Emit inbound-rtp stats for each track + mid := "" + if r.tr != nil { + mid = r.tr.Mid() + } + now := statsTimestampNow() + for trackIndex := range r.tracks { + remoteTrack := r.tracks[trackIndex].track + if remoteTrack == nil { + continue + } + + collector.Collecting() + + inboundID := fmt.Sprintf("inbound-rtp-%d", uint32(remoteTrack.SSRC())) + codecID := "" + if remoteTrack.codec.statsID != "" { + codecID = remoteTrack.codec.statsID + } + + inboundStats := InboundRTPStreamStats{ + Mid: mid, + Timestamp: now, + Type: StatsTypeInboundRTP, + ID: inboundID, + SSRC: remoteTrack.SSRC(), + Kind: r.kind.String(), + TransportID: "iceTransport", + CodecID: codecID, + } + + stats := statsGetter.Get(uint32(remoteTrack.SSRC())) + if stats != nil { //nolint:nestif // nested to keep mapping local + // Wrap-around casting by design, with warnings if overflow/underflow is detected. + pr := stats.InboundRTPStreamStats.PacketsReceived + if pr > math.MaxUint32 { + r.log.Warnf("Inbound PacketsReceived exceeds uint32 and will wrap: %d", pr) + } + inboundStats.PacketsReceived = uint32(pr) //nolint:gosec + + pl := stats.InboundRTPStreamStats.PacketsLost + if pl > math.MaxInt32 || pl < math.MinInt32 { + r.log.Warnf("Inbound PacketsLost exceeds int32 range and will wrap: %d", pl) + } + inboundStats.PacketsLost = int32(pl) //nolint:gosec + + inboundStats.Jitter = stats.InboundRTPStreamStats.Jitter + inboundStats.BytesReceived = stats.InboundRTPStreamStats.BytesReceived + inboundStats.HeaderBytesReceived = stats.InboundRTPStreamStats.HeaderBytesReceived + timestamp := stats.InboundRTPStreamStats.LastPacketReceivedTimestamp + inboundStats.LastPacketReceivedTimestamp = StatsTimestamp( + timestamp.UnixNano() / int64(time.Millisecond)) + inboundStats.FIRCount = stats.InboundRTPStreamStats.FIRCount + inboundStats.PLICount = stats.InboundRTPStreamStats.PLICount + inboundStats.NACKCount = stats.InboundRTPStreamStats.NACKCount + } + + collector.Collect(inboundID, inboundStats) + } +} + func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { for i := range r.tracks { if r.tracks[i].track == t { diff --git a/rtpreceiver_test.go b/rtpreceiver_test.go index 0c09ca92..f9f1db5d 100644 --- a/rtpreceiver_test.go +++ b/rtpreceiver_test.go @@ -8,13 +8,17 @@ package webrtc import ( "context" + "math" "testing" "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/pion/webrtc/v4/pkg/media" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Assert that SetReadDeadline works as expected @@ -64,3 +68,70 @@ func Test_RTPReceiver_SetReadDeadline(t *testing.T) { assert.NoError(t, wan.Stop()) closePairNow(t, sender, receiver) } + +// TestRTPReceiver_CollectStats_Mapping validates that collectStats maps +// interceptor/pkg/stats values into InboundRTPStreamStats. +func TestRTPReceiver_CollectStats_Mapping(t *testing.T) { + ssrc := SSRC(1234) + now := time.Now() + pr := uint64(math.MaxUint32) + 42 + pl := int64(math.MaxInt32) + 7 + jitter := 0.123 + bytes := uint64(98765) + hdrBytes := uint64(4321) + fir := uint32(3) + pli := uint32(5) + nack := uint32(7) + + fg := &fakeGetter{s: stats.Stats{ + InboundRTPStreamStats: stats.InboundRTPStreamStats{ + ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{ + PacketsReceived: pr, + PacketsLost: pl, + Jitter: jitter, + }, + LastPacketReceivedTimestamp: now, + HeaderBytesReceived: hdrBytes, + BytesReceived: bytes, + FIRCount: fir, + PLICount: pli, + NACKCount: nack, + }, + }} + + // Minimal RTPReceiver with one track + r := &RTPReceiver{ + kind: RTPCodecTypeVideo, + log: logging.NewDefaultLoggerFactory().NewLogger("RTPReceiverTest"), + } + tr := newTrackRemote(RTPCodecTypeVideo, ssrc, 0, "", r) + r.tracks = []trackStreams{{track: tr}} + + collector := newStatsReportCollector() + r.collectStats(collector, fg) + report := collector.Ready() + + // Fetch the generated inbound-rtp stat by ID + statID := "inbound-rtp-1234" + got, ok := report[statID] + require.True(t, ok, "missing inbound stat") + + inbound, ok := got.(InboundRTPStreamStats) + require.True(t, ok) + + // Wrap-around semantics for casts + assert.Equal(t, uint32(pr), inbound.PacketsReceived) //nolint:gosec + assert.Equal(t, int32(pl), inbound.PacketsLost) //nolint:gosec + assert.Equal(t, jitter, inbound.Jitter) + assert.Equal(t, bytes, inbound.BytesReceived) + assert.Equal(t, hdrBytes, inbound.HeaderBytesReceived) + assert.Equal(t, fir, inbound.FIRCount) + assert.Equal(t, pli, inbound.PLICount) + assert.Equal(t, nack, inbound.NACKCount) + // Timestamp should be set (millisecond precision) + assert.Greater(t, float64(inbound.LastPacketReceivedTimestamp), 0.0) +} + +type fakeGetter struct{ s stats.Stats } + +func (f *fakeGetter) Get(uint32) *stats.Stats { return &f.s } diff --git a/stats_go_test.go b/stats_go_test.go index fd4d221b..e03a3191 100644 --- a/stats_go_test.go +++ b/stats_go_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pion/ice/v4" + "github.com/pion/webrtc/v4/pkg/media" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1277,6 +1278,28 @@ func findCandidatePairStats(t *testing.T, report StatsReport) []ICECandidatePair return result } +func findInboundRTPStats(report StatsReport) []InboundRTPStreamStats { + result := []InboundRTPStreamStats{} + for _, s := range report { + if stats, ok := s.(InboundRTPStreamStats); ok { + result = append(result, stats) + } + } + + return result +} + +func findInboundRTPStatsBySSRC(report StatsReport, ssrc SSRC) []InboundRTPStreamStats { + result := []InboundRTPStreamStats{} + for _, s := range report { + if stats, ok := s.(InboundRTPStreamStats); ok && stats.SSRC == ssrc { + result = append(result, stats) + } + } + + return result +} + func signalPairForStats(pcOffer *PeerConnection, pcAnswer *PeerConnection) error { offerChan := make(chan SessionDescription) pcOffer.OnICECandidate(func(candidate *ICECandidate) { @@ -1360,7 +1383,7 @@ func TestStatsConvertState(t *testing.T) { } } -func TestPeerConnection_GetStats(t *testing.T) { +func TestPeerConnection_GetStats(t *testing.T) { //nolint:cyclop // involves multiple branches and waits offerPC, answerPC, err := newPair() assert.NoError(t, err) @@ -1438,6 +1461,68 @@ func TestPeerConnection_GetStats(t *testing.T) { assert.NotEmpty(t, findLocalCandidateStats(reportPCAnswer)) assert.NotEmpty(t, findRemoteCandidateStats(reportPCAnswer)) assert.NotEmpty(t, findCandidatePairStats(t, reportPCAnswer)) + + inboundAnswer := findInboundRTPStats(reportPCAnswer) + assert.NotEmpty(t, inboundAnswer) + + // Send a sample frame to generate RTP packets + sample := media.Sample{ + Data: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}, + Duration: time.Second / 30, // 30 FPS + Timestamp: time.Now(), + } + assert.NoError(t, track1.WriteSample(sample)) + + // Poll for packets to arrive rather than using a fixed wait time. + // This ensures the test is deterministic and fails fast with clear context + // if packets don't arrive within the timeout period. + assert.Eventually(t, func() bool { + reportPCAnswer = answerPC.GetStats() + receivers := answerPC.GetReceivers() + for _, r := range receivers { + for _, tr := range r.Tracks() { + if tr.SSRC() == 0 { + continue + } + matches := findInboundRTPStatsBySSRC(reportPCAnswer, tr.SSRC()) + if len(matches) > 0 && matches[0].PacketsReceived > 0 { + return true + } + } + } + + return false + }, time.Second, 10*time.Millisecond, "Expected packets to be received") + + // Get fresh stats after sending the sample + reportPCAnswer = answerPC.GetStats() + + receivers := answerPC.GetReceivers() + for _, r := range receivers { + for _, tr := range r.Tracks() { + if tr.SSRC() == 0 { + continue + } + matches := findInboundRTPStatsBySSRC(reportPCAnswer, tr.SSRC()) + require.NotEmpty(t, matches) + + for _, inboundStats := range matches { + assert.Equal(t, StatsTypeInboundRTP, inboundStats.Type) + assert.Equal(t, tr.SSRC(), inboundStats.SSRC) + assert.NotEmpty(t, inboundStats.Kind) + assert.NotEmpty(t, inboundStats.TransportID) + assert.Greater(t, inboundStats.PacketsReceived, uint32(0)) + assert.GreaterOrEqual(t, inboundStats.PacketsLost, int32(0)) + assert.Greater(t, inboundStats.BytesReceived, uint64(0)) + assert.GreaterOrEqual(t, inboundStats.Jitter, 0.0) + assert.GreaterOrEqual(t, inboundStats.HeaderBytesReceived, uint64(0)) + assert.GreaterOrEqual(t, inboundStats.LastPacketReceivedTimestamp, StatsTimestamp(0)) + assert.GreaterOrEqual(t, inboundStats.FIRCount, uint32(0)) + assert.GreaterOrEqual(t, inboundStats.PLICount, uint32(0)) + assert.GreaterOrEqual(t, inboundStats.NACKCount, uint32(0)) + } + } + } assert.NoError(t, err) for i := range offerPC.api.mediaEngine.videoCodecs { codecStat := getCodecStats(t, reportPCOffer, &(offerPC.api.mediaEngine.videoCodecs[i]))