From d11fe222c3be5bb9f15d9d5169d9836dfe3ac5d0 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 24 Jun 2022 16:41:01 +0200 Subject: [PATCH] add more efficient bit reader / writer --- go.mod | 1 - go.sum | 4 - pkg/aac/mpeg4audioconfig.go | 50 +++--- pkg/aac/mpeg4audioconfig_test.go | 50 ------ pkg/bits/read.go | 120 ++++++++++++++ pkg/bits/read_test.go | 70 ++++++++ pkg/bits/write.go | 26 +++ pkg/bits/write_test.go | 18 +++ pkg/h264/dtsextractor.go | 25 +-- pkg/h264/sps.go | 263 +++++++++++-------------------- pkg/h264/sps_test.go | 15 ++ pkg/rtpaac/decoder.go | 14 +- pkg/rtpaac/encoder.go | 21 ++- pkg/rtpaac/rtpaac_test.go | 2 +- 14 files changed, 394 insertions(+), 285 deletions(-) create mode 100644 pkg/bits/read.go create mode 100644 pkg/bits/read_test.go create mode 100644 pkg/bits/write.go create mode 100644 pkg/bits/write_test.go diff --git a/go.mod b/go.mod index fabf3950..de758a61 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.16 require ( github.com/asticode/go-astits v1.10.0 - github.com/icza/bitio v1.1.0 github.com/pion/rtcp v1.2.9 github.com/pion/rtp v1.7.13 github.com/pion/sdp/v3 v3.0.5 diff --git a/go.sum b/go.sum index 16264173..a22ae27a 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/asticode/go-astits v1.10.0 h1:ixKsRl84nWtjgHWcWKTDkUHNQ4kxbf9nKmjuSCn github.com/asticode/go-astits v1.10.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0= -github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= -github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= -github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= diff --git a/pkg/aac/mpeg4audioconfig.go b/pkg/aac/mpeg4audioconfig.go index 7cae5741..682fa4f2 100644 --- a/pkg/aac/mpeg4audioconfig.go +++ b/pkg/aac/mpeg4audioconfig.go @@ -1,10 +1,9 @@ package aac import ( - "bytes" "fmt" - "github.com/icza/bitio" + "github.com/aler9/gortsplib/pkg/bits" ) // MPEG4AudioConfig is a MPEG-4 Audio configuration. @@ -20,12 +19,12 @@ type MPEG4AudioConfig struct { } // Unmarshal decodes an MPEG4AudioConfig. -func (c *MPEG4AudioConfig) Unmarshal(byts []byte) error { +func (c *MPEG4AudioConfig) Unmarshal(buf []byte) error { // ref: ISO 14496-3 - r := bitio.NewReader(bytes.NewBuffer(byts)) + pos := 0 - tmp, err := r.ReadBits(5) + tmp, err := bits.ReadBits(buf, &pos, 5) if err != nil { return err } @@ -37,7 +36,7 @@ func (c *MPEG4AudioConfig) Unmarshal(byts []byte) error { return fmt.Errorf("unsupported type: %d", c.Type) } - sampleRateIndex, err := r.ReadBits(4) + sampleRateIndex, err := bits.ReadBits(buf, &pos, 4) if err != nil { return err } @@ -47,7 +46,7 @@ func (c *MPEG4AudioConfig) Unmarshal(byts []byte) error { c.SampleRate = sampleRates[sampleRateIndex] case sampleRateIndex == 15: - tmp, err := r.ReadBits(24) + tmp, err := bits.ReadBits(buf, &pos, 24) if err != nil { return err } @@ -57,7 +56,7 @@ func (c *MPEG4AudioConfig) Unmarshal(byts []byte) error { return fmt.Errorf("invalid sample rate index (%d)", sampleRateIndex) } - channelConfig, err := r.ReadBits(4) + channelConfig, err := bits.ReadBits(buf, &pos, 4) if err != nil { return err } @@ -76,31 +75,28 @@ func (c *MPEG4AudioConfig) Unmarshal(byts []byte) error { return fmt.Errorf("invalid channel configuration (%d)", channelConfig) } - tmp, err = r.ReadBits(1) + c.FrameLengthFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - c.FrameLengthFlag = (tmp == 1) - tmp, err = r.ReadBits(1) + c.DependsOnCoreCoder, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - c.DependsOnCoreCoder = (tmp == 1) if c.DependsOnCoreCoder { - tmp, err := r.ReadBits(14) + tmp, err := bits.ReadBits(buf, &pos, 14) if err != nil { return err } c.CoreCoderDelay = uint16(tmp) } - tmp, err = r.ReadBits(1) + extensionFlag, err := bits.ReadFlag(buf, &pos) if err != nil { return err } - extensionFlag := (tmp == 1) if extensionFlag { return fmt.Errorf("unsupported") @@ -134,16 +130,16 @@ func (c MPEG4AudioConfig) marshalSize() int { // Marshal encodes an MPEG4AudioConfig. func (c MPEG4AudioConfig) Marshal() ([]byte, error) { buf := make([]byte, c.marshalSize()) - w := bitio.NewWriter(bytes.NewBuffer(buf[:0])) + pos := 0 - w.WriteBits(uint64(c.Type), 5) + bits.WriteBits(buf, &pos, uint64(c.Type), 5) sampleRateIndex, ok := reverseSampleRates[c.SampleRate] if !ok { - w.WriteBits(uint64(15), 4) - w.WriteBits(uint64(c.SampleRate), 24) + bits.WriteBits(buf, &pos, uint64(15), 4) + bits.WriteBits(buf, &pos, uint64(c.SampleRate), 24) } else { - w.WriteBits(uint64(sampleRateIndex), 4) + bits.WriteBits(buf, &pos, uint64(sampleRateIndex), 4) } var channelConfig int @@ -158,25 +154,23 @@ func (c MPEG4AudioConfig) Marshal() ([]byte, error) { return nil, fmt.Errorf("invalid channel count (%d)", c.ChannelCount) } - w.WriteBits(uint64(channelConfig), 4) + bits.WriteBits(buf, &pos, uint64(channelConfig), 4) if c.FrameLengthFlag { - w.WriteBits(1, 1) + bits.WriteBits(buf, &pos, 1, 1) } else { - w.WriteBits(0, 1) + bits.WriteBits(buf, &pos, 0, 1) } if c.DependsOnCoreCoder { - w.WriteBits(1, 1) + bits.WriteBits(buf, &pos, 1, 1) } else { - w.WriteBits(0, 1) + bits.WriteBits(buf, &pos, 0, 1) } if c.DependsOnCoreCoder { - w.WriteBits(uint64(c.CoreCoderDelay), 14) + bits.WriteBits(buf, &pos, uint64(c.CoreCoderDelay), 14) } - w.Close() - return buf, nil } diff --git a/pkg/aac/mpeg4audioconfig_test.go b/pkg/aac/mpeg4audioconfig_test.go index 45d22be7..259aa9d4 100644 --- a/pkg/aac/mpeg4audioconfig_test.go +++ b/pkg/aac/mpeg4audioconfig_test.go @@ -80,56 +80,6 @@ func TestConfigUnmarshal(t *testing.T) { } } -func TestConfigUnmarshalErrors(t *testing.T) { - for _, ca := range []struct { - name string - byts []byte - err string - }{ - { - "empty", - []byte{}, - "EOF", - }, - { - "unsupported type", - []byte{18 << 3}, - "unsupported type: 18", - }, - { - "sample rate missing", - []byte{0x12}, - "EOF", - }, - { - "sample rate invalid", - []byte{0x17, 0}, - "invalid sample rate index (14)", - }, - { - "explicit sample rate missing", - []byte{0x17, 0x80, 0x67}, - "EOF", - }, - { - "channel configuration invalid", - []byte{0x11, 0xF0}, - "invalid channel configuration (14)", - }, - { - "channel configuration zero", - []byte{0x11, 0x80}, - "not yet supported", - }, - } { - t.Run(ca.name, func(t *testing.T) { - var dec MPEG4AudioConfig - err := dec.Unmarshal(ca.byts) - require.EqualError(t, err, ca.err) - }) - } -} - func TestConfigMarshal(t *testing.T) { for _, ca := range configCases { t.Run(ca.name, func(t *testing.T) { diff --git a/pkg/bits/read.go b/pkg/bits/read.go new file mode 100644 index 00000000..97dde87d --- /dev/null +++ b/pkg/bits/read.go @@ -0,0 +1,120 @@ +// Package bits contains functions to read/write bits from/to buffers. +package bits + +import ( + "fmt" +) + +// ReadBits reads N bits. +func ReadBits(buf []byte, pos *int, n int) (uint64, error) { + if n > ((len(buf) * 8) - *pos) { + return 0, fmt.Errorf("not enough bits") + } + + v := uint64(0) + + res := 8 - (*pos & 0x07) + if n < res { + v := uint64((buf[*pos>>0x03] >> (res - n)) & (1<>0x03]&(1<= 8 { + v = (v << 8) | uint64(buf[*pos>>0x03]) + *pos += 8 + n -= 8 + } + + if n > 0 { + v = (v << n) | uint64(buf[*pos>>0x03]>>(8-n)) + *pos += n + } + + return v, nil +} + +// ReadGolombUnsigned reads an unsigned golomb-encoded value. +func ReadGolombUnsigned(buf []byte, pos *int) (uint32, error) { + leadingZeroBits := uint32(0) + + for { + if (len(buf)*8 - *pos) == 0 { + return 0, fmt.Errorf("not enough bits") + } + + b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01 + *pos++ + if b != 0 { + break + } + + leadingZeroBits++ + if leadingZeroBits > 32 { + return 0, fmt.Errorf("invalid value") + } + } + + if (len(buf)*8 - *pos) < int(leadingZeroBits) { + return 0, fmt.Errorf("not enough bits") + } + + codeNum := uint32(0) + + for n := leadingZeroBits; n > 0; n-- { + b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01 + *pos++ + codeNum |= uint32(b) << (n - 1) + } + + codeNum = (1 << leadingZeroBits) - 1 + codeNum + + return codeNum, nil +} + +// ReadGolombSigned reads a signed golomb-encoded value. +func ReadGolombSigned(buf []byte, pos *int) (int32, error) { + v, err := ReadGolombUnsigned(buf, pos) + if err != nil { + return 0, err + } + + vi := int32(v) + if (vi & 0x01) != 0 { + return (vi + 1) / 2, nil + } + return -vi / 2, nil +} + +// ReadFlag reads a boolean flag. +func ReadFlag(buf []byte, pos *int) (bool, error) { + if (len(buf)*8 - *pos) == 0 { + return false, fmt.Errorf("not enough bits") + } + + b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01 + *pos++ + return b == 1, nil +} + +// ReadUint8 reads a uint8. +func ReadUint8(buf []byte, pos *int) (uint8, error) { + v, err := ReadBits(buf, pos, 8) + return uint8(v), err +} + +// ReadUint16 reads a uint16. +func ReadUint16(buf []byte, pos *int) (uint16, error) { + v, err := ReadBits(buf, pos, 16) + return uint16(v), err +} + +// ReadUint32 reads a uint32. +func ReadUint32(buf []byte, pos *int) (uint32, error) { + v, err := ReadBits(buf, pos, 32) + return uint32(v), err +} diff --git a/pkg/bits/read_test.go b/pkg/bits/read_test.go new file mode 100644 index 00000000..e0be2477 --- /dev/null +++ b/pkg/bits/read_test.go @@ -0,0 +1,70 @@ +package bits + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReadBits(t *testing.T) { + buf := []byte{0xA8, 0xC7, 0xD6, 0xAA, 0xBB, 0x10} + pos := 0 + v, _ := ReadBits(buf, &pos, 6) + require.Equal(t, uint64(0x2a), v) + v, _ = ReadBits(buf, &pos, 6) + require.Equal(t, uint64(0x0c), v) + v, _ = ReadBits(buf, &pos, 6) + require.Equal(t, uint64(0x1f), v) + v, _ = ReadBits(buf, &pos, 8) + require.Equal(t, uint64(0x5a), v) + v, _ = ReadBits(buf, &pos, 20) + require.Equal(t, uint64(0xaaec4), v) +} + +func TestReadBitsError(t *testing.T) { + buf := []byte{0xA8} + pos := 0 + _, err := ReadBits(buf, &pos, 6) + require.NoError(t, err) + _, err = ReadBits(buf, &pos, 6) + require.EqualError(t, err, "not enough bits") +} + +func TestReadGolombUnsigned(t *testing.T) { + buf := []byte{0x38} + pos := 0 + v, _ := ReadGolombUnsigned(buf, &pos) + require.Equal(t, uint32(6), v) +} + +func TestReadGolombSigned(t *testing.T) { + buf := []byte{0x38} + pos := 0 + v, _ := ReadGolombSigned(buf, &pos) + require.Equal(t, int32(-3), v) +} + +func TestReadFlag(t *testing.T) { + buf := []byte{0xFF} + pos := 0 + v, _ := ReadFlag(buf, &pos) + require.Equal(t, true, v) +} + +func TestReadFlagError(t *testing.T) { + buf := []byte{} + pos := 0 + _, err := ReadFlag(buf, &pos) + require.EqualError(t, err, "not enough bits") +} + +func TestUint(t *testing.T) { + buf := []byte{0x45, 0x46, 0x47, 0x48, 0x49, 0x50, 0x51} + pos := 0 + u8, _ := ReadUint8(buf, &pos) + require.Equal(t, uint8(0x45), u8) + u16, _ := ReadUint16(buf, &pos) + require.Equal(t, uint16(0x4647), u16) + u32, _ := ReadUint32(buf, &pos) + require.Equal(t, uint32(0x48495051), u32) +} diff --git a/pkg/bits/write.go b/pkg/bits/write.go new file mode 100644 index 00000000..4c71275e --- /dev/null +++ b/pkg/bits/write.go @@ -0,0 +1,26 @@ +package bits + +// WriteBits writes N bits. +func WriteBits(buf []byte, pos *int, bits uint64, n int) { + res := 8 - (*pos & 0x07) + if n < res { + buf[*pos>>0x03] |= byte(bits << (res - n)) + *pos += n + return + } + + buf[*pos>>3] |= byte(bits >> (n - res)) + *pos += res + n -= res + + for n >= 8 { + buf[*pos>>3] = byte(bits >> (n - 8)) + *pos += 8 + n -= 8 + } + + if n > 0 { + buf[*pos>>3] = byte((bits & (1< 0; n-- { - b, err := br.ReadBits(1) - if err != nil { - return 0, err - } - - codeNum |= uint32(b) << (n - 1) - } - - codeNum = (1 << leadingZeroBits) - 1 + codeNum - - return codeNum, nil -} - -func readGolombSigned(br *bitio.Reader) (int32, error) { - v, err := readGolombUnsigned(br) - if err != nil { - return 0, err - } - vi := int32(v) - - if (vi & 0x01) != 0 { - return (vi + 1) / 2, nil - } - - return -vi / 2, nil -} - -func readFlag(br *bitio.Reader) (bool, error) { - tmp, err := br.ReadBits(1) - if err != nil { - return false, err - } - return (tmp == 1), nil -} - -func readUint8(br *bitio.Reader) (uint8, error) { - tmp, err := br.ReadBits(8) - if err != nil { - return 0, err - } - return uint8(tmp), nil -} - -func readUint16(br *bitio.Reader) (uint16, error) { - tmp, err := br.ReadBits(16) - if err != nil { - return 0, err - } - return uint16(tmp), nil -} - -func readUint32(br *bitio.Reader) (uint32, error) { - tmp, err := br.ReadBits(32) - if err != nil { - return 0, err - } - return uint32(tmp), nil -} - -func readScalingList(br *bitio.Reader, size int) ([]int32, bool, error) { +func readScalingList(buf []byte, pos *int, size int) ([]int32, bool, error) { lastScale := int32(8) nextScale := int32(8) scalingList := make([]int32, size) @@ -93,7 +14,7 @@ func readScalingList(br *bitio.Reader, size int) ([]int32, bool, error) { for j := 0; j < size; j++ { if nextScale != 0 { - deltaScale, err := readGolombSigned(br) + deltaScale, err := bits.ReadGolombSigned(buf, pos) if err != nil { return nil, false, err } @@ -128,64 +49,64 @@ type SPS_HRD struct { //nolint:revive TimeOffsetLength uint8 } -func (h *SPS_HRD) unmarshal(br *bitio.Reader) error { +func (h *SPS_HRD) unmarshal(buf []byte, pos *int) error { var err error - h.CpbCntMinus1, err = readGolombUnsigned(br) + h.CpbCntMinus1, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - tmp, err := br.ReadBits(4) + tmp, err := bits.ReadBits(buf, pos, 4) if err != nil { return err } h.BitRateScale = uint8(tmp) - tmp, err = br.ReadBits(4) + tmp, err = bits.ReadBits(buf, pos, 4) if err != nil { return err } h.CpbSizeScale = uint8(tmp) for i := uint32(0); i <= h.CpbCntMinus1; i++ { - v, err := readGolombUnsigned(br) + v, err := bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } h.BitRateValueMinus1 = append(h.BitRateValueMinus1, v) - v, err = readGolombUnsigned(br) + v, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } h.CpbSizeValueMinus1 = append(h.CpbSizeValueMinus1, v) - vb, err := readFlag(br) + vb, err := bits.ReadFlag(buf, pos) if err != nil { return err } h.CbrFlag = append(h.CbrFlag, vb) } - tmp, err = br.ReadBits(5) + tmp, err = bits.ReadBits(buf, pos, 5) if err != nil { return err } h.InitialCpbRemovalDelayLengthMinus1 = uint8(tmp) - tmp, err = br.ReadBits(5) + tmp, err = bits.ReadBits(buf, pos, 5) if err != nil { return err } h.CpbRemovalDelayLengthMinus1 = uint8(tmp) - tmp, err = br.ReadBits(5) + tmp, err = bits.ReadBits(buf, pos, 5) if err != nil { return err } h.DpbOutputDelayLengthMinus1 = uint8(tmp) - tmp, err = br.ReadBits(5) + tmp, err = bits.ReadBits(buf, pos, 5) if err != nil { return err } @@ -201,19 +122,19 @@ type SPS_TimingInfo struct { //nolint:revive FixedFrameRateFlag bool } -func (t *SPS_TimingInfo) unmarshal(br *bitio.Reader) error { +func (t *SPS_TimingInfo) unmarshal(buf []byte, pos *int) error { var err error - t.NumUnitsInTick, err = readUint32(br) + t.NumUnitsInTick, err = bits.ReadUint32(buf, pos) if err != nil { return err } - t.TimeScale, err = readUint32(br) + t.TimeScale, err = bits.ReadUint32(buf, pos) if err != nil { return err } - t.FixedFrameRateFlag, err = readFlag(br) + t.FixedFrameRateFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } @@ -232,39 +153,39 @@ type SPS_BitstreamRestriction struct { //nolint:revive MaxDecFrameBuffering uint32 } -func (r *SPS_BitstreamRestriction) unmarshal(br *bitio.Reader) error { +func (r *SPS_BitstreamRestriction) unmarshal(buf []byte, pos *int) error { var err error - r.MotionVectorsOverPicBoundariesFlag, err = readFlag(br) + r.MotionVectorsOverPicBoundariesFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } - r.MaxBytesPerPicDenom, err = readGolombUnsigned(br) + r.MaxBytesPerPicDenom, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - r.MaxBitsPerMbDenom, err = readGolombUnsigned(br) + r.MaxBitsPerMbDenom, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - r.Log2MaxMvLengthHorizontal, err = readGolombUnsigned(br) + r.Log2MaxMvLengthHorizontal, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - r.Log2MaxMvLengthVertical, err = readGolombUnsigned(br) + r.Log2MaxMvLengthVertical, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - r.MaxNumReorderFrames, err = readGolombUnsigned(br) + r.MaxNumReorderFrames, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - r.MaxDecFrameBuffering, err = readGolombUnsigned(br) + r.MaxDecFrameBuffering, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } @@ -314,160 +235,160 @@ type SPS_VUI struct { //nolint:revive BitstreamRestriction *SPS_BitstreamRestriction } -func (v *SPS_VUI) unmarshal(br *bitio.Reader) error { +func (v *SPS_VUI) unmarshal(buf []byte, pos *int) error { var err error - v.AspectRatioInfoPresentFlag, err = readFlag(br) + v.AspectRatioInfoPresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } if v.AspectRatioInfoPresentFlag { - v.AspectRatioIdc, err = readUint8(br) + v.AspectRatioIdc, err = bits.ReadUint8(buf, pos) if err != nil { return err } if v.AspectRatioIdc == 255 { // Extended_SAR - v.SarWidth, err = readUint16(br) + v.SarWidth, err = bits.ReadUint16(buf, pos) if err != nil { return err } - v.SarHeight, err = readUint16(br) + v.SarHeight, err = bits.ReadUint16(buf, pos) if err != nil { return err } } } - v.OverscanInfoPresentFlag, err = readFlag(br) + v.OverscanInfoPresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } if v.OverscanInfoPresentFlag { - v.OverscanAppropriateFlag, err = readFlag(br) + v.OverscanAppropriateFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } } - v.VideoSignalTypePresentFlag, err = readFlag(br) + v.VideoSignalTypePresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } if v.VideoSignalTypePresentFlag { - tmp, err := br.ReadBits(3) + tmp, err := bits.ReadBits(buf, pos, 3) if err != nil { return err } v.VideoFormat = uint8(tmp) - v.VideoFullRangeFlag, err = readFlag(br) + v.VideoFullRangeFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } - v.ColourDescriptionPresentFlag, err = readFlag(br) + v.ColourDescriptionPresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } if v.ColourDescriptionPresentFlag { - v.ColourPrimaries, err = readUint8(br) + v.ColourPrimaries, err = bits.ReadUint8(buf, pos) if err != nil { return err } - v.TransferCharacteristics, err = readUint8(br) + v.TransferCharacteristics, err = bits.ReadUint8(buf, pos) if err != nil { return err } - v.MatrixCoefficients, err = readUint8(br) + v.MatrixCoefficients, err = bits.ReadUint8(buf, pos) if err != nil { return err } } } - v.ChromaLocInfoPresentFlag, err = readFlag(br) + v.ChromaLocInfoPresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } if v.ChromaLocInfoPresentFlag { - v.ChromaSampleLocTypeTopField, err = readGolombUnsigned(br) + v.ChromaSampleLocTypeTopField, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - v.ChromaSampleLocTypeBottomField, err = readGolombUnsigned(br) + v.ChromaSampleLocTypeBottomField, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } } - timingInfoPresentFlag, err := readFlag(br) + timingInfoPresentFlag, err := bits.ReadFlag(buf, pos) if err != nil { return err } if timingInfoPresentFlag { v.TimingInfo = &SPS_TimingInfo{} - err := v.TimingInfo.unmarshal(br) + err := v.TimingInfo.unmarshal(buf, pos) if err != nil { return err } } - nalHrdParametersPresentFlag, err := readFlag(br) + nalHrdParametersPresentFlag, err := bits.ReadFlag(buf, pos) if err != nil { return err } if nalHrdParametersPresentFlag { v.NalHRD = &SPS_HRD{} - err := v.NalHRD.unmarshal(br) + err := v.NalHRD.unmarshal(buf, pos) if err != nil { return err } } - vclHrdParametersPresentFlag, err := readFlag(br) + vclHrdParametersPresentFlag, err := bits.ReadFlag(buf, pos) if err != nil { return err } if vclHrdParametersPresentFlag { v.VclHRD = &SPS_HRD{} - err := v.VclHRD.unmarshal(br) + err := v.VclHRD.unmarshal(buf, pos) if err != nil { return err } } if nalHrdParametersPresentFlag || vclHrdParametersPresentFlag { - v.LowDelayHrdFlag, err = readFlag(br) + v.LowDelayHrdFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } } - v.PicStructPresentFlag, err = readFlag(br) + v.PicStructPresentFlag, err = bits.ReadFlag(buf, pos) if err != nil { return err } - bitstreamRestrictionFlag, err := readFlag(br) + bitstreamRestrictionFlag, err := bits.ReadFlag(buf, pos) if err != nil { return err } if bitstreamRestrictionFlag { v.BitstreamRestriction = &SPS_BitstreamRestriction{} - err := v.BitstreamRestriction.unmarshal(br) + err := v.BitstreamRestriction.unmarshal(buf, pos) if err != nil { return err } @@ -484,24 +405,24 @@ type SPS_FrameCropping struct { //nolint:revive BottomOffset uint32 } -func (c *SPS_FrameCropping) unmarshal(br *bitio.Reader) error { +func (c *SPS_FrameCropping) unmarshal(buf []byte, pos *int) error { var err error - c.LeftOffset, err = readGolombUnsigned(br) + c.LeftOffset, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - c.RightOffset, err = readGolombUnsigned(br) + c.RightOffset, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - c.TopOffset, err = readGolombUnsigned(br) + c.TopOffset, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } - c.BottomOffset, err = readGolombUnsigned(br) + c.BottomOffset, err = bits.ReadGolombUnsigned(buf, pos) if err != nil { return err } @@ -599,24 +520,24 @@ func (s *SPS) Unmarshal(buf []byte) error { s.ConstraintSet5Flag = (buf[2] >> 2 & 0x01) == 1 s.LevelIdc = buf[3] - r := bytes.NewReader(buf[4:]) - br := bitio.NewReader(r) + buf = buf[4:] + pos := 0 var err error - s.ID, err = readGolombUnsigned(br) + s.ID, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } switch s.ProfileIdc { case 100, 110, 122, 244, 44, 83, 86, 118, 128, 138, 139, 134, 135: - s.ChromeFormatIdc, err = readGolombUnsigned(br) + s.ChromeFormatIdc, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } if s.ChromeFormatIdc == 3 { - s.SeparateColourPlaneFlag, err = readFlag(br) + s.SeparateColourPlaneFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } @@ -624,22 +545,22 @@ func (s *SPS) Unmarshal(buf []byte) error { s.SeparateColourPlaneFlag = false } - s.BitDepthLumaMinus8, err = readGolombUnsigned(br) + s.BitDepthLumaMinus8, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.BitDepthChromaMinus8, err = readGolombUnsigned(br) + s.BitDepthChromaMinus8, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.QpprimeYZeroTransformBypassFlag, err = readFlag(br) + s.QpprimeYZeroTransformBypassFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - seqScalingMatrixPresentFlag, err := readFlag(br) + seqScalingMatrixPresentFlag, err := bits.ReadFlag(buf, &pos) if err != nil { return err } @@ -653,25 +574,27 @@ func (s *SPS) Unmarshal(buf []byte) error { } for i := 0; i < lim; i++ { - seqScalingListPresentFlag, err := readFlag(br) + seqScalingListPresentFlag, err := bits.ReadFlag(buf, &pos) if err != nil { return err } if seqScalingListPresentFlag { if i < 6 { - scalingList, useDefaultScalingMatrixFlag, err := readScalingList(br, 16) + scalingList, useDefaultScalingMatrixFlag, err := readScalingList(buf, &pos, 16) if err != nil { return err } + s.ScalingList4x4 = append(s.ScalingList4x4, scalingList) s.UseDefaultScalingMatrix4x4Flag = append(s.UseDefaultScalingMatrix4x4Flag, useDefaultScalingMatrixFlag) } else { - scalingList, useDefaultScalingMatrixFlag, err := readScalingList(br, 64) + scalingList, useDefaultScalingMatrixFlag, err := readScalingList(buf, &pos, 64) if err != nil { return err } + s.ScalingList8x8 = append(s.ScalingList8x8, scalingList) s.UseDefaultScalingMatrix8x8Flag = append(s.UseDefaultScalingMatrix8x8Flag, useDefaultScalingMatrixFlag) @@ -688,19 +611,19 @@ func (s *SPS) Unmarshal(buf []byte) error { s.QpprimeYZeroTransformBypassFlag = false } - s.Log2MaxFrameNumMinus4, err = readGolombUnsigned(br) + s.Log2MaxFrameNumMinus4, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.PicOrderCntType, err = readGolombUnsigned(br) + s.PicOrderCntType, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } switch s.PicOrderCntType { case 0: - s.Log2MaxPicOrderCntLsbMinus4, err = readGolombUnsigned(br) + s.Log2MaxPicOrderCntLsbMinus4, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } @@ -708,34 +631,34 @@ func (s *SPS) Unmarshal(buf []byte) error { case 1: s.Log2MaxPicOrderCntLsbMinus4 = 0 - s.DeltaPicOrderAlwaysZeroFlag, err = readFlag(br) + s.DeltaPicOrderAlwaysZeroFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - s.OffsetForNonRefPic, err = readGolombSigned(br) + s.OffsetForNonRefPic, err = bits.ReadGolombSigned(buf, &pos) if err != nil { return err } - s.OffsetForTopToBottomField, err = readGolombSigned(br) + s.OffsetForTopToBottomField, err = bits.ReadGolombSigned(buf, &pos) if err != nil { return err } - numRefFramesInPicOrderCntCycle, err := readGolombUnsigned(br) + numRefFramesInPicOrderCntCycle, err := bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.OffsetForRefFrames = nil + s.OffsetForRefFrames = make([]int32, numRefFramesInPicOrderCntCycle) for i := uint32(0); i < numRefFramesInPicOrderCntCycle; i++ { - v, err := readGolombSigned(br) + v, err := bits.ReadGolombSigned(buf, &pos) if err != nil { return err } - s.OffsetForRefFrames = append(s.OffsetForRefFrames, v) + s.OffsetForRefFrames[i] = v } default: @@ -746,33 +669,33 @@ func (s *SPS) Unmarshal(buf []byte) error { s.OffsetForRefFrames = nil } - s.MaxNumRefFrames, err = readGolombUnsigned(br) + s.MaxNumRefFrames, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.GapsInFrameNumValueAllowedFlag, err = readFlag(br) + s.GapsInFrameNumValueAllowedFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - s.PicWidthInMbsMinus1, err = readGolombUnsigned(br) + s.PicWidthInMbsMinus1, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.PicHeightInMbsMinus1, err = readGolombUnsigned(br) + s.PicHeightInMbsMinus1, err = bits.ReadGolombUnsigned(buf, &pos) if err != nil { return err } - s.FrameMbsOnlyFlag, err = readFlag(br) + s.FrameMbsOnlyFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } if !s.FrameMbsOnlyFlag { - s.MbAdaptiveFrameFieldFlag, err = readFlag(br) + s.MbAdaptiveFrameFieldFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } @@ -780,19 +703,19 @@ func (s *SPS) Unmarshal(buf []byte) error { s.MbAdaptiveFrameFieldFlag = false } - s.Direct8x8InferenceFlag, err = readFlag(br) + s.Direct8x8InferenceFlag, err = bits.ReadFlag(buf, &pos) if err != nil { return err } - frameCroppingFlag, err := readFlag(br) + frameCroppingFlag, err := bits.ReadFlag(buf, &pos) if err != nil { return err } if frameCroppingFlag { s.FrameCropping = &SPS_FrameCropping{} - err := s.FrameCropping.unmarshal(br) + err := s.FrameCropping.unmarshal(buf, &pos) if err != nil { return err } @@ -800,14 +723,14 @@ func (s *SPS) Unmarshal(buf []byte) error { s.FrameCropping = nil } - vuiParameterPresentFlag, err := readFlag(br) + vuiParameterPresentFlag, err := bits.ReadFlag(buf, &pos) if err != nil { return err } if vuiParameterPresentFlag { s.VUI = &SPS_VUI{} - err := s.VUI.unmarshal(br) + err := s.VUI.unmarshal(buf, &pos) if err != nil { return err } diff --git a/pkg/h264/sps_test.go b/pkg/h264/sps_test.go index 510b775d..5bde4561 100644 --- a/pkg/h264/sps_test.go +++ b/pkg/h264/sps_test.go @@ -435,3 +435,18 @@ func TestSPSUnmarshal(t *testing.T) { }) } } + +func BenchmarkSPSUnmarshal(b *testing.B) { + for i := 0; i < b.N; i++ { + var sps SPS + sps.Unmarshal([]byte{ + 103, 77, 0, 41, 154, 100, 3, 192, + 17, 63, 46, 2, 220, 4, 4, 5, + 0, 0, 3, 3, 232, 0, 0, 195, + 80, 232, 96, 0, 186, 180, 0, 2, + 234, 196, 187, 203, 141, 12, 0, 23, + 86, 128, 0, 93, 88, 151, 121, 112, + 160, + }) + } +} diff --git a/pkg/rtpaac/decoder.go b/pkg/rtpaac/decoder.go index c16acc36..20a36d54 100644 --- a/pkg/rtpaac/decoder.go +++ b/pkg/rtpaac/decoder.go @@ -1,16 +1,15 @@ package rtpaac import ( - "bytes" "encoding/binary" "errors" "fmt" "time" - "github.com/icza/bitio" "github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/aac" + "github.com/aler9/gortsplib/pkg/bits" "github.com/aler9/gortsplib/pkg/rtptimedec" ) @@ -177,8 +176,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { return [][]byte{ret}, d.timeDecoder.Decode(pkt.Timestamp), nil } -func (d *Decoder) readAUHeaders(payload []byte, headersLen int) ([]uint64, error) { - br := bitio.NewReader(bytes.NewReader(payload)) +func (d *Decoder) readAUHeaders(buf []byte, headersLen int) ([]uint64, error) { firstRead := false count := 0 @@ -195,9 +193,11 @@ func (d *Decoder) readAUHeaders(payload []byte, headersLen int) ([]uint64, error dataLens := make([]uint64, count) + pos := 0 i := 0 + for headersLen > 0 { - dataLen, err := br.ReadBits(uint8(d.SizeLength)) + dataLen, err := bits.ReadBits(buf, &pos, d.SizeLength) if err != nil { return nil, err } @@ -206,7 +206,7 @@ func (d *Decoder) readAUHeaders(payload []byte, headersLen int) ([]uint64, error if !firstRead { firstRead = true if d.IndexLength > 0 { - auIndex, err := br.ReadBits(uint8(d.IndexLength)) + auIndex, err := bits.ReadBits(buf, &pos, d.IndexLength) if err != nil { return nil, err } @@ -217,7 +217,7 @@ func (d *Decoder) readAUHeaders(payload []byte, headersLen int) ([]uint64, error } } } else if d.IndexDeltaLength > 0 { - auIndexDelta, err := br.ReadBits(uint8(d.IndexDeltaLength)) + auIndexDelta, err := bits.ReadBits(buf, &pos, d.IndexDeltaLength) if err != nil { return nil, err } diff --git a/pkg/rtpaac/encoder.go b/pkg/rtpaac/encoder.go index 869d029d..af46461a 100644 --- a/pkg/rtpaac/encoder.go +++ b/pkg/rtpaac/encoder.go @@ -1,15 +1,14 @@ package rtpaac import ( - "bytes" "crypto/rand" "encoding/binary" "time" - "github.com/icza/bitio" "github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/aac" + "github.com/aler9/gortsplib/pkg/bits" ) func randUint32() uint32 { @@ -161,10 +160,9 @@ func (e *Encoder) writeFragmented(au []byte, pts time.Duration) ([]*rtp.Packet, binary.BigEndian.PutUint16(byts, uint16(auHeadersLen)) // AU-headers - bw := bitio.NewWriter(bytes.NewBuffer(byts[2:2])) - bw.WriteBits(uint64(le), uint8(e.SizeLength)) - bw.WriteBits(0, uint8(e.IndexLength)) - bw.Close() + pos := 0 + bits.WriteBits(byts[2:], &pos, uint64(le), e.SizeLength) + bits.WriteBits(byts[2:], &pos, 0, e.IndexLength) // AU copy(byts[2+auHeadersLenBytes:], au[:le]) @@ -228,20 +226,19 @@ func (e *Encoder) writeAggregated(aus [][]byte, firstPTS time.Duration) ([]*rtp. // AU-headers written := 0 - bw := bitio.NewWriter(bytes.NewBuffer(payload[2:2])) + pos := 0 for i, au := range aus { - bw.WriteBits(uint64(len(au)), uint8(e.SizeLength)) + bits.WriteBits(payload[2:], &pos, uint64(len(au)), e.SizeLength) written += e.SizeLength if i == 0 { - bw.WriteBits(0, uint8(e.IndexLength)) + bits.WriteBits(payload[2:], &pos, 0, e.IndexLength) written += e.IndexLength } else { - bw.WriteBits(0, uint8(e.IndexDeltaLength)) + bits.WriteBits(payload[2:], &pos, 0, e.IndexDeltaLength) written += e.IndexDeltaLength } } - bw.Close() - pos := 2 + (written / 8) + pos = 2 + (written / 8) if (written % 8) != 0 { pos++ } diff --git a/pkg/rtpaac/rtpaac_test.go b/pkg/rtpaac/rtpaac_test.go index a463bf06..4b8f6189 100644 --- a/pkg/rtpaac/rtpaac_test.go +++ b/pkg/rtpaac/rtpaac_test.go @@ -616,7 +616,7 @@ func TestDecodeErrors(t *testing.T) { Payload: []byte{0x00, 0x10}, }, }, - "EOF", + "not enough bits", }, { "missing au",