diff --git a/pkg/h264/sps.go b/pkg/h264/sps.go index 71c159cf..278dbf16 100644 --- a/pkg/h264/sps.go +++ b/pkg/h264/sps.go @@ -58,10 +58,33 @@ func readFlag(br *bitio.Reader) (bool, error) { if err != nil { return false, err } - return (tmp == 1), nil } +func readUint8(br *bitio.Reader) (uint8, error) { + tmp, err := br.ReadBits(8) + if err != nil { + return 0, err + } + return uint8(tmp), nil +} + +func readUint16(br *bitio.Reader) (uint16, error) { + tmp, err := br.ReadBits(16) + if err != nil { + return 0, err + } + return uint16(tmp), nil +} + +func readUint32(br *bitio.Reader) (uint32, error) { + tmp, err := br.ReadBits(32) + if err != nil { + return 0, err + } + return uint32(tmp), nil +} + func readScalingList(br *bitio.Reader, size int) ([]int32, bool, error) { lastScale := int32(8) nextScale := int32(8) @@ -179,17 +202,16 @@ type SPS_TimingInfo struct { //nolint:revive } func (t *SPS_TimingInfo) unmarshal(br *bitio.Reader) error { - tmp, err := br.ReadBits(32) + var err error + t.NumUnitsInTick, err = readUint32(br) if err != nil { return err } - t.NumUnitsInTick = uint32(tmp) - tmp, err = br.ReadBits(32) + t.TimeScale, err = readUint32(br) if err != nil { return err } - t.TimeScale = uint32(tmp) t.FixedFrameRateFlag, err = readFlag(br) if err != nil { @@ -256,24 +278,21 @@ func (v *SPS_VUI) unmarshal(br *bitio.Reader) error { } if v.AspectRatioInfoPresentFlag { - tmp, err := br.ReadBits(8) + v.AspectRatioIdc, err = readUint8(br) if err != nil { return err } - v.AspectRatioIdc = uint8(tmp) if v.AspectRatioIdc == 255 { // Extended_SAR - tmp, err := br.ReadBits(16) + v.SarWidth, err = readUint16(br) if err != nil { return err } - v.SarWidth = uint16(tmp) - tmp, err = br.ReadBits(16) + v.SarHeight, err = readUint16(br) if err != nil { return err } - v.SarHeight = uint16(tmp) } } @@ -312,23 +331,20 @@ func (v *SPS_VUI) unmarshal(br *bitio.Reader) error { } if v.ColourDescriptionPresentFlag { - tmp, err := br.ReadBits(8) + v.ColourPrimaries, err = readUint8(br) if err != nil { return err } - v.ColourPrimaries = uint8(tmp) - tmp, err = br.ReadBits(8) + v.TransferCharacteristics, err = readUint8(br) if err != nil { return err } - v.TransferCharacteristics = uint8(tmp) - tmp, err = br.ReadBits(8) + v.MatrixCoefficients, err = readUint8(br) if err != nil { return err } - v.MatrixCoefficients = uint8(tmp) } }