rtp*: remove Read(), return nalus and pts separately

This commit is contained in:
aler9
2021-04-05 17:15:56 +02:00
parent db676cab85
commit dcbd9d8211
8 changed files with 197 additions and 312 deletions

View File

@@ -41,14 +41,14 @@ func main() {
err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) {
if trackID == h264Track { if trackID == h264Track {
// convert RTP frames into H264 NALUs // convert RTP frames into H264 NALUs
nts, err := dec.Decode(buf) nalus, _, err := dec.Decode(buf)
if err != nil { if err != nil {
return return
} }
// print NALUs // print NALUs
for _, nt := range nts { for _, nalu := range nalus {
fmt.Printf("received H264 NALU of size %d\n", len(nt.NALU)) fmt.Printf("received H264 NALU of size %d\n", len(nalu))
} }
} }
}) })

View File

@@ -4,7 +4,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
"time" "time"
"github.com/pion/rtp" "github.com/pion/rtp"
@@ -29,9 +28,6 @@ type Decoder struct {
// for Decode() // for Decode()
state decoderState state decoderState
fragmentedBuf []byte fragmentedBuf []byte
// for Read()
readQueue []*AUAndTimestamp
} }
// NewDecoder allocates a Decoder. // NewDecoder allocates a Decoder.
@@ -176,33 +172,3 @@ func (d *Decoder) Decode(byts []byte) ([]*AUAndTimestamp, error) {
}}, nil }}, nil
} }
} }
// Read reads RTP/AAC packets from a reader until an AU is decoded.
func (d *Decoder) Read(r io.Reader) (*AUAndTimestamp, error) {
if len(d.readQueue) > 0 {
au := d.readQueue[0]
d.readQueue = d.readQueue[1:]
return au, nil
}
buf := make([]byte, 2048)
for {
n, err := r.Read(buf)
if err != nil {
return nil, err
}
aus, err := d.Decode(buf[:n])
if err != nil {
if err == ErrMorePacketsNeeded {
continue
}
return nil, err
}
au := aus[0]
d.readQueue = aus[1:]
return au, nil
}
}

View File

@@ -2,7 +2,6 @@ package rtpaac
import ( import (
"bytes" "bytes"
"io"
"testing" "testing"
"time" "time"
@@ -25,12 +24,6 @@ func mergeBytes(vals ...[]byte) []byte {
return res return res
} }
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
var cases = []struct { var cases = []struct {
name string name string
dec []*AUAndTimestamp dec []*AUAndTimestamp
@@ -207,16 +200,6 @@ func TestEncode(t *testing.T) {
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
i := 0
r := readerFunc(func(p []byte) (int, error) {
if i == len(ca.enc) {
return 0, io.EOF
}
i++
return copy(p, ca.enc[i-1]), nil
})
d := NewDecoder(48000) d := NewDecoder(48000)
// send an initial packet downstream // send an initial packet downstream
@@ -228,14 +211,19 @@ func TestDecode(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
for _, dec0 := range ca.dec { var ats []*AUAndTimestamp
dec, err := d.Read(r)
require.NoError(t, err) for _, pkt := range ca.enc {
require.Equal(t, dec0, dec) addATs, err := d.Decode(pkt)
if err == ErrMorePacketsNeeded {
continue
} }
_, err = d.Read(r) require.NoError(t, err)
require.Equal(t, io.EOF, err) ats = append(ats, addATs...)
}
require.Equal(t, ca.dec, ats)
}) })
} }
} }

View File

@@ -40,9 +40,6 @@ type Decoder struct {
// for Decode() // for Decode()
state decoderState state decoderState
fragmentedBuf []byte fragmentedBuf []byte
// for Read()
readQueue []*NALUAndTimestamp
} }
// NewDecoder allocates a Decoder. // NewDecoder allocates a Decoder.
@@ -59,13 +56,13 @@ func (d *Decoder) decodeTimestamp(ts uint32) time.Duration {
// * no NALUs and ErrMorePacketsNeeded // * no NALUs and ErrMorePacketsNeeded
// * one NALU (in case of FU-A) // * one NALU (in case of FU-A)
// * multiple NALUs (in case of STAP-A) // * multiple NALUs (in case of STAP-A)
func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) { func (d *Decoder) Decode(byts []byte) ([][]byte, time.Duration, error) {
switch d.state { switch d.state {
case decoderStateInitial: case decoderStateInitial:
pkt := rtp.Packet{} pkt := rtp.Packet{}
err := pkt.Unmarshal(byts) err := pkt.Unmarshal(byts)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
if !d.initialTsSet { if !d.initialTsSet {
@@ -74,19 +71,19 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) {
} }
if len(pkt.Payload) < 1 { if len(pkt.Payload) < 1 {
return nil, fmt.Errorf("payload is too short") return nil, 0, fmt.Errorf("payload is too short")
} }
typ := NALUType(pkt.Payload[0] & 0x1F) typ := NALUType(pkt.Payload[0] & 0x1F)
switch typ { switch typ {
case NALUTypeStapA: case NALUTypeSTAPA:
var ret []*NALUAndTimestamp var nalus [][]byte
pkt.Payload = pkt.Payload[1:] pkt.Payload = pkt.Payload[1:]
for len(pkt.Payload) > 0 { for len(pkt.Payload) > 0 {
if len(pkt.Payload) < 2 { if len(pkt.Payload) < 2 {
return nil, fmt.Errorf("Invalid STAP-A packet") return nil, 0, fmt.Errorf("Invalid STAP-A packet")
} }
size := binary.BigEndian.Uint16(pkt.Payload) size := binary.BigEndian.Uint16(pkt.Payload)
@@ -98,30 +95,27 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) {
} }
if int(size) > len(pkt.Payload) { if int(size) > len(pkt.Payload) {
return nil, fmt.Errorf("Invalid STAP-A packet") return nil, 0, fmt.Errorf("Invalid STAP-A packet")
} }
ret = append(ret, &NALUAndTimestamp{ nalus = append(nalus, pkt.Payload[:size])
NALU: pkt.Payload[:size],
Timestamp: d.decodeTimestamp(pkt.Timestamp),
})
pkt.Payload = pkt.Payload[size:] pkt.Payload = pkt.Payload[size:]
} }
if len(ret) == 0 { if len(nalus) == 0 {
return nil, fmt.Errorf("STAP-A packet doesn't contain any NALU") return nil, 0, fmt.Errorf("STAP-A packet doesn't contain any NALU")
} }
return ret, nil return nalus, d.decodeTimestamp(pkt.Timestamp), nil
case NALUTypeFuA: // first packet of a fragmented NALU case NALUTypeFUA: // first packet of a fragmented NALU
if len(pkt.Payload) < 2 { if len(pkt.Payload) < 2 {
return nil, fmt.Errorf("Invalid FU-A packet") return nil, 0, fmt.Errorf("Invalid FU-A packet")
} }
start := pkt.Payload[1] >> 7 start := pkt.Payload[1] >> 7
if start != 1 { if start != 1 {
return nil, fmt.Errorf("first NALU does not contain the start bit") return nil, 0, fmt.Errorf("first NALU does not contain the start bit")
} }
nri := (pkt.Payload[0] >> 5) & 0x03 nri := (pkt.Payload[0] >> 5) & 0x03
@@ -129,35 +123,32 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) {
d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...) d.fragmentedBuf = append([]byte{(nri << 5) | typ}, pkt.Payload[2:]...)
d.state = decoderStateReadingFragmented d.state = decoderStateReadingFragmented
return nil, ErrMorePacketsNeeded return nil, 0, ErrMorePacketsNeeded
case NALUTypeStapB, NALUTypeMtap16, case NALUTypeSTAPB, NALUTypeMTAP16,
NALUTypeMtap24, NALUTypeFuB: NALUTypeMTAP24, NALUTypeFUB:
return nil, fmt.Errorf("NALU type not yet supported (%v)", typ) return nil, 0, fmt.Errorf("NALU type not supported (%v)", typ)
} }
return []*NALUAndTimestamp{{ return [][]byte{pkt.Payload}, d.decodeTimestamp(pkt.Timestamp), nil
NALU: pkt.Payload,
Timestamp: d.decodeTimestamp(pkt.Timestamp),
}}, nil
default: // decoderStateReadingFragmented default: // decoderStateReadingFragmented
pkt := rtp.Packet{} pkt := rtp.Packet{}
err := pkt.Unmarshal(byts) err := pkt.Unmarshal(byts)
if err != nil { if err != nil {
d.state = decoderStateInitial d.state = decoderStateInitial
return nil, err return nil, 0, err
} }
if len(pkt.Payload) < 2 { if len(pkt.Payload) < 2 {
d.state = decoderStateInitial d.state = decoderStateInitial
return nil, fmt.Errorf("Invalid FU-A packet") return nil, 0, fmt.Errorf("Invalid FU-A packet")
} }
typ := NALUType(pkt.Payload[0] & 0x1F) typ := NALUType(pkt.Payload[0] & 0x1F)
if typ != NALUTypeFuA { if typ != NALUTypeFUA {
d.state = decoderStateInitial d.state = decoderStateInitial
return nil, fmt.Errorf("non-starting NALU is not FU-A") return nil, 0, fmt.Errorf("non-starting NALU is not FU-A")
} }
end := (pkt.Payload[1] >> 6) & 0x01 end := (pkt.Payload[1] >> 6) & 0x01
@@ -165,44 +156,11 @@ func (d *Decoder) Decode(byts []byte) ([]*NALUAndTimestamp, error) {
d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...) d.fragmentedBuf = append(d.fragmentedBuf, pkt.Payload[2:]...)
if end != 1 { if end != 1 {
return nil, ErrMorePacketsNeeded return nil, 0, ErrMorePacketsNeeded
} }
d.state = decoderStateInitial d.state = decoderStateInitial
return []*NALUAndTimestamp{{ return [][]byte{d.fragmentedBuf}, d.decodeTimestamp(pkt.Timestamp), nil
NALU: d.fragmentedBuf,
Timestamp: d.decodeTimestamp(pkt.Timestamp),
}}, nil
}
}
// Read reads RTP/H264 packets from a reader until a NALU is decoded.
func (d *Decoder) Read(r io.Reader) (*NALUAndTimestamp, error) {
if len(d.readQueue) > 0 {
nalu := d.readQueue[0]
d.readQueue = d.readQueue[1:]
return nalu, nil
}
buf := make([]byte, 2048)
for {
n, err := r.Read(buf)
if err != nil {
return nil, err
}
nalus, err := d.Decode(buf[:n])
if err != nil {
if err == ErrMorePacketsNeeded {
continue
}
return nil, err
}
nalu := nalus[0]
d.readQueue = nalus[1:]
return nalu, nil
} }
} }
@@ -212,24 +170,35 @@ func (d *Decoder) ReadSPSPPS(r io.Reader) ([]byte, []byte, error) {
var sps []byte var sps []byte
var pps []byte var pps []byte
buf := make([]byte, 2048)
for { for {
nt, err := d.Read(r) n, err := r.Read(buf)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
switch NALUType(nt.NALU[0] & 0x1F) { nalus, _, err := d.Decode(buf[:n])
if err != nil {
if err == ErrMorePacketsNeeded {
continue
}
return nil, nil, err
}
for _, nalu := range nalus {
switch NALUType(nalu[0] & 0x1F) {
case NALUTypeSPS: case NALUTypeSPS:
sps = append([]byte(nil), nt.NALU...) sps = append([]byte(nil), nalu...)
if sps != nil && pps != nil { if sps != nil && pps != nil {
return sps, pps, nil return sps, pps, nil
} }
case NALUTypePPS: case NALUTypePPS:
pps = append([]byte(nil), nt.NALU...) pps = append([]byte(nil), nalu...)
if sps != nil && pps != nil { if sps != nil && pps != nil {
return sps, pps, nil return sps, pps, nil
} }
} }
} }
} }
}

View File

@@ -2,7 +2,6 @@ package rtph264
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"math/rand" "math/rand"
"time" "time"
@@ -60,24 +59,20 @@ func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 {
// * a single packets // * a single packets
// * multiple fragmented packets (FU-A) // * multiple fragmented packets (FU-A)
// * an aggregated packet (STAP-A) // * an aggregated packet (STAP-A)
func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) { func (e *Encoder) Encode(nalus [][]byte, pts time.Duration) ([][]byte, error) {
var rets [][]byte var rets [][]byte
var batch []*NALUAndTimestamp var batch [][]byte
// split NALUs into batches // split NALUs into batches
for _, nt := range nts { for _, nalu := range nalus {
if len(batch) > 0 && batch[0].Timestamp != nt.Timestamp { if e.lenAggregated(batch, nalu) <= rtpPayloadMaxSize {
return nil, fmt.Errorf("encoding NALUs with different timestamps is not supported")
}
if e.lenAggregated(batch, nt) <= rtpPayloadMaxSize {
// add to existing batch // add to existing batch
batch = append(batch, nt) batch = append(batch, nalu)
} else { } else {
// write batch // write batch
if batch != nil { if batch != nil {
pkts, err := e.writeBatch(batch, false) pkts, err := e.writeBatch(batch, pts, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -85,13 +80,13 @@ func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) {
} }
// initialize new batch // initialize new batch
batch = []*NALUAndTimestamp{nt} batch = [][]byte{nalu}
} }
} }
// write final batch // write final batch
// marker is used to indicate when all NALUs with same PTS have been sent // marker is used to indicate when all NALUs with same PTS have been sent
pkts, err := e.writeBatch(batch, true) pkts, err := e.writeBatch(batch, pts, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -100,31 +95,31 @@ func (e *Encoder) Encode(nts []*NALUAndTimestamp) ([][]byte, error) {
return rets, nil return rets, nil
} }
func (e *Encoder) writeBatch(nts []*NALUAndTimestamp, marker bool) ([][]byte, error) { func (e *Encoder) writeBatch(nalus [][]byte, pts time.Duration, marker bool) ([][]byte, error) {
if len(nts) == 1 { if len(nalus) == 1 {
// the NALU fits into a single RTP packet // the NALU fits into a single RTP packet
if len(nts[0].NALU) < rtpPayloadMaxSize { if len(nalus[0]) < rtpPayloadMaxSize {
return e.writeSingle(nts[0], marker) return e.writeSingle(nalus[0], pts, marker)
} }
// split the NALU into multiple fragmentation packet // split the NALU into multiple fragmentation packet
return e.writeFragmented(nts[0], marker) return e.writeFragmented(nalus[0], pts, marker)
} }
return e.writeAggregated(nts, marker) return e.writeAggregated(nalus, pts, marker)
} }
func (e *Encoder) writeSingle(nt *NALUAndTimestamp, marker bool) ([][]byte, error) { func (e *Encoder) writeSingle(nalu []byte, pts time.Duration, marker bool) ([][]byte, error) {
rpkt := rtp.Packet{ rpkt := rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.payloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(nt.Timestamp), Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc, SSRC: e.ssrc,
Marker: marker, Marker: marker,
}, },
Payload: nt.NALU, Payload: nalu,
} }
e.sequenceNumber++ e.sequenceNumber++
@@ -136,9 +131,7 @@ func (e *Encoder) writeSingle(nt *NALUAndTimestamp, marker bool) ([][]byte, erro
return [][]byte{frame}, nil return [][]byte{frame}, nil
} }
func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte, error) { func (e *Encoder) writeFragmented(nalu []byte, pts time.Duration, marker bool) ([][]byte, error) {
nalu := nt.NALU
// use only FU-A, not FU-B, since we always use non-interleaved mode // use only FU-A, not FU-B, since we always use non-interleaved mode
// (packetization-mode=1) // (packetization-mode=1)
packetCount := (len(nalu) - 1) / (rtpPayloadMaxSize - 2) packetCount := (len(nalu) - 1) / (rtpPayloadMaxSize - 2)
@@ -148,14 +141,14 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte,
} }
ret := make([][]byte, packetCount) ret := make([][]byte, packetCount)
ts := e.encodeTimestamp(nt.Timestamp) encPTS := e.encodeTimestamp(pts)
nri := (nalu[0] >> 5) & 0x03 nri := (nalu[0] >> 5) & 0x03
typ := nalu[0] & 0x1F typ := nalu[0] & 0x1F
nalu = nalu[1:] // remove header nalu = nalu[1:] // remove header
for i := range ret { for i := range ret {
indicator := (nri << 5) | uint8(NALUTypeFuA) indicator := (nri << 5) | uint8(NALUTypeFUA)
start := uint8(0) start := uint8(0)
if i == 0 { if i == 0 {
@@ -180,7 +173,7 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte,
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.payloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: ts, Timestamp: encPTS,
SSRC: e.ssrc, SSRC: e.ssrc,
Marker: (i == (packetCount-1) && marker), Marker: (i == (packetCount-1) && marker),
}, },
@@ -199,37 +192,37 @@ func (e *Encoder) writeFragmented(nt *NALUAndTimestamp, marker bool) ([][]byte,
return ret, nil return ret, nil
} }
func (e *Encoder) lenAggregated(nts []*NALUAndTimestamp, additionalEl *NALUAndTimestamp) int { func (e *Encoder) lenAggregated(nalus [][]byte, addNALU []byte) int {
ret := 1 // header ret := 1 // header
for _, bnt := range nts { for _, nalu := range nalus {
ret += 2 // size ret += 2 // size
ret += len(bnt.NALU) // nalu ret += len(nalu) // nalu
} }
if additionalEl != nil { if addNALU != nil {
ret += 2 // size ret += 2 // size
ret += len(additionalEl.NALU) // nalu ret += len(addNALU) // nalu
} }
return ret return ret
} }
func (e *Encoder) writeAggregated(nts []*NALUAndTimestamp, marker bool) ([][]byte, error) { func (e *Encoder) writeAggregated(nalus [][]byte, pts time.Duration, marker bool) ([][]byte, error) {
payload := make([]byte, e.lenAggregated(nts, nil)) payload := make([]byte, e.lenAggregated(nalus, nil))
// header // header
payload[0] = uint8(NALUTypeStapA) payload[0] = uint8(NALUTypeSTAPA)
pos := 1 pos := 1
for _, nt := range nts { for _, nalu := range nalus {
// size // size
naluLen := len(nt.NALU) naluLen := len(nalu)
binary.BigEndian.PutUint16(payload[pos:], uint16(naluLen)) binary.BigEndian.PutUint16(payload[pos:], uint16(naluLen))
pos += 2 pos += 2
// nalu // nalu
copy(payload[pos:], nt.NALU) copy(payload[pos:], nalu)
pos += naluLen pos += naluLen
} }
@@ -238,7 +231,7 @@ func (e *Encoder) writeAggregated(nts []*NALUAndTimestamp, marker bool) ([][]byt
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.payloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(nts[0].Timestamp), Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc, SSRC: e.ssrc,
Marker: marker, Marker: marker,
}, },

View File

@@ -28,12 +28,12 @@ const (
NALUTypeSliceExtensionDepth NALUType = 21 NALUTypeSliceExtensionDepth NALUType = 21
NALUTypeReserved22 NALUType = 22 NALUTypeReserved22 NALUType = 22
NALUTypeReserved23 NALUType = 23 NALUTypeReserved23 NALUType = 23
NALUTypeStapA NALUType = 24 NALUTypeSTAPA NALUType = 24
NALUTypeStapB NALUType = 25 NALUTypeSTAPB NALUType = 25
NALUTypeMtap16 NALUType = 26 NALUTypeMTAP16 NALUType = 26
NALUTypeMtap24 NALUType = 27 NALUTypeMTAP24 NALUType = 27
NALUTypeFuA NALUType = 28 NALUTypeFUA NALUType = 28
NALUTypeFuB NALUType = 29 NALUTypeFUB NALUType = 29
) )
// String implements fmt.Stringer. // String implements fmt.Stringer.
@@ -85,18 +85,18 @@ func (nt NALUType) String() string {
return "Reserved22" return "Reserved22"
case NALUTypeReserved23: case NALUTypeReserved23:
return "Reserved23" return "Reserved23"
case NALUTypeStapA: case NALUTypeSTAPA:
return "StapA" return "STAPA"
case NALUTypeStapB: case NALUTypeSTAPB:
return "StapB" return "STAPB"
case NALUTypeMtap16: case NALUTypeMTAP16:
return "Mtap16" return "MTAP16"
case NALUTypeMtap24: case NALUTypeMTAP24:
return "Mtap24" return "MTAP24"
case NALUTypeFuA: case NALUTypeFUA:
return "FuA" return "FUA"
case NALUTypeFuB: case NALUTypeFUB:
return "FuB" return "FUB"
} }
return "unknown" return "unknown"
} }

View File

@@ -1,12 +1,2 @@
// Package rtph264 contains a RTP/H264 decoder and encoder. // Package rtph264 contains a RTP/H264 decoder and encoder.
package rtph264 package rtph264
import (
"time"
)
// NALUAndTimestamp is a Network Abstraction Layer Unit and its timestamp.
type NALUAndTimestamp struct {
Timestamp time.Duration
NALU []byte
}

View File

@@ -2,7 +2,6 @@ package rtph264
import ( import (
"bytes" "bytes"
"io"
"testing" "testing"
"time" "time"
@@ -25,28 +24,21 @@ func mergeBytes(vals ...[]byte) []byte {
return res return res
} }
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
var cases = []struct { var cases = []struct {
name string name string
dec []*NALUAndTimestamp nalus [][]byte
pts time.Duration
enc [][]byte enc [][]byte
}{ }{
{ {
"single", "single",
[]*NALUAndTimestamp{ [][]byte{
{ mergeBytes(
Timestamp: 25 * time.Millisecond,
NALU: mergeBytes(
[]byte{0x05}, []byte{0x05},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8),
), ),
}, },
}, 25 * time.Millisecond,
[][]byte{ [][]byte{
mergeBytes( mergeBytes(
[]byte{ []byte{
@@ -59,15 +51,13 @@ var cases = []struct {
}, },
{ {
"negative timestamp", "negative timestamp",
[]*NALUAndTimestamp{ [][]byte{
{ mergeBytes(
Timestamp: -20 * time.Millisecond,
NALU: mergeBytes(
[]byte{0x05}, []byte{0x05},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8), bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 8),
), ),
}, },
}, -20 * time.Millisecond,
[][]byte{ [][]byte{
mergeBytes( mergeBytes(
[]byte{ []byte{
@@ -80,15 +70,13 @@ var cases = []struct {
}, },
{ {
"fragmented", "fragmented",
[]*NALUAndTimestamp{ [][]byte{
{ mergeBytes(
Timestamp: 55 * time.Millisecond,
NALU: mergeBytes(
[]byte{0x05}, []byte{0x05},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256),
), ),
}, },
}, 55 * time.Millisecond,
[][]byte{ [][]byte{
mergeBytes( mergeBytes(
[]byte{ []byte{
@@ -110,12 +98,9 @@ var cases = []struct {
}, },
{ {
"aggregated", "aggregated",
[]*NALUAndTimestamp{ [][]byte{
{0x09, 0xF0},
{ {
NALU: []byte{0x09, 0xF0},
},
{
NALU: []byte{
0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6,
0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03,
0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00,
@@ -127,7 +112,7 @@ var cases = []struct {
0x00, 0x00, 0x6d, 0x40, 0x00, 0x00, 0x6d, 0x40,
}, },
}, },
}, 0,
[][]byte{ [][]byte{
{ {
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55,
@@ -146,12 +131,9 @@ var cases = []struct {
}, },
{ {
"aggregated followed by single", "aggregated followed by single",
[]*NALUAndTimestamp{ [][]byte{
{0x09, 0xF0},
{ {
NALU: []byte{0x09, 0xF0},
},
{
NALU: []byte{
0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6, 0x41, 0x9a, 0x24, 0x6c, 0x41, 0x4f, 0xfe, 0xd6,
0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x8c, 0xb0, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03,
0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00,
@@ -162,14 +144,12 @@ var cases = []struct {
0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x03,
0x00, 0x00, 0x6d, 0x40, 0x00, 0x00, 0x6d, 0x40,
}, },
}, mergeBytes(
{
NALU: mergeBytes(
[]byte{0x08}, []byte{0x08},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 175), bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 175),
), ),
}, },
}, 0,
[][]byte{ [][]byte{
{ {
0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, 0x80, 0x60, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55,
@@ -195,20 +175,15 @@ var cases = []struct {
}, },
{ {
"fragmented followed by aggregated", "fragmented followed by aggregated",
[]*NALUAndTimestamp{ [][]byte{
{ mergeBytes(
NALU: mergeBytes(
[]byte{0x05}, []byte{0x05},
bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256), bytes.Repeat([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 256),
), ),
{0x09, 0xF0},
{0x09, 0xF0},
}, },
{ 0,
NALU: []byte{0x09, 0xF0},
},
{
NALU: []byte{0x09, 0xF0},
},
},
[][]byte{ [][]byte{
mergeBytes( mergeBytes(
[]byte{ []byte{
@@ -242,7 +217,7 @@ func TestEncode(t *testing.T) {
ssrc := uint32(0x9dbb7812) ssrc := uint32(0x9dbb7812)
initialTs := uint32(0x88776655) initialTs := uint32(0x88776655)
e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs) e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs)
enc, err := e.Encode(ca.dec) enc, err := e.Encode(ca.nalus, ca.pts)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ca.enc, enc) require.Equal(t, ca.enc, enc)
}) })
@@ -252,35 +227,31 @@ func TestEncode(t *testing.T) {
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
i := 0
r := readerFunc(func(p []byte) (int, error) {
if i == len(ca.enc) {
return 0, io.EOF
}
i++
return copy(p, ca.enc[i-1]), nil
})
d := NewDecoder() d := NewDecoder()
// send an initial packet downstream // send an initial packet downstream
// in order to compute the timestamp, // in order to compute the timestamp,
// which is relative to the initial packet // which is relative to the initial packet
_, err := d.Decode([]byte{ _, _, err := d.Decode([]byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x66, 0x55,
0x9d, 0xbb, 0x78, 0x12, 0x06, 0x00, 0x9d, 0xbb, 0x78, 0x12, 0x06, 0x00,
}) })
require.NoError(t, err) require.NoError(t, err)
for _, dec0 := range ca.dec { var nalus [][]byte
dec, err := d.Read(r)
require.NoError(t, err) for _, pkt := range ca.enc {
require.Equal(t, dec0, dec) addNALUs, pts, err := d.Decode(pkt)
if err == ErrMorePacketsNeeded {
continue
} }
_, err = d.Read(r) require.NoError(t, err)
require.Equal(t, io.EOF, err) require.Equal(t, ca.pts, pts)
nalus = append(nalus, addNALUs...)
}
require.Equal(t, ca.nalus, nalus)
}) })
} }
} }
@@ -303,7 +274,7 @@ func TestDecodeErrors(t *testing.T) {
"STAP-A without NALUs", "STAP-A without NALUs",
[]byte{ []byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA),
}, },
"STAP-A packet doesn't contain any NALU", "STAP-A packet doesn't contain any NALU",
}, },
@@ -311,7 +282,7 @@ func TestDecodeErrors(t *testing.T) {
"STAP-A without size", "STAP-A without size",
[]byte{ []byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), 0x01, 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA), 0x01,
}, },
"Invalid STAP-A packet", "Invalid STAP-A packet",
}, },
@@ -319,7 +290,7 @@ func TestDecodeErrors(t *testing.T) {
"STAP-A with invalid size", "STAP-A with invalid size",
[]byte{ []byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeStapA), 0x00, 0x15, 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeSTAPA), 0x00, 0x15,
}, },
"Invalid STAP-A packet", "Invalid STAP-A packet",
}, },
@@ -327,7 +298,7 @@ func TestDecodeErrors(t *testing.T) {
"FU-A without payload", "FU-A without payload",
[]byte{ []byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFuA), 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFUA),
}, },
"Invalid FU-A packet", "Invalid FU-A packet",
}, },
@@ -335,14 +306,22 @@ func TestDecodeErrors(t *testing.T) {
"FU-A without start bit", "FU-A without start bit",
[]byte{ []byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15, 0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFuA), 0x00, 0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeFUA), 0x00,
}, },
"first NALU does not contain the start bit", "first NALU does not contain the start bit",
}, },
{
"MTAP",
[]byte{
0x80, 0xe0, 0x44, 0xed, 0x88, 0x77, 0x6a, 0x15,
0x9d, 0xbb, 0x78, 0x12, byte(NALUTypeMTAP16),
},
"NALU type not supported (MTAP16)",
},
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
d := NewDecoder() d := NewDecoder()
_, err := d.Decode(ca.byts) _, _, err := d.Decode(ca.byts)
require.NotEqual(t, ErrMorePacketsNeeded, err) require.NotEqual(t, ErrMorePacketsNeeded, err)
require.Equal(t, ca.err, err.Error()) require.Equal(t, ca.err, err.Error())
}) })