From aba0f1598c9a3db46ee27dc0472689b5949c4a24 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 15 Nov 2020 20:11:32 +0100 Subject: [PATCH] support calling Pause() in parallel with WriteFrame(); call TEARDOWN after publishing and calling Close(); fix #13 --- connclient.go | 62 +++++++++++++----------------- connclientpublish.go | 77 ++++++++++++++++++-------------------- connclientread.go | 24 ++++++------ dialer.go | 12 ++---- dialer_test.go | 89 +++++++++++++++++++++++++++++++++++++++----- 5 files changed, 157 insertions(+), 107 deletions(-) diff --git a/connclient.go b/connclient.go index c91aafb7..11bc15c4 100644 --- a/connclient.go +++ b/connclient.go @@ -14,7 +14,7 @@ import ( "net" "strconv" "strings" - "sync/atomic" + "sync" "time" "github.com/aler9/gortsplib/pkg/auth" @@ -34,7 +34,7 @@ const ( clientUDPFrameReadBufferSize = 2048 ) -type connClientState int32 +type connClientState int const ( connClientStateInitial connClientState = iota @@ -42,7 +42,6 @@ const ( connClientStatePlay connClientStatePreRecord connClientStateRecord - connClientStateUDPError ) func (s connClientState) String() string { @@ -57,18 +56,9 @@ func (s connClientState) String() string { return "preRecord" case connClientStateRecord: return "record" - case connClientStateUDPError: - return "udpError" } return "uknown" } -func (s *connClientState) load() connClientState { - return connClientState(atomic.LoadInt32((*int32)(s))) -} - -func (s *connClientState) store(v connClientState) { - atomic.StoreInt32((*int32)(s), int32(v)) -} // ConnClient is a client-side RTSP connection. type ConnClient struct { @@ -79,7 +69,7 @@ type ConnClient struct { session string cseq int auth *auth.Client - state *connClientState + state connClientState streamUrl *base.URL streamProtocol *StreamProtocol tracks Tracks @@ -88,25 +78,24 @@ type ConnClient struct { udpRtpListeners map[int]*connClientUDPListener udpRtcpListeners map[int]*connClientUDPListener tcpFrameBuffer *multibuffer.MultiBuffer - writeFrameFunc func(trackId int, streamType StreamType, content []byte) error getParameterSupported bool backgroundError error backgroundTerminate chan struct{} backgroundDone chan struct{} readFrame chan base.InterleavedFrame + writeFrameMutex sync.RWMutex + writeFrameOpen bool } // Close closes all the ConnClient resources. func (c *ConnClient) Close() error { - s := c.state.load() + s := c.state if s == connClientStatePlay || s == connClientStateRecord { close(c.backgroundTerminate) <-c.backgroundDone - } - if s == connClientStatePlay { c.Do(&base.Request{ Method: base.TEARDOWN, URL: c.streamUrl, @@ -126,18 +115,17 @@ func (c *ConnClient) Close() error { return err } -func (c *ConnClient) checkState(allowed map[connClientState]struct{}) (connClientState, error) { - s := c.state.load() - if _, ok := allowed[s]; ok { - return s, nil +func (c *ConnClient) checkState(allowed map[connClientState]struct{}) error { + if _, ok := allowed[c.state]; ok { + return nil } var allowedList []connClientState - for s := range allowed { - allowedList = append(allowedList, s) + for a := range allowed { + allowedList = append(allowedList, a) } - return 0, fmt.Errorf("client must be in state %v, while is in state %v", - allowedList, s) + return fmt.Errorf("client must be in state %v, while is in state %v", + allowedList, c.state) } // NetConn returns the underlying net.Conn. @@ -238,7 +226,7 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) { // Since this method is not implemented by every RTSP server, the function // does not fail if the returned code is StatusNotFound. func (c *ConnClient) Options(u *base.URL) (*base.Response, error) { - _, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, connClientStatePrePlay: {}, connClientStatePreRecord: {}, @@ -278,7 +266,7 @@ func (c *ConnClient) Options(u *base.URL) (*base.Response, error) { // Describe writes a DESCRIBE request and reads a Response. func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { - _, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, connClientStatePrePlay: {}, connClientStatePreRecord: {}, @@ -376,7 +364,7 @@ func (c *ConnClient) urlForTrack(baseUrl *base.URL, mode headers.TransportMode, // if rtpPort and rtcpPort are zero, they are chosen automatically. func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.StreamProtocol, track *Track, rtpPort int, rtcpPort int) (*base.Response, error) { - s, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, connClientStatePrePlay: {}, connClientStatePreRecord: {}, @@ -385,12 +373,12 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S return nil, err } - if mode == headers.TransportModeRecord && s != connClientStatePreRecord { + if mode == headers.TransportModeRecord && c.state != connClientStatePreRecord { return nil, fmt.Errorf("cannot read and publish at the same time") } - if mode == headers.TransportModePlay && s != connClientStatePrePlay && - s != connClientStateInitial { + if mode == headers.TransportModePlay && c.state != connClientStatePrePlay && + c.state != connClientStateInitial { return nil, fmt.Errorf("cannot read and publish at the same time") } @@ -551,9 +539,9 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S } if mode == headers.TransportModePlay { - *c.state = connClientStatePrePlay + c.state = connClientStatePrePlay } else { - *c.state = connClientStatePreRecord + c.state = connClientStatePreRecord } return res, nil @@ -562,7 +550,7 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S // Pause writes a PAUSE request and reads a Response. // This can be called only after Play() or Record(). func (c *ConnClient) Pause() (*base.Response, error) { - s, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStatePlay: {}, connClientStateRecord: {}, }) @@ -585,11 +573,11 @@ func (c *ConnClient) Pause() (*base.Response, error) { return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - switch s { + switch c.state { case connClientStatePlay: - c.state.store(connClientStatePrePlay) + c.state = connClientStatePrePlay case connClientStateRecord: - c.state.store(connClientStatePreRecord) + c.state = connClientStatePreRecord } return res, nil diff --git a/connclientpublish.go b/connclientpublish.go index b629fd79..00f9c962 100644 --- a/connclientpublish.go +++ b/connclientpublish.go @@ -9,7 +9,7 @@ import ( // Announce writes an ANNOUNCE request and reads a Response. func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error) { - _, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, }) if err != nil { @@ -33,7 +33,7 @@ func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error } c.streamUrl = u - *c.state = connClientStatePreRecord + c.state = connClientStatePreRecord return res, nil } @@ -41,7 +41,7 @@ func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error // Record writes a RECORD request and reads a Response. // This can be called only after Announce() and Setup(). func (c *ConnClient) Record() (*base.Response, error) { - _, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStatePreRecord: {}, }) if err != nil { @@ -60,14 +60,9 @@ func (c *ConnClient) Record() (*base.Response, error) { return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - if *c.streamProtocol == StreamProtocolUDP { - c.writeFrameFunc = c.writeFrameUDP - } else { - c.writeFrameFunc = c.writeFrameTCP - } - - c.state.store(connClientStateRecord) + c.state = connClientStateRecord + c.writeFrameOpen = true c.backgroundTerminate = make(chan struct{}) c.backgroundDone = make(chan struct{}) @@ -83,15 +78,22 @@ func (c *ConnClient) Record() (*base.Response, error) { func (c *ConnClient) backgroundRecordUDP() { defer close(c.backgroundDone) - c.nconn.SetReadDeadline(time.Time{}) // disable deadline + defer func() { + c.writeFrameMutex.Lock() + defer c.writeFrameMutex.Unlock() + c.writeFrameOpen = false + }() - readDone := make(chan error) + // disable deadline + c.nconn.SetReadDeadline(time.Time{}) + + readerDone := make(chan error) go func() { for { var res base.Response err := res.Read(c.br) if err != nil { - readDone <- err + readerDone <- err return } } @@ -100,42 +102,43 @@ func (c *ConnClient) backgroundRecordUDP() { select { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) - <-readDone + <-readerDone c.backgroundError = fmt.Errorf("terminated") - c.state.store(connClientStateUDPError) return - case err := <-readDone: + case err := <-readerDone: c.backgroundError = err - c.state.store(connClientStateUDPError) return } } func (c *ConnClient) backgroundRecordTCP() { defer close(c.backgroundDone) + + defer func() { + c.writeFrameMutex.Lock() + defer c.writeFrameMutex.Unlock() + c.writeFrameOpen = false + }() + + <-c.backgroundTerminate } -func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error { - switch c.state.load() { - case connClientStateUDPError: +// WriteFrame writes a frame. +// This can be used only after Record(). +func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error { + c.writeFrameMutex.RLock() + defer c.writeFrameMutex.RUnlock() + + if !c.writeFrameOpen { return c.backgroundError - - case connClientStateRecord: - - default: - return fmt.Errorf("not recording") } - if streamType == StreamTypeRtp { - return c.udpRtpListeners[trackId].write(content) - } - return c.udpRtcpListeners[trackId].write(content) -} - -func (c *ConnClient) writeFrameTCP(trackId int, streamType StreamType, content []byte) error { - if c.state.load() != connClientStateRecord { - return fmt.Errorf("not recording") + if *c.streamProtocol == StreamProtocolUDP { + if streamType == StreamTypeRtp { + return c.udpRtpListeners[trackId].write(content) + } + return c.udpRtcpListeners[trackId].write(content) } c.nconn.SetWriteDeadline(time.Now().Add(c.d.WriteTimeout)) @@ -146,9 +149,3 @@ func (c *ConnClient) writeFrameTCP(trackId int, streamType StreamType, content [ } return frame.Write(c.bw) } - -// WriteFrame writes a frame. -// This can be used only after Record(). -func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error { - return c.writeFrameFunc(trackId, streamType, content) -} diff --git a/connclientread.go b/connclientread.go index b4c77b33..263f4497 100644 --- a/connclientread.go +++ b/connclientread.go @@ -11,7 +11,7 @@ import ( // Play writes a PLAY request and reads a Response. // This can be called only after Setup(). func (c *ConnClient) Play() (*base.Response, error) { - _, err := c.checkState(map[connClientState]struct{}{ + err := c.checkState(map[connClientState]struct{}{ connClientStatePrePlay: {}, }) if err != nil { @@ -30,7 +30,7 @@ func (c *ConnClient) Play() (*base.Response, error) { return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - c.state.store(connClientStatePlay) + c.state = connClientStatePlay c.readFrame = make(chan base.InterleavedFrame) c.backgroundTerminate = make(chan struct{}) @@ -80,13 +80,13 @@ func (c *ConnClient) backgroundPlayUDP() { // disable deadline c.nconn.SetReadDeadline(time.Time{}) - readDone := make(chan error) + readerDone := make(chan error) go func() { for { var res base.Response err := res.Read(c.br) if err != nil { - readDone <- err + readerDone <- err return } } @@ -105,7 +105,7 @@ func (c *ConnClient) backgroundPlayUDP() { select { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) - <-readDone + <-readerDone c.backgroundError = fmt.Errorf("terminated") return @@ -130,7 +130,7 @@ func (c *ConnClient) backgroundPlayUDP() { }) if err != nil { c.nconn.SetReadDeadline(time.Now()) - <-readDone + <-readerDone c.backgroundError = err return } @@ -143,13 +143,13 @@ func (c *ConnClient) backgroundPlayUDP() { if now.Sub(last) >= c.d.ReadTimeout { c.nconn.SetReadDeadline(time.Now()) - <-readDone + <-readerDone c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") return } } - case err := <-readDone: + case err := <-readerDone: c.backgroundError = err return } @@ -168,7 +168,7 @@ func (c *ConnClient) backgroundPlayTCP() { close(ch) }() - readDone := make(chan error) + readerDone := make(chan error) go func() { for { c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) @@ -177,7 +177,7 @@ func (c *ConnClient) backgroundPlayTCP() { } err := frame.Read(c.br) if err != nil { - readDone <- err + readerDone <- err return } @@ -194,7 +194,7 @@ func (c *ConnClient) backgroundPlayTCP() { select { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) - <-readDone + <-readerDone c.backgroundError = fmt.Errorf("terminated") return @@ -210,7 +210,7 @@ func (c *ConnClient) backgroundPlayTCP() { frame.Write(c.bw) } - case err := <-readDone: + case err := <-readerDone: c.backgroundError = err return } diff --git a/dialer.go b/dialer.go index 281e005e..5f16e2a3 100644 --- a/dialer.go +++ b/dialer.go @@ -88,14 +88,10 @@ func (d Dialer) Dial(host string) (*ConnClient, error) { } return &ConnClient{ - d: d, - nconn: nconn, - br: bufio.NewReaderSize(nconn, clientReadBufferSize), - bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), - state: func() *connClientState { - v := connClientState(0) - return &v - }(), + d: d, + nconn: nconn, + br: bufio.NewReaderSize(nconn, clientReadBufferSize), + bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), udpLastFrameTimes: make(map[int]*int64), udpRtpListeners: make(map[int]*connClientUDPListener), diff --git a/dialer_test.go b/dialer_test.go index d80fe806..2eaaa3f7 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -144,9 +144,9 @@ func TestDialReadParallel(t *testing.T) { conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - readDone := make(chan struct{}) + readerDone := make(chan struct{}) go func() { - defer close(readDone) + defer close(readerDone) for { _, _, _, err := conn.ReadFrame() @@ -159,7 +159,7 @@ func TestDialReadParallel(t *testing.T) { time.Sleep(1 * time.Second) conn.Close() - <-readDone + <-readerDone }) } } @@ -287,9 +287,9 @@ func TestDialReadPauseParallel(t *testing.T) { conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - readDone := make(chan struct{}) + readerDone := make(chan struct{}) go func() { - defer close(readDone) + defer close(readerDone) for { _, _, _, err := conn.ReadFrame() @@ -301,8 +301,9 @@ func TestDialReadPauseParallel(t *testing.T) { time.Sleep(1 * time.Second) - conn.Pause() - <-readDone + _, err = conn.Pause() + require.NoError(t, err) + <-readerDone conn.Close() }) @@ -415,8 +416,8 @@ func TestDialPublishParallel(t *testing.T) { track, err := NewTrackH264(0, sps, pps) require.NoError(t, err) - writeDone := make(chan struct{}) - defer func() { <-writeDone }() + writerDone := make(chan struct{}) + defer func() { <-writerDone }() var conn *ConnClient defer func() { conn.Close() }() @@ -429,7 +430,7 @@ func TestDialPublishParallel(t *testing.T) { }() go func() { - defer close(writeDone) + defer close(writerDone) port := "8554" if ca.server == "ffmpeg" { @@ -542,3 +543,71 @@ func TestDialPublishPause(t *testing.T) { }) } } + +func TestDialPublishPauseParallel(t *testing.T) { + for _, proto := range []string{ + "udp", + "tcp", + } { + t.Run(proto, func(t *testing.T) { + cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"}) + require.NoError(t, err) + defer cnt1.close() + + time.Sleep(1 * time.Second) + + pc, err := net.ListenPacket("udp4", "127.0.0.1:0") + require.NoError(t, err) + defer pc.Close() + + cnt2, err := newContainer("gstreamer", "source", []string{ + "filesrc location=emptyvideo.ts ! tsdemux ! video/x-h264" + + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=" + strconv.FormatInt(int64(pc.LocalAddr().(*net.UDPAddr).Port), 10), + }) + require.NoError(t, err) + defer cnt2.close() + + decoder := rtph264.NewDecoderFromPacketConn(pc) + sps, pps, err := decoder.ReadSPSPPS() + require.NoError(t, err) + + track, err := NewTrackH264(0, sps, pps) + require.NoError(t, err) + + dialer := func() Dialer { + if proto == "udp" { + return Dialer{} + } + return Dialer{StreamProtocol: StreamProtocolTCP} + }() + + conn, err := dialer.DialPublish("rtsp://localhost:8554/teststream", + Tracks{track}) + require.NoError(t, err) + + writerDone := make(chan struct{}) + go func() { + defer close(writerDone) + + buf := make([]byte, 2048) + for { + n, _, err := pc.ReadFrom(buf) + require.NoError(t, err) + + err = conn.WriteFrame(track.Id, StreamTypeRtp, buf[:n]) + if err != nil { + break + } + } + }() + + time.Sleep(1 * time.Second) + + _, err = conn.Pause() + require.NoError(t, err) + <-writerDone + + conn.Close() + }) + } +}