From dcbd9d8211503043555ffde231d69ba8a10f5bd3 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 5 Apr 2021 17:15:56 +0200 Subject: [PATCH] rtp*: remove Read(), return nalus and pts separately --- examples/client-read-h264/main.go | 6 +- pkg/rtpaac/decoder.go | 34 ------ pkg/rtpaac/rtpaac_test.go | 32 ++--- pkg/rtph264/decoder.go | 123 +++++++------------ pkg/rtph264/encoder.go | 77 ++++++------ pkg/rtph264/nalutype.go | 36 +++--- pkg/rtph264/rtph264.go | 10 -- pkg/rtph264/rtph264_test.go | 191 +++++++++++++----------------- 8 files changed, 197 insertions(+), 312 deletions(-) diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index e10eac4a..8a3180ca 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -41,14 +41,14 @@ func main() { err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { if trackID == h264Track { // convert RTP frames into H264 NALUs - nts, err := dec.Decode(buf) + nalus, _, err := dec.Decode(buf) if err != nil { return } // print NALUs - for _, nt := range nts { - fmt.Printf("received H264 NALU of size %d\n", len(nt.NALU)) + for _, nalu := range nalus { + fmt.Printf("received H264 NALU of size %d\n", len(nalu)) } } }) diff --git a/pkg/rtpaac/decoder.go b/pkg/rtpaac/decoder.go index b199ec87..5f8ea6a4 100644 --- a/pkg/rtpaac/decoder.go +++ b/pkg/rtpaac/decoder.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "errors" "fmt" - "io" "time" "github.com/pion/rtp" @@ -29,9 +28,6 @@ type Decoder struct { // for Decode() state decoderState fragmentedBuf []byte - - // for Read() - readQueue []*AUAndTimestamp } // NewDecoder allocates a Decoder. @@ -176,33 +172,3 @@ func (d *Decoder) Decode(byts []byte) ([]*AUAndTimestamp, error) { }}, nil } } - -// Read reads RTP/AAC packets from a reader until an AU is decoded. -func (d *Decoder) Read(r io.Reader) (*AUAndTimestamp, error) { - if len(d.readQueue) > 0 { - au := d.readQueue[0] - d.readQueue = d.readQueue[1:] - return au, nil - } - - buf := make([]byte, 2048) - for { - n, err := r.Read(buf) - if err != nil { - return nil, err - } - - aus, err := d.Decode(buf[:n]) - if err != nil { - if err == ErrMorePacketsNeeded { - continue - } - return nil, err - } - - au := aus[0] - d.readQueue = aus[1:] - - return au, nil - } -} diff --git a/pkg/rtpaac/rtpaac_test.go b/pkg/rtpaac/rtpaac_test.go index 309d22b8..7ef67b3c 100644 --- a/pkg/rtpaac/rtpaac_test.go +++ b/pkg/rtpaac/rtpaac_test.go @@ -2,7 +2,6 @@ package rtpaac import ( "bytes" - "io" "testing" "time" @@ -25,12 +24,6 @@ func mergeBytes(vals ...[]byte) []byte { return res } -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} - var cases = []struct { name string dec []*AUAndTimestamp @@ -207,16 +200,6 @@ func TestEncode(t *testing.T) { func TestDecode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { - i := 0 - r := readerFunc(func(p []byte) (int, error) { - if i == len(ca.enc) { - return 0, io.EOF - } - - i++ - return copy(p, ca.enc[i-1]), nil - }) - d := NewDecoder(48000) // send an initial packet downstream @@ -228,14 +211,19 @@ func TestDecode(t *testing.T) { }) require.NoError(t, err) - for _, dec0 := range ca.dec { - dec, err := d.Read(r) + var ats []*AUAndTimestamp + + for _, pkt := range ca.enc { + addATs, err := d.Decode(pkt) + if err == ErrMorePacketsNeeded { + continue + } + require.NoError(t, err) - require.Equal(t, dec0, dec) + ats = append(ats, addATs...) } - _, err = d.Read(r) - require.Equal(t, io.EOF, err) + require.Equal(t, ca.dec, ats) }) } } diff --git a/pkg/rtph264/decoder.go b/pkg/rtph264/decoder.go index 663cb6eb..ec205b89 100644 --- a/pkg/rtph264/decoder.go +++ b/pkg/rtph264/decoder.go @@ -40,9 +40,6 @@ type Decoder struct { // for Decode() state decoderState fragmentedBuf []byte - - // for Read() - readQueue []*NALUAndTimestamp } // NewDecoder allocates a Decoder. @@ -59,13 +56,13 @@ func (d *Decoder) decodeTimestamp(ts uint32) time.Duration { // * no NALUs and ErrMorePacketsNeeded // * one NALU (in case of FU-A) // * multiple NALUs (in case of STAP-A) -func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { +func (d *Decoder) Decode(byts []byte) ([][]byte, time.Duration, error) { switch d.state { case decoderStateInitial: pkt := rtp.Packet{} err := pkt.Unmarshal(byts) if err != nil { - return nil, err + return nil, 0, err } if !d.initialTsSet { @@ -74,19 +71,19 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { } if len(pkt.Payload) < 1 { - return nil, fmt.Errorf("payload is too short") + return nil, 0, fmt.Errorf("payload is too short") } typ := NALUType(pkt.Payload[0] & 0x1F) switch typ { - case NALUTypeStapA: - var ret []*NALUAndTimestamp + case NALUTypeSTAPA: + var nalus [][]byte pkt.Payload = pkt.Payload[1:] for len(pkt.Payload) > 0 { if len(pkt.Payload) < 2 { - return nil, fmt.Errorf("Invalid STAP-A packet") + return nil, 0, fmt.Errorf("Invalid STAP-A packet") } size := binary.BigEndian.Uint16(pkt.Payload) @@ -98,30 +95,27 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { } if int(size) > len(pkt.Payload) { - return nil, fmt.Errorf("Invalid STAP-A packet") + return nil, 0, fmt.Errorf("Invalid STAP-A packet") } - ret = append(ret, &NALUAndTimestamp{ - NALU: pkt.Payload[:size], - Timestamp: d.decodeTimestamp(pkt.Timestamp), - }) + nalus = append(nalus, pkt.Payload[:size]) pkt.Payload = pkt.Payload[size:] } - if len(ret) == 0 { - return nil, fmt.Errorf("STAP-A packet doesn't contain any NALU") + if len(nalus) == 0 { + return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU") } - return ret, nil + return nalus, d.decodeTimestamp(pkt.Timestamp), nil - case NALUTypeFuA: // first packet of a fragmented NALU + case NALUTypeFUA: // first packet of a fragmented NALU if len(pkt.Payload) < 2 { - return nil, fmt.Errorf("Invalid FU-A packet") + return nil, 0, fmt.Errorf("Invalid FU-A packet") } start := pkt.Payload[1] >> 7 if start != 1 { - return nil, fmt.Errorf("first NALU does not contain the start bit") + return nil, 0, fmt.Errorf("first NALU does not contain the start bit") } nri := (pkt.Payload[0] >> 5) & 0x03 @@ -129,35 +123,32 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...) d.state = decoderStateReadingFragmented - return nil, ErrMorePacketsNeeded + return nil, 0, ErrMorePacketsNeeded - case NALUTypeStapB, NALUTypeMtap16, - NALUTypeMtap24, NALUTypeFuB: - return nil, fmt.Errorf("NALU type not yet supported (%v)", typ) + case NALUTypeSTAPB, NALUTypeMTAP16, + NALUTypeMTAP24, NALUTypeFUB: + return nil, 0, fmt.Errorf("NALU type not supported (%v)", typ) } - return []*NALUAndTimestamp{{ - NALU: pkt.Payload, - Timestamp: d.decodeTimestamp(pkt.Timestamp), - }}, nil + return [][]byte{pkt.Payload}, d.decodeTimestamp(pkt.Timestamp), nil default: // decoderStateReadingFragmented pkt := rtp.Packet{} err := pkt.Unmarshal(byts) if err != nil { d.state = decoderStateInitial - return nil, err + return nil, 0, err } if len(pkt.Payload) < 2 { d.state = decoderStateInitial - return nil, fmt.Errorf("Invalid FU-A packet") + return nil, 0, fmt.Errorf("Invalid FU-A packet") } typ := NALUType(pkt.Payload[0] & 0x1F) - if typ != NALUTypeFuA { + if typ != NALUTypeFUA { d.state = decoderStateInitial - return nil, fmt.Errorf("non-starting NALU is not FU-A") + return nil, 0, fmt.Errorf("non-starting NALU is not FU-A") } end := (pkt.Payload[1] >> 6) & 0x01 @@ -165,44 +156,11 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...) if end != 1 { - return nil, ErrMorePacketsNeeded + return nil, 0, ErrMorePacketsNeeded } d.state = decoderStateInitial - return []*NALUAndTimestamp{{ - NALU: d.fragmentedBuf, - Timestamp: d.decodeTimestamp(pkt.Timestamp), - }}, nil - } -} - -// Read reads RTP/H264 packets from a reader until a NALU is decoded. -func (d *Decoder) Read(r io.Reader) (*NALUAndTimestamp, error) { - if len(d.readQueue) > 0 { - nalu := d.readQueue[0] - d.readQueue = d.readQueue[1:] - return nalu, nil - } - - buf := make([]byte, 2048) - for { - n, err := r.Read(buf) - if err != nil { - return nil, err - } - - nalus, err := d.Decode(buf[:n]) - if err != nil { - if err == ErrMorePacketsNeeded { - continue - } - return nil, err - } - - nalu := nalus[0] - d.readQueue = nalus[1:] - - return nalu, nil + return [][]byte{d.fragmentedBuf}, d.decodeTimestamp(pkt.Timestamp), nil } } @@ -212,23 +170,34 @@ func (d *Decoder) ReadSPSPPS(r io.Reader) ([]byte, []byte, error) { var sps []byte var pps []byte + buf := make([]byte, 2048) for { - nt, err := d.Read(r) + n, err := r.Read(buf) if err != nil { return nil, nil, err } - switch NALUType(nt.NALU[0] & 0x1F) { - case NALUTypeSPS: - sps = append([]byte(nil), nt.NALU...) - if sps != nil && pps != nil { - return sps, pps, nil + nalus, _, err := d.Decode(buf[:n]) + if err != nil { + if err == ErrMorePacketsNeeded { + continue } + return nil, nil, err + } - case NALUTypePPS: - pps = append([]byte(nil), nt.NALU...) - if sps != nil && pps != nil { - return sps, pps, nil + for _, nalu := range nalus { + switch NALUType(nalu[0] & 0x1F) { + case NALUTypeSPS: + sps = append([]byte(nil), nalu...) + if sps != nil && pps != nil { + return sps, pps, nil + } + + case NALUTypePPS: + pps = append([]byte(nil), nalu...) + if sps != nil && pps != nil { + return sps, pps, nil + } } } } diff --git a/pkg/rtph264/encoder.go b/pkg/rtph264/encoder.go index 01ac3aff..fcda87c9 100644 --- a/pkg/rtph264/encoder.go +++ b/pkg/rtph264/encoder.go @@ -2,7 +2,6 @@ package rtph264 import ( "encoding/binary" - "fmt" "math/rand" "time" @@ -60,24 +59,20 @@ func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 { // * a single packets // * multiple fragmented packets (FU-A) // * an aggregated packet (STAP-A) -func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) { +func (e *Encoder) Encode(nalus [][]byte, pts time.Duration) ([][]byte, error) { var rets [][]byte - var batch []*NALUAndTimestamp + var batch [][]byte // split NALUs into batches - for _, nt := range nts { - if len(batch) > 0 && batch[0].Timestamp != nt.Timestamp { - return nil, fmt.Errorf("encoding NALUs with different timestamps is not supported") - } - - if e.lenAggregated(batch, nt) <= rtpPayloadMaxSize { + for _, nalu := range nalus { + if e.lenAggregated(batch, nalu) <= rtpPayloadMaxSize { // add to existing batch - batch = append(batch, nt) + batch = append(batch, nalu) } else { // write batch if batch != nil { - pkts, err := e.writeBatch(batch, false) + pkts, err := e.writeBatch(batch, pts, false) if err != nil { return nil, err } @@ -85,13 +80,13 @@ func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) { } // initialize new batch - batch = []*NALUAndTimestamp{nt} + batch = [][]byte{nalu} } } // write final batch // marker is used to indicate when all NALUs with same PTS have been sent - pkts, err := e.writeBatch(batch, true) + pkts, err := e.writeBatch(batch, pts, true) if err != nil { return nil, err } @@ -100,31 +95,31 @@ func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) { return rets, nil } -func (e *Encoder) writeBatch(nts []*NALUAndTimestamp, marker bool) ([][]byte, error) { - if len(nts) == 1 { +func (e *Encoder) writeBatch(nalus [][]byte, pts time.Duration, marker bool) ([][]byte, error) { + if len(nalus) == 1 { // the NALU fits into a single RTP packet - if len(nts[0].NALU) < rtpPayloadMaxSize { - return e.writeSingle(nts[0], marker) + if len(nalus[0]) < rtpPayloadMaxSize { + return e.writeSingle(nalus[0], pts, marker) } // split the NALU into multiple fragmentation packet - return e.writeFragmented(nts[0], marker) + return e.writeFragmented(nalus[0], pts, marker) } - return e.writeAggregated(nts, marker) + return e.writeAggregated(nalus, pts, marker) } -func (e *Encoder) writeSingle(nt *NALUAndTimestamp, marker bool) ([][]byte, error) { +func (e *Encoder) writeSingle(nalu []byte, pts time.Duration, marker bool) ([][]byte, error) { rpkt := rtp.Packet{ Header: rtp.Header{ Version: rtpVersion, PayloadType: e.payloadType, SequenceNumber: e.sequenceNumber, - Timestamp: e.encodeTimestamp(nt.Timestamp), + Timestamp: e.encodeTimestamp(pts), SSRC: e.ssrc, Marker: marker, }, - Payload: nt.NALU, + Payload: nalu, } e.sequenceNumber++ @@ -136,9 +131,7 @@ func (e *Encoder) writeSingle(nt *NALUAndTimestamp, marker bool) ([][]byte, erro return [][]byte{frame}, nil } -func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte, error) { - nalu := nt.NALU - +func (e *Encoder) writeFragmented(nalu []byte, pts time.Duration, marker bool) ([][]byte, error) { // use only FU-A, not FU-B, since we always use non-interleaved mode // (packetization-mode=1) packetCount := (len(nalu) - 1) / (rtpPayloadMaxSize - 2) @@ -148,14 +141,14 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte, } ret := make([][]byte, packetCount) - ts := e.encodeTimestamp(nt.Timestamp) + encPTS := e.encodeTimestamp(pts) nri := (nalu[0] >> 5) & 0x03 typ := nalu[0] & 0x1F nalu = nalu[1:] // remove header for i := range ret { - indicator := (nri << 5) | uint8(NALUTypeFuA) + indicator := (nri << 5) | uint8(NALUTypeFUA) start := uint8(0) if i == 0 { @@ -180,7 +173,7 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte, Version: rtpVersion, PayloadType: e.payloadType, SequenceNumber: e.sequenceNumber, - Timestamp: ts, + Timestamp: encPTS, SSRC: e.ssrc, Marker: (i == (packetCount-1) && marker), }, @@ -199,37 +192,37 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte, return ret, nil } -func (e *Encoder) lenAggregated(nts []*NALUAndTimestamp, additionalEl *NALUAndTimestamp) int { +func (e *Encoder) lenAggregated(nalus [][]byte, addNALU []byte) int { ret := 1 // header - for _, bnt := range nts { - ret += 2 // size - ret += len(bnt.NALU) // nalu + for _, nalu := range nalus { + ret += 2 // size + ret += len(nalu) // nalu } - if additionalEl != nil { - ret += 2 // size - ret += len(additionalEl.NALU) // nalu + if addNALU != nil { + ret += 2 // size + ret += len(addNALU) // nalu } return ret } -func (e *Encoder) writeAggregated(nts []*NALUAndTimestamp, marker bool) ([][]byte, error) { - payload := make([]byte, e.lenAggregated(nts, nil)) +func (e *Encoder) writeAggregated(nalus [][]byte, pts time.Duration, marker bool) ([][]byte, error) { + payload := make([]byte, e.lenAggregated(nalus, nil)) // header - payload[0] = uint8(NALUTypeStapA) + payload[0] = uint8(NALUTypeSTAPA) pos := 1 - for _, nt := range nts { + for _, nalu := range nalus { // size - naluLen := len(nt.NALU) + naluLen := len(nalu) binary.BigEndian.PutUint16(payload[pos:], uint16(naluLen)) pos += 2 // nalu - copy(payload[pos:], nt.NALU) + copy(payload[pos:], nalu) pos += naluLen } @@ -238,7 +231,7 @@ func (e *Encoder) writeAggregated(nts []*NALUAndTimestamp, marker bool) ([][]byt Version: rtpVersion, PayloadType: e.payloadType, SequenceNumber: e.sequenceNumber, - Timestamp: e.encodeTimestamp(nts[0].Timestamp), + Timestamp: e.encodeTimestamp(pts), SSRC: e.ssrc, Marker: marker, }, diff --git a/pkg/rtph264/nalutype.go b/pkg/rtph264/nalutype.go index 0a1ee69c..00d7f03d 100644 --- a/pkg/rtph264/nalutype.go +++ b/pkg/rtph264/nalutype.go @@ -28,12 +28,12 @@ const ( NALUTypeSliceExtensionDepth NALUType = 21 NALUTypeReserved22 NALUType = 22 NALUTypeReserved23 NALUType = 23 - NALUTypeStapA NALUType = 24 - NALUTypeStapB NALUType = 25 - NALUTypeMtap16 NALUType = 26 - NALUTypeMtap24 NALUType = 27 - NALUTypeFuA NALUType = 28 - NALUTypeFuB NALUType = 29 + NALUTypeSTAPA NALUType = 24 + NALUTypeSTAPB NALUType = 25 + NALUTypeMTAP16 NALUType = 26 + NALUTypeMTAP24 NALUType = 27 + NALUTypeFUA NALUType = 28 + NALUTypeFUB NALUType = 29 ) // String implements fmt.Stringer. @@ -85,18 +85,18 @@ func (nt NALUType) String() string { return "Reserved22" case NALUTypeReserved23: return "Reserved23" - case NALUTypeStapA: - return "StapA" - case NALUTypeStapB: - return "StapB" - case NALUTypeMtap16: - return "Mtap16" - case NALUTypeMtap24: - return "Mtap24" - case NALUTypeFuA: - return "FuA" - case NALUTypeFuB: - return "FuB" + case NALUTypeSTAPA: + return "STAPA" + case NALUTypeSTAPB: + return "STAPB" + case NALUTypeMTAP16: + return "MTAP16" + case NALUTypeMTAP24: + return "MTAP24" + case NALUTypeFUA: + return "FUA" + case NALUTypeFUB: + return "FUB" } return "unknown" } diff --git a/pkg/rtph264/rtph264.go b/pkg/rtph264/rtph264.go index d787a3bd..ab0c3f6f 100644 --- a/pkg/rtph264/rtph264.go +++ b/pkg/rtph264/rtph264.go @@ -1,12 +1,2 @@ // Package rtph264 contains a RTP/H264 decoder and encoder. package rtph264 - -import ( - "time" -) - -// NALUAndTimestamp is a Network Abstraction Layer Unit and its timestamp. -type NALUAndTimestamp struct { - Timestamp time.Duration - NALU []byte -} diff --git a/pkg/rtph264/rtph264_test.go b/pkg/rtph264/rtph264_test.go index bc96113e..513a98f4 100644 --- a/pkg/rtph264/rtph264_test.go +++ b/pkg/rtph264/rtph264_test.go @@ -2,7 +2,6 @@ package rtph264 import ( "bytes" - "io" "testing" "time" @@ -25,28 +24,21 @@ func mergeBytes(vals ...[]byte) []byte { return res } -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} - var cases = []struct { - name string - dec []*NALUAndTimestamp - enc [][]byte + name string + nalus [][]byte + pts time.Duration + enc [][]byte }{ { "single", - []*NALUAndTimestamp{ - { - Timestamp: 25 * time.Millisecond, - NALU: mergeBytes( - []byte{0x05}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), - ), - }, + [][]byte{ + mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + ), }, + 25 * time.Millisecond, [][]byte{ mergeBytes( []byte{ @@ -59,15 +51,13 @@ var cases = []struct { }, { "negative timestamp", - []*NALUAndTimestamp{ - { - Timestamp: -20 * time.Millisecond, - NALU: mergeBytes( - []byte{0x05}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), - ), - }, + [][]byte{ + mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), + ), }, + -20 * time.Millisecond, [][]byte{ mergeBytes( []byte{ @@ -80,15 +70,13 @@ var cases = []struct { }, { "fragmented", - []*NALUAndTimestamp{ - { - Timestamp: 55 * time.Millisecond, - NALU: mergeBytes( - []byte{0x05}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), - ), - }, + [][]byte{ + mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), + ), }, + 55 * time.Millisecond, [][]byte{ mergeBytes( []byte{ @@ -110,24 +98,21 @@ var cases = []struct { }, { "aggregated", - []*NALUAndTimestamp{ + [][]byte{ + {0x09, 0xF0}, { - NALU: []byte{0x09, 0xF0}, - }, - { - NALU: []byte{ - 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, - 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, - 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, - 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x6d, 0x40, - }, + 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, + 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x6d, 0x40, }, }, + 0, [][]byte{ { 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, @@ -146,30 +131,25 @@ var cases = []struct { }, { "aggregated followed by single", - []*NALUAndTimestamp{ + [][]byte{ + {0x09, 0xF0}, { - NALU: []byte{0x09, 0xF0}, - }, - { - NALU: []byte{ - 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, - 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, - 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, - 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, - 0x00, 0x00, 0x6d, 0x40, - }, - }, - { - NALU: mergeBytes( - []byte{0x08}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 175), - ), + 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, + 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x6d, 0x40, }, + mergeBytes( + []byte{0x08}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 175), + ), }, + 0, [][]byte{ { 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, @@ -195,20 +175,15 @@ var cases = []struct { }, { "fragmented followed by aggregated", - []*NALUAndTimestamp{ - { - NALU: mergeBytes( - []byte{0x05}, - bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), - ), - }, - { - NALU: []byte{0x09, 0xF0}, - }, - { - NALU: []byte{0x09, 0xF0}, - }, + [][]byte{ + mergeBytes( + []byte{0x05}, + bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), + ), + {0x09, 0xF0}, + {0x09, 0xF0}, }, + 0, [][]byte{ mergeBytes( []byte{ @@ -242,7 +217,7 @@ func TestEncode(t *testing.T) { ssrc := uint32(0x9dbb7812) initialTs := uint32(0x88776655) e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs) - enc, err := e.Encode(ca.dec) + enc, err := e.Encode(ca.nalus, ca.pts) require.NoError(t, err) require.Equal(t, ca.enc, enc) }) @@ -252,35 +227,31 @@ func TestEncode(t *testing.T) { func TestDecode(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { - i := 0 - r := readerFunc(func(p []byte) (int, error) { - if i == len(ca.enc) { - return 0, io.EOF - } - - i++ - return copy(p, ca.enc[i-1]), nil - }) - d := NewDecoder() // send an initial packet downstream // in order to compute the timestamp, // which is relative to the initial packet - _, err := d.Decode([]byte{ + _, _, err := d.Decode([]byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, 0x9d, 0xbb, 0x78, 0x12, 0x06, 0x00, }) require.NoError(t, err) - for _, dec0 := range ca.dec { - dec, err := d.Read(r) + var nalus [][]byte + + for _, pkt := range ca.enc { + addNALUs, pts, err := d.Decode(pkt) + if err == ErrMorePacketsNeeded { + continue + } + require.NoError(t, err) - require.Equal(t, dec0, dec) + require.Equal(t, ca.pts, pts) + nalus = append(nalus, addNALUs...) } - _, err = d.Read(r) - require.Equal(t, io.EOF, err) + require.Equal(t, ca.nalus, nalus) }) } } @@ -303,7 +274,7 @@ func TestDecodeErrors(t *testing.T) { "STAP-A without NALUs", []byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA), }, "STAP-A packet doesn't contain any NALU", }, @@ -311,7 +282,7 @@ func TestDecodeErrors(t *testing.T) { "STAP-A without size", []byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), 0x01, + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA), 0x01, }, "Invalid STAP-A packet", }, @@ -319,7 +290,7 @@ func TestDecodeErrors(t *testing.T) { "STAP-A with invalid size", []byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), 0x00, 0x15, + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA), 0x00, 0x15, }, "Invalid STAP-A packet", }, @@ -327,7 +298,7 @@ func TestDecodeErrors(t *testing.T) { "FU-A without payload", []byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFuA), + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFUA), }, "Invalid FU-A packet", }, @@ -335,14 +306,22 @@ func TestDecodeErrors(t *testing.T) { "FU-A without start bit", []byte{ 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, - 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFuA), 0x00, + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFUA), 0x00, }, "first NALU does not contain the start bit", }, + { + "MTAP", + []byte{ + 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, + 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeMTAP16), + }, + "NALU type not supported (MTAP16)", + }, } { t.Run(ca.name, func(t *testing.T) { d := NewDecoder() - _, err := d.Decode(ca.byts) + _, _, err := d.Decode(ca.byts) require.NotEqual(t, ErrMorePacketsNeeded, err) require.Equal(t, ca.err, err.Error()) })