diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 2fc48bec..cde1c0f4 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "math/rand" "regexp" "strings" "sync" @@ -2353,21 +2352,54 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop pcOffer, pcAnswer, wan := createVNetPair(t, nil) - wan.AddChunkFilter(func(vnet.Chunk) bool { - return rand.Intn(5) != 4 //nolint: gosec - }) - track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "track-id", "stream-id") assert.NoError(t, err) rtpSender, err := pcOffer.AddTrack(track) assert.NoError(t, err) + // Signal pair first to negotiate codecs + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + + // Get the negotiated payload type for the media codec + mediaPayloadType := uint8(rtpSender.GetParameters().Codecs[0].PayloadType) + + // Use deterministic packet dropping: drop every 5th packet (20% loss) + // This is more realistic and provides faster, more consistent test results + var packetCount atomic.Uint32 + wan.AddChunkFilter(func(c vnet.Chunk) bool { + // Only filter RTP packets (not RTCP, STUN, etc) + h := &rtp.Header{} + if _, err := h.Unmarshal(c.UserData()); err != nil { + return true // Not an RTP packet, let it through + } + + // Drop every 5th media packet to trigger NACK/RTX + if h.PayloadType == mediaPayloadType { + count := packetCount.Add(1) + if count%5 == 0 { + return false // Drop this packet + } + } + + return true + }) + + // Create context for coordinated cleanup + testCtx, testCancel := context.WithCancel(context.Background()) + defer testCancel() + + // RTCP reader with proper cleanup go func() { rtcpBuf := make([]byte, 1500) for { - if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { + select { + case <-testCtx.Done(): return + default: + if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { + return + } } } }() @@ -2376,45 +2408,88 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop ssrc := rtpSender.GetParameters().Encodings[0].SSRC rtxRead, rtxReadCancel := context.WithCancel(context.Background()) + defer rtxReadCancel() // Ensure cleanup even if RTX is never detected + + // Track whether we've seen RTX + var rtxDetected atomic.Bool + + // OnTrack with proper cleanup pcAnswer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { for { - pkt, attributes, readRTPErr := track.ReadRTP() - if errors.Is(readRTPErr, io.EOF) { + select { + case <-testCtx.Done(): return - } else if pkt.PayloadType == 0 { - continue + default: } - assert.NotNil(t, pkt) - assert.Equal(t, pkt.SSRC, uint32(ssrc)) - assert.Equal(t, pkt.PayloadType, uint8(96)) + pkt, attributes, readRTPErr := track.ReadRTP() + if readRTPErr != nil { + return + } + // Validate packet - fail fast if unexpected + if !assert.NotNil(t, pkt) { + return + } + if !assert.Equal(t, uint32(ssrc), pkt.SSRC, "Unexpected SSRC") { + return + } + if !assert.Equal(t, mediaPayloadType, pkt.PayloadType, "Unexpected payload type") { + return + } + + // Check if this is an RTX retransmission rtxPayloadType := attributes.Get(AttributeRtxPayloadType) rtxSequenceNumber := attributes.Get(AttributeRtxSequenceNumber) rtxSSRC := attributes.Get(AttributeRtxSsrc) if rtxPayloadType != nil && rtxSequenceNumber != nil && rtxSSRC != nil { - assert.Equal(t, rtxPayloadType, uint8(97)) - assert.Equal(t, rtxSSRC, uint32(rtxSsrc)) + // Validate RTX attributes + if !assert.Equal(t, uint8(97), rtxPayloadType, "Unexpected RTX payload type") { + return + } + if !assert.Equal(t, uint32(rtxSsrc), rtxSSRC, "Unexpected RTX SSRC") { + return + } - rtxReadCancel() + // RTX detected successfully + if rtxDetected.CompareAndSwap(false, true) { + rtxReadCancel() + + return + } } } }) - assert.NoError(t, signalPair(pcOffer, pcAnswer)) + // Send packets until RTX is detected or timeout + // With 20% loss, we should see RTX within a few seconds + rtxTimeout := time.NewTimer(10 * time.Second) + defer rtxTimeout.Stop() func() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + for { select { - case <-time.After(20 * time.Millisecond): + case <-ticker.C: writeErr := track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}) assert.NoError(t, writeErr) case <-rtxRead.Done(): + + return + case <-rtxTimeout.C: + assert.Fail(t, "RTX packet not detected within timeout - NACK/RTX mechanism may not be working") + return } } }() - assert.NoError(t, wan.Stop()) + // Verify RTX was actually detected + assert.True(t, rtxDetected.Load(), "RTX packet should have been detected") + + // Close peer connections before stopping the network closePairNow(t, pcOffer, pcAnswer) + assert.NoError(t, wan.Stop()) }