diff --git a/clientudpl.go b/clientudpl.go index aca9bf35..491b34b9 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -182,11 +182,12 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { return } + u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTP(now, pkt, true) + // remove padding pkt.Header.Padding = false pkt.PaddingSize = 0 - u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTP(now, pkt) u.c.OnPacketRTP(u.trackID, pkt) } diff --git a/examples/client-read-h264-save-to-disk/mpegtsencoder.go b/examples/client-read-h264-save-to-disk/mpegtsencoder.go index 40c44043..193167c9 100644 --- a/examples/client-read-h264-save-to-disk/mpegtsencoder.go +++ b/examples/client-read-h264-save-to-disk/mpegtsencoder.go @@ -62,17 +62,6 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { e.startPTS = pts } - // check if there's an IDR - idrPresent := func() bool { - for _, nalu := range nalus { - typ := h264.NALUType(nalu[0] & 0x1F) - if typ == h264.NALUTypeIDR { - return true - } - } - return false - }() - // prepend an AUD. This is required by some players filteredNALUs := [][]byte{ {byte(h264.NALUTypeAccessUnitDelimiter), 240}, @@ -131,7 +120,7 @@ func (e *mpegtsEncoder) encode(nalus [][]byte, pts time.Duration) error { _, err = e.mux.WriteData(&astits.MuxerData{ PID: 256, AdaptationField: &astits.PacketAdaptationField{ - RandomAccessIndicator: idrPresent, + RandomAccessIndicator: h264.IDRPresent(filteredNALUs), }, PES: &astits.PESData{ Header: &astits.PESHeader{ diff --git a/pkg/h264/idrpresent.go b/pkg/h264/idrpresent.go new file mode 100644 index 00000000..9c329d49 --- /dev/null +++ b/pkg/h264/idrpresent.go @@ -0,0 +1,12 @@ +package h264 + +// IDRPresent check if there's an IDR inside provided NALUs. +func IDRPresent(nalus [][]byte) bool { + for _, nalu := range nalus { + typ := NALUType(nalu[0] & 0x1F) + if typ == NALUTypeIDR { + return true + } + } + return false +} diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index 96af15a5..9c0b6e74 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -131,7 +131,7 @@ func (rr *RTCPReceiver) report(ts time.Time) rtcp.Packet { } // ProcessPacketRTP extracts the needed data from RTP packets. -func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet) { +func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet, ptsEqualsDTS bool) { rr.mutex.Lock() defer rr.mutex.Unlock() @@ -140,8 +140,11 @@ func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet) { rr.firstRTPReceived = true rr.totalSinceReport = 1 rr.lastSequenceNumber = pkt.Header.SequenceNumber - rr.lastRTPTimeRTP = pkt.Header.Timestamp - rr.lastRTPTimeTime = ts + + if ptsEqualsDTS { + rr.lastRTPTimeRTP = pkt.Header.Timestamp + rr.lastRTPTimeTime = ts + } // subsequent packets } else { @@ -168,19 +171,25 @@ func (rr *RTCPReceiver) ProcessPacketRTP(ts time.Time, pkt *rtp.Packet) { } } - // compute jitter - // https://tools.ietf.org/html/rfc3550#page-39 - D := ts.Sub(rr.lastRTPTimeTime).Seconds()*rr.clockRate - - (float64(pkt.Header.Timestamp) - float64(rr.lastRTPTimeRTP)) - if D < 0 { - D = -D - } - rr.jitter += (D - rr.jitter) / 16 - rr.totalSinceReport += uint32(uint16(diff)) rr.lastSequenceNumber = pkt.Header.SequenceNumber - rr.lastRTPTimeRTP = pkt.Header.Timestamp - rr.lastRTPTimeTime = ts + + if ptsEqualsDTS { + var zero time.Time + if rr.lastRTPTimeTime != zero { + // update jitter + // https://tools.ietf.org/html/rfc3550#page-39 + D := ts.Sub(rr.lastRTPTimeTime).Seconds()*rr.clockRate - + (float64(pkt.Header.Timestamp) - float64(rr.lastRTPTimeRTP)) + if D < 0 { + D = -D + } + rr.jitter += (D - rr.jitter) / 16 + } + + rr.lastRTPTimeRTP = pkt.Header.Timestamp + rr.lastRTPTimeTime = ts + } } // ignore invalid packets (diff = 0) or reordered packets (diff < 0) } diff --git a/pkg/rtcpreceiver/rtcpreceiver_test.go b/pkg/rtcpreceiver/rtcpreceiver_test.go index 1c114a3e..0cac69f2 100644 --- a/pkg/rtcpreceiver/rtcpreceiver_test.go +++ b/pkg/rtcpreceiver/rtcpreceiver_test.go @@ -55,7 +55,7 @@ func TestRTCPReceiverBase(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -69,7 +69,7 @@ func TestRTCPReceiverBase(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) <-done } @@ -119,7 +119,7 @@ func TestRTCPReceiverOverflow(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -133,7 +133,7 @@ func TestRTCPReceiverOverflow(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) <-done } @@ -188,7 +188,7 @@ func TestRTCPReceiverPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -202,7 +202,7 @@ func TestRTCPReceiverPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) <-done } @@ -257,7 +257,7 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -271,7 +271,7 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) <-done } @@ -321,7 +321,7 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -335,7 +335,7 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) <-done } @@ -353,7 +353,7 @@ func TestRTCPReceiverJitter(t *testing.T) { Reports: []rtcp.ReceptionReport{ { SSRC: 0xba9da416, - LastSequenceNumber: 947, + LastSequenceNumber: 948, LastSenderReport: 0x887a17ce, Delay: 2 * 65536, Jitter: 45000 / 16, @@ -386,7 +386,7 @@ func TestRTCPReceiverJitter(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -400,7 +400,21 @@ func TestRTCPReceiverJitter(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacketRTP(ts, &rtpPkt) + rr.ProcessPacketRTP(ts, &rtpPkt, true) + + rtpPkt = rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 948, + Timestamp: 0xafb45733, + SSRC: 0xba9da416, + }, + Payload: []byte("\x00\x00"), + } + ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) + rr.ProcessPacketRTP(ts, &rtpPkt, false) <-done } diff --git a/serverconn.go b/serverconn.go index 8922b97f..d1bded4b 100644 --- a/serverconn.go +++ b/serverconn.go @@ -257,31 +257,15 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { return } - // remove padding - pkt.Header.Padding = false - pkt.PaddingSize = 0 - - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: sc.session, - TrackID: trackID, - Packet: pkt, - }) - } + sc.session.onPacketRTP(time.Time{}, trackID, pkt) } else { packets, err := rtcp.Unmarshal(payload) if err != nil { return } - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { - for _, pkt := range packets { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.session, - TrackID: trackID, - Packet: pkt, - }) - } + for _, pkt := range packets { + sc.session.onPacketRTCP(trackID, pkt) } } } diff --git a/serverhandler.go b/serverhandler.go index 2372d60e..a4aadbac 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -1,6 +1,8 @@ package gortsplib import ( + "time" + "github.com/pion/rtcp" "github.com/pion/rtp" @@ -184,9 +186,12 @@ type ServerHandlerOnSetParameter interface { // ServerHandlerOnPacketRTPCtx is the context of a RTP packet. type ServerHandlerOnPacketRTPCtx struct { - Session *ServerSession - TrackID int - Packet *rtp.Packet + Session *ServerSession + TrackID int + Packet *rtp.Packet + PTSEqualsDTS bool + H264NALUs [][]byte + H264PTS time.Duration } // ServerHandlerOnPacketRTP can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index 867d1a07..544ffb17 100644 --- a/serversession.go +++ b/serversession.go @@ -15,10 +15,12 @@ import ( "github.com/pion/rtp" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" + "github.com/aler9/gortsplib/pkg/rtph264" ) func stringsReverseIndex(s, substr string) int { @@ -33,7 +35,7 @@ func stringsReverseIndex(s, substr string) int { func setupGetTrackIDPathQuery( url *base.URL, thMode *headers.TransportMode, - announcedTracks []ServerSessionAnnouncedTrack, + announcedTracks []*ServerSessionAnnouncedTrack, setuppedPath *string, setuppedQuery *string, setuppedBaseURL *base.URL, @@ -152,6 +154,7 @@ type ServerSessionSetuppedTrack struct { type ServerSessionAnnouncedTrack struct { track Track rtcpReceiver *rtcpreceiver.RTCPReceiver + h264Decoder *rtph264.Decoder } // ServerSession is a server-side RTSP session. @@ -173,8 +176,8 @@ type ServerSession struct { setuppedQuery *string lastRequestTime time.Time tcpConn *ServerConn - announcedTracks []ServerSessionAnnouncedTrack // publish - udpLastFrameTime *int64 // publish + announcedTracks []*ServerSessionAnnouncedTrack // publish + udpLastFrameTime *int64 // publish udpCheckStreamTimer *time.Timer writerRunning bool writeBuffer *ringbuffer.RingBuffer @@ -237,7 +240,7 @@ func (ss *ServerSession) SetuppedTransport() *Transport { } // AnnouncedTracks returns the announced tracks. -func (ss *ServerSession) AnnouncedTracks() []ServerSessionAnnouncedTrack { +func (ss *ServerSession) AnnouncedTracks() []*ServerSessionAnnouncedTrack { return ss.announcedTracks } @@ -282,9 +285,9 @@ func (ss *ServerSession) run() { ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) - for trackID := range ss.setuppedTracks { - ss.announcedTracks[trackID].rtcpReceiver.Close() - ss.announcedTracks[trackID].rtcpReceiver = nil + for _, at := range ss.announcedTracks { + at.rtcpReceiver.Close() + at.rtcpReceiver = nil } } } @@ -296,7 +299,6 @@ func (ss *ServerSession) run() { if ss.writerRunning { ss.writeBuffer.Close() <-ss.writerDone - ss.writerRunning = false } for sc := range ss.conns { @@ -550,9 +552,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.setuppedQuery = &query ss.setuppedBaseURL = req.URL - ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks)) + ss.announcedTracks = make([]*ServerSessionAnnouncedTrack, len(tracks)) for trackID, track := range tracks { - ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{ + ss.announcedTracks[trackID] = &ServerSessionAnnouncedTrack{ track: track, } } @@ -871,7 +873,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn = sc - ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP err = errSwitchReadFunc @@ -976,6 +977,13 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStateRecord + for _, at := range ss.announcedTracks { + if _, ok := at.track.(*TrackH264); ok { + at.h264Decoder = &rtph264.Decoder{} + at.h264Decoder.Init() + } + } + switch *ss.setuppedTransport { case TransportUDP: ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) @@ -984,25 +992,24 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.writerDone = make(chan struct{}) go ss.runWriter() - for trackID, track := range ss.setuppedTracks { + for trackID, at := range ss.announcedTracks { // open the firewall by sending packets to the counterpart ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) ctrackID := trackID - ss.announcedTracks[trackID].rtcpReceiver = rtcpreceiver.New(ss.s.udpReceiverReportPeriod, - nil, ss.announcedTracks[trackID].track.ClockRate(), func(pkt rtcp.Packet) { + at.rtcpReceiver = rtcpreceiver.New(ss.s.udpReceiverReportPeriod, + nil, at.track.ClockRate(), func(pkt rtcp.Packet) { ss.WritePacketRTCP(ctrackID, pkt) }) - ss.s.udpRTPListener.addClient(ss.author.ip(), track.udpRTPPort, ss, trackID, true) - ss.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, true) + ss.s.udpRTPListener.addClient(ss.author.ip(), ss.setuppedTracks[trackID].udpRTPPort, ss, trackID, true) + ss.s.udpRTCPListener.addClient(ss.author.ip(), ss.setuppedTracks[trackID].udpRTCPPort, ss, trackID, true) } default: // TCP ss.tcpConn = sc - ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP err = errSwitchReadFunc @@ -1072,13 +1079,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard err = errSwitchReadFunc - ss.tcpConn = nil } case ServerSessionStateRecord: - ss.state = ServerSessionStatePreRecord - switch *ss.setuppedTransport { case TransportUDP: ss.udpCheckStreamTimer = emptyTimer() @@ -1086,17 +1090,22 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) - for trackID := range ss.setuppedTracks { - ss.announcedTracks[trackID].rtcpReceiver.Close() - ss.announcedTracks[trackID].rtcpReceiver = nil + for _, at := range ss.announcedTracks { + at.rtcpReceiver.Close() + at.rtcpReceiver = nil } default: // TCP ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard err = errSwitchReadFunc - ss.tcpConn = nil } + + for _, at := range ss.announcedTracks { + at.h264Decoder = nil + } + + ss.state = ServerSessionStatePreRecord } return res, err @@ -1203,6 +1212,76 @@ func (ss *ServerSession) runWriter() { } } +func (ss *ServerSession) onPacketRTP(now time.Time, trackID int, pkt *rtp.Packet) { + // remove padding + pkt.Header.Padding = false + pkt.PaddingSize = 0 + + at := ss.announcedTracks[trackID] + + if at.h264Decoder != nil { + nalus, pts, err := at.h264Decoder.DecodeUntilMarker(pkt) + if err == nil { + ptsEqualsDTS := h264.IDRPresent(nalus) + + rr := at.rtcpReceiver + if rr != nil { + rr.ProcessPacketRTP(now, pkt, ptsEqualsDTS) + } + + if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: ss, + TrackID: trackID, + Packet: pkt, + PTSEqualsDTS: ptsEqualsDTS, + H264NALUs: append([][]byte(nil), nalus...), + H264PTS: pts, + }) + } + } else { + rr := at.rtcpReceiver + if rr != nil { + rr.ProcessPacketRTP(now, pkt, false) + } + + if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: ss, + TrackID: trackID, + Packet: pkt, + PTSEqualsDTS: false, + }) + } + } + return + } + + rr := at.rtcpReceiver + if rr != nil { + rr.ProcessPacketRTP(now, pkt, true) + } + + if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: ss, + TrackID: trackID, + Packet: pkt, + PTSEqualsDTS: true, + }) + } +} + +func (ss *ServerSession) onPacketRTCP(trackID int, pkt rtcp.Packet) { + if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTCP); ok { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: ss, + TrackID: trackID, + Packet: pkt, + }) + } +} + func (ss *ServerSession) writePacketRTP(trackID int, byts []byte) { if _, ok := ss.setuppedTracks[trackID]; !ok { return diff --git a/serverudpl.go b/serverudpl.go index 97e5caca..c3139128 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -204,21 +204,10 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { return } - // remove padding - pkt.Header.Padding = false - pkt.PaddingSize = 0 - now := time.Now() atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) - clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, pkt) - if h, ok := u.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - Packet: pkt, - }) - } + clientData.ss.onPacketRTP(now, clientData.trackID, pkt) } func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { @@ -236,14 +225,8 @@ func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) } } - if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok { - for _, pkt := range packets { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - Packet: pkt, - }) - } + for _, pkt := range packets { + clientData.ss.onPacketRTCP(clientData.trackID, pkt) } }