rtph264: fix decode error (#150)

Intact NALUs received after corrupted NALUs were wrongly discarded.
This commit is contained in:
Alessandro Ros
2022-11-14 14:30:30 +01:00
committed by GitHub
parent 12c8845fef
commit 764ac1ce35
5 changed files with 183 additions and 293 deletions

View File

@@ -26,9 +26,8 @@ var ErrNonStartingPacketAndNoPrevious = errors.New(
type Decoder struct { type Decoder struct {
timeDecoder *rtptimedec.Decoder timeDecoder *rtptimedec.Decoder
firstPacketReceived bool firstPacketReceived bool
fragmentedMode bool
fragmentedParts [][]byte
fragmentedSize int fragmentedSize int
fragments [][]byte
firstNALUParsed bool firstNALUParsed bool
annexBMode bool annexBMode bool
@@ -36,23 +35,79 @@ type Decoder struct {
naluBuffer [][]byte naluBuffer [][]byte
} }
// Init initializes the decoder // Init initializes the decoder.
func (d *Decoder) Init() { func (d *Decoder) Init() {
d.timeDecoder = rtptimedec.New(rtpClockRate) d.timeDecoder = rtptimedec.New(rtpClockRate)
} }
// Decode decodes NALUs from a RTP/H264 packet. // Decode decodes NALUs from a RTP/H264 packet.
func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
if !d.fragmentedMode {
if len(pkt.Payload) < 1 { if len(pkt.Payload) < 1 {
d.fragments = d.fragments[:0] // discard pending fragmented packets
return nil, 0, fmt.Errorf("payload is too short") return nil, 0, fmt.Errorf("payload is too short")
} }
typ := naluType(pkt.Payload[0] & 0x1F) typ := naluType(pkt.Payload[0] & 0x1F)
var nalus [][]byte
switch typ { switch typ {
case naluTypeFUA:
if len(pkt.Payload) < 2 {
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
}
start := pkt.Payload[1] >> 7
end := (pkt.Payload[1] >> 6) & 0x01
if start == 1 {
d.fragments = d.fragments[:0] // discard pending fragmented packets
if end != 0 {
return nil, 0, fmt.Errorf("invalid FU-A packet (can't contain both a start and end bit)")
}
nri := (pkt.Payload[0] >> 5) & 0x03
typ := pkt.Payload[1] & 0x1F
d.fragmentedSize = len(pkt.Payload[1:])
d.fragments = append(d.fragments, []byte{(nri << 5) | typ}, pkt.Payload[2:])
d.firstPacketReceived = true
return nil, 0, ErrMorePacketsNeeded
}
if len(d.fragments) == 0 {
if !d.firstPacketReceived {
return nil, 0, ErrNonStartingPacketAndNoPrevious
}
return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)")
}
d.fragmentedSize += len(pkt.Payload[2:])
if d.fragmentedSize > h264.MaxNALUSize {
d.fragments = d.fragments[:0]
return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize)
}
d.fragments = append(d.fragments, pkt.Payload[2:])
if end != 1 {
return nil, 0, ErrMorePacketsNeeded
}
nalu := make([]byte, d.fragmentedSize)
pos := 0
for _, frag := range d.fragments {
pos += copy(nalu[pos:], frag)
}
d.fragments = d.fragments[:0]
nalus = [][]byte{nalu}
case naluTypeSTAPA: case naluTypeSTAPA:
var nalus [][]byte d.fragments = d.fragments[:0] // discard pending fragmented packets
payload := pkt.Payload[1:] payload := pkt.Payload[1:]
for len(payload) > 0 { for len(payload) > 0 {
@@ -76,119 +131,30 @@ func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
payload = payload[size:] payload = payload[size:]
} }
if len(nalus) == 0 { if nalus == nil {
return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU") return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU")
} }
d.firstPacketReceived = true d.firstPacketReceived = true
var err error
nalus, err = d.finalize(nalus)
if err != nil {
return nil, 0, err
}
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil
case naluTypeFUA: // first packet of a fragmented NALU
if len(pkt.Payload) < 2 {
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
}
start := pkt.Payload[1] >> 7
if start != 1 {
if !d.firstPacketReceived {
return nil, 0, ErrNonStartingPacketAndNoPrevious
}
return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)")
}
end := (pkt.Payload[1] >> 6) & 0x01
if end != 0 {
return nil, 0, fmt.Errorf("invalid FU-A packet (can't contain both a start and end bit)")
}
nri := (pkt.Payload[0] >> 5) & 0x03
typ := pkt.Payload[1] & 0x1F
d.fragmentedSize = len(pkt.Payload) - 1
d.fragmentedParts = append(d.fragmentedParts, []byte{(nri << 5) | typ})
d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:])
d.fragmentedMode = true
d.firstPacketReceived = true
return nil, 0, ErrMorePacketsNeeded
case naluTypeSTAPB, naluTypeMTAP16, case naluTypeSTAPB, naluTypeMTAP16,
naluTypeMTAP24, naluTypeFUB: naluTypeMTAP24, naluTypeFUB:
return nil, 0, fmt.Errorf("packet type not supported (%v)", typ) d.fragments = d.fragments[:0] // discard pending fragmented packets
}
nalus := [][]byte{pkt.Payload}
d.firstPacketReceived = true d.firstPacketReceived = true
return nil, 0, fmt.Errorf("packet type not supported (%v)", typ)
var err error default:
nalus, err = d.finalize(nalus) d.fragments = d.fragments[:0] // discard pending fragmented packets
d.firstPacketReceived = true
nalus = [][]byte{pkt.Payload}
}
nalus, err := d.removeAnnexB(nalus)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil return nalus, d.timeDecoder.Decode(pkt.Timestamp), err
}
// we are decoding a fragmented NALU
if len(pkt.Payload) < 2 {
d.fragmentedParts = d.fragmentedParts[:0]
d.fragmentedMode = false
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
}
typ := naluType(pkt.Payload[0] & 0x1F)
if typ != naluTypeFUA {
d.fragmentedParts = d.fragmentedParts[:0]
d.fragmentedMode = false
return nil, 0, fmt.Errorf("expected FU-A packet, got %s packet", typ)
}
start := pkt.Payload[1] >> 7
if start == 1 {
d.fragmentedParts = d.fragmentedParts[:0]
d.fragmentedMode = false
return nil, 0, fmt.Errorf("invalid FU-A packet (decoded two starting packets in a row)")
}
d.fragmentedSize += len(pkt.Payload[2:])
if d.fragmentedSize > h264.MaxNALUSize {
d.fragmentedParts = d.fragmentedParts[:0]
d.fragmentedMode = false
return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize)
}
d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:])
end := (pkt.Payload[1] >> 6) & 0x01
if end != 1 {
return nil, 0, ErrMorePacketsNeeded
}
ret := make([]byte, d.fragmentedSize)
n := 0
for _, p := range d.fragmentedParts {
n += copy(ret[n:], p)
}
nalus := [][]byte{ret}
d.fragmentedParts = d.fragmentedParts[:0]
d.fragmentedMode = false
var err error
nalus, err = d.finalize(nalus)
if err != nil {
return nil, 0, err
}
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil
} }
// DecodeUntilMarker decodes NALUs from a RTP/H264 packet and puts them in a buffer. // DecodeUntilMarker decodes NALUs from a RTP/H264 packet and puts them in a buffer.
@@ -217,7 +183,7 @@ func (d *Decoder) DecodeUntilMarker(pkt *rtp.Packet) ([][]byte, time.Duration, e
return ret, pts, nil return ret, pts, nil
} }
func (d *Decoder) finalize(nalus [][]byte) ([][]byte, error) { func (d *Decoder) removeAnnexB(nalus [][]byte) ([][]byte, error) {
// some cameras / servers wrap NALUs into Annex-B // some cameras / servers wrap NALUs into Annex-B
if !d.firstNALUParsed { if !d.firstNALUParsed {
d.firstNALUParsed = true d.firstNALUParsed = true

View File

@@ -349,44 +349,42 @@ func TestDecode(t *testing.T) {
} }
} }
func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) { func TestDecodeCorruptedFragment(t *testing.T) {
d := &Decoder{} d := &Decoder{}
d.Init() d.Init()
pkt := rtp.Packet{ _, _, err := d.Decode(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: false,
PayloadType: 96,
SequenceNumber: 17647,
Timestamp: 2289531307,
SSRC: 0x9dbb7812,
},
Payload: mergeBytes(
[]byte{0x1c, 0x45},
[]byte{0x04, 0x05, 0x06, 0x07},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 147),
),
}
_, _, err := d.Decode(&pkt)
require.Equal(t, ErrNonStartingPacketAndNoPrevious, err)
pkt = rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96, PayloadType: 96,
SequenceNumber: 17645, SequenceNumber: 17645,
Timestamp: 2289528607, Timestamp: 2289527317,
SSRC: 0x9dbb7812, SSRC: 0x9dbb7812,
}, },
Payload: mergeBytes( Payload: mergeBytes(
[]byte{0x05}, []byte{
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), 0x1c, 0x85,
},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
), ),
} })
_, _, err = d.Decode(&pkt) require.Equal(t, ErrMorePacketsNeeded, err)
nalus, _, err := d.Decode(&rtp.Packet{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17646,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: []byte{0x01, 0x00},
})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, [][]byte{{0x01, 0x00}}, nalus)
} }
func TestDecodeSTAPAWithPadding(t *testing.T) { func TestDecodeSTAPAWithPadding(t *testing.T) {
@@ -538,7 +536,24 @@ func TestDecodeErrors(t *testing.T) {
"invalid FU-A packet (invalid size)", "invalid FU-A packet (invalid size)",
}, },
{ {
"FU-A without start bit", "FU-A with start and end bit",
[]*rtp.Packet{
{
Header: rtp.Header{
Version: 2,
Marker: true,
PayloadType: 96,
SequenceNumber: 17646,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: []byte{0x1c, 0b11000000},
},
},
"invalid FU-A packet (can't contain both a start and end bit)",
},
{
"FU-A non-starting",
[]*rtp.Packet{ []*rtp.Packet{
{ {
Header: rtp.Header{ Header: rtp.Header{
@@ -563,112 +578,11 @@ func TestDecodeErrors(t *testing.T) {
Timestamp: 2289527317, Timestamp: 2289527317,
SSRC: 0x9dbb7812, SSRC: 0x9dbb7812,
}, },
Payload: []byte{0x1c, 0x00}, Payload: []byte{0x1c, 0b01000000},
}, },
}, },
"invalid FU-A packet (non-starting)", "invalid FU-A packet (non-starting)",
}, },
{
"FU-A with 2nd packet empty",
[]*rtp.Packet{
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17645,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: mergeBytes(
[]byte{0x1c, 0x85},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
},
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17646,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
},
},
"invalid FU-A packet (invalid size)",
},
{
"FU-A with 2nd packet invalid",
[]*rtp.Packet{
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17645,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: mergeBytes(
[]byte{
0x1c, 0x85,
},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
},
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17646,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: []byte{0x01, 0x00},
},
},
"expected FU-A packet, got NonIDR packet",
},
{
"FU-A with two starting packets",
[]*rtp.Packet{
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17645,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: mergeBytes(
[]byte{0x1c, 0x85},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
},
{
Header: rtp.Header{
Version: 2,
Marker: false,
PayloadType: 96,
SequenceNumber: 17646,
Timestamp: 2289527317,
SSRC: 0x9dbb7812,
},
Payload: mergeBytes(
[]byte{0x1c, 0x85},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
[]byte{0x00, 0x01},
),
},
},
"invalid FU-A packet (decoded two starting packets in a row)",
},
{ {
"MTAP", "MTAP",
[]*rtp.Packet{ []*rtp.Packet{
@@ -699,40 +613,3 @@ func TestDecodeErrors(t *testing.T) {
}) })
} }
} }
// TestEncode feeds every entry of cases to an Encoder configured with a
// fixed SSRC, initial sequence number and initial timestamp, and requires
// that the produced packets equal the packets stored in the entry.
func TestEncode(t *testing.T) {
	for _, ca := range cases {
		t.Run(ca.name, func(t *testing.T) {
			// Fixed initial state, so that the output is deterministic.
			ssrc := uint32(0x9dbb7812)
			seq := uint16(0x44ed)
			ts := uint32(0x88776655)

			enc := &Encoder{
				PayloadType:           96,
				SSRC:                  &ssrc,
				InitialSequenceNumber: &seq,
				InitialTimestamp:      &ts,
			}
			enc.Init()

			got, err := enc.Encode(ca.nalus, ca.pts)
			require.NoError(t, err)
			require.Equal(t, ca.pkts, got)
		})
	}
}
// TestEncodeRandomInitialState checks that, after Init has been called on an
// Encoder that only has PayloadType set, the SSRC, InitialSequenceNumber and
// InitialTimestamp pointer fields are non-nil.
func TestEncodeRandomInitialState(t *testing.T) {
	e := &Encoder{
		PayloadType: 96,
	}
	e.Init()

	// NotNil is used instead of NotEqual(t, nil, ptr): an untyped nil is never
	// equal to a typed nil pointer stored in an interface, so the NotEqual form
	// passes even when the field was left unset.
	require.NotNil(t, e.SSRC)
	require.NotNil(t, e.InitialSequenceNumber)
	require.NotNil(t, e.InitialTimestamp)
}

View File

@@ -7,6 +7,10 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
) )
const (
// rtpVersion is presumably the value placed in the RTP header version
// field (the tests in this package build headers with Version: 2) —
// TODO confirm against the encoder.
rtpVersion = 2
)
func randUint32() uint32 { func randUint32() uint32 {
var b [4]byte var b [4]byte
rand.Read(b[:]) rand.Read(b[:])

View File

@@ -0,0 +1,44 @@
package rtph264
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestEncode feeds every entry of cases to an Encoder configured with a
// fixed SSRC, initial sequence number and initial timestamp, and requires
// that the produced packets equal the packets stored in the entry.
func TestEncode(t *testing.T) {
	for _, ca := range cases {
		t.Run(ca.name, func(t *testing.T) {
			// Fixed initial state, so that the output is deterministic.
			ssrc := uint32(0x9dbb7812)
			seq := uint16(0x44ed)
			ts := uint32(0x88776655)

			enc := &Encoder{
				PayloadType:           96,
				SSRC:                  &ssrc,
				InitialSequenceNumber: &seq,
				InitialTimestamp:      &ts,
			}
			enc.Init()

			got, err := enc.Encode(ca.nalus, ca.pts)
			require.NoError(t, err)
			require.Equal(t, ca.pkts, got)
		})
	}
}
// TestEncodeRandomInitialState checks that, after Init has been called on an
// Encoder that only has PayloadType set, the SSRC, InitialSequenceNumber and
// InitialTimestamp pointer fields are non-nil.
func TestEncodeRandomInitialState(t *testing.T) {
	e := &Encoder{
		PayloadType: 96,
	}
	e.Init()

	// NotNil is used instead of NotEqual(t, nil, ptr): an untyped nil is never
	// equal to a typed nil pointer stored in an interface, so the NotEqual form
	// passes even when the field was left unset.
	require.NotNil(t, e.SSRC)
	require.NotNil(t, e.InitialSequenceNumber)
	require.NotNil(t, e.InitialTimestamp)
}

View File

@@ -2,6 +2,5 @@
package rtph264 package rtph264
const ( const (
rtpVersion = 0x02
rtpClockRate = 90000 // h264 always uses 90khz rtpClockRate = 90000 // h264 always uses 90khz
) )