From 277e89f3ac1889dd32028f1ab7d8cd5eca4404df Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 2 Jun 2022 14:05:47 +0200 Subject: [PATCH] h264: simplify DTS extractor usage --- .../mpegtsencoder.go | 27 ++---------- pkg/h264/dtsextractor.go | 43 +++++++++++++++---- pkg/h264/dtsextractor_test.go | 13 +----- 3 files changed, 38 insertions(+), 45 deletions(-) diff --git a/examples/client-read-h264-save-to-disk/mpegtsencoder.go b/examples/client-read-h264-save-to-disk/mpegtsencoder.go index 8d1581de..737d9dc2 100644 --- a/examples/client-read-h264-save-to-disk/mpegtsencoder.go +++ b/examples/client-read-h264-save-to-disk/mpegtsencoder.go @@ -2,7 +2,6 @@ package main import ( "bufio" - "bytes" "context" "log" "os" @@ -27,7 +26,6 @@ type mpegtsEncoder struct { dtsExtractor *h264.DTSExtractor firstPacketWritten bool startPTS time.Duration - spsp *h264.SPS firstIDRReceived bool } @@ -46,15 +44,6 @@ func newMPEGTSEncoder(sps []byte, pps []byte) (*mpegtsEncoder, error) { }) mux.SetPCRPID(256) - var spsp *h264.SPS - if sps != nil { - spsp = &h264.SPS{} - err := spsp.Unmarshal(sps) - if err != nil { - return nil, err - } - } - return &mpegtsEncoder{ sps: sps, pps: pps, @@ -62,7 +51,6 @@ func newMPEGTSEncoder(sps []byte, pps []byte) (*mpegtsEncoder, error) { b: b, mux: mux, dtsExtractor: h264.NewDTSExtractor(), - spsp: spsp, }, nil } @@ -84,20 +72,10 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { {byte(h264.NALUTypeAccessUnitDelimiter), 240}, } - idrPresent := h264.IDRPresent(nalus) - for _, nalu := range nalus { typ := h264.NALUType(nalu[0] & 0x1F) switch typ { case h264.NALUTypeSPS: - if e.sps == nil || !bytes.Equal(e.sps, nalu) { - var spsp h264.SPS - err := spsp.Unmarshal(nalu) - if err != nil { - return err - } - e.spsp = &spsp - } e.sps = append([]byte(nil), nalu...) continue @@ -122,6 +100,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { return nil } + idrPresent := h264.IDRPresent(nalus) if !e.firstIDRReceived && !idrPresent { return nil } @@ -135,7 +114,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { pts -= e.startPTS - dts, err := e.dtsExtractor.Extract(filteredNALUs, idrPresent, pts, e.spsp) + dts, err := e.dtsExtractor.Extract(filteredNALUs, pts) if err != nil { return err } @@ -159,7 +138,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { _, err = e.mux.WriteData(&astits.MuxerData{ PID: 256, AdaptationField: &astits.PacketAdaptationField{ - RandomAccessIndicator: h264.IDRPresent(filteredNALUs), + RandomAccessIndicator: idrPresent, }, PES: &astits.PESData{ Header: &astits.PESHeader{ diff --git a/pkg/h264/dtsextractor.go b/pkg/h264/dtsextractor.go index 6a016d0f..d21d83f2 100644 --- a/pkg/h264/dtsextractor.go +++ b/pkg/h264/dtsextractor.go @@ -96,6 +96,8 @@ func getPOCDiff(poc1 uint32, poc2 uint32, sps *SPS) int32 { // DTSExtractor is a utility that allows to extract NALU DTS from PTS. type DTSExtractor struct { + sps []byte + spsp *SPS prevPTS time.Duration prevDTS time.Duration prevPOCDiff int32 @@ -109,25 +111,50 @@ func NewDTSExtractor() *DTSExtractor { func (d *DTSExtractor) extractInner( nalus [][]byte, - idrPresent bool, pts time.Duration, - sps *SPS, ) (time.Duration, int32, error) { - if idrPresent || sps.PicOrderCntType == 2 { + idrPresent := false + + for _, nalu := range nalus { + typ := NALUType(nalu[0] & 0x1F) + switch typ { + // parse SPS + case NALUTypeSPS: + if d.sps == nil || !bytes.Equal(d.sps, nalu) { + var spsp SPS + err := spsp.Unmarshal(nalu) + if err != nil { + return 0, 0, err + } + d.sps = append([]byte(nil), nalu...) + d.spsp = &spsp + } + + // set IDR present flag + case NALUTypeIDR: + idrPresent = true + } + } + + if d.spsp == nil { + return 0, 0, fmt.Errorf("SPS not received yet") + } + + if idrPresent || d.spsp.PicOrderCntType == 2 { d.expectedPOC = 0 return pts, 0, nil } // compute expectedPOC immediately in order to store it even in case of errors d.expectedPOC += 2 - d.expectedPOC &= ((1 << (sps.Log2MaxPicOrderCntLsbMinus4 + 4)) - 1) + d.expectedPOC &= ((1 << (d.spsp.Log2MaxPicOrderCntLsbMinus4 + 4)) - 1) - poc, err := getNALUSPOC(nalus, sps) + poc, err := getNALUSPOC(nalus, d.spsp) if err != nil { return 0, 0, err } - pocDiff := getPOCDiff(poc, d.expectedPOC, sps) + pocDiff := getPOCDiff(poc, d.expectedPOC, d.spsp) if pocDiff == 0 { return pts, pocDiff, nil @@ -149,11 +176,9 @@ func (d *DTSExtractor) extractInner( // Extract extracts the DTS of a NALU group. func (d *DTSExtractor) Extract( nalus [][]byte, - idrPresent bool, pts time.Duration, - sps *SPS, ) (time.Duration, error) { - dts, pocDiff, err := d.extractInner(nalus, idrPresent, pts, sps) + dts, pocDiff, err := d.extractInner(nalus, pts) if err != nil { return 0, err } diff --git a/pkg/h264/dtsextractor_test.go b/pkg/h264/dtsextractor_test.go index d5c5add2..50e38fd4 100644 --- a/pkg/h264/dtsextractor_test.go +++ b/pkg/h264/dtsextractor_test.go @@ -135,20 +135,9 @@ func TestDTSExtractor(t *testing.T) { } ex := NewDTSExtractor() - sps := &SPS{} for _, sample := range sequence { - idrPresent := IDRPresent(sample.nalus) - - for _, nalu := range sample.nalus { - if NALUType(nalu[0]&0x1F) == NALUTypeSPS { - err := sps.Unmarshal(nalu) - require.NoError(t, err) - break - } - } - - dts, err := ex.Extract(sample.nalus, idrPresent, sample.pts, sps) + dts, err := ex.Extract(sample.nalus, sample.pts) require.NoError(t, err) require.Equal(t, sample.dts, dts) }