From b2a849dbd84c692a9604cc45d10e02f94a7204d3 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 13 Mar 2022 01:30:37 +0100 Subject: [PATCH] move RTCP sender / receiver writes into dedicate routine --- client.go | 172 +++++++++---------- client_publish_test.go | 5 - client_read_test.go | 12 +- pkg/rtcpreceiver/rtcpreceiver.go | 58 ++++++- pkg/rtcpreceiver/rtcpreceiver_test.go | 233 +++++++++++++++----------- pkg/rtcpsender/rtcpsender.go | 56 ++++++- pkg/rtcpsender/rtcpsender_test.go | 27 +-- serversession.go | 47 +++--- 8 files changed, 354 insertions(+), 256 deletions(-) diff --git a/client.go b/client.go index 4cc147b3..812b4ee0 100644 --- a/client.go +++ b/client.go @@ -213,12 +213,11 @@ type Client struct { lastDescribeURL *base.URL streamBaseURL *base.URL effectiveTransport *Transport - tracks map[int]clientTrack + tracks map[int]*clientTrack tracksByChannel map[int]int lastRange *headers.Range writeMutex sync.RWMutex // publish writeFrameAllowed bool // publish - udpReportTimer *time.Timer checkStreamTimer *time.Timer checkStreamInitial bool tcpLastFrameTime *int64 @@ -310,7 +309,6 @@ func (c *Client) Start(scheme string, host string) error { c.host = host c.ctx = ctx c.ctxCancel = ctxCancel - c.udpReportTimer = emptyTimer() c.checkStreamTimer = emptyTimer() c.keepaliveTimer = emptyTimer() c.options = make(chan optionsReq) @@ -468,29 +466,6 @@ func (c *Client) runInner() error { res, err := c.doPause() req.res <- clientRes{res: res, err: err} - case <-c.udpReportTimer.C: - if c.state == clientStatePlay { - now := time.Now() - for trackID, cct := range c.tracks { - rr := cct.rtcpReceiver.Report(now) - if rr != nil { - c.WritePacketRTCP(trackID, rr) - } - } - - c.udpReportTimer = time.NewTimer(c.udpReceiverReportPeriod) - } else { // Record - now := time.Now() - for trackID, cct := range c.tracks { - sr := cct.rtcpSender.Report(now) - if sr != nil { - c.WritePacketRTCP(trackID, sr) - } - } - - c.udpReportTimer = time.NewTimer(c.udpSenderReportPeriod) - } - case <-c.checkStreamTimer.C: if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { @@ -581,13 +556,6 @@ func (c *Client) runInner() error { func (c *Client) doClose() { if c.state == clientStatePlay || c.state == clientStateRecord { - if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - c.playRecordStop(true) c.do(&base.Request{ @@ -697,20 +665,41 @@ func (c *Client) playRecordStart() { c.writeFrameAllowed = true c.writeMutex.Unlock() - // start timers if c.state == clientStatePlay { c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) switch *c.effectiveTransport { case TransportUDP: - c.udpReportTimer = time.NewTimer(c.udpReceiverReportPeriod) + for trackID, cct := range c.tracks { + cct.rtcpReceiver = rtcpreceiver.New(c.udpReceiverReportPeriod, nil, + cct.track.ClockRate(), func(pkt rtcp.Packet) { + c.WritePacketRTCP(trackID, pkt) + }) + } + c.checkStreamTimer = time.NewTimer(c.InitialUDPReadTimeout) c.checkStreamInitial = true + for _, cct := range c.tracks { + cct.udpRTPListener.start(true) + cct.udpRTCPListener.start(true) + } + case TransportUDPMulticast: - c.udpReportTimer = time.NewTimer(c.udpReceiverReportPeriod) + for trackID, cct := range c.tracks { + cct.rtcpReceiver = rtcpreceiver.New(c.udpReceiverReportPeriod, nil, + cct.track.ClockRate(), func(pkt rtcp.Packet) { + c.WritePacketRTCP(trackID, pkt) + }) + } + c.checkStreamTimer = time.NewTimer(c.checkStreamPeriod) + for _, cct := range c.tracks { + cct.udpRTPListener.start(true) + cct.udpRTCPListener.start(true) + } + default: // TCP c.checkStreamTimer = time.NewTimer(c.checkStreamPeriod) v := time.Now().Unix() @@ -719,10 +708,30 @@ func (c *Client) playRecordStart() { } else { switch *c.effectiveTransport { case TransportUDP: - c.udpReportTimer = time.NewTimer(c.udpSenderReportPeriod) + for trackID, cct := range c.tracks { + cct.rtcpSender = rtcpsender.New(c.udpSenderReportPeriod, + cct.track.ClockRate(), func(pkt rtcp.Packet) { + c.WritePacketRTCP(trackID, pkt) + }) + } + + for _, cct := range c.tracks { + cct.udpRTPListener.start(true) + cct.udpRTCPListener.start(true) + } case TransportUDPMulticast: - c.udpReportTimer = time.NewTimer(c.udpSenderReportPeriod) + for trackID, cct := range c.tracks { + cct.rtcpSender = rtcpsender.New(c.udpSenderReportPeriod, + cct.track.ClockRate(), func(pkt rtcp.Packet) { + c.WritePacketRTCP(trackID, pkt) + }) + } + + for _, cct := range c.tracks { + cct.udpRTPListener.start(true) + cct.udpRTCPListener.start(true) + } } } @@ -838,16 +847,35 @@ func (c *Client) playRecordStop(isClosing bool) { <-c.readerErr } - // stop timers - c.udpReportTimer = emptyTimer() - c.checkStreamTimer = emptyTimer() - c.keepaliveTimer = emptyTimer() - // forbid writing c.writeMutex.Lock() c.writeFrameAllowed = false c.writeMutex.Unlock() + if *c.effectiveTransport == TransportUDP || + *c.effectiveTransport == TransportUDPMulticast { + for _, cct := range c.tracks { + cct.udpRTPListener.stop() + cct.udpRTCPListener.stop() + } + + if c.state == clientStatePlay { + for _, cct := range c.tracks { + cct.rtcpReceiver.Close() + cct.rtcpReceiver = nil + } + } else { + for _, cct := range c.tracks { + cct.rtcpSender.Close() + cct.rtcpSender = nil + } + } + } + + // stop timers + c.checkStreamTimer = emptyTimer() + c.keepaliveTimer = emptyTimer() + // stop writer c.writeBuffer.Close() <-c.writerDone @@ -1461,17 +1489,14 @@ func (c *Client) doSetup( } } - cct := clientTrack{ + cct := &clientTrack{ track: track, } - clockRate := track.ClockRate() if mode == headers.TransportModePlay { c.state = clientStatePrePlay - cct.rtcpReceiver = rtcpreceiver.New(nil, clockRate) } else { c.state = clientStatePreRecord - cct.rtcpSender = rtcpsender.New(clockRate) } c.streamBaseURL = baseURL @@ -1547,7 +1572,7 @@ func (c *Client) doSetup( } if c.tracks == nil { - c.tracks = make(map[int]clientTrack) + c.tracks = make(map[int]*clientTrack) } c.tracks[trackID] = cct @@ -1590,14 +1615,9 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp return nil, err } - // setup UDP communication before sending the request. + // open the firewall by sending packets to the counterpart. + // do this before sending the request. if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.start(true) - cct.udpRTCPListener.start(true) - } - - // open the firewall by sending packets to the counterpart. for _, cct := range c.tracks { byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() cct.udpRTPListener.write(byts) @@ -1624,24 +1644,10 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp }, }, false, *c.effectiveTransport == TransportTCP) if err != nil { - if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - return nil, err } if res.StatusCode != base.StatusOK { - if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } @@ -1689,36 +1695,15 @@ func (c *Client) doRecord() (*base.Response, error) { return nil, err } - if *c.effectiveTransport == TransportUDP { - for _, cct := range c.tracks { - cct.udpRTPListener.start(false) - cct.udpRTCPListener.start(false) - } - } - res, err := c.do(&base.Request{ Method: base.Record, URL: c.streamBaseURL, }, false, false) if err != nil { - if *c.effectiveTransport == TransportUDP { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - return nil, err } if res.StatusCode != base.StatusOK { - if *c.effectiveTransport == TransportUDP { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } @@ -1755,13 +1740,6 @@ func (c *Client) doPause() (*base.Response, error) { c.playRecordStop(false) - if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - // change state regardless of the response switch c.state { case clientStatePlay: diff --git a/client_publish_test.go b/client_publish_test.go index 48ae10fd..cca6281e 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -15,7 +15,6 @@ import ( "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" - "github.com/aler9/gortsplib/pkg/rtcpreceiver" ) var testRTPPacket = rtp.Packet{ @@ -945,15 +944,12 @@ func TestClientPublishRTCPReport(t *testing.T) { _, err = conn.Write(bb.Bytes()) require.NoError(t, err) - rr := rtcpreceiver.New(nil, 90000) - buf := make([]byte, 2048) n, _, err := l1.ReadFrom(buf) require.NoError(t, err) var pkt rtp.Packet err = pkt.Unmarshal(buf[:n]) require.NoError(t, err) - rr.ProcessPacketRTP(time.Now(), &pkt) buf = make([]byte, 2048) n, _, err = l2.ReadFrom(buf) @@ -969,7 +965,6 @@ func TestClientPublishRTCPReport(t *testing.T) { PacketCount: 1, OctetCount: 4, }, sr) - rr.ProcessPacketRTCP(time.Now(), packets[0]) close(reportReceived) diff --git a/client_read_test.go b/client_read_test.go index 74808f6c..8eb527d1 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -19,7 +19,6 @@ import ( "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" - "github.com/aler9/gortsplib/pkg/rtcpsender" ) func TestClientReadTracks(t *testing.T) { @@ -1977,8 +1976,6 @@ func TestClientReadRTCPReport(t *testing.T) { _, _, err = l2.ReadFrom(buf) require.NoError(t, err) - rs := rtcpsender.New(90000) - pkt := rtp.Packet{ Header: rtp.Header{ Version: 2, @@ -1996,9 +1993,14 @@ func TestClientReadRTCPReport(t *testing.T) { Port: inTH.ClientPorts[0], }) require.NoError(t, err) - rs.ProcessPacketRTP(time.Now(), &pkt) - sr := rs.Report(time.Now()) + sr := &rtcp.SenderReport{ + SSRC: 753621, + NTPTime: 0, + RTPTime: 0, + PacketCount: 0, + OctetCount: 0, + } byts, _ = sr.Marshal() _, err = l2.WriteTo(byts, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index 3e204aa4..285b5709 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -16,11 +16,15 @@ func randUint32() uint32 { return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) } +var now = time.Now + // RTCPReceiver is a utility to generate RTCP receiver reports. type RTCPReceiver struct { - receiverSSRC uint32 - clockRate float64 - mutex sync.Mutex + period time.Duration + receiverSSRC uint32 + clockRate float64 + writePacketRTCP func(rtcp.Packet) + mutex sync.Mutex // data from RTP packets firstRTPReceived bool @@ -37,24 +41,60 @@ type RTCPReceiver struct { senderSSRC uint32 lastSenderReport uint32 lastSenderReportTime time.Time + + terminate chan struct{} + done chan struct{} } // New allocates a RTCPReceiver. -func New(receiverSSRC *uint32, clockRate int) *RTCPReceiver { - return &RTCPReceiver{ +func New(period time.Duration, receiverSSRC *uint32, clockRate int, + writePacketRTCP func(rtcp.Packet)) *RTCPReceiver { + rr := &RTCPReceiver{ + period: period, receiverSSRC: func() uint32 { if receiverSSRC == nil { return randUint32() } return *receiverSSRC }(), - clockRate: float64(clockRate), + clockRate: float64(clockRate), + writePacketRTCP: writePacketRTCP, + terminate: make(chan struct{}), + done: make(chan struct{}), + } + + go rr.run() + + return rr +} + +// Close closes the RTCPReceiver. +func (rr *RTCPReceiver) Close() { + close(rr.terminate) + <-rr.done +} + +func (rr *RTCPReceiver) run() { + defer close(rr.done) + + t := time.NewTicker(rr.period) + defer t.Stop() + + for { + select { + case <-t.C: + report := rr.report(now()) + if report != nil { + rr.writePacketRTCP(report) + } + + case <-rr.terminate: + return + } } } -// Report generates a RTCP receiver report. -// It returns nil if no RTCP sender reports have been passed to ProcessPacketRTCP yet. -func (rr *RTCPReceiver) Report(ts time.Time) rtcp.Packet { +func (rr *RTCPReceiver) report(ts time.Time) rtcp.Packet { rr.mutex.Lock() defer rr.mutex.Unlock() diff --git a/pkg/rtcpreceiver/rtcpreceiver_test.go b/pkg/rtcpreceiver/rtcpreceiver_test.go index 0304feb6..17051328 100644 --- a/pkg/rtcpreceiver/rtcpreceiver_test.go +++ b/pkg/rtcpreceiver/rtcpreceiver_test.go @@ -10,10 +10,28 @@ import ( ) func TestRTCPReceiverBase(t *testing.T) { + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) + } + done := make(chan struct{}) v := uint32(0x65f83afb) - rr := New(&v, 90000) - require.Equal(t, nil, rr.Report(time.Now())) + rr := New(500*time.Millisecond, &v, 90000, + func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 947, + LastSenderReport: 0x887a17ce, + Delay: 2 * 65536, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -53,24 +71,31 @@ func TestRTCPReceiverBase(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 947, - LastSenderReport: 0x887a17ce, - Delay: 2 * 65536, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } func TestRTCPReceiverOverflow(t *testing.T) { + done := make(chan struct{}) + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) + } v := uint32(0x65f83afb) - rr := New(&v, 90000) + + rr := New(500*time.Millisecond, &v, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 1<<16 | 0x0000, + LastSenderReport: 0x887a17ce, + Delay: 1 * 65536, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -110,24 +135,36 @@ func TestRTCPReceiverOverflow(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 1<<16 | 0x0000, - LastSenderReport: 0x887a17ce, - Delay: 1 * 65536, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } func TestRTCPReceiverPacketLost(t *testing.T) { + done := make(chan struct{}) + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) + } v := uint32(0x65f83afb) - rr := New(&v, 90000) + + rr := New(500*time.Millisecond, &v, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 0x0122, + LastSenderReport: 0x887a17ce, + FractionLost: func() uint8 { + v := float64(1) / 3 + return uint8(v * 256) + }(), + TotalLost: 1, + Delay: 1 * 65536, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -167,29 +204,36 @@ func TestRTCPReceiverPacketLost(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 0x0122, - LastSenderReport: 0x887a17ce, - FractionLost: func() uint8 { - v := float64(1) / 3 - return uint8(v * 256) - }(), - TotalLost: 1, - Delay: 1 * 65536, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } func TestRTCPReceiverOverflowPacketLost(t *testing.T) { + done := make(chan struct{}) + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) + } v := uint32(0x65f83afb) - rr := New(&v, 90000) + + rr := New(500*time.Millisecond, &v, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 1<<16 | 0x0002, + LastSenderReport: 0x887a17ce, + FractionLost: func() uint8 { + v := float64(2) / 4 + return uint8(v * 256) + }(), + TotalLost: 2, + Delay: 1 * 65536, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -229,29 +273,31 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 1<<16 | 0x0002, - LastSenderReport: 0x887a17ce, - FractionLost: func() uint8 { - v := float64(2) / 4 - return uint8(v * 256) - }(), - TotalLost: 2, - Delay: 1 * 65536, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } func TestRTCPReceiverReorderedPackets(t *testing.T) { + done := make(chan struct{}) + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) + } v := uint32(0x65f83afb) - rr := New(&v, 90000) + + rr := New(500*time.Millisecond, &v, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 0x43a7, + LastSenderReport: 0x887a17ce, + Delay: 1 * 65536, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -291,24 +337,32 @@ func TestRTCPReceiverReorderedPackets(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 0x43a7, - LastSenderReport: 0x887a17ce, - Delay: 1 * 65536, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } func TestRTCPReceiverJitter(t *testing.T) { + done := make(chan struct{}) + now = func() time.Time { + return time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) + } v := uint32(0x65f83afb) - rr := New(&v, 90000) + + rr := New(500*time.Millisecond, &v, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.ReceiverReport{ + SSRC: 0x65f83afb, + Reports: []rtcp.ReceptionReport{ + { + SSRC: 0xba9da416, + LastSequenceNumber: 947, + LastSenderReport: 0x887a17ce, + Delay: 2 * 65536, + Jitter: 45000 / 16, + }, + }, + }, pkt) + close(done) + }) + defer rr.Close() srPkt := rtcp.SenderReport{ SSRC: 0xba9da416, @@ -348,18 +402,5 @@ func TestRTCPReceiverJitter(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) rr.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.ReceiverReport{ - SSRC: 0x65f83afb, - Reports: []rtcp.ReceptionReport{ - { - SSRC: 0xba9da416, - LastSequenceNumber: 947, - LastSenderReport: 0x887a17ce, - Delay: 2 * 65536, - Jitter: 45000 / 16, - }, - }, - } - ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) - require.Equal(t, &expectedPkt, rr.Report(ts)) + <-done } diff --git a/pkg/rtcpsender/rtcpsender.go b/pkg/rtcpsender/rtcpsender.go index 5cbcea41..89c774a0 100644 --- a/pkg/rtcpsender/rtcpsender.go +++ b/pkg/rtcpsender/rtcpsender.go @@ -9,10 +9,14 @@ import ( "github.com/pion/rtp/v2" ) +var now = time.Now + // RTCPSender is a utility to generate RTCP sender reports. type RTCPSender struct { - clockRate float64 - mutex sync.Mutex + period time.Duration + clockRate float64 + writePacketRTCP func(rtcp.Packet) + mutex sync.Mutex // data from RTP packets firstRTPReceived bool @@ -21,18 +25,54 @@ type RTCPSender struct { lastRTPTimeTime time.Time packetCount uint32 octetCount uint32 + + terminate chan struct{} + done chan struct{} } // New allocates a RTCPSender. -func New(clockRate int) *RTCPSender { - return &RTCPSender{ - clockRate: float64(clockRate), +func New(period time.Duration, clockRate int, + writePacketRTCP func(rtcp.Packet)) *RTCPSender { + rs := &RTCPSender{ + period: period, + clockRate: float64(clockRate), + writePacketRTCP: writePacketRTCP, + terminate: make(chan struct{}), + done: make(chan struct{}), + } + + go rs.run() + + return rs +} + +// Close closes the RTCPSender. +func (rs *RTCPSender) Close() { + close(rs.terminate) + <-rs.done +} + +func (rs *RTCPSender) run() { + defer close(rs.done) + + t := time.NewTicker(rs.period) + defer t.Stop() + + for { + select { + case <-t.C: + report := rs.report(now()) + if report != nil { + rs.writePacketRTCP(report) + } + + case <-rs.terminate: + return + } } } -// Report generates a RTCP sender report. -// It returns nil if no packets has been passed to ProcessPacketRTP yet. -func (rs *RTCPSender) Report(ts time.Time) rtcp.Packet { +func (rs *RTCPSender) report(ts time.Time) rtcp.Packet { rs.mutex.Lock() defer rs.mutex.Unlock() diff --git a/pkg/rtcpsender/rtcpsender_test.go b/pkg/rtcpsender/rtcpsender_test.go index 7c41c9bc..f7a9bb50 100644 --- a/pkg/rtcpsender/rtcpsender_test.go +++ b/pkg/rtcpsender/rtcpsender_test.go @@ -10,9 +10,22 @@ import ( ) func TestRTCPSender(t *testing.T) { - rs := New(90000) + now = func() time.Time { + return time.Date(2008, 5, 20, 22, 16, 20, 600000000, time.UTC) + } + done := make(chan struct{}) - require.Equal(t, nil, rs.Report(time.Now())) + rs := New(500*time.Millisecond, 90000, func(pkt rtcp.Packet) { + require.Equal(t, &rtcp.SenderReport{ + SSRC: 0xba9da416, + NTPTime: 0xcbddcc34999997ff, + RTPTime: 0x4d185ae8, + PacketCount: 2, + OctetCount: 4, + }, pkt) + close(done) + }) + defer rs.Close() rtpPkt := rtp.Packet{ Header: rtp.Header{ @@ -42,13 +55,5 @@ func TestRTCPSender(t *testing.T) { ts = time.Date(2008, 0o5, 20, 22, 15, 20, 500000000, time.UTC) rs.ProcessPacketRTP(ts, &rtpPkt) - expectedPkt := rtcp.SenderReport{ - SSRC: 0xba9da416, - NTPTime: 0xcbddcc34999997ff, - RTPTime: 0x4d185ae8, - PacketCount: 2, - OctetCount: 4, - } - ts = time.Date(2008, 0o5, 20, 22, 16, 20, 600000000, time.UTC) - require.Equal(t, &expectedPkt, rs.Report(ts)) + <-done } diff --git a/serversession.go b/serversession.go index 335cfeb9..2a1b4743 100644 --- a/serversession.go +++ b/serversession.go @@ -164,7 +164,7 @@ type ServerSession struct { ctxCancel func() conns map[*ServerConn]struct{} state ServerSessionState - setuppedTracks map[int]ServerSessionSetuppedTrack + setuppedTracks map[int]*ServerSessionSetuppedTrack tcpTracksByChannel map[int]int setuppedTransport *Transport setuppedBaseURL *base.URL // publish @@ -229,7 +229,7 @@ func (ss *ServerSession) State() ServerSessionState { } // SetuppedTracks returns the setupped tracks. -func (ss *ServerSession) SetuppedTracks() map[int]ServerSessionSetuppedTrack { +func (ss *ServerSession) SetuppedTracks() map[int]*ServerSessionSetuppedTrack { return ss.setuppedTracks } @@ -283,6 +283,11 @@ func (ss *ServerSession) run() { if *ss.setuppedTransport == TransportUDP { ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) + + for trackID := range ss.setuppedTracks { + ss.announcedTracks[trackID].rtcpReceiver.Close() + ss.announcedTracks[trackID].rtcpReceiver = nil + } } } @@ -419,18 +424,6 @@ func (ss *ServerSession) runInner() error { ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - case <-ss.udpReceiverReportTimer.C: - now := time.Now() - - for trackID, track := range ss.announcedTracks { - rr := track.rtcpReceiver.Report(now) - if rr != nil { - ss.WritePacketRTCP(trackID, rr) - } - } - - ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod) - case <-ss.ctx.Done(): return liberrors.ErrServerTerminated{} } @@ -746,7 +739,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base res.Header = make(base.Header) } - sst := ServerSessionSetuppedTrack{} + sst := &ServerSessionSetuppedTrack{} switch transport { case TransportUDP: @@ -797,16 +790,11 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } if ss.setuppedTracks == nil { - ss.setuppedTracks = make(map[int]ServerSessionSetuppedTrack) + ss.setuppedTracks = make(map[int]*ServerSessionSetuppedTrack) } ss.setuppedTracks[trackID] = sst - if ss.state == ServerSessionStatePreRecord && *ss.setuppedTransport != TransportTCP { - ss.announcedTracks[trackID].rtcpReceiver = rtcpreceiver.New(nil, - ss.announcedTracks[trackID].track.ClockRate()) - } - res.Header["Transport"] = th.Write() return res, err @@ -1000,19 +988,23 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base switch *ss.setuppedTransport { case TransportUDP: ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod) ss.writerRunning = true ss.writerDone = make(chan struct{}) go ss.runWriter() for trackID, track := range ss.setuppedTracks { - ss.s.udpRTPListener.addClient(ss.author.ip(), track.udpRTPPort, ss, trackID, true) - ss.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, true) - // open the firewall by sending packets to the counterpart ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) + + ss.announcedTracks[trackID].rtcpReceiver = rtcpreceiver.New(ss.s.udpReceiverReportPeriod, + nil, ss.announcedTracks[trackID].track.ClockRate(), func(pkt rtcp.Packet) { + ss.WritePacketRTCP(trackID, pkt) + }) + + ss.s.udpRTPListener.addClient(ss.author.ip(), track.udpRTPPort, ss, trackID, true) + ss.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, true) } default: // TCP @@ -1099,6 +1091,11 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss) + for trackID := range ss.setuppedTracks { + ss.announcedTracks[trackID].rtcpReceiver.Close() + ss.announcedTracks[trackID].rtcpReceiver = nil + } + case TransportUDPMulticast: default: // TCP