move RTP packet handling into separate package

This commit is contained in:
aler9
2022-06-11 13:47:26 +02:00
parent 1fd66bdaed
commit 98b6515c33
7 changed files with 201 additions and 180 deletions

View File

@@ -24,13 +24,12 @@ import (
"github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver" "github.com/aler9/gortsplib/pkg/rtcpreceiver"
"github.com/aler9/gortsplib/pkg/rtcpsender" "github.com/aler9/gortsplib/pkg/rtcpsender"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtpproc"
"github.com/aler9/gortsplib/pkg/sdp" "github.com/aler9/gortsplib/pkg/sdp"
"github.com/aler9/gortsplib/pkg/url" "github.com/aler9/gortsplib/pkg/url"
) )
@@ -91,8 +90,7 @@ type clientTrack struct {
tcpChannel int tcpChannel int
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
rtcpSender *rtcpsender.RTCPSender rtcpSender *rtcpsender.RTCPSender
h264Decoder *rtph264.Decoder proc *rtpproc.Processor
h264Encoder *rtph264.Encoder
} }
func (s clientState) String() string { func (s clientState) String() string {
@@ -710,10 +708,8 @@ func (c *Client) playRecordStart() {
if c.state == clientStatePlay { if c.state == clientStatePlay {
for _, ct := range c.tracks { for _, ct := range c.tracks {
if _, ok := ct.track.(*TrackH264); ok { _, isH264 := ct.track.(*TrackH264)
ct.h264Decoder = &rtph264.Decoder{} ct.proc = rtpproc.NewProcessor(isH264, *c.effectiveTransport == TransportTCP)
ct.h264Decoder.Init()
}
} }
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
@@ -812,57 +808,19 @@ func (c *Client) runReader() {
return err return err
} }
ctx := ClientOnPacketRTPCtx{ out, err := c.tracks[trackID].proc.Process(pkt)
TrackID: trackID, if err != nil {
Packet: pkt, return err
} }
ct := c.tracks[trackID]
c.processPacketRTP(ct, &ctx)
if ct.h264Decoder != nil { for _, entry := range out {
if ct.h264Encoder == nil && len(payload) > maxPacketSize { c.OnPacketRTP(&ClientOnPacketRTPCtx{
v1 := pkt.SSRC TrackID: trackID,
v2 := pkt.SequenceNumber Packet: entry.Packet,
v3 := pkt.Timestamp PTSEqualsDTS: entry.PTSEqualsDTS,
ct.h264Encoder = &rtph264.Encoder{ H264NALUs: entry.H264NALUs,
PayloadType: pkt.PayloadType, H264PTS: entry.H264PTS,
SSRC: &v1, })
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
ct.h264Encoder.Init()
}
if ct.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := ct.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
} else {
ctx.Packet = pkt
c.OnPacketRTP(&ctx)
}
}
}
} else {
c.OnPacketRTP(&ctx)
}
} else {
if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), maxPacketSize)
}
c.OnPacketRTP(&ctx)
} }
} else { } else {
if len(payload) > maxPacketSize { if len(payload) > maxPacketSize {
@@ -975,8 +933,7 @@ func (c *Client) playRecordStop(isClosing bool) {
} }
for _, ct := range c.tracks { for _, ct := range c.tracks {
ct.h264Decoder = nil ct.proc = nil
ct.h264Encoder = nil
} }
// stop timers // stop timers
@@ -1928,24 +1885,6 @@ func (c *Client) runWriter() {
} }
} }
func (c *Client) processPacketRTP(ct *clientTrack, ctx *ClientOnPacketRTPCtx) {
// remove padding
ctx.Packet.Header.Padding = false
ctx.Packet.PaddingSize = 0
// decode
if ct.h264Decoder != nil {
nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil {
ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = nalus
ctx.H264PTS = pts
}
} else {
ctx.PTSEqualsDTS = true
}
}
// WritePacketRTP writes a RTP packet. // WritePacketRTP writes a RTP packet.
func (c *Client) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDTS bool) error { func (c *Client) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDTS bool) error {
c.writeMutex.RLock() c.writeMutex.RLock()

View File

@@ -220,7 +220,7 @@ func TestClientRead(t *testing.T) {
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL) require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) track, err := NewTrackGeneric("application", []string{"97"}, "97 private/90000", "")
require.NoError(t, err) require.NoError(t, err)
tracks := Tracks{track} tracks := Tracks{track}

View File

@@ -178,14 +178,23 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
return return
} }
ctx := ClientOnPacketRTPCtx{
TrackID: u.trackID,
Packet: pkt,
}
ct := u.c.tracks[u.trackID] ct := u.c.tracks[u.trackID]
u.c.processPacketRTP(ct, &ctx)
ct.rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, ctx.PTSEqualsDTS) out, err := ct.proc.Process(pkt)
u.c.OnPacketRTP(&ctx) if err != nil {
return
}
out0 := out[0]
ct.rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS)
u.c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: u.trackID,
Packet: out0.Packet,
PTSEqualsDTS: out0.PTSEqualsDTS,
H264NALUs: out0.H264NALUs,
H264PTS: out0.H264PTS,
})
} }
func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {

134
pkg/rtpproc/processor.go Normal file
View File

@@ -0,0 +1,134 @@
package rtpproc
import (
"fmt"
"time"
"github.com/pion/rtp"
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/rtph264"
)
const (
// 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
maxPacketSize = 1472
)
// ProcessorOutput is the output of Process().
type ProcessorOutput struct {
Packet *rtp.Packet
PTSEqualsDTS bool
H264NALUs [][]byte
H264PTS time.Duration
}
// Processor is used to process incoming RTP packets, in order to:
// - remove padding
// - decode packets encoded with supported codecs
// - re-encode packets if they are bigger than maximum allowed.
type Processor struct {
isH264 bool
isTCP bool
h264Decoder *rtph264.Decoder
h264Encoder *rtph264.Encoder
}
// NewProcessor allocates a Processor.
func NewProcessor(isH264 bool, isTCP bool) *Processor {
p := &Processor{
isH264: isH264,
isTCP: isTCP,
}
if isH264 {
p.h264Decoder = &rtph264.Decoder{}
p.h264Decoder.Init()
}
return p
}
func (p *Processor) processH264(pkt *rtp.Packet) ([]*ProcessorOutput, error) {
// decode
nalus, pts, err := p.h264Decoder.DecodeUntilMarker(pkt)
if err != nil {
if err == rtph264.ErrNonStartingPacketAndNoPrevious ||
err == rtph264.ErrMorePacketsNeeded {
return []*ProcessorOutput{{
Packet: pkt,
PTSEqualsDTS: false,
}}, nil
}
return nil, err
}
ptsEqualsDTS := h264.IDRPresent(nalus)
// re-encode if packets use non-standard sizes
if p.isTCP && p.h264Encoder == nil && pkt.MarshalSize() > maxPacketSize {
v1 := pkt.SSRC
v2 := pkt.SequenceNumber
v3 := pkt.Timestamp
p.h264Encoder = &rtph264.Encoder{
PayloadType: pkt.PayloadType,
SSRC: &v1,
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
p.h264Encoder.Init()
}
if p.h264Encoder != nil {
packets, err := p.h264Encoder.Encode(nalus, pts)
if err != nil {
return nil, err
}
output := make([]*ProcessorOutput, len(packets))
for i, pkt := range packets {
if i != len(packets)-1 {
output[i] = &ProcessorOutput{
Packet: pkt,
PTSEqualsDTS: false,
}
} else {
output[i] = &ProcessorOutput{
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
}
}
}
return output, nil
}
return []*ProcessorOutput{{
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
H264NALUs: nalus,
H264PTS: pts,
}}, nil
}
// Process processes a RTP packet.
func (p *Processor) Process(pkt *rtp.Packet) ([]*ProcessorOutput, error) {
// remove padding
pkt.Header.Padding = false
pkt.PaddingSize = 0
if p.h264Decoder != nil {
return p.processH264(pkt)
}
if p.isTCP && pkt.MarshalSize() > maxPacketSize {
return nil, fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
pkt.MarshalSize(), maxPacketSize)
}
return []*ProcessorOutput{{
Packet: pkt,
PTSEqualsDTS: true,
}}, nil
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/rtph264"
"github.com/aler9/gortsplib/pkg/url" "github.com/aler9/gortsplib/pkg/url"
) )
@@ -258,66 +257,21 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
return err return err
} }
ctx := ServerHandlerOnPacketRTPCtx{ out, err := sc.session.announcedTracks[trackID].proc.Process(pkt)
Session: sc.session, if err != nil {
TrackID: trackID, return err
Packet: pkt,
} }
at := sc.session.announcedTracks[trackID]
sc.session.processPacketRTP(at, &ctx)
if at.h264Decoder != nil { if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
if at.h264Encoder == nil && len(payload) > maxPacketSize { for _, entry := range out {
v1 := pkt.SSRC h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
v2 := pkt.SequenceNumber Session: sc.session,
v3 := pkt.Timestamp TrackID: trackID,
at.h264Encoder = &rtph264.Encoder{ Packet: entry.Packet,
PayloadType: pkt.PayloadType, PTSEqualsDTS: entry.PTSEqualsDTS,
SSRC: &v1, H264NALUs: entry.H264NALUs,
InitialSequenceNumber: &v2, H264PTS: entry.H264PTS,
InitialTimestamp: &v3, })
}
at.h264Encoder.Init()
}
if at.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := at.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.session,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
}
} else {
ctx.Packet = pkt
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
}
}
} else {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
} else {
if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), maxPacketSize)
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
} }
} }
} else { } else {

View File

@@ -14,12 +14,11 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver" "github.com/aler9/gortsplib/pkg/rtcpreceiver"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtpproc"
"github.com/aler9/gortsplib/pkg/url" "github.com/aler9/gortsplib/pkg/url"
) )
@@ -154,8 +153,7 @@ type ServerSessionSetuppedTrack struct {
type ServerSessionAnnouncedTrack struct { type ServerSessionAnnouncedTrack struct {
track Track track Track
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
h264Decoder *rtph264.Decoder proc *rtpproc.Processor
h264Encoder *rtph264.Encoder
} }
// ServerSession is a server-side RTSP session. // ServerSession is a server-side RTSP session.
@@ -980,10 +978,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStateRecord ss.state = ServerSessionStateRecord
for _, at := range ss.announcedTracks { for _, at := range ss.announcedTracks {
if _, ok := at.track.(*TrackH264); ok { _, isH264 := at.track.(*TrackH264)
at.h264Decoder = &rtph264.Decoder{} at.proc = rtpproc.NewProcessor(isH264, *ss.setuppedTransport == TransportTCP)
at.h264Decoder.Init()
}
} }
switch *ss.setuppedTransport { switch *ss.setuppedTransport {
@@ -1104,8 +1100,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
} }
for _, at := range ss.announcedTracks { for _, at := range ss.announcedTracks {
at.h264Decoder = nil at.proc = nil
at.h264Encoder = nil
} }
ss.state = ServerSessionStatePreRecord ss.state = ServerSessionStatePreRecord
@@ -1215,24 +1210,6 @@ func (ss *ServerSession) runWriter() {
} }
} }
func (ss *ServerSession) processPacketRTP(at *ServerSessionAnnouncedTrack, ctx *ServerHandlerOnPacketRTPCtx) {
// remove padding
ctx.Packet.Header.Padding = false
ctx.Packet.PaddingSize = 0
// decode
if at.h264Decoder != nil {
nalus, pts, err := at.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil {
ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = nalus
ctx.H264PTS = pts
}
} else {
ctx.PTSEqualsDTS = true
}
}
func (ss *ServerSession) onPacketRTCP(trackID int, pkt rtcp.Packet) { func (ss *ServerSession) onPacketRTCP(trackID int, pkt rtcp.Packet) {
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTCP); ok { if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTCP); ok {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{

View File

@@ -203,17 +203,25 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
now := time.Now() now := time.Now()
atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
ctx := ServerHandlerOnPacketRTPCtx{
Session: clientData.ss,
TrackID: clientData.trackID,
Packet: pkt,
}
at := clientData.ss.announcedTracks[clientData.trackID] at := clientData.ss.announcedTracks[clientData.trackID]
clientData.ss.processPacketRTP(at, &ctx)
at.rtcpReceiver.ProcessPacketRTP(now, ctx.Packet, ctx.PTSEqualsDTS) out, err := at.proc.Process(pkt)
if err != nil {
return
}
out0 := out[0]
at.rtcpReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS)
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok { if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx) h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: clientData.ss,
TrackID: clientData.trackID,
Packet: out0.Packet,
PTSEqualsDTS: out0.PTSEqualsDTS,
H264NALUs: out0.H264NALUs,
H264PTS: out0.H264PTS,
})
} }
} }