mirror of
https://github.com/pion/webrtc.git
synced 2025-12-24 11:51:03 +08:00
Fix drift in WriteSample
This commit is contained in:
@@ -226,10 +226,12 @@ func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
|
||||
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
|
||||
// If you wish to send a RTP Packet use TrackLocalStaticRTP.
|
||||
type TrackLocalStaticSample struct {
|
||||
mu sync.Mutex
|
||||
packetizer rtp.Packetizer
|
||||
sequencer rtp.Sequencer
|
||||
rtpTrack *TrackLocalStaticRTP
|
||||
clockRate float64
|
||||
remainder float64
|
||||
}
|
||||
|
||||
// NewTrackLocalStaticSample returns a TrackLocalStaticSample.
|
||||
@@ -329,22 +331,36 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
|
||||
s.rtpTrack.mu.RLock()
|
||||
packetizer := s.packetizer
|
||||
clockRate := s.clockRate
|
||||
sequencer := s.sequencer
|
||||
s.rtpTrack.mu.RUnlock()
|
||||
|
||||
if packetizer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
remainder := s.remainder
|
||||
|
||||
// skip packets by the number of previously dropped packets
|
||||
for i := uint16(0); i < sample.PrevDroppedPackets; i++ {
|
||||
s.sequencer.NextSequenceNumber()
|
||||
sequencer.NextSequenceNumber()
|
||||
}
|
||||
|
||||
samples := uint32(sample.Duration.Seconds() * clockRate)
|
||||
tickF := sample.Duration.Seconds() * clockRate
|
||||
|
||||
if sample.PrevDroppedPackets > 0 {
|
||||
packetizer.SkipSamples(samples * uint32(sample.PrevDroppedPackets))
|
||||
dropTotal := tickF*float64(sample.PrevDroppedPackets) + remainder
|
||||
dropTicks := uint32(dropTotal)
|
||||
remainder = dropTotal - float64(dropTicks)
|
||||
packetizer.SkipSamples(dropTicks)
|
||||
}
|
||||
packets := packetizer.Packetize(sample.Data, samples)
|
||||
|
||||
curTotal := tickF + remainder
|
||||
curTicks := uint32(curTotal)
|
||||
remainder = curTotal - float64(curTicks)
|
||||
|
||||
s.remainder = remainder
|
||||
packets := packetizer.Packetize(sample.Data, curTicks)
|
||||
s.mu.Unlock()
|
||||
|
||||
writeErrs := []error{}
|
||||
for _, p := range packets {
|
||||
|
||||
@@ -859,3 +859,130 @@ func TestBaseTrackLocalContext_HeaderExtensions_NilWhenUnset(t *testing.T) {
|
||||
var ctx baseTrackLocalContext
|
||||
assert.Nil(t, ctx.HeaderExtensions())
|
||||
}
|
||||
|
||||
func TestTrackLocalStaticSample_WriteSample_NoTimestampDrift(t *testing.T) {
|
||||
const clockRate = uint32(90000)
|
||||
frameDuration := time.Second / 60
|
||||
totalDuration := time.Hour
|
||||
numFrames := int(totalDuration / frameDuration)
|
||||
|
||||
track, err := NewTrackLocalStaticSample(
|
||||
RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate},
|
||||
"video", "pion",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pack := &countingPacketizer{}
|
||||
|
||||
track.rtpTrack.mu.Lock()
|
||||
track.packetizer = pack
|
||||
track.clockRate = float64(clockRate)
|
||||
track.sequencer = rtp.NewRandomSequencer()
|
||||
track.rtpTrack.mu.Unlock()
|
||||
|
||||
for i := 0; i < numFrames; i++ {
|
||||
err := track.WriteSample(media.Sample{
|
||||
Data: []byte{0x00},
|
||||
Duration: frameDuration,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
expected := (uint64(numFrames) * uint64(frameDuration.Nanoseconds()) * uint64(clockRate)) / 1e9 //nolint:gosec
|
||||
got := pack.totalSamples
|
||||
|
||||
var drift uint64
|
||||
if got > expected {
|
||||
drift = got - expected
|
||||
} else {
|
||||
drift = expected - got
|
||||
}
|
||||
|
||||
t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f",
|
||||
numFrames, frameDuration, expected, got, drift, float64(drift)/float64(clockRate),
|
||||
)
|
||||
|
||||
assert.LessOrEqual(t, drift, uint64(1), "timestamp drift should be negligible")
|
||||
}
|
||||
|
||||
func TestTrackLocalStaticSample_WriteSample_DroppedPackets_NoDrift(t *testing.T) {
|
||||
const clockRate = uint32(90000)
|
||||
frameDuration := time.Second / 60
|
||||
totalDuration := time.Hour
|
||||
numFrames := int(totalDuration / frameDuration)
|
||||
tickF := frameDuration.Seconds() * float64(clockRate)
|
||||
|
||||
track, err := NewTrackLocalStaticSample(
|
||||
RTPCodecCapability{MimeType: MimeTypeVP8, ClockRate: clockRate},
|
||||
"video", "pion",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pack := &countingPacketizer{}
|
||||
|
||||
track.rtpTrack.mu.Lock()
|
||||
track.packetizer = pack
|
||||
track.clockRate = float64(clockRate)
|
||||
track.sequencer = rtp.NewRandomSequencer()
|
||||
track.rtpTrack.mu.Unlock()
|
||||
|
||||
var expectedTotal uint64
|
||||
var remainder float64
|
||||
|
||||
for i := 0; i < numFrames; i++ {
|
||||
var drops uint16
|
||||
if (i+1)%300 == 0 {
|
||||
drops = uint16((i/300)%3 + 1) //nolint:gosec
|
||||
}
|
||||
|
||||
if drops > 0 {
|
||||
dropTotal := tickF*float64(drops) + remainder
|
||||
dropTicks := uint32(dropTotal)
|
||||
remainder = dropTotal - float64(dropTicks)
|
||||
expectedTotal += uint64(dropTicks)
|
||||
}
|
||||
|
||||
curTotal := tickF + remainder
|
||||
curTicks := uint32(curTotal)
|
||||
remainder = curTotal - float64(curTicks)
|
||||
expectedTotal += uint64(curTicks)
|
||||
|
||||
err := track.WriteSample(media.Sample{
|
||||
Data: []byte{0x00},
|
||||
Duration: frameDuration,
|
||||
PrevDroppedPackets: drops,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
got := pack.totalSamples
|
||||
|
||||
var drift uint64
|
||||
if got > expectedTotal {
|
||||
drift = got - expectedTotal
|
||||
} else {
|
||||
drift = expectedTotal - got
|
||||
}
|
||||
|
||||
t.Logf("frames=%d frameDuration=%s expectedTicks=%d gotTicks=%d driftTicks=%d driftSeconds=%.6f",
|
||||
numFrames, frameDuration, expectedTotal, got, drift, float64(drift)/float64(clockRate),
|
||||
)
|
||||
|
||||
assert.LessOrEqual(t, drift, uint64(1), "timestamp drift with drops should be negligible")
|
||||
}
|
||||
|
||||
type countingPacketizer struct {
|
||||
totalSamples uint64
|
||||
}
|
||||
|
||||
func (p *countingPacketizer) Packetize(payload []byte, samples uint32) []*rtp.Packet {
|
||||
p.totalSamples += uint64(samples)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *countingPacketizer) GeneratePadding(samples uint32) []*rtp.Packet { return nil }
|
||||
func (p *countingPacketizer) EnableAbsSendTime(value int) {}
|
||||
func (p *countingPacketizer) SkipSamples(skippedSamples uint32) {
|
||||
p.totalSamples += uint64(skippedSamples)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user