improve performance of H264/H265 SPS parsers

This commit is contained in:
aler9
2022-12-24 12:25:08 +01:00
parent db251994aa
commit 9acff114b7
4 changed files with 92 additions and 210 deletions

View File

@@ -5,19 +5,33 @@ import (
"fmt"
)
// HasSpace checks whether buffer has space for N bits.
func HasSpace(buf []byte, pos int, n int) error {
if n > ((len(buf) * 8) - pos) {
return fmt.Errorf("not enough bits")
}
return nil
}
// ReadBits reads N bits.
func ReadBits(buf []byte, pos *int, n int) (uint64, error) {
if n > ((len(buf) * 8) - *pos) {
return 0, fmt.Errorf("not enough bits")
err := HasSpace(buf, *pos, n)
if err != nil {
return 0, err
}
return ReadBitsUnsafe(buf, pos, n), nil
}
// ReadBitsUnsafe reads N bits.
func ReadBitsUnsafe(buf []byte, pos *int, n int) uint64 {
v := uint64(0)
res := 8 - (*pos & 0x07)
if n < res {
v := uint64((buf[*pos>>0x03] >> (res - n)) & (1<<n - 1))
*pos += n
return v, nil
return v
}
v = (v << res) | uint64(buf[*pos>>0x03]&(1<<res-1))
@@ -35,7 +49,7 @@ func ReadBits(buf []byte, pos *int, n int) (uint64, error) {
*pos += n
}
return v, nil
return v
}
// ReadGolombUnsigned reads an unsigned golomb-encoded value.
@@ -93,29 +107,17 @@ func ReadGolombSigned(buf []byte, pos *int) (int32, error) {
// ReadFlag reads a boolean flag.
func ReadFlag(buf []byte, pos *int) (bool, error) {
if (len(buf)*8 - *pos) == 0 {
return false, fmt.Errorf("not enough bits")
err := HasSpace(buf, *pos, 1)
if err != nil {
return false, err
}
return ReadFlagUnsafe(buf, pos), nil
}
// ReadFlagUnsafe reads a boolean flag.
func ReadFlagUnsafe(buf []byte, pos *int) bool {
b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01
*pos++
return b == 1, nil
}
// ReadUint8 reads a uint8.
func ReadUint8(buf []byte, pos *int) (uint8, error) {
v, err := ReadBits(buf, pos, 8)
return uint8(v), err
}
// ReadUint16 reads a uint16.
func ReadUint16(buf []byte, pos *int) (uint16, error) {
v, err := ReadBits(buf, pos, 16)
return uint16(v), err
}
// ReadUint32 reads a uint32.
func ReadUint32(buf []byte, pos *int) (uint32, error) {
v, err := ReadBits(buf, pos, 32)
return uint32(v), err
return b == 1
}

View File

@@ -86,14 +86,3 @@ func TestReadFlagError(t *testing.T) {
_, err := ReadFlag(buf, &pos)
require.EqualError(t, err, "not enough bits")
}
func TestUint(t *testing.T) {
buf := []byte{0x45, 0x46, 0x47, 0x48, 0x49, 0x50, 0x51}
pos := 0
u8, _ := ReadUint8(buf, &pos)
require.Equal(t, uint8(0x45), u8)
u16, _ := ReadUint16(buf, &pos)
require.Equal(t, uint16(0x4647), u16)
u32, _ := ReadUint32(buf, &pos)
require.Equal(t, uint32(0x48495051), u32)
}

View File

@@ -56,17 +56,13 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error {
return err
}
tmp, err := bits.ReadBits(buf, pos, 4)
err = bits.HasSpace(buf, *pos, 8)
if err != nil {
return err
}
h.BitRateScale = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 4)
if err != nil {
return err
}
h.CpbSizeScale = uint8(tmp)
h.BitRateScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4))
h.CpbSizeScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4))
for i := uint32(0); i <= h.CpbCntMinus1; i++ {
v, err := bits.ReadGolombUnsigned(buf, pos)
@@ -88,29 +84,15 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error {
h.CbrFlag = append(h.CbrFlag, vb)
}
tmp, err = bits.ReadBits(buf, pos, 5)
err = bits.HasSpace(buf, *pos, 5+5+5+5)
if err != nil {
return err
}
h.InitialCpbRemovalDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
h.CpbRemovalDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
h.DpbOutputDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
h.TimeOffsetLength = uint8(tmp)
h.InitialCpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
h.CpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
h.DpbOutputDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
h.TimeOffsetLength = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
return nil
}
@@ -123,21 +105,14 @@ type SPS_TimingInfo struct { //nolint:revive
}
func (t *SPS_TimingInfo) unmarshal(buf []byte, pos *int) error {
var err error
t.NumUnitsInTick, err = bits.ReadUint32(buf, pos)
err := bits.HasSpace(buf, *pos, 32+32+1)
if err != nil {
return err
}
t.TimeScale, err = bits.ReadUint32(buf, pos)
if err != nil {
return err
}
t.FixedFrameRateFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
t.NumUnitsInTick = uint32(bits.ReadBitsUnsafe(buf, pos, 32))
t.TimeScale = uint32(bits.ReadBitsUnsafe(buf, pos, 32))
t.FixedFrameRateFlag = bits.ReadFlagUnsafe(buf, pos)
return nil
}
@@ -243,21 +218,20 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error {
}
if v.AspectRatioInfoPresentFlag {
v.AspectRatioIdc, err = bits.ReadUint8(buf, pos)
tmp, err := bits.ReadBits(buf, pos, 8)
if err != nil {
return err
}
v.AspectRatioIdc = uint8(tmp)
if v.AspectRatioIdc == 255 { // Extended_SAR
v.SarWidth, err = bits.ReadUint16(buf, pos)
err := bits.HasSpace(buf, *pos, 32)
if err != nil {
return err
}
v.SarHeight, err = bits.ReadUint16(buf, pos)
if err != nil {
return err
}
v.SarWidth = uint16(bits.ReadBitsUnsafe(buf, pos, 16))
v.SarHeight = uint16(bits.ReadBitsUnsafe(buf, pos, 16))
}
}
@@ -279,37 +253,24 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error {
}
if v.VideoSignalTypePresentFlag {
tmp, err := bits.ReadBits(buf, pos, 3)
if err != nil {
return err
}
v.VideoFormat = uint8(tmp)
v.VideoFullRangeFlag, err = bits.ReadFlag(buf, pos)
err := bits.HasSpace(buf, *pos, 5)
if err != nil {
return err
}
v.ColourDescriptionPresentFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
v.VideoFormat = uint8(bits.ReadBitsUnsafe(buf, pos, 3))
v.VideoFullRangeFlag = bits.ReadFlagUnsafe(buf, pos)
v.ColourDescriptionPresentFlag = bits.ReadFlagUnsafe(buf, pos)
if v.ColourDescriptionPresentFlag {
v.ColourPrimaries, err = bits.ReadUint8(buf, pos)
err := bits.HasSpace(buf, *pos, 24)
if err != nil {
return err
}
v.TransferCharacteristics, err = bits.ReadUint8(buf, pos)
if err != nil {
return err
}
v.MatrixCoefficients, err = bits.ReadUint8(buf, pos)
if err != nil {
return err
}
v.ColourPrimaries = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
v.TransferCharacteristics = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
v.MatrixCoefficients = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
}
}
@@ -490,7 +451,7 @@ func (s *SPS) Unmarshal(buf []byte) error {
buf = EmulationPreventionRemove(buf)
if len(buf) < 4 {
return fmt.Errorf("buffer too short")
return fmt.Errorf("not enough bits")
}
forbidden := buf[0] >> 7

View File

@@ -46,90 +46,31 @@ type SPS_ProfileLevelTier struct { //nolint:revive
}
func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMinus1 uint8) error {
tmp, err := bits.ReadBits(buf, pos, 2)
err := bits.HasSpace(buf, *pos, 8+32+12+34+8)
if err != nil {
return err
}
p.GeneralProfileSpace = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 1)
if err != nil {
return err
}
p.GeneralTierFlag = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
p.GeneralProfileIdc = uint8(tmp)
p.GeneralProfileSpace = uint8(bits.ReadBitsUnsafe(buf, pos, 2))
p.GeneralTierFlag = uint8(bits.ReadBitsUnsafe(buf, pos, 1))
p.GeneralProfileIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
for j := 0; j < 32; j++ {
p.GeneralProfileCompatibilityFlag[j], err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.GeneralProfileCompatibilityFlag[j] = bits.ReadFlagUnsafe(buf, pos)
}
p.ProgressiveSourceFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.InterlacedSourceFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.NonPackedConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.FrameOnlyConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max12bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max10bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max8bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max422ChromeConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max420ChromaConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.IntraConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.OnePictureOnlyConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.LowerBitRateConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.ProgressiveSourceFlag = bits.ReadFlagUnsafe(buf, pos)
p.InterlacedSourceFlag = bits.ReadFlagUnsafe(buf, pos)
p.NonPackedConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.FrameOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max12bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max10bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max8bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max422ChromeConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max420ChromaConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.IntraConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.OnePictureOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.LowerBitRateConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
if p.GeneralProfileIdc == 5 ||
p.GeneralProfileIdc == 9 ||
@@ -139,43 +80,36 @@ func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMi
p.GeneralProfileCompatibilityFlag[9] ||
p.GeneralProfileCompatibilityFlag[10] ||
p.GeneralProfileCompatibilityFlag[11] {
p.Max14BitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max14BitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
*pos += 33
} else {
*pos += 34
}
tmp, err = bits.ReadBits(buf, pos, 8)
if err != nil {
return err
}
p.LevelIdc = uint8(tmp)
p.LevelIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
if maxNumSubLayersMinus1 > 0 {
p.SubLayerProfilePresentFlag = make([]bool, maxNumSubLayersMinus1)
p.SubLayerLevelPresentFlag = make([]bool, maxNumSubLayersMinus1)
err := bits.HasSpace(buf, *pos, int(2*maxNumSubLayersMinus1))
if err != nil {
return err
}
}
for j := uint8(0); j < maxNumSubLayersMinus1; j++ {
p.SubLayerProfilePresentFlag[j], err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.SubLayerLevelPresentFlag[j], err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.SubLayerProfilePresentFlag[j] = bits.ReadFlagUnsafe(buf, pos)
p.SubLayerLevelPresentFlag[j] = bits.ReadFlagUnsafe(buf, pos)
}
if maxNumSubLayersMinus1 > 0 {
for i := maxNumSubLayersMinus1; i < 8; i++ {
*pos += 2
err := bits.HasSpace(buf, *pos, int(8-maxNumSubLayersMinus1)*2)
if err != nil {
return err
}
*pos += int(8-maxNumSubLayersMinus1) * 2
}
for i := uint8(0); i < maxNumSubLayersMinus1; i++ {
@@ -245,6 +179,10 @@ type SPS struct {
// Unmarshal decodes a SPS from bytes.
func (s *SPS) Unmarshal(buf []byte) error {
if len(buf) < 2 {
return fmt.Errorf("not enough bits")
}
typ := NALUType((buf[0] >> 1) & 0b111111)
if typ != NALUTypeSPS {
@@ -255,22 +193,14 @@ func (s *SPS) Unmarshal(buf []byte) error {
buf = h264.EmulationPreventionRemove(buf)
pos := 0
tmp, err := bits.ReadBits(buf, &pos, 4)
err := bits.HasSpace(buf, pos, 8)
if err != nil {
return err
}
s.VPSID = uint8(tmp)
tmp, err = bits.ReadBits(buf, &pos, 3)
if err != nil {
return err
}
s.MaxNumSubLayersMinus1 = uint8(tmp)
s.TemporalIDNestingFlag, err = bits.ReadFlag(buf, &pos)
if err != nil {
return err
}
s.VPSID = uint8(bits.ReadBitsUnsafe(buf, &pos, 4))
s.MaxNumSubLayersMinus1 = uint8(bits.ReadBitsUnsafe(buf, &pos, 3))
s.TemporalIDNestingFlag = bits.ReadFlagUnsafe(buf, &pos)
err = s.ProfileLevelTier.unmarshal(buf, &pos, s.MaxNumSubLayersMinus1)
if err != nil {