diff --git a/conn-client-udpl.go b/conn-client-udpl.go index 38b49403..6dc9b47f 100644 --- a/conn-client-udpl.go +++ b/conn-client-udpl.go @@ -3,6 +3,8 @@ package gortsplib import ( "net" "strconv" + "sync/atomic" + "time" ) // connClientUDPListener is a UDP listener created by SetupUDP() to receive UDP frames. @@ -46,7 +48,10 @@ func (l *connClientUDPListener) read(buf []byte) (int, error) { continue } + atomic.StoreInt64(l.c.udpLastFrameTimes[l.trackId], time.Now().Unix()) + l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n]) + return n, nil } } diff --git a/conn-client.go b/conn-client.go index 33dd8aa6..248d17d7 100644 --- a/conn-client.go +++ b/conn-client.go @@ -14,6 +14,7 @@ import ( "net/url" "strconv" "strings" + "sync/atomic" "time" ) @@ -50,18 +51,19 @@ type ConnClientConf struct { // ConnClient is a client-side RTSP connection. type ConnClient struct { - conf ConnClientConf - nconn net.Conn - br *bufio.Reader - bw *bufio.Writer - session string - cseq int - auth *authClient - streamUrl *url.URL - streamProtocol *StreamProtocol - rtcpReceivers map[int]*RtcpReceiver - rtpListeners map[int]*connClientUDPListener - rtcpListeners map[int]*connClientUDPListener + conf ConnClientConf + nconn net.Conn + br *bufio.Reader + bw *bufio.Writer + session string + cseq int + auth *authClient + streamUrl *url.URL + streamProtocol *StreamProtocol + rtcpReceivers map[int]*RtcpReceiver + udpLastFrameTimes map[int]*int64 + udpRtpListeners map[int]*connClientUDPListener + udpRtcpListeners map[int]*connClientUDPListener receiverReportTerminate chan struct{} receiverReportDone chan struct{} @@ -88,13 +90,14 @@ func NewConnClient(conf ConnClientConf) (*ConnClient, error) { } return &ConnClient{ - conf: conf, - nconn: nconn, - br: bufio.NewReaderSize(nconn, clientReadBufferSize), - bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), - rtcpReceivers: make(map[int]*RtcpReceiver), - rtpListeners: make(map[int]*connClientUDPListener), - rtcpListeners: make(map[int]*connClientUDPListener), + conf: conf, + nconn: nconn, + br: bufio.NewReaderSize(nconn, clientReadBufferSize), + bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), + rtcpReceivers: make(map[int]*RtcpReceiver), + udpLastFrameTimes: make(map[int]*int64), + udpRtpListeners: make(map[int]*connClientUDPListener), + udpRtcpListeners: make(map[int]*connClientUDPListener), }, nil } @@ -115,15 +118,11 @@ func (c *ConnClient) Close() error { <-c.receiverReportDone } - for _, rr := range c.rtcpReceivers { - rr.Close() - } - - for _, l := range c.rtpListeners { + for _, l := range c.udpRtpListeners { l.close() } - for _, l := range c.rtcpListeners { + for _, l := range c.udpRtcpListeners { l.close() } @@ -439,13 +438,16 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, c.streamProtocol = &streamProtocol c.rtcpReceivers[track.Id] = NewRtcpReceiver() + v := time.Now().Unix() + c.udpLastFrameTimes[track.Id] = &v + rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.publisherPort = (*th.ServerPorts)[0] - c.rtpListeners[track.Id] = rtpListener + c.udpRtpListeners[track.Id] = rtpListener rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.publisherPort = (*th.ServerPorts)[1] - c.rtcpListeners[track.Id] = rtcpListener + c.udpRtcpListeners[track.Id] = rtcpListener return rtpListener.read, rtcpListener.read, res, nil } @@ -560,21 +562,21 @@ func (c *ConnClient) Play(u *url.URL) (*Response, error) { // open the firewall by sending packets to every channel if *c.streamProtocol == StreamProtocolUDP { - for trackId := range c.rtpListeners { - c.rtpListeners[trackId].pc.WriteTo( + for trackId := range c.udpRtpListeners { + c.udpRtpListeners[trackId].pc.WriteTo( []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, &net.UDPAddr{ IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, - Port: c.rtpListeners[trackId].publisherPort, + Port: c.udpRtpListeners[trackId].publisherPort, }) - c.rtcpListeners[trackId].pc.WriteTo( + c.udpRtcpListeners[trackId].pc.WriteTo( []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, &net.UDPAddr{ IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, - Port: c.rtcpListeners[trackId].publisherPort, + Port: c.udpRtcpListeners[trackId].publisherPort, }) } } @@ -597,10 +599,10 @@ func (c *ConnClient) Play(u *url.URL) (*Response, error) { frame := c.rtcpReceivers[trackId].Report() if *c.streamProtocol == StreamProtocolUDP { - c.rtcpListeners[trackId].pc.WriteTo(frame, &net.UDPAddr{ + c.udpRtcpListeners[trackId].pc.WriteTo(frame, &net.UDPAddr{ IP: c.nconn.RemoteAddr().(*net.TCPAddr).IP, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, - Port: c.rtcpListeners[trackId].publisherPort, + Port: c.udpRtcpListeners[trackId].publisherPort, }) } else { @@ -664,11 +666,15 @@ func (c *ConnClient) LoopUDP(u *url.URL) error { } case <-checkStreamTicker.C: - for trackId := range c.rtcpReceivers { - if time.Since(c.rtcpReceivers[trackId].LastFrameTime()) >= c.conf.ReadTimeout { + now := time.Now() + + for _, lastUnix := range c.udpLastFrameTimes { + last := time.Unix(atomic.LoadInt64(lastUnix), 0) + + if now.Sub(last) >= c.conf.ReadTimeout { c.nconn.Close() <-readDone - return fmt.Errorf("no packets received recently (maybe there's a firewall/NAT)") + return fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") } } } diff --git a/rtcp-receiver.go b/rtcp-receiver.go index 7e946e58..f2333fb9 100644 --- a/rtcp-receiver.go +++ b/rtcp-receiver.go @@ -2,7 +2,7 @@ package gortsplib import ( "math/rand" - "time" + "sync" "github.com/pion/rtcp" ) @@ -16,10 +16,6 @@ type frameRtcpReq struct { ntpTimeMiddle uint32 } -type lastFrameTimeReq struct { - res chan time.Time -} - type reportReq struct { res chan []byte } @@ -27,121 +23,68 @@ type reportReq struct { // RtcpReceiver is an object that helps building RTCP receiver reports, by parsing // incoming frames. type RtcpReceiver struct { - frameRtp chan frameRtpReq - frameRtcp chan frameRtcpReq - lastFrameTime chan lastFrameTimeReq - report chan reportReq - terminate chan struct{} - done chan struct{} + mutex sync.Mutex + publisherSSRC uint32 + receiverSSRC uint32 + sequenceNumberCycles uint16 + lastSequenceNumber uint16 + lastSenderReport uint32 } // NewRtcpReceiver allocates a RtcpReceiver. func NewRtcpReceiver() *RtcpReceiver { - rr := &RtcpReceiver{ - frameRtp: make(chan frameRtpReq), - frameRtcp: make(chan frameRtcpReq), - lastFrameTime: make(chan lastFrameTimeReq), - report: make(chan reportReq), - terminate: make(chan struct{}), - done: make(chan struct{}), + return &RtcpReceiver{ + receiverSSRC: rand.Uint32(), } - - go rr.run() - - return rr } -func (rr *RtcpReceiver) run() { - lastFrameTime := time.Now() - publisherSSRC := uint32(0) - receiverSSRC := rand.Uint32() - sequenceNumberCycles := uint16(0) - lastSequenceNumber := uint16(0) - lastSenderReport := uint32(0) - -outer: - for { - select { - case req := <-rr.frameRtp: - if req.sequenceNumber < lastSequenceNumber { - sequenceNumberCycles += 1 - } - lastSequenceNumber = req.sequenceNumber - lastFrameTime = time.Now() - - case req := <-rr.frameRtcp: - publisherSSRC = req.ssrc - lastSenderReport = req.ntpTimeMiddle - - case req := <-rr.lastFrameTime: - req.res <- lastFrameTime - - case req := <-rr.report: - rr := &rtcp.ReceiverReport{ - SSRC: receiverSSRC, - Reports: []rtcp.ReceptionReport{ - { - SSRC: publisherSSRC, - LastSequenceNumber: uint32(sequenceNumberCycles)<<8 | uint32(lastSequenceNumber), - LastSenderReport: lastSenderReport, - }, - }, - } - frame, _ := rr.Marshal() - req.res <- frame - - case <-rr.terminate: - break outer - } - } - - close(rr.frameRtp) - close(rr.frameRtcp) - close(rr.lastFrameTime) - close(rr.report) - close(rr.done) -} - -// Close closes a RtcpReceiver. -func (rr *RtcpReceiver) Close() { - close(rr.terminate) - <-rr.done -} - -// OnFrame process a RTP or RTCP frame and extract the data needed by RTCP receiver reports. +// OnFrame processes a RTP or RTCP frame and extract the data needed by RTCP receiver reports. func (rr *RtcpReceiver) OnFrame(streamType StreamType, buf []byte) { + rr.mutex.Lock() + defer rr.mutex.Unlock() + if streamType == StreamTypeRtp { - // extract sequence number of first frame if len(buf) >= 3 { + // extract the sequence number of the first frame sequenceNumber := uint16(uint16(buf[2])<<8 | uint16(buf[1])) - rr.frameRtp <- frameRtpReq{sequenceNumber} + + if sequenceNumber < rr.lastSequenceNumber { + rr.sequenceNumberCycles += 1 + } + rr.lastSequenceNumber = sequenceNumber } } else { + // we can afford to unmarshal all RTCP frames + // since they are sent with a frequency much lower than the one of the RTP frames frames, err := rtcp.Unmarshal(buf) if err == nil { for _, frame := range frames { if senderReport, ok := (frame).(*rtcp.SenderReport); ok { - rr.frameRtcp <- frameRtcpReq{ - senderReport.SSRC, - uint32(senderReport.NTPTime >> 16), - } + rr.publisherSSRC = senderReport.SSRC + rr.lastSenderReport = uint32(senderReport.NTPTime >> 16) } } } } } -// LastFrameTime returns the time the last frame was received. -func (rr *RtcpReceiver) LastFrameTime() time.Time { - res := make(chan time.Time) - rr.lastFrameTime <- lastFrameTimeReq{res} - return <-res -} - // Report generates a RTCP receiver report. func (rr *RtcpReceiver) Report() []byte { - res := make(chan []byte) - rr.report <- reportReq{res} - return <-res + rr.mutex.Lock() + defer rr.mutex.Unlock() + + report := &rtcp.ReceiverReport{ + SSRC: rr.receiverSSRC, + Reports: []rtcp.ReceptionReport{ + { + SSRC: rr.publisherSSRC, + LastSequenceNumber: uint32(rr.sequenceNumberCycles)<<8 | uint32(rr.lastSequenceNumber), + LastSenderReport: rr.lastSenderReport, + }, + }, + } + + byts, _ := report.Marshal() + return byts }