improve performance of H264/H265 SPS parsers

This commit is contained in:
aler9
2022-12-24 12:25:08 +01:00
parent db251994aa
commit 9acff114b7
4 changed files with 92 additions and 210 deletions

View File

@@ -5,19 +5,33 @@ import (
"fmt" "fmt"
) )
// HasSpace verifies that at least n bits are still available in buf when
// reading starts at bit offset pos. It returns nil when they are, and a
// "not enough bits" error otherwise.
func HasSpace(buf []byte, pos int, n int) error {
	// Bits left between the current bit offset and the end of the buffer.
	remaining := len(buf)*8 - pos
	if remaining >= n {
		return nil
	}
	return fmt.Errorf("not enough bits")
}
// ReadBits reads N bits. // ReadBits reads N bits.
func ReadBits(buf []byte, pos *int, n int) (uint64, error) { func ReadBits(buf []byte, pos *int, n int) (uint64, error) {
if n > ((len(buf) * 8) - *pos) { err := HasSpace(buf, *pos, n)
return 0, fmt.Errorf("not enough bits") if err != nil {
return 0, err
} }
return ReadBitsUnsafe(buf, pos, n), nil
}
// ReadBitsUnsafe reads N bits without a bounds check; call HasSpace first (as ReadBits does).
func ReadBitsUnsafe(buf []byte, pos *int, n int) uint64 {
v := uint64(0) v := uint64(0)
res := 8 - (*pos & 0x07) res := 8 - (*pos & 0x07)
if n < res { if n < res {
v := uint64((buf[*pos>>0x03] >> (res - n)) & (1<<n - 1)) v := uint64((buf[*pos>>0x03] >> (res - n)) & (1<<n - 1))
*pos += n *pos += n
return v, nil return v
} }
v = (v << res) | uint64(buf[*pos>>0x03]&(1<<res-1)) v = (v << res) | uint64(buf[*pos>>0x03]&(1<<res-1))
@@ -35,7 +49,7 @@ func ReadBits(buf []byte, pos *int, n int) (uint64, error) {
*pos += n *pos += n
} }
return v, nil return v
} }
// ReadGolombUnsigned reads an unsigned golomb-encoded value. // ReadGolombUnsigned reads an unsigned golomb-encoded value.
@@ -93,29 +107,17 @@ func ReadGolombSigned(buf []byte, pos *int) (int32, error) {
// ReadFlag reads a boolean flag. // ReadFlag reads a boolean flag.
func ReadFlag(buf []byte, pos *int) (bool, error) { func ReadFlag(buf []byte, pos *int) (bool, error) {
if (len(buf)*8 - *pos) == 0 { err := HasSpace(buf, *pos, 1)
return false, fmt.Errorf("not enough bits") if err != nil {
return false, err
} }
return ReadFlagUnsafe(buf, pos), nil
}
// ReadFlagUnsafe reads a boolean flag without a bounds check; call HasSpace first (as ReadFlag does).
func ReadFlagUnsafe(buf []byte, pos *int) bool {
b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01 b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01
*pos++ *pos++
return b == 1, nil return b == 1
}
// ReadUint8 reads 8 bits through ReadBits and returns them as a uint8,
// together with whatever error ReadBits reported.
func ReadUint8(buf []byte, pos *int) (uint8, error) {
	raw, readErr := ReadBits(buf, pos, 8)
	return uint8(raw), readErr
}
// ReadUint16 reads 16 bits through ReadBits and returns them as a uint16,
// together with whatever error ReadBits reported.
func ReadUint16(buf []byte, pos *int) (uint16, error) {
	raw, readErr := ReadBits(buf, pos, 16)
	return uint16(raw), readErr
}
// ReadUint32 reads a uint32.
func ReadUint32(buf []byte, pos *int) (uint32, error) {
v, err := ReadBits(buf, pos, 32)
return uint32(v), err
} }

View File

@@ -86,14 +86,3 @@ func TestReadFlagError(t *testing.T) {
_, err := ReadFlag(buf, &pos) _, err := ReadFlag(buf, &pos)
require.EqualError(t, err, "not enough bits") require.EqualError(t, err, "not enough bits")
} }
// TestUint reads a uint8, a uint16 and a uint32 back to back from one buffer,
// sharing a single bit position, and checks each decoded value.
func TestUint(t *testing.T) {
	buf := []byte{0x45, 0x46, 0x47, 0x48, 0x49, 0x50, 0x51}
	pos := 0

	// Assert on every read error so a failed read is reported as such,
	// instead of surfacing as a misleading value mismatch.
	u8, err := ReadUint8(buf, &pos)
	require.NoError(t, err)
	require.Equal(t, uint8(0x45), u8)

	u16, err := ReadUint16(buf, &pos)
	require.NoError(t, err)
	require.Equal(t, uint16(0x4647), u16)

	u32, err := ReadUint32(buf, &pos)
	require.NoError(t, err)
	require.Equal(t, uint32(0x48495051), u32)
}

View File

@@ -56,17 +56,13 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error {
return err return err
} }
tmp, err := bits.ReadBits(buf, pos, 4) err = bits.HasSpace(buf, *pos, 8)
if err != nil { if err != nil {
return err return err
} }
h.BitRateScale = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 4) h.BitRateScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4))
if err != nil { h.CpbSizeScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4))
return err
}
h.CpbSizeScale = uint8(tmp)
for i := uint32(0); i <= h.CpbCntMinus1; i++ { for i := uint32(0); i <= h.CpbCntMinus1; i++ {
v, err := bits.ReadGolombUnsigned(buf, pos) v, err := bits.ReadGolombUnsigned(buf, pos)
@@ -88,29 +84,15 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error {
h.CbrFlag = append(h.CbrFlag, vb) h.CbrFlag = append(h.CbrFlag, vb)
} }
tmp, err = bits.ReadBits(buf, pos, 5) err = bits.HasSpace(buf, *pos, 5+5+5+5)
if err != nil { if err != nil {
return err return err
} }
h.InitialCpbRemovalDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5) h.InitialCpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
if err != nil { h.CpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
return err h.DpbOutputDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
} h.TimeOffsetLength = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
h.CpbRemovalDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
h.DpbOutputDelayLengthMinus1 = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
h.TimeOffsetLength = uint8(tmp)
return nil return nil
} }
@@ -123,21 +105,14 @@ type SPS_TimingInfo struct { //nolint:revive
} }
func (t *SPS_TimingInfo) unmarshal(buf []byte, pos *int) error { func (t *SPS_TimingInfo) unmarshal(buf []byte, pos *int) error {
var err error err := bits.HasSpace(buf, *pos, 32+32+1)
t.NumUnitsInTick, err = bits.ReadUint32(buf, pos)
if err != nil { if err != nil {
return err return err
} }
t.TimeScale, err = bits.ReadUint32(buf, pos) t.NumUnitsInTick = uint32(bits.ReadBitsUnsafe(buf, pos, 32))
if err != nil { t.TimeScale = uint32(bits.ReadBitsUnsafe(buf, pos, 32))
return err t.FixedFrameRateFlag = bits.ReadFlagUnsafe(buf, pos)
}
t.FixedFrameRateFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
return nil return nil
} }
@@ -243,21 +218,20 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error {
} }
if v.AspectRatioInfoPresentFlag { if v.AspectRatioInfoPresentFlag {
v.AspectRatioIdc, err = bits.ReadUint8(buf, pos) tmp, err := bits.ReadBits(buf, pos, 8)
if err != nil { if err != nil {
return err return err
} }
v.AspectRatioIdc = uint8(tmp)
if v.AspectRatioIdc == 255 { // Extended_SAR if v.AspectRatioIdc == 255 { // Extended_SAR
v.SarWidth, err = bits.ReadUint16(buf, pos) err := bits.HasSpace(buf, *pos, 32)
if err != nil { if err != nil {
return err return err
} }
v.SarHeight, err = bits.ReadUint16(buf, pos) v.SarWidth = uint16(bits.ReadBitsUnsafe(buf, pos, 16))
if err != nil { v.SarHeight = uint16(bits.ReadBitsUnsafe(buf, pos, 16))
return err
}
} }
} }
@@ -279,37 +253,24 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error {
} }
if v.VideoSignalTypePresentFlag { if v.VideoSignalTypePresentFlag {
tmp, err := bits.ReadBits(buf, pos, 3) err := bits.HasSpace(buf, *pos, 5)
if err != nil {
return err
}
v.VideoFormat = uint8(tmp)
v.VideoFullRangeFlag, err = bits.ReadFlag(buf, pos)
if err != nil { if err != nil {
return err return err
} }
v.ColourDescriptionPresentFlag, err = bits.ReadFlag(buf, pos) v.VideoFormat = uint8(bits.ReadBitsUnsafe(buf, pos, 3))
if err != nil { v.VideoFullRangeFlag = bits.ReadFlagUnsafe(buf, pos)
return err v.ColourDescriptionPresentFlag = bits.ReadFlagUnsafe(buf, pos)
}
if v.ColourDescriptionPresentFlag { if v.ColourDescriptionPresentFlag {
v.ColourPrimaries, err = bits.ReadUint8(buf, pos) err := bits.HasSpace(buf, *pos, 24)
if err != nil { if err != nil {
return err return err
} }
v.TransferCharacteristics, err = bits.ReadUint8(buf, pos) v.ColourPrimaries = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
if err != nil { v.TransferCharacteristics = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
return err v.MatrixCoefficients = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
}
v.MatrixCoefficients, err = bits.ReadUint8(buf, pos)
if err != nil {
return err
}
} }
} }
@@ -490,7 +451,7 @@ func (s *SPS) Unmarshal(buf []byte) error {
buf = EmulationPreventionRemove(buf) buf = EmulationPreventionRemove(buf)
if len(buf) < 4 { if len(buf) < 4 {
return fmt.Errorf("buffer too short") return fmt.Errorf("not enough bits")
} }
forbidden := buf[0] >> 7 forbidden := buf[0] >> 7

View File

@@ -46,90 +46,31 @@ type SPS_ProfileLevelTier struct { //nolint:revive
} }
func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMinus1 uint8) error { func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMinus1 uint8) error {
tmp, err := bits.ReadBits(buf, pos, 2) err := bits.HasSpace(buf, *pos, 8+32+12+34+8)
if err != nil { if err != nil {
return err return err
} }
p.GeneralProfileSpace = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 1) p.GeneralProfileSpace = uint8(bits.ReadBitsUnsafe(buf, pos, 2))
if err != nil { p.GeneralTierFlag = uint8(bits.ReadBitsUnsafe(buf, pos, 1))
return err p.GeneralProfileIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 5))
}
p.GeneralTierFlag = uint8(tmp)
tmp, err = bits.ReadBits(buf, pos, 5)
if err != nil {
return err
}
p.GeneralProfileIdc = uint8(tmp)
for j := 0; j < 32; j++ { for j := 0; j < 32; j++ {
p.GeneralProfileCompatibilityFlag[j], err = bits.ReadFlag(buf, pos) p.GeneralProfileCompatibilityFlag[j] = bits.ReadFlagUnsafe(buf, pos)
if err != nil {
return err
}
} }
p.ProgressiveSourceFlag, err = bits.ReadFlag(buf, pos) p.ProgressiveSourceFlag = bits.ReadFlagUnsafe(buf, pos)
if err != nil { p.InterlacedSourceFlag = bits.ReadFlagUnsafe(buf, pos)
return err p.NonPackedConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
} p.FrameOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.Max12bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.InterlacedSourceFlag, err = bits.ReadFlag(buf, pos) p.Max10bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
if err != nil { p.Max8bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
return err p.Max422ChromeConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
} p.Max420ChromaConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.IntraConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
p.NonPackedConstraintFlag, err = bits.ReadFlag(buf, pos) p.OnePictureOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
if err != nil { p.LowerBitRateConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
return err
}
p.FrameOnlyConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max12bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max10bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max8bitConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max422ChromeConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.Max420ChromaConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.IntraConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.OnePictureOnlyConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
p.LowerBitRateConstraintFlag, err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
if p.GeneralProfileIdc == 5 || if p.GeneralProfileIdc == 5 ||
p.GeneralProfileIdc == 9 || p.GeneralProfileIdc == 9 ||
@@ -139,43 +80,36 @@ func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMi
p.GeneralProfileCompatibilityFlag[9] || p.GeneralProfileCompatibilityFlag[9] ||
p.GeneralProfileCompatibilityFlag[10] || p.GeneralProfileCompatibilityFlag[10] ||
p.GeneralProfileCompatibilityFlag[11] { p.GeneralProfileCompatibilityFlag[11] {
p.Max14BitConstraintFlag, err = bits.ReadFlag(buf, pos) p.Max14BitConstraintFlag = bits.ReadFlagUnsafe(buf, pos)
if err != nil {
return err
}
*pos += 33 *pos += 33
} else { } else {
*pos += 34 *pos += 34
} }
tmp, err = bits.ReadBits(buf, pos, 8) p.LevelIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 8))
if err != nil {
return err
}
p.LevelIdc = uint8(tmp)
if maxNumSubLayersMinus1 > 0 { if maxNumSubLayersMinus1 > 0 {
p.SubLayerProfilePresentFlag = make([]bool, maxNumSubLayersMinus1) p.SubLayerProfilePresentFlag = make([]bool, maxNumSubLayersMinus1)
p.SubLayerLevelPresentFlag = make([]bool, maxNumSubLayersMinus1) p.SubLayerLevelPresentFlag = make([]bool, maxNumSubLayersMinus1)
err := bits.HasSpace(buf, *pos, int(2*maxNumSubLayersMinus1))
if err != nil {
return err
}
} }
for j := uint8(0); j < maxNumSubLayersMinus1; j++ { for j := uint8(0); j < maxNumSubLayersMinus1; j++ {
p.SubLayerProfilePresentFlag[j], err = bits.ReadFlag(buf, pos) p.SubLayerProfilePresentFlag[j] = bits.ReadFlagUnsafe(buf, pos)
if err != nil { p.SubLayerLevelPresentFlag[j] = bits.ReadFlagUnsafe(buf, pos)
return err
}
p.SubLayerLevelPresentFlag[j], err = bits.ReadFlag(buf, pos)
if err != nil {
return err
}
} }
if maxNumSubLayersMinus1 > 0 { if maxNumSubLayersMinus1 > 0 {
for i := maxNumSubLayersMinus1; i < 8; i++ { err := bits.HasSpace(buf, *pos, int(8-maxNumSubLayersMinus1)*2)
*pos += 2 if err != nil {
return err
} }
*pos += int(8-maxNumSubLayersMinus1) * 2
} }
for i := uint8(0); i < maxNumSubLayersMinus1; i++ { for i := uint8(0); i < maxNumSubLayersMinus1; i++ {
@@ -245,6 +179,10 @@ type SPS struct {
// Unmarshal decodes a SPS from bytes. // Unmarshal decodes a SPS from bytes.
func (s *SPS) Unmarshal(buf []byte) error { func (s *SPS) Unmarshal(buf []byte) error {
if len(buf) < 2 {
return fmt.Errorf("not enough bits")
}
typ := NALUType((buf[0] >> 1) & 0b111111) typ := NALUType((buf[0] >> 1) & 0b111111)
if typ != NALUTypeSPS { if typ != NALUTypeSPS {
@@ -255,22 +193,14 @@ func (s *SPS) Unmarshal(buf []byte) error {
buf = h264.EmulationPreventionRemove(buf) buf = h264.EmulationPreventionRemove(buf)
pos := 0 pos := 0
tmp, err := bits.ReadBits(buf, &pos, 4) err := bits.HasSpace(buf, pos, 8)
if err != nil { if err != nil {
return err return err
} }
s.VPSID = uint8(tmp)
tmp, err = bits.ReadBits(buf, &pos, 3) s.VPSID = uint8(bits.ReadBitsUnsafe(buf, &pos, 4))
if err != nil { s.MaxNumSubLayersMinus1 = uint8(bits.ReadBitsUnsafe(buf, &pos, 3))
return err s.TemporalIDNestingFlag = bits.ReadFlagUnsafe(buf, &pos)
}
s.MaxNumSubLayersMinus1 = uint8(tmp)
s.TemporalIDNestingFlag, err = bits.ReadFlag(buf, &pos)
if err != nil {
return err
}
err = s.ProfileLevelTier.unmarshal(buf, &pos, s.MaxNumSubLayersMinus1) err = s.ProfileLevelTier.unmarshal(buf, &pos, s.MaxNumSubLayersMinus1)
if err != nil { if err != nil {