diff --git a/pkg/rtph264/decoder.go b/pkg/rtph264/decoder.go index af90a394..af690457 100644 --- a/pkg/rtph264/decoder.go +++ b/pkg/rtph264/decoder.go @@ -14,6 +14,10 @@ import ( // ErrMorePacketsNeeded is returned when more packets are needed. var ErrMorePacketsNeeded = errors.New("need more packets") +// ErrNonStartingPacketAndNoPrevious is returned when we decoded a non-starting +// fragmented packet and we didn't received anything before. +var ErrNonStartingPacketAndNoPrevious = errors.New("decoded a non-starting fragmented packet without any previous starting packets") + // PacketConnReader creates a io.Reader around a net.PacketConn. type PacketConnReader struct { net.PacketConn @@ -38,8 +42,9 @@ type Decoder struct { initialTsSet bool // for Decode() - state decoderState - fragmentedBuf []byte + startingPacketReceived bool + state decoderState + fragmentedBuf []byte } // NewDecoder allocates a Decoder. @@ -87,7 +92,7 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { for len(pkt.Payload) > 0 { if len(pkt.Payload) < 2 { - return nil, 0, fmt.Errorf("Invalid STAP-A packet") + return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") } size := binary.BigEndian.Uint16(pkt.Payload) @@ -99,7 +104,7 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { } if int(size) > len(pkt.Payload) { - return nil, 0, fmt.Errorf("Invalid STAP-A packet") + return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") } nalus = append(nalus, pkt.Payload[:size]) @@ -110,16 +115,20 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU") } + d.startingPacketReceived = true return nalus, d.decodeTimestamp(pkt.Timestamp), nil case naluTypeFUA: // first packet of a fragmented NALU if len(pkt.Payload) < 2 { - return nil, 0, fmt.Errorf("Invalid FU-A packet") + return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") } start := pkt.Payload[1] >> 7 if start != 1 { - return nil, 0, fmt.Errorf("first NALU does not contain the start bit") + if !d.startingPacketReceived { + return nil, 0, ErrNonStartingPacketAndNoPrevious + } + return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)") } nri := (pkt.Payload[0] >> 5) & 0x03 @@ -127,25 +136,27 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...) d.state = decoderStateReadingFragmented + d.startingPacketReceived = true return nil, 0, ErrMorePacketsNeeded case naluTypeSTAPB, naluTypeMTAP16, naluTypeMTAP24, naluTypeFUB: - return nil, 0, fmt.Errorf("NALU type not supported (%v)", typ) + return nil, 0, fmt.Errorf("packet type not supported (%v)", typ) } + d.startingPacketReceived = true return [][]byte{pkt.Payload}, d.decodeTimestamp(pkt.Timestamp), nil default: // decoderStateReadingFragmented if len(pkt.Payload) < 2 { d.state = decoderStateInitial - return nil, 0, fmt.Errorf("Invalid non-starting FU-A packet") + return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") } typ := naluType(pkt.Payload[0] & 0x1F) if typ != naluTypeFUA { d.state = decoderStateInitial - return nil, 0, fmt.Errorf("Packet is not FU-A") + return nil, 0, fmt.Errorf("expected FU-A packet, got another type") } end := (pkt.Payload[1] >> 6) & 0x01 @@ -157,6 +168,7 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { } d.state = decoderStateInitial + d.startingPacketReceived = true return [][]byte{d.fragmentedBuf}, d.decodeTimestamp(pkt.Timestamp), nil } } diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/rtph264_test.go index cb0b5683..fe54f8da 100644 --- a/pkg/rtph264/rtph264_test.go +++ b/pkg/rtph264/rtph264_test.go @@ -265,6 +265,29 @@ func TestDecode(t *testing.T) { } } +func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) { + d := NewDecoder() + + _, _, err := d.Decode(mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xef, 0x88, 0x77, 0x79, 0xab, + 0x9d, 0xbb, 0x78, 0x12, 0x1c, 0x45, + }, + []byte{0x04, 0x05, 0x06, 0x07}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 147), + )) + require.Equal(t, ErrNonStartingPacketAndNoPrevious, err) + + _, _, err = d.Decode(mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6f, 0x1f, + 0x9d, 0xbb, 0x78, 0x12, 0x05, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + )) + require.NoError(t, err) +} + func TestDecodeSTAPAWithPadding(t *testing.T) { d := NewDecoder() nalus, _, err := d.Decode([]byte{ @@ -318,7 +341,7 @@ func TestDecodeErrors(t *testing.T) { 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeSTAPA), 0x01, }}, - "Invalid STAP-A packet", + "invalid STAP-A packet (invalid size)", }, { "STAP-A with invalid size", @@ -326,7 +349,7 @@ func TestDecodeErrors(t *testing.T) { 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeSTAPA), 0x00, 0x15, }}, - "Invalid STAP-A packet", + "invalid STAP-A packet (invalid size)", }, { "FU-A without payload", @@ -334,15 +357,24 @@ func TestDecodeErrors(t *testing.T) { 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeFUA), }}, - "Invalid FU-A packet", + "invalid FU-A packet (invalid size)", }, { "FU-A without start bit", - [][]byte{{ - 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeFUA), 0x00, - }}, - "first NALU does not contain the start bit", + [][]byte{ + mergeBytes( + []byte{ + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6f, 0x1f, + 0x9d, 0xbb, 0x78, 0x12, 0x05, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + ), + { + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeFUA), 0x00, + }, + }, + "invalid FU-A packet (non-starting)", }, { "FU-A with 2nd packet empty", @@ -362,7 +394,7 @@ func TestDecodeErrors(t *testing.T) { }, ), }, - "Invalid non-starting FU-A packet", + "invalid FU-A packet (invalid size)", }, { "FU-A with 2nd packet invalid", @@ -382,7 +414,7 @@ func TestDecodeErrors(t *testing.T) { }, ), }, - "Packet is not FU-A", + "expected FU-A packet, got another type", }, { "MTAP", @@ -390,7 +422,7 @@ func TestDecodeErrors(t *testing.T) { 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x9d, 0xbb, 0x78, 0x12, byte(naluTypeMTAP16), }}, - "NALU type not supported (MTAP16)", + "packet type not supported (MTAP16)", }, } { t.Run(ca.name, func(t *testing.T) {