rtph264: add error in case of a fragmented NALU with two starting packets

This commit is contained in:
aler9
2021-07-24 17:14:39 +02:00
parent b60c4a65b2
commit 4ac9cda1fe
2 changed files with 61 additions and 38 deletions

View File

@@ -15,8 +15,10 @@ import (
var ErrMorePacketsNeeded = errors.New("need more packets") var ErrMorePacketsNeeded = errors.New("need more packets")
// ErrNonStartingPacketAndNoPrevious is returned when we decoded a non-starting // ErrNonStartingPacketAndNoPrevious is returned when we decoded a non-starting
// fragmented packet and we didn't received anything before. // packet of a fragmented NALU and we didn't received anything before.
var ErrNonStartingPacketAndNoPrevious = errors.New("decoded a non-starting fragmented packet without any previous starting packets") // It's normal to receive this when we are decoding a stream that has been already
// running for some time.
var ErrNonStartingPacketAndNoPrevious = errors.New("decoded a non-starting fragmented packet without any previous starting packet")
// PacketConnReader creates a io.Reader around a net.PacketConn. // PacketConnReader creates a io.Reader around a net.PacketConn.
type PacketConnReader struct { type PacketConnReader struct {
@@ -29,13 +31,6 @@ func (r PacketConnReader) Read(p []byte) (int, error) {
return n, err return n, err
} }
type decoderState int
const (
decoderStateInitial decoderState = iota
decoderStateReadingFragmented
)
// Decoder is a RTP/H264 decoder. // Decoder is a RTP/H264 decoder.
type Decoder struct { type Decoder struct {
initialTs uint32 initialTs uint32
@@ -43,7 +38,7 @@ type Decoder struct {
// for Decode() // for Decode()
startingPacketReceived bool startingPacketReceived bool
state decoderState isReadingFragmented bool
fragmentedBuf []byte fragmentedBuf []byte
} }
@@ -63,7 +58,7 @@ func (d *Decoder) Decode(byts []byte) ([][]byte, time.Duration, error) {
pkt := rtp.Packet{} pkt := rtp.Packet{}
err := pkt.Unmarshal(byts) err := pkt.Unmarshal(byts)
if err != nil { if err != nil {
d.state = decoderStateInitial d.isReadingFragmented = false
return nil, 0, err return nil, 0, err
} }
@@ -72,8 +67,7 @@ func (d *Decoder) Decode(byts []byte) ([][]byte, time.Duration, error) {
// DecodeRTP decodes NALUs from a rtp.Packet. // DecodeRTP decodes NALUs from a rtp.Packet.
func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) { func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
switch d.state { if !d.isReadingFragmented {
case decoderStateInitial:
if !d.initialTsSet { if !d.initialTsSet {
d.initialTsSet = true d.initialTsSet = true
d.initialTs = pkt.Timestamp d.initialTs = pkt.Timestamp
@@ -135,7 +129,7 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
typ := pkt.Payload[1] & 0x1F typ := pkt.Payload[1] & 0x1F
d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...) d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...)
d.state = decoderStateReadingFragmented d.isReadingFragmented = true
d.startingPacketReceived = true d.startingPacketReceived = true
return nil, 0, ErrMorePacketsNeeded return nil, 0, ErrMorePacketsNeeded
@@ -146,31 +140,38 @@ func (d *Decoder) DecodeRTP(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
d.startingPacketReceived = true d.startingPacketReceived = true
return [][]byte{pkt.Payload}, d.decodeTimestamp(pkt.Timestamp), nil return [][]byte{pkt.Payload}, d.decodeTimestamp(pkt.Timestamp), nil
}
// we are decoding a fragmented packet
default: // decoderStateReadingFragmented
if len(pkt.Payload) < 2 { if len(pkt.Payload) < 2 {
d.state = decoderStateInitial d.isReadingFragmented = false
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
} }
typ := naluType(pkt.Payload[0] & 0x1F) typ := naluType(pkt.Payload[0] & 0x1F)
if typ != naluTypeFUA { if typ != naluTypeFUA {
d.state = decoderStateInitial d.isReadingFragmented = false
return nil, 0, fmt.Errorf("expected FU-A packet, got another type") return nil, 0, fmt.Errorf("expected FU-A packet, got another type")
} }
start := pkt.Payload[1] >> 7
end := (pkt.Payload[1] >> 6) & 0x01 end := (pkt.Payload[1] >> 6) & 0x01
if start == 1 {
d.isReadingFragmented = false
return nil, 0, fmt.Errorf("invalid FU-A packet (decoded two starting packets in a row)")
}
d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...) d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...)
if end != 1 { if end != 1 {
return nil, 0, ErrMorePacketsNeeded return nil, 0, ErrMorePacketsNeeded
} }
d.state = decoderStateInitial d.isReadingFragmented = false
d.startingPacketReceived = true d.startingPacketReceived = true
return [][]byte{d.fragmentedBuf}, d.decodeTimestamp(pkt.Timestamp), nil return [][]byte{d.fragmentedBuf}, d.decodeTimestamp(pkt.Timestamp), nil
}
} }
// ReadSPSPPS reads RTP/H264 packets from a reader until SPS and PPS are // ReadSPSPPS reads RTP/H264 packets from a reader until SPS and PPS are

View File

@@ -416,6 +416,28 @@ func TestDecodeErrors(t *testing.T) {
}, },
"expected FU-A packet, got another type", "expected FU-A packet, got another type",
}, },
{
"FU-A with two starting packets",
[][]byte{
mergeBytes(
[]byte{
0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x79, 0xab,
0x9d, 0xbb, 0x78, 0x12, 0x1c, 0x85,
},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
mergeBytes(
[]byte{
0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x79, 0xab,
0x9d, 0xbb, 0x78, 0x12, 0x1c, 0x85,
},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
},
"invalid FU-A packet (decoded two starting packets in a row)",
},
{ {
"MTAP", "MTAP",
[][]byte{{ [][]byte{{