diff --git a/pkg/rtph264/decoder.go b/pkg/rtph264/decoder.go index 1e162b3d..c41f2e78 100644 --- a/pkg/rtph264/decoder.go +++ b/pkg/rtph264/decoder.go @@ -26,9 +26,8 @@ var ErrNonStartingPacketAndNoPrevious = errors.New( type Decoder struct { timeDecoder *rtptimedec.Decoder firstPacketReceived bool - fragmentedMode bool - fragmentedParts [][]byte fragmentedSize int + fragments [][]byte firstNALUParsed bool annexBMode bool @@ -36,159 +35,126 @@ type Decoder struct { naluBuffer [][]byte } -// Init initializes the decoder +// Init initializes the decoder. func (d *Decoder) Init() { d.timeDecoder = rtptimedec.New(rtpClockRate) } // Decode decodes NALUs from a RTP/H264 packet. func (d *Decoder) Decode(pkt *rtp.Packet) ([][]byte, time.Duration, error) { - if !d.fragmentedMode { - if len(pkt.Payload) < 1 { - return nil, 0, fmt.Errorf("payload is too short") + if len(pkt.Payload) < 1 { + d.fragments = d.fragments[:0] // discard pending fragmented packets + return nil, 0, fmt.Errorf("payload is too short") + } + + typ := naluType(pkt.Payload[0] & 0x1F) + var nalus [][]byte + + switch typ { + case naluTypeFUA: + if len(pkt.Payload) < 2 { + return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") } - typ := naluType(pkt.Payload[0] & 0x1F) + start := pkt.Payload[1] >> 7 + end := (pkt.Payload[1] >> 6) & 0x01 - switch typ { - case naluTypeSTAPA: - var nalus [][]byte - payload := pkt.Payload[1:] + if start == 1 { + d.fragments = d.fragments[:0] // discard pending fragmented packets - for len(payload) > 0 { - if len(payload) < 2 { - return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") - } - - size := uint16(payload[0])<<8 | uint16(payload[1]) - payload = payload[2:] - - // avoid final padding - if size == 0 { - break - } - - if int(size) > len(payload) { - return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") - } - - nalus = append(nalus, payload[:size]) - payload = payload[size:] - } - - if len(nalus) == 0 { - return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU") - } - - d.firstPacketReceived = true - - var err error - nalus, err = d.finalize(nalus) - if err != nil { - return nil, 0, err - } - - return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil - - case naluTypeFUA: // first packet of a fragmented NALU - if len(pkt.Payload) < 2 { - return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") - } - - start := pkt.Payload[1] >> 7 - if start != 1 { - if !d.firstPacketReceived { - return nil, 0, ErrNonStartingPacketAndNoPrevious - } - return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)") - } - - end := (pkt.Payload[1] >> 6) & 0x01 if end != 0 { return nil, 0, fmt.Errorf("invalid FU-A packet (can't contain both a start and end bit)") } nri := (pkt.Payload[0] >> 5) & 0x03 typ := pkt.Payload[1] & 0x1F - d.fragmentedSize = len(pkt.Payload) - 1 - d.fragmentedParts = append(d.fragmentedParts, []byte{(nri << 5) | typ}) - d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:]) - d.fragmentedMode = true - + d.fragmentedSize = len(pkt.Payload[1:]) + d.fragments = append(d.fragments, []byte{(nri << 5) | typ}, pkt.Payload[2:]) d.firstPacketReceived = true - return nil, 0, ErrMorePacketsNeeded - case naluTypeSTAPB, naluTypeMTAP16, - naluTypeMTAP24, naluTypeFUB: - return nil, 0, fmt.Errorf("packet type not supported (%v)", typ) + return nil, 0, ErrMorePacketsNeeded } - nalus := [][]byte{pkt.Payload} + if len(d.fragments) == 0 { + if !d.firstPacketReceived { + return nil, 0, ErrNonStartingPacketAndNoPrevious + } + + return nil, 0, fmt.Errorf("invalid FU-A packet (non-starting)") + } + + d.fragmentedSize += len(pkt.Payload[2:]) + if d.fragmentedSize > h264.MaxNALUSize { + d.fragments = d.fragments[:0] + return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize) + } + + d.fragments = append(d.fragments, pkt.Payload[2:]) + + if end != 1 { + return nil, 0, ErrMorePacketsNeeded + } + + nalu := make([]byte, d.fragmentedSize) + pos := 0 + + for _, frag := range d.fragments { + pos += copy(nalu[pos:], frag) + } + + d.fragments = d.fragments[:0] + nalus = [][]byte{nalu} + + case naluTypeSTAPA: + d.fragments = d.fragments[:0] // discard pending fragmented packets + + payload := pkt.Payload[1:] + + for len(payload) > 0 { + if len(payload) < 2 { + return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") + } + + size := uint16(payload[0])<<8 | uint16(payload[1]) + payload = payload[2:] + + // avoid final padding + if size == 0 { + break + } + + if int(size) > len(payload) { + return nil, 0, fmt.Errorf("invalid STAP-A packet (invalid size)") + } + + nalus = append(nalus, payload[:size]) + payload = payload[size:] + } + + if nalus == nil { + return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU") + } d.firstPacketReceived = true - var err error - nalus, err = d.finalize(nalus) - if err != nil { - return nil, 0, err - } + case naluTypeSTAPB, naluTypeMTAP16, + naluTypeMTAP24, naluTypeFUB: + d.fragments = d.fragments[:0] // discard pending fragmented packets + d.firstPacketReceived = true + return nil, 0, fmt.Errorf("packet type not supported (%v)", typ) - return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil + default: + d.fragments = d.fragments[:0] // discard pending fragmented packets + d.firstPacketReceived = true + nalus = [][]byte{pkt.Payload} } - // we are decoding a fragmented NALU - - if len(pkt.Payload) < 2 { - d.fragmentedParts = d.fragmentedParts[:0] - d.fragmentedMode = false - return nil, 0, fmt.Errorf("invalid FU-A packet (invalid size)") - } - - typ := naluType(pkt.Payload[0] & 0x1F) - if typ != naluTypeFUA { - d.fragmentedParts = d.fragmentedParts[:0] - d.fragmentedMode = false - return nil, 0, fmt.Errorf("expected FU-A packet, got %s packet", typ) - } - - start := pkt.Payload[1] >> 7 - if start == 1 { - d.fragmentedParts = d.fragmentedParts[:0] - d.fragmentedMode = false - return nil, 0, fmt.Errorf("invalid FU-A packet (decoded two starting packets in a row)") - } - - d.fragmentedSize += len(pkt.Payload[2:]) - if d.fragmentedSize > h264.MaxNALUSize { - d.fragmentedParts = d.fragmentedParts[:0] - d.fragmentedMode = false - return nil, 0, fmt.Errorf("NALU size (%d) is too big (maximum is %d)", d.fragmentedSize, h264.MaxNALUSize) - } - - d.fragmentedParts = append(d.fragmentedParts, pkt.Payload[2:]) - - end := (pkt.Payload[1] >> 6) & 0x01 - if end != 1 { - return nil, 0, ErrMorePacketsNeeded - } - - ret := make([]byte, d.fragmentedSize) - n := 0 - for _, p := range d.fragmentedParts { - n += copy(ret[n:], p) - } - nalus := [][]byte{ret} - - d.fragmentedParts = d.fragmentedParts[:0] - d.fragmentedMode = false - - var err error - nalus, err = d.finalize(nalus) + nalus, err := d.removeAnnexB(nalus) if err != nil { return nil, 0, err } - return nalus, d.timeDecoder.Decode(pkt.Timestamp), nil + return nalus, d.timeDecoder.Decode(pkt.Timestamp), err } // DecodeUntilMarker decodes NALUs from a RTP/H264 packet and puts them in a buffer. @@ -217,7 +183,7 @@ func (d *Decoder) DecodeUntilMarker(pkt *rtp.Packet) ([][]byte, time.Duration, e return ret, pts, nil } -func (d *Decoder) finalize(nalus [][]byte) ([][]byte, error) { +func (d *Decoder) removeAnnexB(nalus [][]byte) ([][]byte, error) { // some cameras / servers wrap NALUs into Annex-B if !d.firstNALUParsed { d.firstNALUParsed = true diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/decoder_test.go similarity index 78% rename from pkg/rtph264/rtph264_test.go rename to pkg/rtph264/decoder_test.go index 633626df..3edaad8b 100644 --- a/pkg/rtph264/rtph264_test.go +++ b/pkg/rtph264/decoder_test.go @@ -349,44 +349,42 @@ func TestDecode(t *testing.T) { } } -func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) { +func TestDecodeCorruptedFragment(t *testing.T) { d := &Decoder{} d.Init() - pkt := rtp.Packet{ + _, _, err := d.Decode(&rtp.Packet{ Header: rtp.Header{ Version: 2, - Marker: true, - PayloadType: 96, - SequenceNumber: 17647, - Timestamp: 2289531307, - SSRC: 0x9dbb7812, - }, - Payload: mergeBytes( - []byte{0x1c, 0x45}, - []byte{0x04, 0x05, 0x06, 0x07}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 147), - ), - } - _, _, err := d.Decode(&pkt) - require.Equal(t, ErrNonStartingPacketAndNoPrevious, err) - - pkt = rtp.Packet{ - Header: rtp.Header{ - Version: 2, - Marker: true, + Marker: false, PayloadType: 96, SequenceNumber: 17645, - Timestamp: 2289528607, + Timestamp: 2289527317, SSRC: 0x9dbb7812, }, Payload: mergeBytes( - []byte{0x05}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + []byte{ + 0x1c, 0x85, + }, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), + []byte{0x00, 0x01}, ), - } - _, _, err = d.Decode(&pkt) + }) + require.Equal(t, ErrMorePacketsNeeded, err) + + nalus, _, err := d.Decode(&rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: false, + PayloadType: 96, + SequenceNumber: 17646, + Timestamp: 2289527317, + SSRC: 0x9dbb7812, + }, + Payload: []byte{0x01, 0x00}, + }) require.NoError(t, err) + require.Equal(t, [][]byte{{0x01, 0x00}}, nalus) } func TestDecodeSTAPAWithPadding(t *testing.T) { @@ -538,7 +536,24 @@ func TestDecodeErrors(t *testing.T) { "invalid FU-A packet (invalid size)", }, { - "FU-A without start bit", + "FU-A with start and end bit", + []*rtp.Packet{ + { + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 17646, + Timestamp: 2289527317, + SSRC: 0x9dbb7812, + }, + Payload: []byte{0x1c, 0b11000000}, + }, + }, + "invalid FU-A packet (can't contain both a start and end bit)", + }, + { + "FU-A non-starting", []*rtp.Packet{ { Header: rtp.Header{ @@ -563,112 +578,11 @@ func TestDecodeErrors(t *testing.T) { Timestamp: 2289527317, SSRC: 0x9dbb7812, }, - Payload: []byte{0x1c, 0x00}, + Payload: []byte{0x1c, 0b01000000}, }, }, "invalid FU-A packet (non-starting)", }, - { - "FU-A with 2nd packet empty", - []*rtp.Packet{ - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17645, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - Payload: mergeBytes( - []byte{0x1c, 0x85}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), - []byte{0x00, 0x01}, - ), - }, - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17646, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - }, - }, - "invalid FU-A packet (invalid size)", - }, - { - "FU-A with 2nd packet invalid", - []*rtp.Packet{ - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17645, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - Payload: mergeBytes( - []byte{ - 0x1c, 0x85, - }, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), - []byte{0x00, 0x01}, - ), - }, - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17646, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - Payload: []byte{0x01, 0x00}, - }, - }, - "expected FU-A packet, got NonIDR packet", - }, - { - "FU-A with two starting packets", - []*rtp.Packet{ - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17645, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - Payload: mergeBytes( - []byte{0x1c, 0x85}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), - []byte{0x00, 0x01}, - ), - }, - { - Header: rtp.Header{ - Version: 2, - Marker: false, - PayloadType: 96, - SequenceNumber: 17646, - Timestamp: 2289527317, - SSRC: 0x9dbb7812, - }, - Payload: mergeBytes( - []byte{0x1c, 0x85}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 182), - []byte{0x00, 0x01}, - ), - }, - }, - "invalid FU-A packet (decoded two starting packets in a row)", - }, { "MTAP", []*rtp.Packet{ @@ -699,40 +613,3 @@ func TestDecodeErrors(t *testing.T) { }) } } - -func TestEncode(t *testing.T) { - for _, ca := range cases { - t.Run(ca.name, func(t *testing.T) { - e := &Encoder{ - PayloadType: 96, - SSRC: func() *uint32 { - v := uint32(0x9dbb7812) - return &v - }(), - InitialSequenceNumber: func() *uint16 { - v := uint16(0x44ed) - return &v - }(), - InitialTimestamp: func() *uint32 { - v := uint32(0x88776655) - return &v - }(), - } - e.Init() - - pkts, err := e.Encode(ca.nalus, ca.pts) - require.NoError(t, err) - require.Equal(t, ca.pkts, pkts) - }) - } -} - -func TestEncodeRandomInitialState(t *testing.T) { - e := &Encoder{ - PayloadType: 96, - } - e.Init() - require.NotEqual(t, nil, e.SSRC) - require.NotEqual(t, nil, e.InitialSequenceNumber) - require.NotEqual(t, nil, e.InitialTimestamp) -} diff --git a/pkg/rtph264/encoder.go b/pkg/rtph264/encoder.go index f1b16900..bad30f45 100644 --- a/pkg/rtph264/encoder.go +++ b/pkg/rtph264/encoder.go @@ -7,6 +7,10 @@ import ( "github.com/pion/rtp" ) +const ( + rtpVersion = 2 +) + func randUint32() uint32 { var b [4]byte rand.Read(b[:]) diff --git a/pkg/rtph264/encoder_test.go b/pkg/rtph264/encoder_test.go new file mode 100644 index 00000000..70f1388f --- /dev/null +++ b/pkg/rtph264/encoder_test.go @@ -0,0 +1,44 @@ +package rtph264 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEncode(t *testing.T) { + for _, ca := range cases { + t.Run(ca.name, func(t *testing.T) { + e := &Encoder{ + PayloadType: 96, + SSRC: func() *uint32 { + v := uint32(0x9dbb7812) + return &v + }(), + InitialSequenceNumber: func() *uint16 { + v := uint16(0x44ed) + return &v + }(), + InitialTimestamp: func() *uint32 { + v := uint32(0x88776655) + return &v + }(), + } + e.Init() + + pkts, err := e.Encode(ca.nalus, ca.pts) + require.NoError(t, err) + require.Equal(t, ca.pkts, pkts) + }) + } +} + +func TestEncodeRandomInitialState(t *testing.T) { + e := &Encoder{ + PayloadType: 96, + } + e.Init() + require.NotEqual(t, nil, e.SSRC) + require.NotEqual(t, nil, e.InitialSequenceNumber) + require.NotEqual(t, nil, e.InitialTimestamp) +} diff --git a/pkg/rtph264/rtph264.go b/pkg/rtph264/rtph264.go index cbf81d60..5d38a93d 100644 --- a/pkg/rtph264/rtph264.go +++ b/pkg/rtph264/rtph264.go @@ -2,6 +2,5 @@ package rtph264 const ( - rtpVersion = 0x02 rtpClockRate = 90000 // h264 always uses 90khz )