diff --git a/pkg/format/rtpav1/decoder.go b/pkg/format/rtpav1/decoder.go index fb24dd2e..91757d17 100644 --- a/pkg/format/rtpav1/decoder.go +++ b/pkg/format/rtpav1/decoder.go @@ -28,6 +28,14 @@ func joinFragments(fragments [][]byte, size int) []byte { return ret } +func tuSize(tu [][]byte) int { + s := 0 + for _, obu := range tu { + s += len(obu) + } + return s +} + // Decoder is a RTP/AV1 decoder. // Specification: https://aomediacodec.github.io/av1-rtp-spec/ type Decoder struct { @@ -107,15 +115,7 @@ func (d *Decoder) decodeOBUs(pkt *rtp.Packet) ([][]byte, error) { if av1header.Y { elementCount := len(av1header.OBUElements) - d.fragmentsSize += len(av1header.OBUElements[elementCount-1]) - - if d.fragmentsSize > av1.MaxTemporalUnitSize { - errSize := d.fragmentsSize - d.resetFragments() - return nil, fmt.Errorf("temporal unit size (%d) is too big, maximum is %d", - errSize, av1.MaxTemporalUnitSize) - } - + d.fragmentsSize = len(av1header.OBUElements[elementCount-1]) d.fragments = append(d.fragments, av1header.OBUElements[elementCount-1]) av1header.OBUElements = av1header.OBUElements[:elementCount-1] d.fragmentNextSeqNum = pkt.SequenceNumber + 1 @@ -151,11 +151,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) { errCount, av1.MaxOBUsPerTemporalUnit) } - addSize := 0 - - for _, obu := range obus { - addSize += len(obu) - } + addSize := tuSize(obus) if (d.frameBufferSize + addSize) > av1.MaxTemporalUnitSize { errSize := d.frameBufferSize + addSize diff --git a/pkg/format/rtpav1/decoder_test.go b/pkg/format/rtpav1/decoder_test.go index a16bee8d..182518a8 100644 --- a/pkg/format/rtpav1/decoder_test.go +++ b/pkg/format/rtpav1/decoder_test.go @@ -1,6 +1,7 @@ package rtpav1 import ( + "bytes" "errors" "testing" @@ -33,6 +34,48 @@ func TestDecode(t *testing.T) { } } +func TestDecoderErrorTUSize(t *testing.T) { + d := &Decoder{} + err := d.Init() + require.NoError(t, err) + + size := 0 + i := uint16(0) + + for size < av1.MaxTemporalUnitSize { + var header byte + if i == 0 { + header = 0b01000000 + } else { + header = 0b11000000 + } + + fragmentLenLEB := av1.LEB128(1400) + buf := make([]byte, fragmentLenLEB.MarshalSize()) + fragmentLenLEB.MarshalTo(buf) + + payload := append([]byte{header}, buf...) + payload = append(payload, bytes.Repeat([]byte{1, 2, 3, 4}, 1400/4)...) + + _, err = d.Decode(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: false, + PayloadType: 96, + SequenceNumber: 17645 + i, + Timestamp: 2289527317, + SSRC: 0x9dbb7812, + }, + Payload: payload, + }) + + size += 1400 + i++ + } + + require.EqualError(t, err, "temporal unit size (3145800) is too big, maximum is 3145728") +} + func TestDecoderErrorOBUCount(t *testing.T) { d := &Decoder{} err := d.Init() diff --git a/pkg/format/rtph264/decoder.go b/pkg/format/rtph264/decoder.go index 8ccc781b..519c045d 100644 --- a/pkg/format/rtph264/decoder.go +++ b/pkg/format/rtph264/decoder.go @@ -38,6 +38,14 @@ func isAllZero(buf []byte) bool { return true } +func auSize(au [][]byte) int { + s := 0 + for _, nalu := range au { + s += len(nalu) + } + return s +} + // Decoder is a RTP/H264 decoder. // Specification: https://datatracker.ietf.org/doc/html/rfc6184 type Decoder struct { @@ -223,11 +231,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) { errCount, h264.MaxNALUsPerAccessUnit) } - addSize := 0 - - for _, nalu := range nalus { - addSize += len(nalu) - } + addSize := auSize(nalus) if (d.frameBufferSize + addSize) > h264.MaxAccessUnitSize { errSize := d.frameBufferSize + addSize diff --git a/pkg/format/rtph265/decoder.go b/pkg/format/rtph265/decoder.go index 4777a06c..837af18e 100644 --- a/pkg/format/rtph265/decoder.go +++ b/pkg/format/rtph265/decoder.go @@ -28,6 +28,14 @@ func joinFragments(fragments [][]byte, size int) []byte { return ret } +func auSize(au [][]byte) int { + s := 0 + for _, nalu := range au { + s += len(nalu) + } + return s +} + // Decoder is a RTP/H265 decoder. // Specification: https://datatracker.ietf.org/doc/html/rfc7798 type Decoder struct { @@ -182,11 +190,7 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, error) { errCount, h265.MaxNALUsPerAccessUnit) } - addSize := 0 - - for _, nalu := range nalus { - addSize += len(nalu) - } + addSize := auSize(nalus) if (d.frameBufferSize + addSize) > h265.MaxAccessUnitSize { errSize := d.frameBufferSize + addSize