change signature of rtp encoders and decoders

This commit is contained in:
aler9
2022-03-20 11:30:04 +01:00
parent ba99421e74
commit e7aca99c73
13 changed files with 175 additions and 110 deletions

View File

@@ -28,7 +28,8 @@ func main() {
" ! rtph264pay ! udpsink host=127.0.0.1 port=9000") " ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS // get SPS and PPS
decoder := rtph264.NewDecoder() decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -29,7 +29,8 @@ func main() {
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS // get SPS and PPS
decoder := rtph264.NewDecoder() decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -30,7 +30,8 @@ func main() {
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")
// get SPS and PPS // get SPS and PPS
decoder := rtph264.NewDecoder() decoder := &rtph264.Decoder{}
decoder.Init()
sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc}) sps, pps, err := decoder.ReadSPSPPS(rtph264.PacketConnReader{pc})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -52,7 +52,10 @@ func main() {
} }
// setup decoder // setup decoder
dec := rtpaac.NewDecoder(clockRate) dec := &rtpaac.Decoder{
SampleRate: clockRate,
}
dec.Init()
// called when a RTP packet arrives // called when a RTP packet arrives
c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) { c.OnPacketRTP = func(trackID int, pkt *rtp.Packet) {

View File

@@ -76,7 +76,8 @@ func main() {
} }
// setup RTP->H264 decoder // setup RTP->H264 decoder
rtpDec := rtph264.NewDecoder() rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup H264->raw frames decoder // setup H264->raw frames decoder
h264dec, err := newH264Decoder() h264dec, err := newH264Decoder()

View File

@@ -51,10 +51,11 @@ func main() {
panic("H264 track not found") panic("H264 track not found")
} }
// setup decoder // setup RTP->H264 decoder
dec := rtph264.NewDecoder() rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup encoder // setup H264->MPEGTS encoder
enc, err := newMPEGTSEncoder(sps, pps) enc, err := newMPEGTSEncoder(sps, pps)
if err != nil { if err != nil {
panic(err) panic(err)
@@ -67,7 +68,7 @@ func main() {
} }
// decode H264 NALUs from the RTP packet // decode H264 NALUs from the RTP packet
nalus, pts, err := dec.DecodeUntilMarker(pkt) nalus, pts, err := rtpDec.DecodeUntilMarker(pkt)
if err != nil { if err != nil {
return return
} }

View File

@@ -53,7 +53,8 @@ func main() {
} }
// setup RTP->H264 decoder // setup RTP->H264 decoder
rtpDec := rtph264.NewDecoder() rtpDec := &rtph264.Decoder{}
rtpDec.Init()
// setup H264->raw frames decoder // setup H264->raw frames decoder
h264dec, err := newH264Decoder() h264dec, err := newH264Decoder()

View File

@@ -16,16 +16,17 @@ var ErrMorePacketsNeeded = errors.New("need more packets")
// Decoder is a RTP/AAC decoder. // Decoder is a RTP/AAC decoder.
type Decoder struct { type Decoder struct {
// sample rate of input packets.
SampleRate int
timeDecoder *rtptimedec.Decoder timeDecoder *rtptimedec.Decoder
isDecodingFragmented bool isDecodingFragmented bool
fragmentedBuf []byte fragmentedBuf []byte
} }
// NewDecoder allocates a Decoder. // Init initializes the decoder
func NewDecoder(clockRate int) *Decoder { func (d *Decoder) Init() {
return &Decoder{ d.timeDecoder = rtptimedec.New(d.SampleRate)
timeDecoder: rtptimedec.New(clockRate),
}
} }
// Decode decodes AUs from a RTP/AAC packet. // Decode decodes AUs from a RTP/AAC packet.

View File

@@ -21,45 +21,44 @@ func randUint32() uint32 {
// Encoder is a RTP/AAC encoder. // Encoder is a RTP/AAC encoder.
type Encoder struct { type Encoder struct {
payloadType uint8 // payload type of packets.
sampleRate float64 PayloadType uint8
// sample rate of packets.
SampleRate int
// SSRC of packets (optional).
SSRC *uint32
// initial sequence number of packets (optional).
InitialSequenceNumber *uint16
// initial timestamp of packets (optional).
InitialTimestamp *uint32
sequenceNumber uint16 sequenceNumber uint16
ssrc uint32
initialTs uint32
} }
// NewEncoder allocates an Encoder. // Init initializes the encoder.
func NewEncoder(payloadType uint8, func (e *Encoder) Init() {
sampleRate int, if e.SSRC == nil {
sequenceNumber *uint16, v := randUint32()
ssrc *uint32, e.SSRC = &v
initialTs *uint32) *Encoder {
return &Encoder{
payloadType: payloadType,
sampleRate: float64(sampleRate),
sequenceNumber: func() uint16 {
if sequenceNumber != nil {
return *sequenceNumber
}
return uint16(randUint32())
}(),
ssrc: func() uint32 {
if ssrc != nil {
return *ssrc
}
return randUint32()
}(),
initialTs: func() uint32 {
if initialTs != nil {
return *initialTs
}
return randUint32()
}(),
} }
if e.InitialSequenceNumber == nil {
v := uint16(randUint32())
e.InitialSequenceNumber = &v
}
if e.InitialTimestamp == nil {
v := randUint32()
e.InitialTimestamp = &v
}
e.sequenceNumber = *e.InitialSequenceNumber
} }
func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 { func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 {
return e.initialTs + uint32(ts.Seconds()*e.sampleRate) return *e.InitialTimestamp + uint32(ts.Seconds()*float64(e.SampleRate))
} }
// Encode encodes AUs into RTP/AAC packets. // Encode encodes AUs into RTP/AAC packets.
@@ -82,7 +81,7 @@ func (e *Encoder) Encode(aus [][]byte, firstPTS time.Duration) ([]*rtp.Packet, e
return nil, err return nil, err
} }
rets = append(rets, pkts...) rets = append(rets, pkts...)
pts += time.Duration(len(batch)) * 1000 * time.Second / time.Duration(e.sampleRate) pts += time.Duration(len(batch)) * 1000 * time.Second / time.Duration(e.SampleRate)
} }
// initialize new batch // initialize new batch
@@ -139,10 +138,10 @@ func (e *Encoder) writeFragmented(au []byte, pts time.Duration) ([]*rtp.Packet,
ret[i] = &rtp.Packet{ ret[i] = &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: encPTS, Timestamp: encPTS,
SSRC: e.ssrc, SSRC: *e.SSRC,
Marker: (i == (packetCount - 1)), Marker: (i == (packetCount - 1)),
}, },
Payload: data, Payload: data,
@@ -192,10 +191,10 @@ func (e *Encoder) writeAggregated(aus [][]byte, firstPTS time.Duration) ([]*rtp.
pkt := &rtp.Packet{ pkt := &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(firstPTS), Timestamp: e.encodeTimestamp(firstPTS),
SSRC: e.ssrc, SSRC: *e.SSRC,
Marker: true, Marker: true,
}, },
Payload: payload, Payload: payload,

View File

@@ -279,7 +279,10 @@ var cases = []struct {
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
d := NewDecoder(48000) d := &Decoder{
SampleRate: 48000,
}
d.Init()
// send an initial packet downstream // send an initial packet downstream
// in order to compute the right timestamp, // in order to compute the right timestamp,
@@ -562,7 +565,11 @@ func TestDecodeErrors(t *testing.T) {
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
d := NewDecoder(48000) d := &Decoder{
SampleRate: 48000,
}
d.Init()
var lastErr error var lastErr error
for _, pkt := range ca.pkts { for _, pkt := range ca.pkts {
_, _, lastErr = d.Decode(pkt) _, _, lastErr = d.Decode(pkt)
@@ -575,10 +582,23 @@ func TestDecodeErrors(t *testing.T) {
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
sequenceNumber := uint16(0x44ed) e := &Encoder{
ssrc := uint32(0x9dbb7812) PayloadType: 96,
initialTs := uint32(0x88776655) SampleRate: 48000,
e := NewEncoder(96, 48000, &sequenceNumber, &ssrc, &initialTs) SSRC: func() *uint32 {
v := uint32(0x9dbb7812)
return &v
}(),
InitialSequenceNumber: func() *uint16 {
v := uint16(0x44ed)
return &v
}(),
InitialTimestamp: func() *uint32 {
v := uint32(0x88776655)
return &v
}(),
}
e.Init()
pkts, err := e.Encode(ca.aus, ca.pts) pkts, err := e.Encode(ca.aus, ca.pts)
require.NoError(t, err) require.NoError(t, err)
@@ -588,5 +608,12 @@ func TestEncode(t *testing.T) {
} }
func TestEncodeRandomInitialState(t *testing.T) { func TestEncodeRandomInitialState(t *testing.T) {
NewEncoder(96, 48000, nil, nil, nil) e := &Encoder{
PayloadType: 96,
SampleRate: 48000,
}
e.Init()
require.NotEqual(t, nil, e.SSRC)
require.NotEqual(t, nil, e.InitialSequenceNumber)
require.NotEqual(t, nil, e.InitialTimestamp)
} }

View File

@@ -46,11 +46,9 @@ type Decoder struct {
naluBuffer [][]byte naluBuffer [][]byte
} }
// NewDecoder allocates a Decoder. // Init initializes the decoder
func NewDecoder() *Decoder { func (d *Decoder) Init() {
return &Decoder{ d.timeDecoder = rtptimedec.New(90000)
timeDecoder: rtptimedec.New(90000),
}
} }
// Decode decodes NALUs from a RTP/H264 packet. // Decode decodes NALUs from a RTP/H264 packet.

View File

@@ -22,42 +22,41 @@ func randUint32() uint32 {
// Encoder is a RTP/H264 encoder. // Encoder is a RTP/H264 encoder.
type Encoder struct { type Encoder struct {
payloadType uint8 // payload type of packets.
PayloadType uint8
// SSRC of packets (optional).
SSRC *uint32
// initial sequence number of packets (optional).
InitialSequenceNumber *uint16
// initial timestamp of packets (optional).
InitialTimestamp *uint32
sequenceNumber uint16 sequenceNumber uint16
ssrc uint32
initialTs uint32
} }
// NewEncoder allocates an Encoder. // Init initializes the encoder.
func NewEncoder(payloadType uint8, func (e *Encoder) Init() {
sequenceNumber *uint16, if e.SSRC == nil {
ssrc *uint32, v := randUint32()
initialTs *uint32) *Encoder { e.SSRC = &v
return &Encoder{
payloadType: payloadType,
sequenceNumber: func() uint16 {
if sequenceNumber != nil {
return *sequenceNumber
}
return uint16(randUint32())
}(),
ssrc: func() uint32 {
if ssrc != nil {
return *ssrc
}
return randUint32()
}(),
initialTs: func() uint32 {
if initialTs != nil {
return *initialTs
}
return randUint32()
}(),
} }
if e.InitialSequenceNumber == nil {
v := uint16(randUint32())
e.InitialSequenceNumber = &v
}
if e.InitialTimestamp == nil {
v := randUint32()
e.InitialTimestamp = &v
}
e.sequenceNumber = *e.InitialSequenceNumber
} }
func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 { func (e *Encoder) encodeTimestamp(ts time.Duration) uint32 {
return e.initialTs + uint32(ts.Seconds()*rtpClockRate) return *e.InitialTimestamp + uint32(ts.Seconds()*rtpClockRate)
} }
// Encode encodes NALUs into RTP/H264 packets. // Encode encodes NALUs into RTP/H264 packets.
@@ -114,10 +113,10 @@ func (e *Encoder) writeSingle(nalu []byte, pts time.Duration, marker bool) ([]*r
pkt := &rtp.Packet{ pkt := &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(pts), Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc, SSRC: *e.SSRC,
Marker: marker, Marker: marker,
}, },
Payload: nalu, Payload: nalu,
@@ -168,10 +167,10 @@ func (e *Encoder) writeFragmented(nalu []byte, pts time.Duration, marker bool) (
ret[i] = &rtp.Packet{ ret[i] = &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: encPTS, Timestamp: encPTS,
SSRC: e.ssrc, SSRC: *e.SSRC,
Marker: (i == (packetCount-1) && marker), Marker: (i == (packetCount-1) && marker),
}, },
Payload: data, Payload: data,
@@ -220,10 +219,10 @@ func (e *Encoder) writeAggregated(nalus [][]byte, pts time.Duration, marker bool
pkt := &rtp.Packet{ pkt := &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: rtpVersion, Version: rtpVersion,
PayloadType: e.payloadType, PayloadType: e.PayloadType,
SequenceNumber: e.sequenceNumber, SequenceNumber: e.sequenceNumber,
Timestamp: e.encodeTimestamp(pts), Timestamp: e.encodeTimestamp(pts),
SSRC: e.ssrc, SSRC: *e.SSRC,
Marker: marker, Marker: marker,
}, },
Payload: payload, Payload: payload,

View File

@@ -307,7 +307,8 @@ var cases = []struct {
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
d := NewDecoder() d := &Decoder{}
d.Init()
// send an initial packet downstream // send an initial packet downstream
// in order to compute the right timestamp, // in order to compute the right timestamp,
@@ -350,7 +351,8 @@ func TestDecode(t *testing.T) {
} }
func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) { func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) {
d := NewDecoder() d := &Decoder{}
d.Init()
pkt := rtp.Packet{ pkt := rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -389,6 +391,9 @@ func TestDecodePartOfFragmentedBeforeSingle(t *testing.T) {
} }
func TestDecodeSTAPAWithPadding(t *testing.T) { func TestDecodeSTAPAWithPadding(t *testing.T) {
d := &Decoder{}
d.Init()
pkt := rtp.Packet{ pkt := rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
@@ -405,7 +410,8 @@ func TestDecodeSTAPAWithPadding(t *testing.T) {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}, },
} }
nalus, _, err := NewDecoder().Decode(&pkt)
nalus, _, err := d.Decode(&pkt)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, [][]byte{ require.Equal(t, [][]byte{
{0xaa, 0xbb}, {0xaa, 0xbb},
@@ -654,7 +660,9 @@ func TestDecodeErrors(t *testing.T) {
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
d := NewDecoder() d := &Decoder{}
d.Init()
var lastErr error var lastErr error
for _, pkt := range ca.pkts { for _, pkt := range ca.pkts {
_, _, lastErr = d.Decode(pkt) _, _, lastErr = d.Decode(pkt)
@@ -667,10 +675,22 @@ func TestDecodeErrors(t *testing.T) {
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
for _, ca := range cases { for _, ca := range cases {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
sequenceNumber := uint16(0x44ed) e := &Encoder{
ssrc := uint32(0x9dbb7812) PayloadType: 96,
initialTs := uint32(0x88776655) SSRC: func() *uint32 {
e := NewEncoder(96, &sequenceNumber, &ssrc, &initialTs) v := uint32(0x9dbb7812)
return &v
}(),
InitialSequenceNumber: func() *uint16 {
v := uint16(0x44ed)
return &v
}(),
InitialTimestamp: func() *uint32 {
v := uint32(0x88776655)
return &v
}(),
}
e.Init()
pkts, err := e.Encode(ca.nalus, ca.pts) pkts, err := e.Encode(ca.nalus, ca.pts)
require.NoError(t, err) require.NoError(t, err)
@@ -680,7 +700,13 @@ func TestEncode(t *testing.T) {
} }
func TestEncodeRandomInitialState(t *testing.T) { func TestEncodeRandomInitialState(t *testing.T) {
NewEncoder(96, nil, nil, nil) e := &Encoder{
PayloadType: 96,
}
e.Init()
require.NotEqual(t, nil, e.SSRC)
require.NotEqual(t, nil, e.InitialSequenceNumber)
require.NotEqual(t, nil, e.InitialTimestamp)
} }
type dummyReader struct { type dummyReader struct {
@@ -724,7 +750,10 @@ func TestReadSPSPPS(t *testing.T) {
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
sps, pps, err := NewDecoder().ReadSPSPPS(&dummyReader{byts: ca.byts}) d := &Decoder{}
d.Init()
sps, pps, err := d.ReadSPSPPS(&dummyReader{byts: ca.byts})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ca.sps, sps) require.Equal(t, ca.sps, sps)
require.Equal(t, ca.pps, pps) require.Equal(t, ca.pps, pps)
@@ -759,7 +788,10 @@ func TestReadSPSPPSErrors(t *testing.T) {
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
_, _, err := NewDecoder().ReadSPSPPS(&dummyReader{byts: ca.byts}) d := &Decoder{}
d.Init()
_, _, err := d.ReadSPSPPS(&dummyReader{byts: ca.byts})
require.EqualError(t, err, ca.err) require.EqualError(t, err, ca.err)
}) })
} }