Fix drift in WriteSample

This commit is contained in:
Joe Turki
2025-12-16 02:01:33 +02:00
parent 1519afa466
commit ef58d38984
2 changed files with 148 additions and 5 deletions

View File

@@ -226,10 +226,12 @@ func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
// If you wish to send a RTP Packet use TrackLocalStaticRTP.
type TrackLocalStaticSample struct {
mu sync.Mutex
packetizer rtp.Packetizer
sequencer rtp.Sequencer
rtpTrack *TrackLocalStaticRTP
clockRate float64
remainder float64
}
// NewTrackLocalStaticSample returns a TrackLocalStaticSample.
@@ -329,22 +331,36 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
s.rtpTrack.mu.RLock()
packetizer := s.packetizer
clockRate := s.clockRate
sequencer := s.sequencer
s.rtpTrack.mu.RUnlock()
if packetizer == nil {
return nil
}
s.mu.Lock()
remainder := s.remainder
// skip packets by the number of previously dropped packets
for i := uint16(0); i < sample.PrevDroppedPackets; i++ {
s.sequencer.NextSequenceNumber()
sequencer.NextSequenceNumber()
}
samples := uint32(sample.Duration.Seconds() * clockRate)
tickF := sample.Duration.Seconds() * clockRate
if sample.PrevDroppedPackets > 0 {
packetizer.SkipSamples(samples * uint32(sample.PrevDroppedPackets))
dropTotal := tickF*float64(sample.PrevDroppedPackets) + remainder
dropTicks := uint32(dropTotal)
remainder = dropTotal - float64(dropTicks)
packetizer.SkipSamples(dropTicks)
}
packets := packetizer.Packetize(sample.Data, samples)
curTotal := tickF + remainder
curTicks := uint32(curTotal)
remainder = curTotal - float64(curTicks)
s.remainder = remainder
packets := packetizer.Packetize(sample.Data, curTicks)
s.mu.Unlock()
writeErrs := []error{}
for _, p := range packets {

View File

@@ -859,3 +859,130 @@ func TestBaseTrackLocalContext_HeaderExtensions_NilWhenUnset(t *testing.T) {
var ctx baseTrackLocalContext
assert.Nil(t, ctx.HeaderExtensions())
}
func TestTrackLocalStaticSample_WriteSample_NoTimestampDrift(t *testing.T) {
const clockRate = uint32(90000)
frameDuration := time.Second / 60
totalDuration := time.Hour
numFrames := int(totalDuration / frameDuration)
track, err := NewTrackLocalStaticSample(
RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate},
"video", "pion",
)
assert.NoError(t, err)
pack := &countingPacketizer{}
track.rtpTrack.mu.Lock()
track.packetizer = pack
track.clockRate = float64(clockRate)
track.sequencer = rtp.NewRandomSequencer()
track.rtpTrack.mu.Unlock()
for i := 0; i < numFrames; i++ {
err := track.WriteSample(media.Sample{
Data: []byte{0x00},
Duration: frameDuration,
})
assert.NoError(t, err)
}
expected := (uint64(numFrames) * uint64(frameDuration.Nanoseconds()) * uint64(clockRate)) / 1e9 //nolint:gosec
got := pack.totalSamples
var drift uint64
if got > expected {
drift = got - expected
} else {
drift = expected - got
}
t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f",
numFrames, frameDuration, expected, got, drift, float64(drift)/float64(clockRate),
)
assert.LessOrEqual(t, drift, uint64(1), "timestamp drift should be negligible")
}
func TestTrackLocalStaticSample_WriteSample_DroppedPackets_NoDrift(t *testing.T) {
const clockRate = uint32(90000)
frameDuration := time.Second / 60
totalDuration := time.Hour
numFrames := int(totalDuration / frameDuration)
tickF := frameDuration.Seconds() * float64(clockRate)
track, err := NewTrackLocalStaticSample(
RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate},
"video", "pion",
)
assert.NoError(t, err)
pack := &countingPacketizer{}
track.rtpTrack.mu.Lock()
track.packetizer = pack
track.clockRate = float64(clockRate)
track.sequencer = rtp.NewRandomSequencer()
track.rtpTrack.mu.Unlock()
var expectedTotal uint64
var remainder float64
for i := 0; i < numFrames; i++ {
var drops uint16
if (i+1)%300 == 0 {
drops = uint16((i/300)%3 + 1) //nolint:gosec
}
if drops > 0 {
dropTotal := tickF*float64(drops) + remainder
dropTicks := uint32(dropTotal)
remainder = dropTotal - float64(dropTicks)
expectedTotal += uint64(dropTicks)
}
curTotal := tickF + remainder
curTicks := uint32(curTotal)
remainder = curTotal - float64(curTicks)
expectedTotal += uint64(curTicks)
err := track.WriteSample(media.Sample{
Data: []byte{0x00},
Duration: frameDuration,
PrevDroppedPackets: drops,
})
assert.NoError(t, err)
}
got := pack.totalSamples
var drift uint64
if got > expectedTotal {
drift = got - expectedTotal
} else {
drift = expectedTotal - got
}
t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f",
numFrames, frameDuration, expectedTotal, got, drift, float64(drift)/float64(clockRate),
)
assert.LessOrEqual(t, drift, uint64(1), "timestamp drift with drops should be negligible")
}
type countingPacketizer struct {
totalSamples uint64
}
func (p *countingPacketizer) Packetize(payload []byte, samples uint32) []*rtp.Packet {
p.totalSamples += uint64(samples)
return nil
}
func (p *countingPacketizer) GeneratePadding(samples uint32) []*rtp.Packet { return nil }
func (p *countingPacketizer) EnableAbsSendTime(value int) {}
func (p *countingPacketizer) SkipSamples(skippedSamples uint32) {
p.totalSamples += uint64(skippedSamples)
}