From 71882c1d5ccb262eed55fef06a675aadfd1b97d5 Mon Sep 17 00:00:00 2001 From: Quentin Renard Date: Sat, 22 Apr 2023 14:32:01 +0200 Subject: [PATCH] Rate enforcer pts reference is now global --- libav/pts_reference.go | 48 ---------------- libav/rate_enforcer.go | 125 ++++++++++++++++++----------------------- 2 files changed, 54 insertions(+), 119 deletions(-) diff --git a/libav/pts_reference.go b/libav/pts_reference.go index b860762..822b3c2 100644 --- a/libav/pts_reference.go +++ b/libav/pts_reference.go @@ -5,7 +5,6 @@ import ( "time" "github.com/asticode/go-astiav" - "github.com/asticode/go-astiencoder" ) type PTSReference struct { @@ -63,50 +62,3 @@ func (r *PTSReference) Update(pts int64, t time.Time, timeBase astiav.Rational) defer r.m.Unlock() return r.updateUnsafe(pts, t, timeBase) } - -type ptsReferences struct { - m *sync.Mutex - p map[astiencoder.Node]*PTSReference -} - -func newPTSReferences(p map[astiencoder.Node]*PTSReference) *ptsReferences { - // Create pts references - r := &ptsReferences{ - m: &sync.Mutex{}, - p: make(map[astiencoder.Node]*PTSReference), - } - - // Copy pts references - for k, v := range p { - r.p[k] = v - } - return r -} - -func (rs *ptsReferences) get(n astiencoder.Node) *PTSReference { - rs.m.Lock() - defer rs.m.Unlock() - return rs.p[n] -} - -func (rs *ptsReferences) set(n astiencoder.Node, r *PTSReference) { - rs.m.Lock() - defer rs.m.Unlock() - rs.p[n] = r -} - -func (rs *ptsReferences) lockAll() { - rs.m.Lock() - defer rs.m.Unlock() - for _, r := range rs.p { - r.lock() - } -} - -func (rs *ptsReferences) unlockAll() { - rs.m.Lock() - defer rs.m.Unlock() - for _, r := range rs.p { - r.unlock() - } -} diff --git a/libav/rate_enforcer.go b/libav/rate_enforcer.go index 2f513e1..a104330 100644 --- a/libav/rate_enforcer.go +++ b/libav/rate_enforcer.go @@ -17,25 +17,26 @@ var countRateEnforcer uint64 // RateEnforcer represents an object capable of enforcing rate based on PTS type RateEnforcer struct { *astiencoder.BaseNode - c *astikit.Chan - currentNode astiencoder.Node - d *frameDispatcher - delay time.Duration - descriptor Descriptor - desiredNode astiencoder.Node - eh *astiencoder.EventHandler - f RateEnforcerFiller - frames map[astiencoder.Node][]*astiav.Frame - m *sync.Mutex - outputCtx Context - p *framePool - period time.Duration - ptsReferences *ptsReferences - restamper FrameRestamper - statFramesDelay *astikit.AtomicDuration - statFramesFilled uint64 - statFramesProcessed uint64 - statFramesReceived uint64 + c *astikit.Chan + currentNode astiencoder.Node + d *frameDispatcher + delay time.Duration + descriptor Descriptor + desiredNode astiencoder.Node + eh *astiencoder.EventHandler + f RateEnforcerFiller + frames map[astiencoder.Node][]*astiav.Frame + m *sync.Mutex + outputCtx Context + p *framePool + period time.Duration + ptsReference *PTSReference + restamper FrameRestamper + statFramesDelay *astikit.AtomicDuration + statFramesFilled uint64 + statFramesProcessed uint64 + statFramesReceived uint64 + updatePTSReferenceOnFrame bool } // RateEnforcerOptions represents rate enforcer options @@ -44,9 +45,10 @@ type RateEnforcerOptions struct { Filler RateEnforcerFiller Node astiencoder.NodeOptions // Both FrameRate and TimeBase are mandatory - OutputCtx Context - PTSReferences map[astiencoder.Node]*PTSReference - Restamper FrameRestamper + OutputCtx Context + PTSReference *PTSReference + Restamper FrameRestamper + UpdatePTSReferenceOnFrame bool } // NewRateEnforcer creates a new rate enforcer @@ -57,18 +59,19 @@ func NewRateEnforcer(o RateEnforcerOptions, eh *astiencoder.EventHandler, c *ast // Create rate enforcer r = &RateEnforcer{ - c: astikit.NewChan(astikit.ChanOptions{ProcessAll: true}), - delay: o.Delay, - descriptor: o.OutputCtx.Descriptor(), - frames: make(map[astiencoder.Node][]*astiav.Frame), - eh: eh, - f: o.Filler, - m: &sync.Mutex{}, - outputCtx: o.OutputCtx, - period: time.Duration(float64(1e9) / o.OutputCtx.FrameRate.ToDouble()), - ptsReferences: newPTSReferences(o.PTSReferences), - restamper: o.Restamper, - statFramesDelay: astikit.NewAtomicDuration(0), + c: astikit.NewChan(astikit.ChanOptions{ProcessAll: true}), + delay: o.Delay, + descriptor: o.OutputCtx.Descriptor(), + frames: make(map[astiencoder.Node][]*astiav.Frame), + eh: eh, + f: o.Filler, + m: &sync.Mutex{}, + outputCtx: o.OutputCtx, + period: time.Duration(float64(1e9) / o.OutputCtx.FrameRate.ToDouble()), + ptsReference: o.PTSReference, + restamper: o.Restamper, + statFramesDelay: astikit.NewAtomicDuration(0), + updatePTSReferenceOnFrame: o.UpdatePTSReferenceOnFrame, } // Create base node @@ -85,6 +88,11 @@ func NewRateEnforcer(o RateEnforcerOptions, eh *astiencoder.EventHandler, c *ast r.f = newPreviousRateEnforcerFiller(r, r.eh, r.p) } + // Create pts reference + if r.ptsReference == nil { + r.ptsReference = NewPTSReference() + } + // Add stat options r.addStatOptions() return @@ -275,31 +283,18 @@ func (r *RateEnforcer) HandleFrame(p FrameHandlerPayload) { r.frames[p.Node] = append(r.frames[p.Node], f) } - // Get pts reference - ptsReference := r.ptsReferences.get(p.Node) - // Lock pts reference - if ptsReference != nil { - ptsReference.lock() - defer ptsReference.unlock() - } + r.ptsReference.lock() + defer r.ptsReference.unlock() // Increment frames delay before updating pts reference - if ptsReference != nil && !ptsReference.isZeroUnsafe() && (r.currentNode == p.Node || (r.currentNode == nil && r.desiredNode == p.Node)) { - r.statFramesDelay.Add(t.Sub(ptsReference.timeFromPTSUnsafe(f.Pts(), r.outputCtx.TimeBase))) - } - - // Make sure pts reference exists - if ptsReference == nil { - ptsReference = NewPTSReference() - ptsReference.lock() - defer ptsReference.unlock() - r.ptsReferences.set(p.Node, ptsReference) + if r.ptsReference != nil && !r.ptsReference.isZeroUnsafe() && (r.currentNode == p.Node || (r.currentNode == nil && r.desiredNode == p.Node)) { + r.statFramesDelay.Add(t.Sub(r.ptsReference.timeFromPTSUnsafe(f.Pts(), r.outputCtx.TimeBase))) } // Update pts reference - if ptsReference.isZeroUnsafe() || ptsReference.timeFromPTSUnsafe(f.Pts(), r.outputCtx.TimeBase).After(t) { - ptsReference.updateUnsafe(f.Pts(), t, r.outputCtx.TimeBase) + if r.ptsReference.isZeroUnsafe() || (r.updatePTSReferenceOnFrame && r.ptsReference.timeFromPTSUnsafe(f.Pts(), r.outputCtx.TimeBase).After(t)) { + r.ptsReference.updateUnsafe(f.Pts(), t, r.outputCtx.TimeBase) } }) }) @@ -373,9 +368,9 @@ func (r *RateEnforcer) tickFunc(ctx context.Context, nextAt *time.Time) (stop bo } func (r *RateEnforcer) frame(from time.Time) (f *astiav.Frame, n astiencoder.Node, filled bool) { - // Lock pts references - r.ptsReferences.lockAll() - defer r.ptsReferences.unlockAll() + // Lock pts reference + r.ptsReference.lock() + defer r.ptsReference.unlock() // Get to to := from.Add(r.period) @@ -408,15 +403,9 @@ func (r *RateEnforcer) frame(from time.Time) (f *astiav.Frame, n astiencoder.Nod } func (r *RateEnforcer) frameForNode(n astiencoder.Node, from, to time.Time) (f *astiav.Frame) { - // Get pts reference - ptsReference := r.ptsReferences.get(n) - if ptsReference == nil { - return - } - // Get pts boundaries - ptsMin := ptsReference.ptsFromTimeUnsafe(from, r.outputCtx.TimeBase) - ptsMax := ptsReference.ptsFromTimeUnsafe(to, r.outputCtx.TimeBase) + ptsMin := r.ptsReference.ptsFromTimeUnsafe(from, r.outputCtx.TimeBase) + ptsMax := r.ptsReference.ptsFromTimeUnsafe(to, r.outputCtx.TimeBase) // Loop through frames for idx := range r.frames[n] { @@ -432,14 +421,8 @@ func (r *RateEnforcer) frameForNode(n astiencoder.Node, from, to time.Time) (f * func (r *RateEnforcer) cleanup(to time.Time) { // Loop through nodes for n := range r.frames { - // Get pts reference - ptsReference := r.ptsReferences.get(n) - if ptsReference == nil { - continue - } - // Get max pts - ptsMax := ptsReference.ptsFromTimeUnsafe(to, r.outputCtx.TimeBase) + ptsMax := r.ptsReference.ptsFromTimeUnsafe(to, r.outputCtx.TimeBase) // Loop through frames for idx := 0; idx < len(r.frames[n]); idx++ {