diff --git a/pkg/bits/read.go b/pkg/bits/read.go index 97688c1d..6d925b42 100644 --- a/pkg/bits/read.go +++ b/pkg/bits/read.go @@ -5,19 +5,33 @@ import ( "fmt" ) +// HasSpace checks whether buffer has space for N bits. +func HasSpace(buf []byte, pos int, n int) error { + if n > ((len(buf) * 8) - pos) { + return fmt.Errorf("not enough bits") + } + return nil +} + // ReadBits reads N bits. func ReadBits(buf []byte, pos *int, n int) (uint64, error) { - if n > ((len(buf) * 8) - *pos) { - return 0, fmt.Errorf("not enough bits") + err := HasSpace(buf, *pos, n) + if err != nil { + return 0, err } + return ReadBitsUnsafe(buf, pos, n), nil +} + +// ReadBitsUnsafe reads N bits. +func ReadBitsUnsafe(buf []byte, pos *int, n int) uint64 { v := uint64(0) res := 8 - (*pos & 0x07) if n < res { v := uint64((buf[*pos>>0x03] >> (res - n)) & (1<>0x03]&(1<>0x03] >> (7 - (*pos & 0x07))) & 0x01 *pos++ - return b == 1, nil -} - -// ReadUint8 reads a uint8. -func ReadUint8(buf []byte, pos *int) (uint8, error) { - v, err := ReadBits(buf, pos, 8) - return uint8(v), err -} - -// ReadUint16 reads a uint16. -func ReadUint16(buf []byte, pos *int) (uint16, error) { - v, err := ReadBits(buf, pos, 16) - return uint16(v), err -} - -// ReadUint32 reads a uint32. -func ReadUint32(buf []byte, pos *int) (uint32, error) { - v, err := ReadBits(buf, pos, 32) - return uint32(v), err + return b == 1 } diff --git a/pkg/bits/read_test.go b/pkg/bits/read_test.go index d10e3faa..6a72c8ae 100644 --- a/pkg/bits/read_test.go +++ b/pkg/bits/read_test.go @@ -86,14 +86,3 @@ func TestReadFlagError(t *testing.T) { _, err := ReadFlag(buf, &pos) require.EqualError(t, err, "not enough bits") } - -func TestUint(t *testing.T) { - buf := []byte{0x45, 0x46, 0x47, 0x48, 0x49, 0x50, 0x51} - pos := 0 - u8, _ := ReadUint8(buf, &pos) - require.Equal(t, uint8(0x45), u8) - u16, _ := ReadUint16(buf, &pos) - require.Equal(t, uint16(0x4647), u16) - u32, _ := ReadUint32(buf, &pos) - require.Equal(t, uint32(0x48495051), u32) -} diff --git a/pkg/h264/sps.go b/pkg/h264/sps.go index 56eb48a0..8c2b14f4 100644 --- a/pkg/h264/sps.go +++ b/pkg/h264/sps.go @@ -56,17 +56,13 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error { return err } - tmp, err := bits.ReadBits(buf, pos, 4) + err = bits.HasSpace(buf, *pos, 8) if err != nil { return err } - h.BitRateScale = uint8(tmp) - tmp, err = bits.ReadBits(buf, pos, 4) - if err != nil { - return err - } - h.CpbSizeScale = uint8(tmp) + h.BitRateScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4)) + h.CpbSizeScale = uint8(bits.ReadBitsUnsafe(buf, pos, 4)) for i := uint32(0); i <= h.CpbCntMinus1; i++ { v, err := bits.ReadGolombUnsigned(buf, pos) @@ -88,29 +84,15 @@ func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error { h.CbrFlag = append(h.CbrFlag, vb) } - tmp, err = bits.ReadBits(buf, pos, 5) + err = bits.HasSpace(buf, *pos, 5+5+5+5) if err != nil { return err } - h.InitialCpbRemovalDelayLengthMinus1 = uint8(tmp) - tmp, err = bits.ReadBits(buf, pos, 5) - if err != nil { - return err - } - h.CpbRemovalDelayLengthMinus1 = uint8(tmp) - - tmp, err = bits.ReadBits(buf, pos, 5) - if err != nil { - return err - } - h.DpbOutputDelayLengthMinus1 = uint8(tmp) - - tmp, err = bits.ReadBits(buf, pos, 5) - if err != nil { - return err - } - h.TimeOffsetLength = uint8(tmp) + h.InitialCpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5)) + h.CpbRemovalDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5)) + h.DpbOutputDelayLengthMinus1 = uint8(bits.ReadBitsUnsafe(buf, pos, 5)) + h.TimeOffsetLength = uint8(bits.ReadBitsUnsafe(buf, pos, 5)) return nil } @@ -123,21 +105,14 @@ type SPS_TimingInfo struct { //nolint:revive } func (t *SPS_TimingInfo) unmarshal(buf []byte, pos *int) error { - var err error - t.NumUnitsInTick, err = bits.ReadUint32(buf, pos) + err := bits.HasSpace(buf, *pos, 32+32+1) if err != nil { return err } - t.TimeScale, err = bits.ReadUint32(buf, pos) - if err != nil { - return err - } - - t.FixedFrameRateFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } + t.NumUnitsInTick = uint32(bits.ReadBitsUnsafe(buf, pos, 32)) + t.TimeScale = uint32(bits.ReadBitsUnsafe(buf, pos, 32)) + t.FixedFrameRateFlag = bits.ReadFlagUnsafe(buf, pos) return nil } @@ -243,21 +218,20 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error { } if v.AspectRatioInfoPresentFlag { - v.AspectRatioIdc, err = bits.ReadUint8(buf, pos) + tmp, err := bits.ReadBits(buf, pos, 8) if err != nil { return err } + v.AspectRatioIdc = uint8(tmp) if v.AspectRatioIdc == 255 { // Extended_SAR - v.SarWidth, err = bits.ReadUint16(buf, pos) + err := bits.HasSpace(buf, *pos, 32) if err != nil { return err } - v.SarHeight, err = bits.ReadUint16(buf, pos) - if err != nil { - return err - } + v.SarWidth = uint16(bits.ReadBitsUnsafe(buf, pos, 16)) + v.SarHeight = uint16(bits.ReadBitsUnsafe(buf, pos, 16)) } } @@ -279,37 +253,24 @@ func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error { } if v.VideoSignalTypePresentFlag { - tmp, err := bits.ReadBits(buf, pos, 3) - if err != nil { - return err - } - v.VideoFormat = uint8(tmp) - - v.VideoFullRangeFlag, err = bits.ReadFlag(buf, pos) + err := bits.HasSpace(buf, *pos, 5) if err != nil { return err } - v.ColourDescriptionPresentFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } + v.VideoFormat = uint8(bits.ReadBitsUnsafe(buf, pos, 3)) + v.VideoFullRangeFlag = bits.ReadFlagUnsafe(buf, pos) + v.ColourDescriptionPresentFlag = bits.ReadFlagUnsafe(buf, pos) if v.ColourDescriptionPresentFlag { - v.ColourPrimaries, err = bits.ReadUint8(buf, pos) + err := bits.HasSpace(buf, *pos, 24) if err != nil { return err } - v.TransferCharacteristics, err = bits.ReadUint8(buf, pos) - if err != nil { - return err - } - - v.MatrixCoefficients, err = bits.ReadUint8(buf, pos) - if err != nil { - return err - } + v.ColourPrimaries = uint8(bits.ReadBitsUnsafe(buf, pos, 8)) + v.TransferCharacteristics = uint8(bits.ReadBitsUnsafe(buf, pos, 8)) + v.MatrixCoefficients = uint8(bits.ReadBitsUnsafe(buf, pos, 8)) } } @@ -490,7 +451,7 @@ func (s *SPS) Unmarshal(buf []byte) error { buf = EmulationPreventionRemove(buf) if len(buf) < 4 { - return fmt.Errorf("buffer too short") + return fmt.Errorf("not enough bits") } forbidden := buf[0] >> 7 diff --git a/pkg/h265/sps.go b/pkg/h265/sps.go index 0cb09c75..c356ac5e 100644 --- a/pkg/h265/sps.go +++ b/pkg/h265/sps.go @@ -46,90 +46,31 @@ type SPS_ProfileLevelTier struct { //nolint:revive } func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMinus1 uint8) error { - tmp, err := bits.ReadBits(buf, pos, 2) + err := bits.HasSpace(buf, *pos, 8+32+12+34+8) if err != nil { return err } - p.GeneralProfileSpace = uint8(tmp) - tmp, err = bits.ReadBits(buf, pos, 1) - if err != nil { - return err - } - p.GeneralTierFlag = uint8(tmp) - - tmp, err = bits.ReadBits(buf, pos, 5) - if err != nil { - return err - } - p.GeneralProfileIdc = uint8(tmp) + p.GeneralProfileSpace = uint8(bits.ReadBitsUnsafe(buf, pos, 2)) + p.GeneralTierFlag = uint8(bits.ReadBitsUnsafe(buf, pos, 1)) + p.GeneralProfileIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 5)) for j := 0; j < 32; j++ { - p.GeneralProfileCompatibilityFlag[j], err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } + p.GeneralProfileCompatibilityFlag[j] = bits.ReadFlagUnsafe(buf, pos) } - p.ProgressiveSourceFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.InterlacedSourceFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.NonPackedConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.FrameOnlyConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.Max12bitConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.Max10bitConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.Max8bitConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.Max422ChromeConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.Max420ChromaConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.IntraConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.OnePictureOnlyConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.LowerBitRateConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } + p.ProgressiveSourceFlag = bits.ReadFlagUnsafe(buf, pos) + p.InterlacedSourceFlag = bits.ReadFlagUnsafe(buf, pos) + p.NonPackedConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.FrameOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.Max12bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.Max10bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.Max8bitConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.Max422ChromeConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.Max420ChromaConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.IntraConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.OnePictureOnlyConstraintFlag = bits.ReadFlagUnsafe(buf, pos) + p.LowerBitRateConstraintFlag = bits.ReadFlagUnsafe(buf, pos) if p.GeneralProfileIdc == 5 || p.GeneralProfileIdc == 9 || @@ -139,43 +80,36 @@ func (p *SPS_ProfileLevelTier) unmarshal(buf []byte, pos *int, maxNumSubLayersMi p.GeneralProfileCompatibilityFlag[9] || p.GeneralProfileCompatibilityFlag[10] || p.GeneralProfileCompatibilityFlag[11] { - p.Max14BitConstraintFlag, err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - + p.Max14BitConstraintFlag = bits.ReadFlagUnsafe(buf, pos) *pos += 33 } else { *pos += 34 } - tmp, err = bits.ReadBits(buf, pos, 8) - if err != nil { - return err - } - p.LevelIdc = uint8(tmp) + p.LevelIdc = uint8(bits.ReadBitsUnsafe(buf, pos, 8)) if maxNumSubLayersMinus1 > 0 { p.SubLayerProfilePresentFlag = make([]bool, maxNumSubLayersMinus1) p.SubLayerLevelPresentFlag = make([]bool, maxNumSubLayersMinus1) + + err := bits.HasSpace(buf, *pos, int(2*maxNumSubLayersMinus1)) + if err != nil { + return err + } } for j := uint8(0); j < maxNumSubLayersMinus1; j++ { - p.SubLayerProfilePresentFlag[j], err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } - - p.SubLayerLevelPresentFlag[j], err = bits.ReadFlag(buf, pos) - if err != nil { - return err - } + p.SubLayerProfilePresentFlag[j] = bits.ReadFlagUnsafe(buf, pos) + p.SubLayerLevelPresentFlag[j] = bits.ReadFlagUnsafe(buf, pos) } if maxNumSubLayersMinus1 > 0 { - for i := maxNumSubLayersMinus1; i < 8; i++ { - *pos += 2 + err := bits.HasSpace(buf, *pos, int(8-maxNumSubLayersMinus1)*2) + if err != nil { + return err } + + *pos += int(8-maxNumSubLayersMinus1) * 2 } for i := uint8(0); i < maxNumSubLayersMinus1; i++ { @@ -245,6 +179,10 @@ type SPS struct { // Unmarshal decodes a SPS from bytes. func (s *SPS) Unmarshal(buf []byte) error { + if len(buf) < 2 { + return fmt.Errorf("not enough bits") + } + typ := NALUType((buf[0] >> 1) & 0b111111) if typ != NALUTypeSPS { @@ -255,22 +193,14 @@ func (s *SPS) Unmarshal(buf []byte) error { buf = h264.EmulationPreventionRemove(buf) pos := 0 - tmp, err := bits.ReadBits(buf, &pos, 4) + err := bits.HasSpace(buf, pos, 8) if err != nil { return err } - s.VPSID = uint8(tmp) - tmp, err = bits.ReadBits(buf, &pos, 3) - if err != nil { - return err - } - s.MaxNumSubLayersMinus1 = uint8(tmp) - - s.TemporalIDNestingFlag, err = bits.ReadFlag(buf, &pos) - if err != nil { - return err - } + s.VPSID = uint8(bits.ReadBitsUnsafe(buf, &pos, 4)) + s.MaxNumSubLayersMinus1 = uint8(bits.ReadBitsUnsafe(buf, &pos, 3)) + s.TemporalIDNestingFlag = bits.ReadFlagUnsafe(buf, &pos) err = s.ProfileLevelTier.unmarshal(buf, &pos, s.MaxNumSubLayersMinus1) if err != nil {