change signature of rtp encoders and decoders

This commit is contained in:
aler9
2022-03-20 11:30:04 +01:00
parent ba99421e74
commit e7aca99c73
13 changed files with 175 additions and 110 deletions

View File

@@ -28,7 +28,8 @@ func main() {
" ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS
decoder := rtph264.NewDecoder()
decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil {
panic(err)

View File

@@ -29,7 +29,8 @@ func main() {
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS
decoder := rtph264.NewDecoder()
decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil {
panic(err)

View File

@@ -30,7 +30,8 @@ func main() {
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS
decoder := rtph264.NewDecoder()
decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil {
panic(err)

View File

@@ -52,7 +52,10 @@ func main() {
}
// setup decoder
dec := rtpaac.NewDecoder(clockRate)
dec := &rtpaac.Decoder{
SampleRate: clockRate,
}
dec.Init()
// called when a RTP packet arrives
c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) {

View File

@@ -76,7 +76,8 @@ func main() {
}
// setup RTP->H264 decoder
rtpDec := rtph264.NewDecoder()
rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup H264->raw frames decoder
h264dec, err := newH264Decoder()

View File

@@ -51,10 +51,11 @@ func main() {
panic("H264 track not found")
}
// setup decoder
dec := rtph264.NewDecoder()
// setup RTP->H264 decoder
rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup encoder
// setup H264->MPEGTS encoder
enc, err := newMPEGTSEncoder(sps, pps)
if err != nil {
panic(err)
@@ -67,7 +68,7 @@ func main() {
}
// decode H264 NALUs from the RTP packet
nalus, pts, err := dec.DecodeUntilMarker(pkt)
nalus, pts, err := rtpDec.DecodeUntilMarker(pkt)
if err != nil {
return
}

View File

@@ -53,7 +53,8 @@ func main() {
}
// setup RTP->H264 decoder
rtpDec := rtph264.NewDecoder()
rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup H264->raw frames decoder
h264dec, err := newH264Decoder()

View File

@@ -16,16 +16,17 @@ var ErrMorePacketsNeeded = errors.New("need more packets")
// Decoder is a RTP/AAC decoder.
type Decoder struct {
// sample rate of input packets.
SampleRate int
timeDecoder *rtptimedec.Decoder
isDecodingFragmented bool
fragmentedBuf []byte
}
// NewDecoder allocates a Decoder.
func NewDecoder(clockRate int) *Decoder {
return &Decoder{
timeDecoder: rtptimedec.New(clockRate),
}
// Init initializes the decoder
func (d *Decoder) Init() {
d.timeDecoder = rtptimedec.New(d.SampleRate)
}
// Decode decodes AUs from a RTP/AAC packet.

View File

@@ -21,45 +21,44 @@ func randUint32() uint32 {
// Encoder is a RTP/AAC encoder.
type Encoder struct {
payloadType uint8
sampleRate float64
// payload type of packets.
PayloadType uint8
// sample rate of packets.
SampleRate int
// SSRC of packets (optional).
SSRC *uint32
// initial sequence number of packets (optional).
InitialSequenceNumber *uint16
// initial timestamp of packets (optional).
InitialTimestamp *uint32
sequenceNumber uint16
ssrc uint32
initialTs uint32
}
// NewEncoder allocates an Encoder.
func NewEncoder(payloadType uint8,
sampleRate int,
sequenceNumber *uint16,
ssrc *uint32,
initialTs *uint32) *Encoder {
return &Encoder{
payloadType: payloadType,
sampleRate: float64(sampleRate),
sequenceNumber: func() uint16 {
if sequenceNumber != nil {
return *sequenceNumber
// Init initializes the encoder.
func (e *Encoder) Init() {
if e.SSRC == nil {
v := randUint32()
e.SSRC = &v
}
return uint16(randUint32())
}(),
ssrc: func() uint32 {
if ssrc != nil {
return *ssrc
if e.InitialSequenceNumber == nil {
v := uint16(randUint32())
e.InitialSequenceNumber = &v
}
return randUint32()
}(),
initialTs: func() uint32 {
if initialTs != nil {
return *initialTs
}
return randUint32()
}(),
if e.InitialTimestamp == nil {
v := randUint32()
e.InitialTimestamp = &v
}
e.sequenceNumber = *e.InitialSequenceNumber
}
func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 {
return e.initialTs + uint32(ts.Seconds()*e.sampleRate)
return *e.InitialTimestamp + uint32(ts.Seconds()*float64(e.SampleRate))
}
// Encode encodes AUs into RTP/AAC packets.
@@ -82,7 +81,7 @@ func (e *Encoder) Encode(aus [][]byte, firstPTS time.Duration) ([]*rtp.Packet, e
return nil, err
}
rets = append(rets, pkts...)
pts += time.Duration(len(batch)) * 1000 * time.Second / time.Duration(e.sampleRate)
pts += time.Duration(len(batch)) * 1000 * time.Second / time.Duration(e.SampleRate)
}
// initialize new batch
@@ -139,10 +138,10 @@ func (e *Encoder) writeFragmented(au []byte, pts time.Duration) ([]*rtp.Packet,
ret[i] = &rtp.Packet{
Header: rtp.Header{
Version: rtpVersion,
PayloadType: e.payloadType,
PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber,
Timestamp: encPTS,
SSRC: e.ssrc,
SSRC: *e.SSRC,
Marker: (i == (packetCount - 1)),
},
Payload: data,
@@ -192,10 +191,10 @@ func (e *Encoder) writeAggregated(aus [][]byte, firstPTS time.Duration) ([]*rtp.
pkt := &rtp.Packet{
Header: rtp.Header{
Version: rtpVersion,
PayloadType: e.payloadType,
PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(firstPTS),
SSRC: e.ssrc,
SSRC: *e.SSRC,
Marker: true,
},
Payload: payload,

View File

@@ -279,7 +279,10 @@ var cases = []struct {
func TestDecode(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
d := NewDecoder(48000)
d := &Decoder{
SampleRate: 48000,
}
d.Init()
// send an initial packet downstream
// in order to compute the right timestamp,
@@ -562,7 +565,11 @@ func TestDecodeErrors(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
d := NewDecoder(48000)
d := &Decoder{
SampleRate: 48000,
}
d.Init()
var lastErr error
for _, pkt := range ca.pkts {
_, _, lastErr = d.Decode(pkt)
@@ -575,10 +582,23 @@ func TestDecodeErrors(t *testing.T) {
func TestEncode(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
sequenceNumber := uint16(0x44ed)
ssrc := uint32(0x9dbb7812)
initialTs := uint32(0x88776655)
e := NewEncoder(96, 48000, &sequenceNumber, &ssrc, &initialTs)
e := &Encoder{
PayloadType: 96,
SampleRate: 48000,
SSRC: func() *uint32 {
v := uint32(0x9dbb7812)
return &v
}(),
InitialSequenceNumber: func() *uint16 {
v := uint16(0x44ed)
return &v
}(),
InitialTimestamp: func() *uint32 {
v := uint32(0x88776655)
return &v
}(),
}
e.Init()
pkts, err := e.Encode(ca.aus, ca.pts)
require.NoError(t, err)
@@ -588,5 +608,12 @@ func TestEncode(t *testing.T) {
}
func TestEncodeRandomInitialState(t *testing.T) {
NewEncoder(96, 48000, nil, nil, nil)
e := &Encoder{
PayloadType: 96,
SampleRate: 48000,
}
e.Init()
require.NotEqual(t, nil, e.SSRC)
require.NotEqual(t, nil, e.InitialSequenceNumber)
require.NotEqual(t, nil, e.InitialTimestamp)
}

View File

@@ -46,11 +46,9 @@ type Decoder struct {
naluBuffer [][]byte
}
// NewDecoder allocates a Decoder.
func NewDecoder() *Decoder {
return &Decoder{
timeDecoder: rtptimedec.New(90000),
}
// Init initializes the decoder
func (d *Decoder) Init() {
d.timeDecoder = rtptimedec.New(90000)
}
// Decode decodes NALUs from a RTP/H264 packet.

View File

@@ -22,42 +22,41 @@ func randUint32() uint32 {
// Encoder is a RTP/H264 encoder.
type Encoder struct {
payloadType uint8
// payload type of packets.
PayloadType uint8
// SSRC of packets (optional).
SSRC *uint32
// initial sequence number of packets (optional).
InitialSequenceNumber *uint16
// initial timestamp of packets (optional).
InitialTimestamp *uint32
sequenceNumber uint16
ssrc uint32
initialTs uint32
}
// NewEncoder allocates an Encoder.
func NewEncoder(payloadType uint8,
sequenceNumber *uint16,
ssrc *uint32,
initialTs *uint32) *Encoder {
return &Encoder{
payloadType: payloadType,
sequenceNumber: func() uint16 {
if sequenceNumber != nil {
return *sequenceNumber
// Init initializes the encoder.
func (e *Encoder) Init() {
if e.SSRC == nil {
v := randUint32()
e.SSRC = &v
}
return uint16(randUint32())
}(),
ssrc: func() uint32 {
if ssrc != nil {
return *ssrc
if e.InitialSequenceNumber == nil {
v := uint16(randUint32())
e.InitialSequenceNumber = &v
}
return randUint32()
}(),
initialTs: func() uint32 {
if initialTs != nil {
return *initialTs
}
return randUint32()
}(),
if e.InitialTimestamp == nil {
v := randUint32()
e.InitialTimestamp = &v
}
e.sequenceNumber = *e.InitialSequenceNumber
}
func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 {
return e.initialTs + uint32(ts.Seconds()*rtpClockRate)
return *e.InitialTimestamp + uint32(ts.Seconds()*rtpClockRate)
}
// Encode encodes NALUs into RTP/H264 packets.
@@ -114,10 +113,10 @@ func (e *Encoder) writeSingle(nalu []byte, pts time.Duration, marker bool) ([]*r
pkt := &rtp.Packet{
Header: rtp.Header{
Version: rtpVersion,
PayloadType: e.payloadType,
PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc,
SSRC: *e.SSRC,
Marker: marker,
},
Payload: nalu,
@@ -168,10 +167,10 @@ func (e *Encoder) writeFragmented(nalu []byte, pts time.Duration, marker bool) (
ret[i] = &rtp.Packet{
Header: rtp.Header{
Version: rtpVersion,
PayloadType: e.payloadType,
PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber,
Timestamp: encPTS,
SSRC: e.ssrc,
SSRC: *e.SSRC,
Marker: (i == (packetCount-1) && marker),
},
Payload: data,
@@ -220,10 +219,10 @@ func (e *Encoder) writeAggregated(nalus [][]byte, pts time.Duration, marker bool
pkt := &rtp.Packet{
Header: rtp.Header{
Version: rtpVersion,
PayloadType: e.payloadType,
PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc,
SSRC: *e.SSRC,
Marker: marker,
},
Payload: payload,

View File

@@ -307,7 +307,8 @@ var cases = []struct {
func TestDecode(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
d := NewDecoder()
d := &Decoder{}
d.Init()
// send an initial packet downstream
// in order to compute the right timestamp,
@@ -350,7 +351,8 @@ func TestDecode(t *testing.T) {
}
func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) {
d := NewDecoder()
d := &Decoder{}
d.Init()
pkt := rtp.Packet{
Header: rtp.Header{
@@ -389,6 +391,9 @@ func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) {
}
func TestDecodeSTAPAWithPadding(t *testing.T) {
d := &Decoder{}
d.Init()
pkt := rtp.Packet{
Header: rtp.Header{
Version: 2,
@@ -405,7 +410,8 @@ func TestDecodeSTAPAWithPadding(t *testing.T) {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
}
nalus, _, err := NewDecoder().Decode(&pkt)
nalus, _, err := d.Decode(&pkt)
require.NoError(t, err)
require.Equal(t, [][]byte{
{0xaa, 0xbb},
@@ -654,7 +660,9 @@ func TestDecodeErrors(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
d := NewDecoder()
d := &Decoder{}
d.Init()
var lastErr error
for _, pkt := range ca.pkts {
_, _, lastErr = d.Decode(pkt)
@@ -667,10 +675,22 @@ func TestDecodeErrors(t *testing.T) {
func TestEncode(t *testing.T) {
for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) {
sequenceNumber := uint16(0x44ed)
ssrc := uint32(0x9dbb7812)
initialTs := uint32(0x88776655)
e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs)
e := &Encoder{
PayloadType: 96,
SSRC: func() *uint32 {
v := uint32(0x9dbb7812)
return &v
}(),
InitialSequenceNumber: func() *uint16 {
v := uint16(0x44ed)
return &v
}(),
InitialTimestamp: func() *uint32 {
v := uint32(0x88776655)
return &v
}(),
}
e.Init()
pkts, err := e.Encode(ca.nalus, ca.pts)
require.NoError(t, err)
@@ -680,7 +700,13 @@ func TestEncode(t *testing.T) {
}
func TestEncodeRandomInitialState(t *testing.T) {
NewEncoder(96, nil, nil, nil)
e := &Encoder{
PayloadType: 96,
}
e.Init()
require.NotEqual(t, nil, e.SSRC)
require.NotEqual(t, nil, e.InitialSequenceNumber)
require.NotEqual(t, nil, e.InitialTimestamp)
}
type dummyReader struct {
@@ -724,7 +750,10 @@ func TestReadSPSPPS(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
sps, pps, err := NewDecoder().ReadSPSPPS(&dummyReader{byts: ca.byts})
d := &Decoder{}
d.Init()
sps, pps, err := d.ReadSPSPPS(&dummyReader{byts: ca.byts})
require.NoError(t, err)
require.Equal(t, ca.sps, sps)
require.Equal(t, ca.pps, pps)
@@ -759,7 +788,10 @@ func TestReadSPSPPSErrors(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
_, _, err := NewDecoder().ReadSPSPPS(&dummyReader{byts: ca.byts})
d := &Decoder{}
d.Init()
_, _, err := d.ReadSPSPPS(&dummyReader{byts: ca.byts})
require.EqualError(t, err, ca.err)
})
}