diff --git a/pkg/rtpaac/decoder.go b/pkg/rtpaac/decoder.go index e2e455ad..78581d36 100644 --- a/pkg/rtpaac/decoder.go +++ b/pkg/rtpaac/decoder.go @@ -57,6 +57,19 @@ func (d *Decoder) Decode(byts []byte) ([][]byte, time.Duration, error) { // DecodeRTP decodes AUs from a rtp.Packet. func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { + if len(pkt.Payload) < 2 { + d.state = decoderStateInitial + return nil, 0, fmt.Errorf("payload is too short") + } + + // AU-headers-length + headersLen := binary.BigEndian.Uint16(pkt.Payload) + if (headersLen % 16) != 0 { + d.state = decoderStateInitial + return nil, 0, fmt.Errorf("invalid AU-headers-length (%d)", headersLen) + } + pkt.Payload = pkt.Payload[2:] + switch d.state { case decoderStateInitial: if !d.initialTsSet { @@ -65,17 +78,6 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { } if pkt.Header.Marker { - if len(pkt.Payload) < 2 { - return nil, 0, fmt.Errorf("payload is too short") - } - - // AU-headers-length - headersLen := binary.BigEndian.Uint16(pkt.Payload) - if (headersLen % 16) != 0 { - return nil, 0, fmt.Errorf("invalid AU-headers-length (%d)", headersLen) - } - pkt.Payload = pkt.Payload[2:] - // AU-headers // AAC headers are 16 bits, where // * 13 bits are data size @@ -112,49 +114,47 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { return aus, d.decodeTimestamp(pkt.Timestamp), nil } - // AU-headers-length - headersLen := binary.BigEndian.Uint16(pkt.Payload) if headersLen != 16 { - return nil, 0, fmt.Errorf("invalid AU-headers-length (%d)", headersLen) + return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") } // AU-header - header := binary.BigEndian.Uint16(pkt.Payload[2:]) + header := binary.BigEndian.Uint16(pkt.Payload) dataLen := header >> 3 auIndex := header & 0x03 if auIndex != 0 { return nil, 0, fmt.Errorf("AU-index field is not zero") } + pkt.Payload = pkt.Payload[2:] if len(pkt.Payload) < int(dataLen) { return nil, 0, fmt.Errorf("payload is too short") } - d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[4:]...) + d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload...) d.state = decoderStateReadingFragmented return nil, 0, ErrMorePacketsNeeded default: // decoderStateReadingFragmented - // AU-headers-length - headersLen := binary.BigEndian.Uint16(pkt.Payload) if headersLen != 16 { - return nil, 0, fmt.Errorf("invalid AU-headers-length (%d)", headersLen) + return nil, 0, fmt.Errorf("a fragmented packet can only contain one AU") } // AU-header - header := binary.BigEndian.Uint16(pkt.Payload[2:]) + header := binary.BigEndian.Uint16(pkt.Payload) dataLen := header >> 3 auIndex := header & 0x03 if auIndex != 0 { return nil, 0, fmt.Errorf("AU-index field is not zero") } + pkt.Payload = pkt.Payload[2:] if len(pkt.Payload) < int(dataLen) { return nil, 0, fmt.Errorf("payload is too short") } - d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[4:]...) + d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload...) if !pkt.Header.Marker { return nil, 0, ErrMorePacketsNeeded diff --git a/pkg/rtpaac/rtpaac_test.go b/pkg/rtpaac/rtpaac_test.go index 3d74766d..a38d34ba 100644 --- a/pkg/rtpaac/rtpaac_test.go +++ b/pkg/rtpaac/rtpaac_test.go @@ -152,7 +152,7 @@ var cases = []struct { { "fragmented", [][]byte{ - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 512), }, 0, [][]byte{ @@ -165,10 +165,17 @@ var cases = []struct { ), mergeBytes( []byte{ - 0x80, 0xe0, 0x44, 0xee, 0x88, 0x77, 0x66, 0x55, - 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x02, 0x50, + 0x80, 0x60, 0x44, 0xee, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x05, 0xb0, }, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 74), + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), + ), + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xef, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x04, 0xa0, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 148), ), }, }, @@ -257,46 +264,153 @@ func TestDecode(t *testing.T) { func TestDecodeErrors(t *testing.T) { for _, ca := range []struct { name string - byts []byte + pkts [][]byte err string }{ { "missing payload", - []byte{ - 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, + [][]byte{ + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, + }, }, "payload is too short", }, { "missing au header", - []byte{ - 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, + [][]byte{ + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, + }, }, "payload is too short", }, { "missing au", - []byte{ - 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8, + [][]byte{ + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8, + }, }, "payload is too short", }, + { + "invalid au headers length", + [][]byte{ + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x09, + }, + }, + "invalid AU-headers-length (9)", + }, { "au index not zero", - []byte{ - 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8 | 0x01, + [][]byte{ + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8 | 0x01, + }, }, "AU-index field is not zero", }, + { + "fragmented with multiple AUs", + [][]byte{ + { + 0x80, 0x60, 0x0e, 0xa2, 0x0e, 0x01, 0x9b, 0xb7, + 0x35, 0x6e, 0xcb, 0x3b, 0x00, 0x20, + }, + }, + "a fragmented packet can only contain one AU", + }, + { + "fragmented with AU index not zero", + [][]byte{ + { + 0x80, 0x60, 0x0e, 0xa2, 0x0e, 0x01, 0x9b, 0xb7, + 0x35, 0x6e, 0xcb, 0x3b, 0x00, 0x10, 0x0a, 0xd8 | 0x01, + }, + }, + "AU-index field is not zero", + }, + { + "fragmented with missing au", + [][]byte{ + { + 0x80, 0x60, 0x0e, 0xa2, 0x0e, 0x01, 0x9b, 0xb7, + 0x35, 0x6e, 0xcb, 0x3b, 0x00, 0x10, 0x0a, 0xd8, + }, + }, + "payload is too short", + }, + { + "fragmented with multiple AUs in 2nd packet", + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x0, 0x10, 0x5, 0xb0, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), + ), + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xee, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x20, + }, + ), + }, + "a fragmented packet can only contain one AU", + }, + { + "fragmented with au index not zero in 2nd packet", + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x0, 0x10, 0x5, 0xb0, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), + ), + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xee, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8 | 0x01, + }, + ), + }, + "AU-index field is not zero", + }, + { + "fragmented without payload in 2nd packet", + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x0, 0x10, 0x5, 0xb0, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), + ), + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xee, 0x88, 0x77, 0x66, 0x55, + 0x9d, 0xbb, 0x78, 0x12, 0x00, 0x10, 0x0a, 0xd8, + }, + ), + }, + "payload is too short", + }, } { t.Run(ca.name, func(t *testing.T) { d := NewDecoder(48000) - _, _, err := d.Decode(ca.byts) - require.NotEqual(t, ErrMorePacketsNeeded, err) + var err error + for _, pkt := range ca.pkts { + _, _, err = d.Decode(pkt) + } require.Equal(t, ca.err, err.Error()) }) } diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/rtph264_test.go index c8487720..4cf2bbd3 100644 --- a/pkg/rtph264/rtph264_test.go +++ b/pkg/rtph264/rtph264_test.go @@ -390,7 +390,6 @@ func TestDecodeErrors(t *testing.T) { for _, pkt := range ca.pkts { _, _, err = d.Decode(pkt) } - require.NotEqual(t, ErrMorePacketsNeeded, err) require.Equal(t, ca.err, err.Error()) }) }