diff --git a/connclient.go b/connclient.go index be35e86d..c91aafb7 100644 --- a/connclient.go +++ b/connclient.go @@ -87,17 +87,14 @@ type ConnClient struct { udpLastFrameTimes map[int]*int64 udpRtpListeners map[int]*connClientUDPListener udpRtcpListeners map[int]*connClientUDPListener - response *base.Response - frame *base.InterleavedFrame tcpFrameBuffer *multibuffer.MultiBuffer - readFrameFunc func() (int, StreamType, []byte, error) writeFrameFunc func(trackId int, streamType StreamType, content []byte) error getParameterSupported bool - backgroundUDPError error + backgroundError error backgroundTerminate chan struct{} backgroundDone chan struct{} - udpFrame chan base.InterleavedFrame + readFrame chan base.InterleavedFrame } // Close closes all the ConnClient resources. @@ -117,15 +114,6 @@ func (c *ConnClient) Close() error { }) } - if s == connClientStatePlay { - if *c.streamProtocol == StreamProtocolUDP { - go func() { - for range c.udpFrame { - } - }() - } - } - for _, l := range c.udpRtpListeners { l.close() } @@ -134,12 +122,6 @@ func (c *ConnClient) Close() error { l.close() } - if s == connClientStatePlay { - if *c.streamProtocol == StreamProtocolUDP { - close(c.udpFrame) - } - } - err := c.nconn.Close() return err } @@ -169,10 +151,12 @@ func (c *ConnClient) Tracks() Tracks { } func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { - c.frame.Content = c.tcpFrameBuffer.Next() - c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) - return base.ReadInterleavedFrameOrResponse(c.frame, c.response, c.br) + f := base.InterleavedFrame{ + Content: c.tcpFrameBuffer.Next(), + } + r := base.Response{} + return base.ReadInterleavedFrameOrResponse(&f, &r, c.br) } // Do writes a Request and reads a Response. @@ -589,23 +573,6 @@ func (c *ConnClient) Pause() (*base.Response, error) { close(c.backgroundTerminate) <-c.backgroundDone - if s == connClientStatePlay { - if *c.streamProtocol == StreamProtocolUDP { - ch := c.udpFrame - go func() { - for range ch { - } - }() - - for trackId := range c.udpRtpListeners { - c.udpRtpListeners[trackId].stop() - c.udpRtcpListeners[trackId].stop() - } - - close(ch) - } - } - res, err := c.Do(&base.Request{ Method: base.PAUSE, URL: c.streamUrl, diff --git a/connclientpublish.go b/connclientpublish.go index 5078fd05..b629fd79 100644 --- a/connclientpublish.go +++ b/connclientpublish.go @@ -101,12 +101,12 @@ func (c *ConnClient) backgroundRecordUDP() { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) <-readDone - c.backgroundUDPError = fmt.Errorf("terminated") + c.backgroundError = fmt.Errorf("terminated") c.state.store(connClientStateUDPError) return case err := <-readDone: - c.backgroundUDPError = err + c.backgroundError = err c.state.store(connClientStateUDPError) return } @@ -119,7 +119,7 @@ func (c *ConnClient) backgroundRecordTCP() { func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error { switch c.state.load() { case connClientStateUDPError: - return c.backgroundUDPError + return c.backgroundError case connClientStateRecord: diff --git a/connclientread.go b/connclientread.go index db6237db..b4c77b33 100644 --- a/connclientread.go +++ b/connclientread.go @@ -30,33 +30,19 @@ func (c *ConnClient) Play() (*base.Response, error) { return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - if *c.streamProtocol == StreamProtocolUDP { - c.readFrameFunc = c.readFrameUDP - c.writeFrameFunc = c.writeFrameUDP - } else { - c.readFrameFunc = c.readFrameTCP - c.writeFrameFunc = c.writeFrameTCP - } - c.state.store(connClientStatePlay) + c.readFrame = make(chan base.InterleavedFrame) c.backgroundTerminate = make(chan struct{}) c.backgroundDone = make(chan struct{}) if *c.streamProtocol == StreamProtocolUDP { - c.udpFrame = make(chan base.InterleavedFrame) - - for trackId := range c.udpRtpListeners { - c.udpRtpListeners[trackId].start() - c.udpRtcpListeners[trackId].start() - } - // open the firewall by sending packets to the counterpart for trackId := range c.udpRtpListeners { - c.WriteFrame(trackId, StreamTypeRtp, + c.udpRtpListeners[trackId].write( []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - c.WriteFrame(trackId, StreamTypeRtcp, + c.udpRtcpListeners[trackId].write( []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } @@ -71,7 +57,28 @@ func (c *ConnClient) Play() (*base.Response, error) { func (c *ConnClient) backgroundPlayUDP() { defer close(c.backgroundDone) - c.nconn.SetReadDeadline(time.Time{}) // disable deadline + defer func() { + ch := c.readFrame + go func() { + for range ch { + } + }() + + for trackId := range c.udpRtpListeners { + c.udpRtpListeners[trackId].stop() + c.udpRtcpListeners[trackId].stop() + } + + close(ch) + }() + + for trackId := range c.udpRtpListeners { + c.udpRtpListeners[trackId].start() + c.udpRtcpListeners[trackId].start() + } + + // disable deadline + c.nconn.SetReadDeadline(time.Time{}) readDone := make(chan error) go func() { @@ -99,14 +106,13 @@ func (c *ConnClient) backgroundPlayUDP() { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) <-readDone - c.backgroundUDPError = fmt.Errorf("terminated") - c.state.store(connClientStateUDPError) + c.backgroundError = fmt.Errorf("terminated") return case <-reportTicker.C: for trackId := range c.rtcpReceivers { - frame := c.rtcpReceivers[trackId].Report() - c.WriteFrame(trackId, StreamTypeRtcp, frame) + report := c.rtcpReceivers[trackId].Report() + c.udpRtcpListeners[trackId].write(report) } case <-keepaliveTicker.C: @@ -125,8 +131,7 @@ func (c *ConnClient) backgroundPlayUDP() { if err != nil { c.nconn.SetReadDeadline(time.Now()) <-readDone - c.backgroundUDPError = err - c.state.store(connClientStateUDPError) + c.backgroundError = err return } @@ -139,11 +144,14 @@ func (c *ConnClient) backgroundPlayUDP() { if now.Sub(last) >= c.d.ReadTimeout { c.nconn.SetReadDeadline(time.Now()) <-readDone - c.backgroundUDPError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") - c.state.store(connClientStateUDPError) + c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") return } } + + case err := <-readDone: + c.backgroundError = err + return } } } @@ -151,51 +159,71 @@ func (c *ConnClient) backgroundPlayUDP() { func (c *ConnClient) backgroundPlayTCP() { defer close(c.backgroundDone) + defer func() { + ch := c.readFrame + go func() { + for range ch { + } + }() + close(ch) + }() + + readDone := make(chan error) + go func() { + for { + c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) + frame := base.InterleavedFrame{ + Content: c.tcpFrameBuffer.Next(), + } + err := frame.Read(c.br) + if err != nil { + readDone <- err + return + } + + c.rtcpReceivers[frame.TrackId].OnFrame(frame.StreamType, frame.Content) + + c.readFrame <- frame + } + }() + reportTicker := time.NewTicker(clientReceiverReportPeriod) defer reportTicker.Stop() for { select { case <-c.backgroundTerminate: + c.nconn.SetReadDeadline(time.Now()) + <-readDone + c.backgroundError = fmt.Errorf("terminated") return case <-reportTicker.C: for trackId := range c.rtcpReceivers { - frame := c.rtcpReceivers[trackId].Report() - c.WriteFrame(trackId, StreamTypeRtcp, frame) + report := c.rtcpReceivers[trackId].Report() + c.nconn.SetWriteDeadline(time.Now().Add(c.d.WriteTimeout)) + frame := base.InterleavedFrame{ + TrackId: trackId, + StreamType: StreamTypeRtcp, + Content: report, + } + frame.Write(c.bw) } + + case err := <-readDone: + c.backgroundError = err + return } } } -func (c *ConnClient) readFrameUDP() (int, StreamType, []byte, error) { - if c.state.load() != connClientStatePlay { - return 0, 0, nil, fmt.Errorf("not playing") - } - - f := <-c.udpFrame - return f.TrackId, f.StreamType, f.Content, nil -} - -func (c *ConnClient) readFrameTCP() (int, StreamType, []byte, error) { - if c.state.load() != connClientStatePlay { - return 0, 0, nil, fmt.Errorf("not playing") - } - - c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) - c.frame.Content = c.tcpFrameBuffer.Next() - err := c.frame.Read(c.br) - if err != nil { - return 0, 0, nil, err - } - - c.rtcpReceivers[c.frame.TrackId].OnFrame(c.frame.StreamType, c.frame.Content) - - return c.frame.TrackId, c.frame.StreamType, c.frame.Content, nil -} - // ReadFrame reads a frame. // This can be used only after Play(). func (c *ConnClient) ReadFrame() (int, StreamType, []byte, error) { - return c.readFrameFunc() + f, ok := <-c.readFrame + if !ok { + return 0, 0, nil, c.backgroundError + } + + return f.TrackId, f.StreamType, f.Content, nil } diff --git a/connclientudpl.go b/connclientudpl.go index 819cd7bf..324bcd5a 100644 --- a/connclientudpl.go +++ b/connclientudpl.go @@ -76,7 +76,7 @@ func (l *connClientUDPListener) run() { l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n]) - l.c.udpFrame <- base.InterleavedFrame{ + l.c.readFrame <- base.InterleavedFrame{ TrackId: l.trackId, StreamType: l.streamType, Content: buf[:n], diff --git a/dialer.go b/dialer.go index 2e4bca6f..281e005e 100644 --- a/dialer.go +++ b/dialer.go @@ -96,14 +96,12 @@ func (d Dialer) Dial(host string) (*ConnClient, error) { v := connClientState(0) return &v }(), - rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), - udpLastFrameTimes: make(map[int]*int64), - udpRtpListeners: make(map[int]*connClientUDPListener), - udpRtcpListeners: make(map[int]*connClientUDPListener), - response: &base.Response{}, - frame: &base.InterleavedFrame{}, - tcpFrameBuffer: multibuffer.New(d.ReadBufferCount, clientTCPFrameReadBufferSize), - backgroundUDPError: fmt.Errorf("not running"), + rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), + udpLastFrameTimes: make(map[int]*int64), + udpRtpListeners: make(map[int]*connClientUDPListener), + udpRtcpListeners: make(map[int]*connClientUDPListener), + tcpFrameBuffer: multibuffer.New(d.ReadBufferCount+1, clientTCPFrameReadBufferSize), + backgroundError: fmt.Errorf("not running"), }, nil } diff --git a/dialer_test.go b/dialer_test.go index a9f7df46..d80fe806 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -251,6 +251,64 @@ func TestDialReadPause(t *testing.T) { } } +func TestDialReadPauseParallel(t *testing.T) { + for _, proto := range []string{ + "udp", + "tcp", + } { + t.Run(proto, func(t *testing.T) { + cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"}) + require.NoError(t, err) + defer cnt1.close() + + time.Sleep(1 * time.Second) + + cnt2, err := newContainer("ffmpeg", "publish", []string{ + "-re", + "-stream_loop", "-1", + "-i", "/emptyvideo.ts", + "-c", "copy", + "-f", "rtsp", + "-rtsp_transport", "udp", + "rtsp://localhost:8554/teststream", + }) + require.NoError(t, err) + defer cnt2.close() + + time.Sleep(1 * time.Second) + + dialer := func() Dialer { + if proto == "udp" { + return Dialer{} + } + return Dialer{StreamProtocol: StreamProtocolTCP} + }() + + conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") + require.NoError(t, err) + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + + for { + _, _, _, err := conn.ReadFrame() + if err != nil { + break + } + } + }() + + time.Sleep(1 * time.Second) + + conn.Pause() + <-readDone + + conn.Close() + }) + } +} + func TestDialPublish(t *testing.T) { for _, proto := range []string{ "udp",