From 0b336c547cbf90094920e04aa466496f2bf27f13 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 7 Mar 2021 20:54:11 +0100 Subject: [PATCH] rtph264: support decoding STAP-A frames --- examples/client-read-h264/main.go | 6 ++- pkg/rtph264/decoder.go | 84 +++++++++++++++++++++++++------ pkg/rtph264/rtph264.go | 65 ++++++++++++++++++++++++ pkg/rtph264/rtph264_test.go | 55 +++++++++++++++++++- 4 files changed, 192 insertions(+), 18 deletions(-) diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index e617f238..7323abc1 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -40,12 +40,14 @@ func main() { // get H264 NALUs of that track err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { if trackID == h264Track { - nt, err := dec.Decode(buf) + nts, err := dec.Decode(buf) if err != nil { return } - fmt.Printf("received H264 NALU of size %d\n", len(nt.NALU)) + for _, nt := range nts { + fmt.Printf("received H264 NALU of size %d\n", len(nt.NALU)) + } } }) panic(err) diff --git a/pkg/rtph264/decoder.go b/pkg/rtph264/decoder.go index fec48b1b..30f579dc 100644 --- a/pkg/rtph264/decoder.go +++ b/pkg/rtph264/decoder.go @@ -1,6 +1,7 @@ package rtph264 import ( + "encoding/binary" "errors" "fmt" "io" @@ -33,10 +34,15 @@ const ( // Decoder is a RTP/H264 decoder. type Decoder struct { - initialTs uint32 - initialTsSet bool + initialTs uint32 + initialTsSet bool + + // for Decode() and FU-A state decoderState fragmentedBuf []byte + + // for Read() + nalusQueue []*NALUAndTimestamp } // NewDecoder allocates a Decoder. @@ -44,10 +50,12 @@ func NewDecoder() *Decoder { return &Decoder{} } -// Decode decodes a NALU from RTP/H264 packets. -// Since a NALU can require multiple RTP/H264 packets, it returns -// one packet, or no packets with ErrMorePacketsNeeded. -func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { +// Decode decodes NALUs from RTP/H264 packets. +// It can return: +// * no NALUs and ErrMorePacketsNeeded +// * one NALU (in case of FU-A) +// * multiple NALUs (in case of STAP-A) +func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { switch d.state { case decoderStateInitial: pkt := rtp.Packet{} @@ -72,10 +80,44 @@ func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { NALUTypeReserved18, NALUTypeSliceLayerWithoutPartitioning, NALUTypeSliceExtension, NALUTypeSliceExtensionDepth, NALUTypeReserved22, NALUTypeReserved23: - return &NALUAndTimestamp{ + return []*NALUAndTimestamp{{ NALU: pkt.Payload, Timestamp: time.Duration(pkt.Timestamp-d.initialTs) * time.Second / rtpClockRate, - }, nil + }}, nil + + case NALUTypeStapA: + var ret []*NALUAndTimestamp + pkt.Payload = pkt.Payload[1:] + + for len(pkt.Payload) > 0 { + if len(pkt.Payload) < 2 { + return nil, fmt.Errorf("Invalid STAP-A packet") + } + + size := binary.BigEndian.Uint16(pkt.Payload) + pkt.Payload = pkt.Payload[2:] + + // avoid final padding + if size == 0 { + break + } + + if int(size) > len(pkt.Payload) { + return nil, fmt.Errorf("Invalid STAP-A packet") + } + + ret = append(ret, &NALUAndTimestamp{ + NALU: pkt.Payload[:size], + Timestamp: time.Duration(pkt.Timestamp-d.initialTs) * time.Second / rtpClockRate, + }) + pkt.Payload = pkt.Payload[size:] + } + + if len(ret) == 0 { + return nil, fmt.Errorf("STAP-A packet doesn't contain any NALU") + } + + return ret, nil case NALUTypeFuA: // first packet of a fragmented NALU nri := (pkt.Payload[0] >> 5) & 0x03 @@ -89,11 +131,11 @@ func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { d.state = decoderStateReadingFragmented return nil, ErrMorePacketsNeeded - case NALUTypeStapA, NALUTypeStapB, NALUTypeMtap16, NALUTypeMtap24, NALUTypeFuB: - return nil, fmt.Errorf("NALU type not supported (%d)", typ) + case NALUTypeStapB, NALUTypeMtap16, NALUTypeMtap24, NALUTypeFuB: + return nil, fmt.Errorf("NALU type not supported (%v)", typ) } - return nil, fmt.Errorf("invalid NALU type (%d)", typ) + return nil, fmt.Errorf("invalid NALU type (%v)", typ) default: // decoderStateReadingFragmented pkt := rtp.Packet{} @@ -102,6 +144,10 @@ func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { return nil, err } + if len(pkt.Payload) < 2 { + return nil, fmt.Errorf("Invalid FU-A packet") + } + typ := NALUType(pkt.Payload[0] & 0x1F) if typ != NALUTypeFuA { return nil, fmt.Errorf("non-starting NALU is not FU-A") @@ -115,15 +161,21 @@ func (d *Decoder) Decode(byts []byte) (*NALUAndTimestamp, error) { } d.state = decoderStateInitial - return &NALUAndTimestamp{ + return []*NALUAndTimestamp{{ NALU: d.fragmentedBuf, Timestamp: time.Duration(pkt.Timestamp-d.initialTs) * time.Second / rtpClockRate, - }, nil + }}, nil } } // Read reads RTP/H264 packets from a reader until a NALU is decoded. func (d *Decoder) Read(r io.Reader) (*NALUAndTimestamp, error) { + if len(d.nalusQueue) > 0 { + nalu := d.nalusQueue[0] + d.nalusQueue = d.nalusQueue[1:] + return nalu, nil + } + buf := make([]byte, 2048) for { n, err := r.Read(buf) @@ -131,13 +183,17 @@ func (d *Decoder) Read(r io.Reader) (*NALUAndTimestamp, error) { return nil, err } - nalu, err := d.Decode(buf[:n]) + nalus, err := d.Decode(buf[:n]) if err != nil { if err == ErrMorePacketsNeeded { continue } return nil, err } + + nalu := nalus[0] + d.nalusQueue = nalus[1:] + return nalu, nil } } diff --git a/pkg/rtph264/rtph264.go b/pkg/rtph264/rtph264.go index 52227618..a221ae00 100644 --- a/pkg/rtph264/rtph264.go +++ b/pkg/rtph264/rtph264.go @@ -46,3 +46,68 @@ const ( NALUTypeFuA NALUType = 28 NALUTypeFuB NALUType = 29 ) + +// String implements fmt.Stringer. +func (nt NALUType) String() string { + switch nt { + case NALUTypeNonIDR: + return "NonIDR" + case NALUTypeDataPartitionA: + return "DataPartitionA" + case NALUTypeDataPartitionB: + return "DataPartitionB" + case NALUTypeDataPartitionC: + return "DataPartitionC" + case NALUTypeIDR: + return "IDR" + case NALUTypeSei: + return "Sei" + case NALUTypeSPS: + return "SPS" + case NALUTypePPS: + return "PPS" + case NALUTypeAccessUnitDelimiter: + return "AccessUnitDelimiter" + case NALUTypeEndOfSequence: + return "EndOfSequence" + case NALUTypeEndOfStream: + return "EndOfStream" + case NALUTypeFillerData: + return "FillerData" + case NALUTypeSPSExtension: + return "SPSExtension" + case NALUTypePrefix: + return "Prefix" + case NALUTypeSubsetSPS: + return "SubsetSPS" + case NALUTypeReserved16: + return "Reserved16" + case NALUTypeReserved17: + return "Reserved17" + case NALUTypeReserved18: + return "Reserved18" + case NALUTypeSliceLayerWithoutPartitioning: + return "SliceLayerWithoutPartitioning" + case NALUTypeSliceExtension: + return "SliceExtension" + case NALUTypeSliceExtensionDepth: + return "SliceExtensionDepth" + case NALUTypeReserved22: + return "Reserved22" + case NALUTypeReserved23: + return "Reserved23" + case NALUTypeStapA: + return "StapA" + case NALUTypeStapB: + return "StapB" + case NALUTypeMtap16: + return "Mtap16" + case NALUTypeMtap24: + return "Mtap24" + case NALUTypeFuA: + return "FuA" + case NALUTypeFuB: + return "FuB" + } + return "unknown" +} diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/rtph264_test.go index da40997d..62e8f0fd 100644 --- a/pkg/rtph264/rtph264_test.go +++ b/pkg/rtph264/rtph264_test.go @@ -109,9 +109,8 @@ func TestDecode(t *testing.T) { return 0, io.EOF } - n := copy(p, ca.enc[i]) i++ - return n, nil + return copy(p, ca.enc[i-1]), nil }) d := NewDecoder() @@ -133,3 +132,55 @@ func TestDecode(t *testing.T) { }) } } + +func TestDecodeStapA(t *testing.T) { + sent := false + r := readerFunc(func(p []byte) (int, error) { + if sent { + return 0, io.EOF + } + + sent = true + pkt := []byte{ + 0x80, 0xe0, 0x0e, 0x6a, 0x48, 0xf1, 0x7d, 0xb9, + 0x23, 0xe6, 0x5d, 0x50, 0x18, 0x00, 0x02, 0x09, + 0xf0, 0x00, 0x44, 0x41, 0x9a, 0x24, 0x6c, 0x41, + 0x4f, 0xfe, 0xd6, 0x8c, 0xb0, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x6d, 0x40, + } + return copy(p, pkt), nil + }) + + d := NewDecoder() + + nt, err := d.Read(r) + require.NoError(t, err) + require.Equal(t, &NALUAndTimestamp{ + NALU: []byte{0x09, 0xF0}, + }, nt) + + nt, err = d.Read(r) + require.NoError(t, err) + require.Equal(t, &NALUAndTimestamp{ + NALU: []byte{ + 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, + 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x6d, 0x40, + }, + }, nt) + + _, err = d.Read(r) + require.Equal(t, io.EOF, err) +}