h264: simplify DTS extractor usage

This commit is contained in:
aler9
2022-06-02 14:05:47 +02:00
parent 622fe12c4b
commit 277e89f3ac
3 changed files with 38 additions and 45 deletions

View File

@@ -2,7 +2,6 @@ package main
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"log" "log"
"os" "os"
@@ -27,7 +26,6 @@ type mpegtsEncoder struct {
dtsExtractor *h264.DTSExtractor dtsExtractor *h264.DTSExtractor
firstPacketWritten bool firstPacketWritten bool
startPTS time.Duration startPTS time.Duration
spsp *h264.SPS
firstIDRReceived bool firstIDRReceived bool
} }
@@ -46,15 +44,6 @@ func newMPEGTSEncoder(sps []byte, pps []byte) (*mpegtsEncoder, error) {
}) })
mux.SetPCRPID(256) mux.SetPCRPID(256)
var spsp *h264.SPS
if sps != nil {
spsp = &h264.SPS{}
err := spsp.Unmarshal(sps)
if err != nil {
return nil, err
}
}
return &mpegtsEncoder{ return &mpegtsEncoder{
sps: sps, sps: sps,
pps: pps, pps: pps,
@@ -62,7 +51,6 @@ func newMPEGTSEncoder(sps []byte, pps []byte) (*mpegtsEncoder, error) {
b: b, b: b,
mux: mux, mux: mux,
dtsExtractor: h264.NewDTSExtractor(), dtsExtractor: h264.NewDTSExtractor(),
spsp: spsp,
}, nil }, nil
} }
@@ -84,20 +72,10 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error {
{byte(h264.NALUTypeAccessUnitDelimiter), 240}, {byte(h264.NALUTypeAccessUnitDelimiter), 240},
} }
idrPresent := h264.IDRPresent(nalus)
for _, nalu := range nalus { for _, nalu := range nalus {
typ := h264.NALUType(nalu[0] & 0x1F) typ := h264.NALUType(nalu[0] & 0x1F)
switch typ { switch typ {
case h264.NALUTypeSPS: case h264.NALUTypeSPS:
if e.sps == nil || !bytes.Equal(e.sps, nalu) {
var spsp h264.SPS
err := spsp.Unmarshal(nalu)
if err != nil {
return err
}
e.spsp = &spsp
}
e.sps = append([]byte(nil), nalu...) e.sps = append([]byte(nil), nalu...)
continue continue
@@ -122,6 +100,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error {
return nil return nil
} }
idrPresent := h264.IDRPresent(nalus)
if !e.firstIDRReceived && !idrPresent { if !e.firstIDRReceived && !idrPresent {
return nil return nil
} }
@@ -135,7 +114,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error {
pts -= e.startPTS pts -= e.startPTS
dts, err := e.dtsExtractor.Extract(filteredNALUs, idrPresent, pts, e.spsp) dts, err := e.dtsExtractor.Extract(filteredNALUs, pts)
if err != nil { if err != nil {
return err return err
} }
@@ -159,7 +138,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error {
_, err = e.mux.WriteData(&astits.MuxerData{ _, err = e.mux.WriteData(&astits.MuxerData{
PID: 256, PID: 256,
AdaptationField: &astits.PacketAdaptationField{ AdaptationField: &astits.PacketAdaptationField{
RandomAccessIndicator: h264.IDRPresent(filteredNALUs), RandomAccessIndicator: idrPresent,
}, },
PES: &astits.PESData{ PES: &astits.PESData{
Header: &astits.PESHeader{ Header: &astits.PESHeader{

View File

@@ -96,6 +96,8 @@ func getPOCDiff(poc1 uint32, poc2 uint32, sps *SPS) int32 {
// DTSExtractor is a utility that allows extracting NALU DTS from PTS. // DTSExtractor is a utility that allows extracting NALU DTS from PTS.
type DTSExtractor struct { type DTSExtractor struct {
sps []byte
spsp *SPS
prevPTS time.Duration prevPTS time.Duration
prevDTS time.Duration prevDTS time.Duration
prevPOCDiff int32 prevPOCDiff int32
@@ -109,25 +111,50 @@ func NewDTSExtractor() *DTSExtractor {
func (d *DTSExtractor) extractInner( func (d *DTSExtractor) extractInner(
nalus [][]byte, nalus [][]byte,
idrPresent bool,
pts time.Duration, pts time.Duration,
sps *SPS,
) (time.Duration, int32, error) { ) (time.Duration, int32, error) {
if idrPresent || sps.PicOrderCntType == 2 { idrPresent := false
for _, nalu := range nalus {
typ := NALUType(nalu[0] & 0x1F)
switch typ {
// parse SPS
case NALUTypeSPS:
if d.sps == nil || !bytes.Equal(d.sps, nalu) {
var spsp SPS
err := spsp.Unmarshal(nalu)
if err != nil {
return 0, 0, err
}
d.sps = append([]byte(nil), nalu...)
d.spsp = &spsp
}
// set IDR present flag
case NALUTypeIDR:
idrPresent = true
}
}
if d.spsp == nil {
return 0, 0, fmt.Errorf("SPS not received yet")
}
if idrPresent || d.spsp.PicOrderCntType == 2 {
d.expectedPOC = 0 d.expectedPOC = 0
return pts, 0, nil return pts, 0, nil
} }
// compute expectedPOC immediately in order to store it even in case of errors // compute expectedPOC immediately in order to store it even in case of errors
d.expectedPOC += 2 d.expectedPOC += 2
d.expectedPOC &= ((1 << (sps.Log2MaxPicOrderCntLsbMinus4 + 4)) - 1) d.expectedPOC &= ((1 << (d.spsp.Log2MaxPicOrderCntLsbMinus4 + 4)) - 1)
poc, err := getNALUSPOC(nalus, sps) poc, err := getNALUSPOC(nalus, d.spsp)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
pocDiff := getPOCDiff(poc, d.expectedPOC, sps) pocDiff := getPOCDiff(poc, d.expectedPOC, d.spsp)
if pocDiff == 0 { if pocDiff == 0 {
return pts, pocDiff, nil return pts, pocDiff, nil
@@ -149,11 +176,9 @@ func (d *DTSExtractor) extractInner(
// Extract extracts the DTS of a NALU group. // Extract extracts the DTS of a NALU group.
func (d *DTSExtractor) Extract( func (d *DTSExtractor) Extract(
nalus [][]byte, nalus [][]byte,
idrPresent bool,
pts time.Duration, pts time.Duration,
sps *SPS,
) (time.Duration, error) { ) (time.Duration, error) {
dts, pocDiff, err := d.extractInner(nalus, idrPresent, pts, sps) dts, pocDiff, err := d.extractInner(nalus, pts)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -135,20 +135,9 @@ func TestDTSExtractor(t *testing.T) {
} }
ex := NewDTSExtractor() ex := NewDTSExtractor()
sps := &SPS{}
for _, sample := range sequence { for _, sample := range sequence {
idrPresent := IDRPresent(sample.nalus) dts, err := ex.Extract(sample.nalus, sample.pts)
for _, nalu := range sample.nalus {
if NALUType(nalu[0]&0x1F) == NALUTypeSPS {
err := sps.Unmarshal(nalu)
require.NoError(t, err)
break
}
}
dts, err := ex.Extract(sample.nalus, idrPresent, sample.pts, sps)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, sample.dts, dts) require.Equal(t, sample.dts, dts)
} }