mirror of
https://github.com/aler9/gortsplib
synced 2025-10-08 16:40:09 +08:00
rtph264: fix decode error (#150)
Intact NALUs received after corrupted NALUs were wrongly discarded.
This commit is contained in:
@@ -26,9 +26,8 @@ var ErrNonStartingPacketAndNoPrevious = errors.New(
|
||||
type Decoder struct {
|
||||
timeDecoder *rtptimedec.Decoder
|
||||
firstPacketReceived bool
|
||||
fragmentedMode bool
|
||||
fragmentedParts [][]byte
|
||||
fragmentedSize int
|
||||
fragments [][]byte
|
||||
firstNALUParsed bool
|
||||
annexBMode bool
|
||||
|
||||
@@ -36,159 +35,126 @@ type Decoder struct {
|
||||
naluBuffer [][]byte
|
||||
}
|
||||
|
||||
// Init initializes the decoder
|
||||
// Init initializes the decoder.
|
||||
func (d *Decoder) Init() {
|
||||
d.timeDecoder = rtptimedec.New(rtpClockRate)
|
||||
}
|
||||
|
||||
// Decode decodes NALUs from a RTP/H264 packet.
|
||||
func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) {
|
||||
if !d.fragmentedMode {
|
||||
if len(pkt.Payload) < 1 {
|
||||
return nil, 0, fmt.Errorf("payload is too short")
|
||||
if len(pkt.Payload) < 1 {
|
||||
d.fragments = d.fragments[:0] // discard pending fragmented packets
|
||||
return nil, 0, fmt.Errorf("payload is too short")
|
||||
}
|
||||
|
||||
typ := naluType(pkt.Payload[0] & 0x1F)
|
||||
var nalus [][]byte
|
||||
|
||||
switch typ {
|
||||
case naluTypeFUA:
|
||||
if len(pkt.Payload) < 2 {
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
|
||||
}
|
||||
|
||||
typ := naluType(pkt.Payload[0] & 0x1F)
|
||||
start := pkt.Payload[1] >> 7
|
||||
end := (pkt.Payload[1] >> 6) & 0x01
|
||||
|
||||
switch typ {
|
||||
case naluTypeSTAPA:
|
||||
var nalus [][]byte
|
||||
payload := pkt.Payload[1:]
|
||||
if start == 1 {
|
||||
d.fragments = d.fragments[:0] // discard pending fragmented packets
|
||||
|
||||
for len(payload) > 0 {
|
||||
if len(payload) < 2 {
|
||||
return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)")
|
||||
}
|
||||
|
||||
size := uint16(payload[0])<<8 | uint16(payload[1])
|
||||
payload = payload[2:]
|
||||
|
||||
// avoid final padding
|
||||
if size == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if int(size) > len(payload) {
|
||||
return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)")
|
||||
}
|
||||
|
||||
nalus = append(nalus, payload[:size])
|
||||
payload = payload[size:]
|
||||
}
|
||||
|
||||
if len(nalus) == 0 {
|
||||
return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU")
|
||||
}
|
||||
|
||||
d.firstPacketReceived = true
|
||||
|
||||
var err error
|
||||
nalus, err = d.finalize(nalus)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil
|
||||
|
||||
case naluTypeFUA: // first packet of a fragmented NALU
|
||||
if len(pkt.Payload) < 2 {
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
|
||||
}
|
||||
|
||||
start := pkt.Payload[1] >> 7
|
||||
if start != 1 {
|
||||
if !d.firstPacketReceived {
|
||||
return nil, 0, ErrNonStartingPacketAndNoPrevious
|
||||
}
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)")
|
||||
}
|
||||
|
||||
end := (pkt.Payload[1] >> 6) & 0x01
|
||||
if end != 0 {
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (can't contain both a start and end bit)")
|
||||
}
|
||||
|
||||
nri := (pkt.Payload[0] >> 5) & 0x03
|
||||
typ := pkt.Payload[1] & 0x1F
|
||||
d.fragmentedSize = len(pkt.Payload) - 1
|
||||
d.fragmentedParts = append(d.fragmentedParts, []byte{(nri << 5) | typ})
|
||||
d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:])
|
||||
d.fragmentedMode = true
|
||||
|
||||
d.fragmentedSize = len(pkt.Payload[1:])
|
||||
d.fragments = append(d.fragments, []byte{(nri << 5) | typ}, pkt.Payload[2:])
|
||||
d.firstPacketReceived = true
|
||||
return nil, 0, ErrMorePacketsNeeded
|
||||
|
||||
case naluTypeSTAPB, naluTypeMTAP16,
|
||||
naluTypeMTAP24, naluTypeFUB:
|
||||
return nil, 0, fmt.Errorf("packet type not supported (%v)", typ)
|
||||
return nil, 0, ErrMorePacketsNeeded
|
||||
}
|
||||
|
||||
nalus := [][]byte{pkt.Payload}
|
||||
if len(d.fragments) == 0 {
|
||||
if !d.firstPacketReceived {
|
||||
return nil, 0, ErrNonStartingPacketAndNoPrevious
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)")
|
||||
}
|
||||
|
||||
d.fragmentedSize += len(pkt.Payload[2:])
|
||||
if d.fragmentedSize > h264.MaxNALUSize {
|
||||
d.fragments = d.fragments[:0]
|
||||
return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize)
|
||||
}
|
||||
|
||||
d.fragments = append(d.fragments, pkt.Payload[2:])
|
||||
|
||||
if end != 1 {
|
||||
return nil, 0, ErrMorePacketsNeeded
|
||||
}
|
||||
|
||||
nalu := make([]byte, d.fragmentedSize)
|
||||
pos := 0
|
||||
|
||||
for _, frag := range d.fragments {
|
||||
pos += copy(nalu[pos:], frag)
|
||||
}
|
||||
|
||||
d.fragments = d.fragments[:0]
|
||||
nalus = [][]byte{nalu}
|
||||
|
||||
case naluTypeSTAPA:
|
||||
d.fragments = d.fragments[:0] // discard pending fragmented packets
|
||||
|
||||
payload := pkt.Payload[1:]
|
||||
|
||||
for len(payload) > 0 {
|
||||
if len(payload) < 2 {
|
||||
return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)")
|
||||
}
|
||||
|
||||
size := uint16(payload[0])<<8 | uint16(payload[1])
|
||||
payload = payload[2:]
|
||||
|
||||
// avoid final padding
|
||||
if size == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if int(size) > len(payload) {
|
||||
return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)")
|
||||
}
|
||||
|
||||
nalus = append(nalus, payload[:size])
|
||||
payload = payload[size:]
|
||||
}
|
||||
|
||||
if nalus == nil {
|
||||
return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU")
|
||||
}
|
||||
|
||||
d.firstPacketReceived = true
|
||||
|
||||
var err error
|
||||
nalus, err = d.finalize(nalus)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
case naluTypeSTAPB, naluTypeMTAP16,
|
||||
naluTypeMTAP24, naluTypeFUB:
|
||||
d.fragments = d.fragments[:0] // discard pending fragmented packets
|
||||
d.firstPacketReceived = true
|
||||
return nil, 0, fmt.Errorf("packet type not supported (%v)", typ)
|
||||
|
||||
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil
|
||||
default:
|
||||
d.fragments = d.fragments[:0] // discard pending fragmented packets
|
||||
d.firstPacketReceived = true
|
||||
nalus = [][]byte{pkt.Payload}
|
||||
}
|
||||
|
||||
// we are decoding a fragmented NALU
|
||||
|
||||
if len(pkt.Payload) < 2 {
|
||||
d.fragmentedParts = d.fragmentedParts[:0]
|
||||
d.fragmentedMode = false
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)")
|
||||
}
|
||||
|
||||
typ := naluType(pkt.Payload[0] & 0x1F)
|
||||
if typ != naluTypeFUA {
|
||||
d.fragmentedParts = d.fragmentedParts[:0]
|
||||
d.fragmentedMode = false
|
||||
return nil, 0, fmt.Errorf("expected FU-A packet, got %s packet", typ)
|
||||
}
|
||||
|
||||
start := pkt.Payload[1] >> 7
|
||||
if start == 1 {
|
||||
d.fragmentedParts = d.fragmentedParts[:0]
|
||||
d.fragmentedMode = false
|
||||
return nil, 0, fmt.Errorf("invalid FU-A packet (decoded two starting packets in a row)")
|
||||
}
|
||||
|
||||
d.fragmentedSize += len(pkt.Payload[2:])
|
||||
if d.fragmentedSize > h264.MaxNALUSize {
|
||||
d.fragmentedParts = d.fragmentedParts[:0]
|
||||
d.fragmentedMode = false
|
||||
return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize)
|
||||
}
|
||||
|
||||
d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:])
|
||||
|
||||
end := (pkt.Payload[1] >> 6) & 0x01
|
||||
if end != 1 {
|
||||
return nil, 0, ErrMorePacketsNeeded
|
||||
}
|
||||
|
||||
ret := make([]byte, d.fragmentedSize)
|
||||
n := 0
|
||||
for _, p := range d.fragmentedParts {
|
||||
n += copy(ret[n:], p)
|
||||
}
|
||||
nalus := [][]byte{ret}
|
||||
|
||||
d.fragmentedParts = d.fragmentedParts[:0]
|
||||
d.fragmentedMode = false
|
||||
|
||||
var err error
|
||||
nalus, err = d.finalize(nalus)
|
||||
nalus, err := d.removeAnnexB(nalus)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil
|
||||
return nalus, d.timeDecoder.Decode(pkt.Timestamp), err
|
||||
}
|
||||
|
||||
// DecodeUntilMarker decodes NALUs from a RTP/H264 packet and puts them in a buffer.
|
||||
@@ -217,7 +183,7 @@ func (d *Decoder) DecodeUntilMarker(pkt *rtp.Packet) ([][]byte, time.Duration, e
|
||||
return ret, pts, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) finalize(nalus [][]byte) ([][]byte, error) {
|
||||
func (d *Decoder) removeAnnexB(nalus [][]byte) ([][]byte, error) {
|
||||
// some cameras / servers wrap NALUs into Annex-B
|
||||
if !d.firstNALUParsed {
|
||||
d.firstNALUParsed = true
|
||||
|
@@ -349,44 +349,42 @@ func TestDecode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) {
|
||||
func TestDecodeCorruptedFragment(t *testing.T) {
|
||||
d := &Decoder{}
|
||||
d.Init()
|
||||
|
||||
pkt := rtp.Packet{
|
||||
_, _, err := d.Decode(&rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: true,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17647,
|
||||
Timestamp: 2289531307,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{0x1c, 0x45},
|
||||
[]byte{0x04, 0x05, 0x06, 0x07},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 147),
|
||||
),
|
||||
}
|
||||
_, _, err := d.Decode(&pkt)
|
||||
require.Equal(t, ErrNonStartingPacketAndNoPrevious, err)
|
||||
|
||||
pkt = rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: true,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17645,
|
||||
Timestamp: 2289528607,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{0x05},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8),
|
||||
[]byte{
|
||||
0x1c, 0x85,
|
||||
},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
|
||||
[]byte{0x00, 0x01},
|
||||
),
|
||||
}
|
||||
_, _, err = d.Decode(&pkt)
|
||||
})
|
||||
require.Equal(t, ErrMorePacketsNeeded, err)
|
||||
|
||||
nalus, _, err := d.Decode(&rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17646,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: []byte{0x01, 0x00},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]byte{{0x01, 0x00}}, nalus)
|
||||
}
|
||||
|
||||
func TestDecodeSTAPAWithPadding(t *testing.T) {
|
||||
@@ -538,7 +536,24 @@ func TestDecodeErrors(t *testing.T) {
|
||||
"invalid FU-A packet (invalid size)",
|
||||
},
|
||||
{
|
||||
"FU-A without start bit",
|
||||
"FU-A with start and end bit",
|
||||
[]*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: true,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17646,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: []byte{0x1c, 0b11000000},
|
||||
},
|
||||
},
|
||||
"invalid FU-A packet (can't contain both a start and end bit)",
|
||||
},
|
||||
{
|
||||
"FU-A non-starting",
|
||||
[]*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
@@ -563,112 +578,11 @@ func TestDecodeErrors(t *testing.T) {
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: []byte{0x1c, 0x00},
|
||||
Payload: []byte{0x1c, 0b01000000},
|
||||
},
|
||||
},
|
||||
"invalid FU-A packet (non-starting)",
|
||||
},
|
||||
{
|
||||
"FU-A with 2nd packet empty",
|
||||
[]*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17645,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{0x1c, 0x85},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
|
||||
[]byte{0x00, 0x01},
|
||||
),
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17646,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
},
|
||||
},
|
||||
"invalid FU-A packet (invalid size)",
|
||||
},
|
||||
{
|
||||
"FU-A with 2nd packet invalid",
|
||||
[]*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17645,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{
|
||||
0x1c, 0x85,
|
||||
},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
|
||||
[]byte{0x00, 0x01},
|
||||
),
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17646,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: []byte{0x01, 0x00},
|
||||
},
|
||||
},
|
||||
"expected FU-A packet, got NonIDR packet",
|
||||
},
|
||||
{
|
||||
"FU-A with two starting packets",
|
||||
[]*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17645,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{0x1c, 0x85},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
|
||||
[]byte{0x00, 0x01},
|
||||
),
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: 17646,
|
||||
Timestamp: 2289527317,
|
||||
SSRC: 0x9dbb7812,
|
||||
},
|
||||
Payload: mergeBytes(
|
||||
[]byte{0x1c, 0x85},
|
||||
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182),
|
||||
[]byte{0x00, 0x01},
|
||||
),
|
||||
},
|
||||
},
|
||||
"invalid FU-A packet (decoded two starting packets in a row)",
|
||||
},
|
||||
{
|
||||
"MTAP",
|
||||
[]*rtp.Packet{
|
||||
@@ -699,40 +613,3 @@ func TestDecodeErrors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
for _, ca := range cases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
e := &Encoder{
|
||||
PayloadType: 96,
|
||||
SSRC: func() *uint32 {
|
||||
v := uint32(0x9dbb7812)
|
||||
return &v
|
||||
}(),
|
||||
InitialSequenceNumber: func() *uint16 {
|
||||
v := uint16(0x44ed)
|
||||
return &v
|
||||
}(),
|
||||
InitialTimestamp: func() *uint32 {
|
||||
v := uint32(0x88776655)
|
||||
return &v
|
||||
}(),
|
||||
}
|
||||
e.Init()
|
||||
|
||||
pkts, err := e.Encode(ca.nalus, ca.pts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.pkts, pkts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeRandomInitialState(t *testing.T) {
|
||||
e := &Encoder{
|
||||
PayloadType: 96,
|
||||
}
|
||||
e.Init()
|
||||
require.NotEqual(t, nil, e.SSRC)
|
||||
require.NotEqual(t, nil, e.InitialSequenceNumber)
|
||||
require.NotEqual(t, nil, e.InitialTimestamp)
|
||||
}
|
@@ -7,6 +7,10 @@ import (
|
||||
"github.com/pion/rtp"
|
||||
)
|
||||
|
||||
const (
|
||||
rtpVersion = 2
|
||||
)
|
||||
|
||||
func randUint32() uint32 {
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
|
44
pkg/rtph264/encoder_test.go
Normal file
44
pkg/rtph264/encoder_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package rtph264
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
for _, ca := range cases {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
e := &Encoder{
|
||||
PayloadType: 96,
|
||||
SSRC: func() *uint32 {
|
||||
v := uint32(0x9dbb7812)
|
||||
return &v
|
||||
}(),
|
||||
InitialSequenceNumber: func() *uint16 {
|
||||
v := uint16(0x44ed)
|
||||
return &v
|
||||
}(),
|
||||
InitialTimestamp: func() *uint32 {
|
||||
v := uint32(0x88776655)
|
||||
return &v
|
||||
}(),
|
||||
}
|
||||
e.Init()
|
||||
|
||||
pkts, err := e.Encode(ca.nalus, ca.pts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.pkts, pkts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeRandomInitialState(t *testing.T) {
|
||||
e := &Encoder{
|
||||
PayloadType: 96,
|
||||
}
|
||||
e.Init()
|
||||
require.NotEqual(t, nil, e.SSRC)
|
||||
require.NotEqual(t, nil, e.InitialSequenceNumber)
|
||||
require.NotEqual(t, nil, e.InitialTimestamp)
|
||||
}
|
@@ -2,6 +2,5 @@
|
||||
package rtph264
|
||||
|
||||
const (
|
||||
rtpVersion = 0x02
|
||||
rtpClockRate = 90000 // h264 always uses 90khz
|
||||
)
|
||||
|
Reference in New Issue
Block a user