diff --git a/track_local_static.go b/track_local_static.go index f590d811..7709b563 100644 --- a/track_local_static.go +++ b/track_local_static.go @@ -226,10 +226,12 @@ func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) { // TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. // If you wish to send a RTP Packet use TrackLocalStaticRTP. type TrackLocalStaticSample struct { + mu sync.Mutex packetizer rtp.Packetizer sequencer rtp.Sequencer rtpTrack *TrackLocalStaticRTP clockRate float64 + remainder float64 } // NewTrackLocalStaticSample returns a TrackLocalStaticSample. @@ -329,22 +331,36 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error { s.rtpTrack.mu.RLock() packetizer := s.packetizer clockRate := s.clockRate + sequencer := s.sequencer s.rtpTrack.mu.RUnlock() - if packetizer == nil { return nil } + s.mu.Lock() + remainder := s.remainder + // skip packets by the number of previously dropped packets for i := uint16(0); i < sample.PrevDroppedPackets; i++ { - s.sequencer.NextSequenceNumber() + sequencer.NextSequenceNumber() } - samples := uint32(sample.Duration.Seconds() * clockRate) + tickF := sample.Duration.Seconds() * clockRate + if sample.PrevDroppedPackets > 0 { - packetizer.SkipSamples(samples * uint32(sample.PrevDroppedPackets)) + dropTotal := tickF*float64(sample.PrevDroppedPackets) + remainder + dropTicks := uint32(dropTotal) + remainder = dropTotal - float64(dropTicks) + packetizer.SkipSamples(dropTicks) } - packets := packetizer.Packetize(sample.Data, samples) + + curTotal := tickF + remainder + curTicks := uint32(curTotal) + remainder = curTotal - float64(curTicks) + + s.remainder = remainder + packets := packetizer.Packetize(sample.Data, curTicks) + s.mu.Unlock() writeErrs := []error{} for _, p := range packets { diff --git a/track_local_static_test.go b/track_local_static_test.go index 48f4b524..2494630d 100644 --- a/track_local_static_test.go +++ b/track_local_static_test.go @@ -859,3 +859,130 @@ func TestBaseTrackLocalContext_HeaderExtensions_NilWhenUnset(t *testing.T) { var ctx baseTrackLocalContext assert.Nil(t, ctx.HeaderExtensions()) } + +func TestTrackLocalStaticSample_WriteSample_NoTimestampDrift(t *testing.T) { + const clockRate = uint32(90000) + frameDuration := time.Second / 60 + totalDuration := time.Hour + numFrames := int(totalDuration / frameDuration) + + track, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate}, + "video", "pion", + ) + assert.NoError(t, err) + + pack := &countingPacketizer{} + + track.rtpTrack.mu.Lock() + track.packetizer = pack + track.clockRate = float64(clockRate) + track.sequencer = rtp.NewRandomSequencer() + track.rtpTrack.mu.Unlock() + + for i := 0; i < numFrames; i++ { + err := track.WriteSample(media.Sample{ + Data: []byte{0x00}, + Duration: frameDuration, + }) + assert.NoError(t, err) + } + + expected := (uint64(numFrames) * uint64(frameDuration.Nanoseconds()) * uint64(clockRate)) / 1e9 //nolint:gosec + got := pack.totalSamples + + var drift uint64 + if got > expected { + drift = got - expected + } else { + drift = expected - got + } + + t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f", + numFrames, frameDuration, expected, got, drift, float64(drift)/float64(clockRate), + ) + + assert.LessOrEqual(t, drift, uint64(1), "timestamp drift should be negligible") +} + +func TestTrackLocalStaticSample_WriteSample_DroppedPackets_NoDrift(t *testing.T) { + const clockRate = uint32(90000) + frameDuration := time.Second / 60 + totalDuration := time.Hour + numFrames := int(totalDuration / frameDuration) + tickF := frameDuration.Seconds() * float64(clockRate) + + track, err := NewTrackLocalStaticSample( + RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate}, + "video", "pion", + ) + assert.NoError(t, err) + + pack := &countingPacketizer{} + + track.rtpTrack.mu.Lock() + track.packetizer = pack + track.clockRate = float64(clockRate) + track.sequencer = rtp.NewRandomSequencer() + track.rtpTrack.mu.Unlock() + + var expectedTotal uint64 + var remainder float64 + + for i := 0; i < numFrames; i++ { + var drops uint16 + if (i+1)%300 == 0 { + drops = uint16((i/300)%3 + 1) //nolint:gosec + } + + if drops > 0 { + dropTotal := tickF*float64(drops) + remainder + dropTicks := uint32(dropTotal) + remainder = dropTotal - float64(dropTicks) + expectedTotal += uint64(dropTicks) + } + + curTotal := tickF + remainder + curTicks := uint32(curTotal) + remainder = curTotal - float64(curTicks) + expectedTotal += uint64(curTicks) + + err := track.WriteSample(media.Sample{ + Data: []byte{0x00}, + Duration: frameDuration, + PrevDroppedPackets: drops, + }) + assert.NoError(t, err) + } + + got := pack.totalSamples + + var drift uint64 + if got > expectedTotal { + drift = got - expectedTotal + } else { + drift = expectedTotal - got + } + + t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f", + numFrames, frameDuration, expectedTotal, got, drift, float64(drift)/float64(clockRate), + ) + + assert.LessOrEqual(t, drift, uint64(1), "timestamp drift with drops should be negligible") +} + +type countingPacketizer struct { + totalSamples uint64 +} + +func (p *countingPacketizer) Packetize(payload []byte, samples uint32) []*rtp.Packet { + p.totalSamples += uint64(samples) + + return nil +} + +func (p *countingPacketizer) GeneratePadding(samples uint32) []*rtp.Packet { return nil } +func (p *countingPacketizer) EnableAbsSendTime(value int) {} +func (p *countingPacketizer) SkipSamples(skippedSamples uint32) { + p.totalSamples += uint64(skippedSamples) +}