diff --git a/client.go b/client.go
index ceb4cdc8..637bee37 100644
--- a/client.go
+++ b/client.go
@@ -24,13 +24,12 @@ import (
 
 	"github.com/aler9/gortsplib/pkg/auth"
 	"github.com/aler9/gortsplib/pkg/base"
-	"github.com/aler9/gortsplib/pkg/h264"
 	"github.com/aler9/gortsplib/pkg/headers"
 	"github.com/aler9/gortsplib/pkg/liberrors"
 	"github.com/aler9/gortsplib/pkg/ringbuffer"
 	"github.com/aler9/gortsplib/pkg/rtcpreceiver"
 	"github.com/aler9/gortsplib/pkg/rtcpsender"
-	"github.com/aler9/gortsplib/pkg/rtph264"
+	"github.com/aler9/gortsplib/pkg/rtpproc"
 	"github.com/aler9/gortsplib/pkg/sdp"
 	"github.com/aler9/gortsplib/pkg/url"
 )
@@ -91,8 +90,7 @@ type clientTrack struct {
 	tcpChannel      int
 	rtcpReceiver    *rtcpreceiver.RTCPReceiver
 	rtcpSender      *rtcpsender.RTCPSender
-	h264Decoder     *rtph264.Decoder
-	h264Encoder     *rtph264.Encoder
+	proc            *rtpproc.Processor
 }
 
 func (s clientState) String() string {
@@ -710,10 +708,8 @@ func (c *Client) playRecordStart() {
 
 	if c.state == clientStatePlay {
 		for _, ct := range c.tracks {
-			if _, ok := ct.track.(*TrackH264); ok {
-				ct.h264Decoder = &rtph264.Decoder{}
-				ct.h264Decoder.Init()
-			}
+			_, isH264 := ct.track.(*TrackH264)
+			ct.proc = rtpproc.NewProcessor(isH264, *c.effectiveTransport == TransportTCP)
 		}
 
 		c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
@@ -812,57 +808,19 @@ func (c *Client) runReader() {
 							return err
 						}
 
-						ctx := ClientOnPacketRTPCtx{
-							TrackID: trackID,
-							Packet:  pkt,
+						out, err := c.tracks[trackID].proc.Process(pkt)
+						if err != nil {
+							return err
 						}
-						ct := c.tracks[trackID]
-						c.processPacketRTP(ct, &ctx)
 
-						if ct.h264Decoder != nil {
-							if ct.h264Encoder == nil && len(payload) > maxPacketSize {
-								v1 := pkt.SSRC
-								v2 := pkt.SequenceNumber
-								v3 := pkt.Timestamp
-								ct.h264Encoder = &rtph264.Encoder{
-									PayloadType:           pkt.PayloadType,
-									SSRC:                  &v1,
-									InitialSequenceNumber: &v2,
-									InitialTimestamp:      &v3,
-								}
-								ct.h264Encoder.Init()
-							}
-
-							if ct.h264Encoder != nil {
-								if ctx.H264NALUs != nil {
-									packets, err := ct.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
-									if err != nil {
-										return err
-									}
-
-									for i, pkt := range packets {
-										if i != len(packets)-1 {
-											c.OnPacketRTP(&ClientOnPacketRTPCtx{
-												TrackID:      trackID,
-												Packet:       pkt,
-												PTSEqualsDTS: false,
-											})
-										} else {
-											ctx.Packet = pkt
-											c.OnPacketRTP(&ctx)
-										}
-									}
-								}
-							} else {
-								c.OnPacketRTP(&ctx)
-							}
-						} else {
-							if len(payload) > maxPacketSize {
-								return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
-									len(payload), maxPacketSize)
-							}
-
-							c.OnPacketRTP(&ctx)
+						for _, entry := range out {
+							c.OnPacketRTP(&ClientOnPacketRTPCtx{
+								TrackID:      trackID,
+								Packet:       entry.Packet,
+								PTSEqualsDTS: entry.PTSEqualsDTS,
+								H264NALUs:    entry.H264NALUs,
+								H264PTS:      entry.H264PTS,
+							})
 						}
 					} else {
 						if len(payload) > maxPacketSize {
@@ -975,8 +933,7 @@ func (c *Client) playRecordStop(isClosing bool) {
 	}
 
 	for _, ct := range c.tracks {
-		ct.h264Decoder = nil
-		ct.h264Encoder = nil
+		ct.proc = nil
 	}
 
 	// stop timers
@@ -1928,24 +1885,6 @@ func (c *Client) runWriter() {
 	}
 }
 
-func (c *Client) processPacketRTP(ct *clientTrack, ctx *ClientOnPacketRTPCtx) {
-	// remove padding
-	ctx.Packet.Header.Padding = false
-	ctx.Packet.PaddingSize = 0
-
-	// decode
-	if ct.h264Decoder != nil {
-		nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(ctx.Packet)
-		if err == nil {
-			ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
-			ctx.H264NALUs = nalus
-			ctx.H264PTS = pts
-		}
-	} else {
-		ctx.PTSEqualsDTS = true
-	}
-}
-
 // WritePacketRTP writes a RTP packet.
 func (c *Client) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDTS bool) error {
 	c.writeMutex.RLock()
diff --git a/client_read_test.go b/client_read_test.go
index 9a6bb8e2..353be336 100644
--- a/client_read_test.go
+++ b/client_read_test.go
@@ -220,7 +220,7 @@ func TestClientRead(t *testing.T) {
 				require.Equal(t, base.Describe, req.Method)
 				require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value"), req.URL)
 
-				track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
+				track, err := NewTrackGeneric("application", []string{"97"}, "97 private/90000", "")
 				require.NoError(t, err)
 
 				tracks := Tracks{track}
diff --git a/clientudpl.go b/clientudpl.go
index 8ebd2a81..e6c1e613 100644
--- a/clientudpl.go
+++ b/clientudpl.go
@@ -178,14 +178,23 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
 		return
 	}
 
-	ctx := ClientOnPacketRTPCtx{
-		TrackID: u.trackID,
-		Packet:  pkt,
-	}
 	ct := u.c.tracks[u.trackID]
-	u.c.processPacketRTP(ct, &ctx)
-	ct.rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, ctx.PTSEqualsDTS)
-	u.c.OnPacketRTP(&ctx)
+
+	out, err := ct.proc.Process(pkt)
+	if err != nil {
+		return
+	}
+	out0 := out[0]
+
+	ct.rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS)
+
+	u.c.OnPacketRTP(&ClientOnPacketRTPCtx{
+		TrackID:      u.trackID,
+		Packet:       out0.Packet,
+		PTSEqualsDTS: out0.PTSEqualsDTS,
+		H264NALUs:    out0.H264NALUs,
+		H264PTS:      out0.H264PTS,
+	})
 }
 
 func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
diff --git a/pkg/rtpproc/processor.go b/pkg/rtpproc/processor.go
new file mode 100644
index 00000000..464ed0f5
--- /dev/null
+++ b/pkg/rtpproc/processor.go
@@ -0,0 +1,152 @@
+package rtpproc
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/pion/rtp"
+
+	"github.com/aler9/gortsplib/pkg/h264"
+	"github.com/aler9/gortsplib/pkg/rtph264"
+)
+
+const (
+	// 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
+	maxPacketSize = 1472
+)
+
+// ProcessorOutput is the output of Process().
+type ProcessorOutput struct {
+	Packet       *rtp.Packet
+	PTSEqualsDTS bool
+	H264NALUs    [][]byte
+	H264PTS      time.Duration
+}
+
+// Processor is used to process incoming RTP packets, in order to:
+// - remove padding
+// - decode packets encoded with supported codecs
+// - re-encode packets if they are bigger than maximum allowed.
+type Processor struct {
+	isH264 bool
+	isTCP  bool
+
+	h264Decoder *rtph264.Decoder
+	h264Encoder *rtph264.Encoder
+}
+
+// NewProcessor allocates a Processor.
+func NewProcessor(isH264 bool, isTCP bool) *Processor {
+	p := &Processor{
+		isH264: isH264,
+		isTCP:  isTCP,
+	}
+
+	if isH264 {
+		p.h264Decoder = &rtph264.Decoder{}
+		p.h264Decoder.Init()
+	}
+
+	return p
+}
+
+// processH264 handles a packet belonging to a H264 track: it feeds the
+// decoder and, over TCP, replaces oversized packets with re-encoded ones.
+func (p *Processor) processH264(pkt *rtp.Packet) ([]*ProcessorOutput, error) {
+	// switch to re-encoding as soon as an oversized packet is seen,
+	// regardless of the decoder state, so that oversized fragments
+	// are never forwarded as they are.
+	if p.isTCP && p.h264Encoder == nil && pkt.MarshalSize() > maxPacketSize {
+		v1 := pkt.SSRC
+		v2 := pkt.SequenceNumber
+		v3 := pkt.Timestamp
+		p.h264Encoder = &rtph264.Encoder{
+			PayloadType:           pkt.PayloadType,
+			SSRC:                  &v1,
+			InitialSequenceNumber: &v2,
+			InitialTimestamp:      &v3,
+		}
+		p.h264Encoder.Init()
+	}
+
+	// decode
+	nalus, pts, err := p.h264Decoder.DecodeUntilMarker(pkt)
+	if err != nil {
+		if err != rtph264.ErrNonStartingPacketAndNoPrevious &&
+			err != rtph264.ErrMorePacketsNeeded {
+			return nil, err
+		}
+
+		// NALUs are not available yet. When re-encoding, incoming packets
+		// are replaced by the ones generated from the NALUs, therefore
+		// there's nothing to emit.
+		if p.h264Encoder != nil {
+			return nil, nil
+		}
+
+		return []*ProcessorOutput{{
+			Packet:       pkt,
+			PTSEqualsDTS: false,
+		}}, nil
+	}
+	ptsEqualsDTS := h264.IDRPresent(nalus)
+
+	// re-encode if packets use non-standard sizes
+	if p.h264Encoder != nil {
+		packets, err := p.h264Encoder.Encode(nalus, pts)
+		if err != nil {
+			return nil, err
+		}
+
+		output := make([]*ProcessorOutput, len(packets))
+
+		for i, pkt := range packets {
+			if i != len(packets)-1 {
+				output[i] = &ProcessorOutput{
+					Packet:       pkt,
+					PTSEqualsDTS: false,
+				}
+			} else {
+				// decoded data is attached to the last packet only.
+				output[i] = &ProcessorOutput{
+					Packet:       pkt,
+					PTSEqualsDTS: ptsEqualsDTS,
+					H264NALUs:    nalus,
+					H264PTS:      pts,
+				}
+			}
+		}
+
+		return output, nil
+	}
+
+	return []*ProcessorOutput{{
+		Packet:       pkt,
+		PTSEqualsDTS: ptsEqualsDTS,
+		H264NALUs:    nalus,
+		H264PTS:      pts,
+	}}, nil
+}
+
+// Process processes an RTP packet.
+// The returned slice can be empty when packets are being re-encoded (TCP only)
+// and the current one does not complete a group of NALUs.
+func (p *Processor) Process(pkt *rtp.Packet) ([]*ProcessorOutput, error) {
+	// remove padding
+	pkt.Header.Padding = false
+	pkt.PaddingSize = 0
+
+	if p.h264Decoder != nil {
+		return p.processH264(pkt)
+	}
+
+	if p.isTCP && pkt.MarshalSize() > maxPacketSize {
+		return nil, fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
+			pkt.MarshalSize(), maxPacketSize)
+	}
+
+	return []*ProcessorOutput{{
+		Packet:       pkt,
+		PTSEqualsDTS: true,
+	}}, nil
+}
diff --git a/serverconn.go b/serverconn.go
index db8d159e..2b5cedd4 100644
--- a/serverconn.go
+++ b/serverconn.go
@@ -15,7 +15,6 @@ import (
 
 	"github.com/aler9/gortsplib/pkg/base"
 	"github.com/aler9/gortsplib/pkg/liberrors"
-	"github.com/aler9/gortsplib/pkg/rtph264"
 	"github.com/aler9/gortsplib/pkg/url"
 )
 
@@ -258,66 +257,21 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
 					return err
 				}
 
-				ctx := ServerHandlerOnPacketRTPCtx{
-					Session: sc.session,
-					TrackID: trackID,
-					Packet:  pkt,
+				out, err := sc.session.announcedTracks[trackID].proc.Process(pkt)
+				if err != nil {
+					return err
 				}
-				at := sc.session.announcedTracks[trackID]
-				sc.session.processPacketRTP(at, &ctx)
 
-				if at.h264Decoder != nil {
-					if at.h264Encoder == nil && len(payload) > maxPacketSize {
-						v1 := pkt.SSRC
-						v2 := pkt.SequenceNumber
-						v3 := pkt.Timestamp
-						at.h264Encoder = &rtph264.Encoder{
-							PayloadType:           pkt.PayloadType,
-							SSRC:                  &v1,
-							InitialSequenceNumber: &v2,
-							InitialTimestamp:      &v3,
-						}
-						at.h264Encoder.Init()
-					}
-
-					if at.h264Encoder != nil {
-						if ctx.H264NALUs != nil {
-							packets, err := at.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
-							if err != nil {
-								return err
-							}
-
-							for i, pkt := range packets {
-								if i != len(packets)-1 {
-									if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
-										h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
-											Session:      sc.session,
-											TrackID:      trackID,
-											Packet:       pkt,
-											PTSEqualsDTS: false,
-										})
-									}
-								} else {
-									ctx.Packet = pkt
-									if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
-										h.OnPacketRTP(&ctx)
-									}
-								}
-							}
-						}
-					} else {
-						if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
-							h.OnPacketRTP(&ctx)
-						}
-					}
-				} else {
-					if len(payload) > maxPacketSize {
-						return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
-							len(payload), maxPacketSize)
-					}
-
-					if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
-						h.OnPacketRTP(&ctx)
+				if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
+					for _, entry := range out {
+						h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
+							Session:      sc.session,
+							TrackID:      trackID,
+							Packet:       entry.Packet,
+							PTSEqualsDTS: entry.PTSEqualsDTS,
+							H264NALUs:    entry.H264NALUs,
+							H264PTS:      entry.H264PTS,
+						})
 					}
 				}
 			} else {
diff --git a/serversession.go b/serversession.go
index 9a21bb5d..15186a47 100644
--- a/serversession.go
+++ b/serversession.go
@@ -14,12 +14,11 @@ import (
 	"github.com/pion/rtp"
 
 	"github.com/aler9/gortsplib/pkg/base"
-	"github.com/aler9/gortsplib/pkg/h264"
 	"github.com/aler9/gortsplib/pkg/headers"
 	"github.com/aler9/gortsplib/pkg/liberrors"
 	"github.com/aler9/gortsplib/pkg/ringbuffer"
 	"github.com/aler9/gortsplib/pkg/rtcpreceiver"
-	"github.com/aler9/gortsplib/pkg/rtph264"
+	"github.com/aler9/gortsplib/pkg/rtpproc"
 	"github.com/aler9/gortsplib/pkg/url"
 )
 
@@ -154,8 +153,7 @@ type ServerSessionSetuppedTrack struct {
 type ServerSessionAnnouncedTrack struct {
 	track        Track
 	rtcpReceiver *rtcpreceiver.RTCPReceiver
-	h264Decoder  *rtph264.Decoder
-	h264Encoder  *rtph264.Encoder
+	proc         *rtpproc.Processor
 }
 
 // ServerSession is a server-side RTSP session.
@@ -980,10 +978,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
 		ss.state = ServerSessionStateRecord
 
 		for _, at := range ss.announcedTracks {
-			if _, ok := at.track.(*TrackH264); ok {
-				at.h264Decoder = &rtph264.Decoder{}
-				at.h264Decoder.Init()
-			}
+			_, isH264 := at.track.(*TrackH264)
+			at.proc = rtpproc.NewProcessor(isH264, *ss.setuppedTransport == TransportTCP)
 		}
 
 		switch *ss.setuppedTransport {
@@ -1104,8 +1100,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
 		}
 
 		for _, at := range ss.announcedTracks {
-			at.h264Decoder = nil
-			at.h264Encoder = nil
+			at.proc = nil
 		}
 
 		ss.state = ServerSessionStatePreRecord
@@ -1215,24 +1210,6 @@ func (ss *ServerSession) runWriter() {
 	}
 }
 
-func (ss *ServerSession) processPacketRTP(at *ServerSessionAnnouncedTrack, ctx *ServerHandlerOnPacketRTPCtx) {
-	// remove padding
-	ctx.Packet.Header.Padding = false
-	ctx.Packet.PaddingSize = 0
-
-	// decode
-	if at.h264Decoder != nil {
-		nalus, pts, err := at.h264Decoder.DecodeUntilMarker(ctx.Packet)
-		if err == nil {
-			ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
-			ctx.H264NALUs = nalus
-			ctx.H264PTS = pts
-		}
-	} else {
-		ctx.PTSEqualsDTS = true
-	}
-}
-
 func (ss *ServerSession) onPacketRTCP(trackID int, pkt rtcp.Packet) {
 	if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTCP); ok {
 		h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
diff --git a/serverudpl.go b/serverudpl.go
index d25b59a2..c6d64a99 100644
--- a/serverudpl.go
+++ b/serverudpl.go
@@ -203,17 +203,25 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
 	now := time.Now()
 	atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
 
-	ctx := ServerHandlerOnPacketRTPCtx{
-		Session: clientData.ss,
-		TrackID: clientData.trackID,
-		Packet:  pkt,
-	}
 	at := clientData.ss.announcedTracks[clientData.trackID]
-	clientData.ss.processPacketRTP(at, &ctx)
-	at.rtcpReceiver.ProcessPacketRTP(now, ctx.Packet, ctx.PTSEqualsDTS)
 
+	out, err := at.proc.Process(pkt)
+	if err != nil {
+		return
+	}
+	out0 := out[0]
+
+	at.rtcpReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS)
+
 	if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
-		h.OnPacketRTP(&ctx)
+		h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
+			Session:      clientData.ss,
+			TrackID:      clientData.trackID,
+			Packet:       out0.Packet,
+			PTSEqualsDTS: out0.PTSEqualsDTS,
+			H264NALUs:    out0.H264NALUs,
+			H264PTS:      out0.H264PTS,
+		})
 	}
 }
 