support calling Pause() in parallel with WriteFrame(); call TEARDOWN after publishing and calling Close(); fix #13

This commit is contained in:
aler9
2020-11-15 20:11:32 +01:00
parent 862cd0ea62
commit aba0f1598c
5 changed files with 157 additions and 107 deletions

View File

@@ -14,7 +14,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync"
"time" "time"
"github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/auth"
@@ -34,7 +34,7 @@ const (
clientUDPFrameReadBufferSize = 2048 clientUDPFrameReadBufferSize = 2048
) )
type connClientState int32 type connClientState int
const ( const (
connClientStateInitial connClientState = iota connClientStateInitial connClientState = iota
@@ -42,7 +42,6 @@ const (
connClientStatePlay connClientStatePlay
connClientStatePreRecord connClientStatePreRecord
connClientStateRecord connClientStateRecord
connClientStateUDPError
) )
func (s connClientState) String() string { func (s connClientState) String() string {
@@ -57,18 +56,9 @@ func (s connClientState) String() string {
return "preRecord" return "preRecord"
case connClientStateRecord: case connClientStateRecord:
return "record" return "record"
case connClientStateUDPError:
return "udpError"
} }
return "uknown" return "uknown"
} }
func (s *connClientState) load() connClientState {
return connClientState(atomic.LoadInt32((*int32)(s)))
}
func (s *connClientState) store(v connClientState) {
atomic.StoreInt32((*int32)(s), int32(v))
}
// ConnClient is a client-side RTSP connection. // ConnClient is a client-side RTSP connection.
type ConnClient struct { type ConnClient struct {
@@ -79,7 +69,7 @@ type ConnClient struct {
session string session string
cseq int cseq int
auth *auth.Client auth *auth.Client
state *connClientState state connClientState
streamUrl *base.URL streamUrl *base.URL
streamProtocol *StreamProtocol streamProtocol *StreamProtocol
tracks Tracks tracks Tracks
@@ -88,25 +78,24 @@ type ConnClient struct {
udpRtpListeners map[int]*connClientUDPListener udpRtpListeners map[int]*connClientUDPListener
udpRtcpListeners map[int]*connClientUDPListener udpRtcpListeners map[int]*connClientUDPListener
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
writeFrameFunc func(trackId int, streamType StreamType, content []byte) error
getParameterSupported bool getParameterSupported bool
backgroundError error backgroundError error
backgroundTerminate chan struct{} backgroundTerminate chan struct{}
backgroundDone chan struct{} backgroundDone chan struct{}
readFrame chan base.InterleavedFrame readFrame chan base.InterleavedFrame
writeFrameMutex sync.RWMutex
writeFrameOpen bool
} }
// Close closes all the ConnClient resources. // Close closes all the ConnClient resources.
func (c *ConnClient) Close() error { func (c *ConnClient) Close() error {
s := c.state.load() s := c.state
if s == connClientStatePlay || s == connClientStateRecord { if s == connClientStatePlay || s == connClientStateRecord {
close(c.backgroundTerminate) close(c.backgroundTerminate)
<-c.backgroundDone <-c.backgroundDone
}
if s == connClientStatePlay {
c.Do(&base.Request{ c.Do(&base.Request{
Method: base.TEARDOWN, Method: base.TEARDOWN,
URL: c.streamUrl, URL: c.streamUrl,
@@ -126,18 +115,17 @@ func (c *ConnClient) Close() error {
return err return err
} }
func (c *ConnClient) checkState(allowed map[connClientState]struct{}) (connClientState, error) { func (c *ConnClient) checkState(allowed map[connClientState]struct{}) error {
s := c.state.load() if _, ok := allowed[c.state]; ok {
if _, ok := allowed[s]; ok { return nil
return s, nil
} }
var allowedList []connClientState var allowedList []connClientState
for s := range allowed { for a := range allowed {
allowedList = append(allowedList, s) allowedList = append(allowedList, a)
} }
return 0, fmt.Errorf("client must be in state %v, while is in state %v", return fmt.Errorf("client must be in state %v, while is in state %v",
allowedList, s) allowedList, c.state)
} }
// NetConn returns the underlying net.Conn. // NetConn returns the underlying net.Conn.
@@ -238,7 +226,7 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) {
// Since this method is not implemented by every RTSP server, the function // Since this method is not implemented by every RTSP server, the function
// does not fail if the returned code is StatusNotFound. // does not fail if the returned code is StatusNotFound.
func (c *ConnClient) Options(u *base.URL) (*base.Response, error) { func (c *ConnClient) Options(u *base.URL) (*base.Response, error) {
_, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStateInitial: {}, connClientStateInitial: {},
connClientStatePrePlay: {}, connClientStatePrePlay: {},
connClientStatePreRecord: {}, connClientStatePreRecord: {},
@@ -278,7 +266,7 @@ func (c *ConnClient) Options(u *base.URL) (*base.Response, error) {
// Describe writes a DESCRIBE request and reads a Response. // Describe writes a DESCRIBE request and reads a Response.
func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) {
_, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStateInitial: {}, connClientStateInitial: {},
connClientStatePrePlay: {}, connClientStatePrePlay: {},
connClientStatePreRecord: {}, connClientStatePreRecord: {},
@@ -376,7 +364,7 @@ func (c *ConnClient) urlForTrack(baseUrl *base.URL, mode headers.TransportMode,
// if rtpPort and rtcpPort are zero, they are chosen automatically. // if rtpPort and rtcpPort are zero, they are chosen automatically.
func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.StreamProtocol, func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.StreamProtocol,
track *Track, rtpPort int, rtcpPort int) (*base.Response, error) { track *Track, rtpPort int, rtcpPort int) (*base.Response, error) {
s, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStateInitial: {}, connClientStateInitial: {},
connClientStatePrePlay: {}, connClientStatePrePlay: {},
connClientStatePreRecord: {}, connClientStatePreRecord: {},
@@ -385,12 +373,12 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S
return nil, err return nil, err
} }
if mode == headers.TransportModeRecord && s != connClientStatePreRecord { if mode == headers.TransportModeRecord && c.state != connClientStatePreRecord {
return nil, fmt.Errorf("cannot read and publish at the same time") return nil, fmt.Errorf("cannot read and publish at the same time")
} }
if mode == headers.TransportModePlay && s != connClientStatePrePlay && if mode == headers.TransportModePlay && c.state != connClientStatePrePlay &&
s != connClientStateInitial { c.state != connClientStateInitial {
return nil, fmt.Errorf("cannot read and publish at the same time") return nil, fmt.Errorf("cannot read and publish at the same time")
} }
@@ -551,9 +539,9 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S
} }
if mode == headers.TransportModePlay { if mode == headers.TransportModePlay {
*c.state = connClientStatePrePlay c.state = connClientStatePrePlay
} else { } else {
*c.state = connClientStatePreRecord c.state = connClientStatePreRecord
} }
return res, nil return res, nil
@@ -562,7 +550,7 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S
// Pause writes a PAUSE request and reads a Response. // Pause writes a PAUSE request and reads a Response.
// This can be called only after Play() or Record(). // This can be called only after Play() or Record().
func (c *ConnClient) Pause() (*base.Response, error) { func (c *ConnClient) Pause() (*base.Response, error) {
s, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStatePlay: {}, connClientStatePlay: {},
connClientStateRecord: {}, connClientStateRecord: {},
}) })
@@ -585,11 +573,11 @@ func (c *ConnClient) Pause() (*base.Response, error) {
return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage)
} }
switch s { switch c.state {
case connClientStatePlay: case connClientStatePlay:
c.state.store(connClientStatePrePlay) c.state = connClientStatePrePlay
case connClientStateRecord: case connClientStateRecord:
c.state.store(connClientStatePreRecord) c.state = connClientStatePreRecord
} }
return res, nil return res, nil

View File

@@ -9,7 +9,7 @@ import (
// Announce writes an ANNOUNCE request and reads a Response. // Announce writes an ANNOUNCE request and reads a Response.
func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error) { func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error) {
_, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStateInitial: {}, connClientStateInitial: {},
}) })
if err != nil { if err != nil {
@@ -33,7 +33,7 @@ func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error
} }
c.streamUrl = u c.streamUrl = u
*c.state = connClientStatePreRecord c.state = connClientStatePreRecord
return res, nil return res, nil
} }
@@ -41,7 +41,7 @@ func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error
// Record writes a RECORD request and reads a Response. // Record writes a RECORD request and reads a Response.
// This can be called only after Announce() and Setup(). // This can be called only after Announce() and Setup().
func (c *ConnClient) Record() (*base.Response, error) { func (c *ConnClient) Record() (*base.Response, error) {
_, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStatePreRecord: {}, connClientStatePreRecord: {},
}) })
if err != nil { if err != nil {
@@ -60,14 +60,9 @@ func (c *ConnClient) Record() (*base.Response, error) {
return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage)
} }
if *c.streamProtocol == StreamProtocolUDP { c.state = connClientStateRecord
c.writeFrameFunc = c.writeFrameUDP
} else {
c.writeFrameFunc = c.writeFrameTCP
}
c.state.store(connClientStateRecord)
c.writeFrameOpen = true
c.backgroundTerminate = make(chan struct{}) c.backgroundTerminate = make(chan struct{})
c.backgroundDone = make(chan struct{}) c.backgroundDone = make(chan struct{})
@@ -83,15 +78,22 @@ func (c *ConnClient) Record() (*base.Response, error) {
func (c *ConnClient) backgroundRecordUDP() { func (c *ConnClient) backgroundRecordUDP() {
defer close(c.backgroundDone) defer close(c.backgroundDone)
c.nconn.SetReadDeadline(time.Time{}) // disable deadline defer func() {
c.writeFrameMutex.Lock()
defer c.writeFrameMutex.Unlock()
c.writeFrameOpen = false
}()
readDone := make(chan error) // disable deadline
c.nconn.SetReadDeadline(time.Time{})
readerDone := make(chan error)
go func() { go func() {
for { for {
var res base.Response var res base.Response
err := res.Read(c.br) err := res.Read(c.br)
if err != nil { if err != nil {
readDone <- err readerDone <- err
return return
} }
} }
@@ -100,42 +102,43 @@ func (c *ConnClient) backgroundRecordUDP() {
select { select {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readerDone
c.backgroundError = fmt.Errorf("terminated") c.backgroundError = fmt.Errorf("terminated")
c.state.store(connClientStateUDPError)
return return
case err := <-readDone: case err := <-readerDone:
c.backgroundError = err c.backgroundError = err
c.state.store(connClientStateUDPError)
return return
} }
} }
func (c *ConnClient) backgroundRecordTCP() { func (c *ConnClient) backgroundRecordTCP() {
defer close(c.backgroundDone) defer close(c.backgroundDone)
defer func() {
c.writeFrameMutex.Lock()
defer c.writeFrameMutex.Unlock()
c.writeFrameOpen = false
}()
<-c.backgroundTerminate
} }
func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error { // WriteFrame writes a frame.
switch c.state.load() { // This can be used only after Record().
case connClientStateUDPError: func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error {
c.writeFrameMutex.RLock()
defer c.writeFrameMutex.RUnlock()
if !c.writeFrameOpen {
return c.backgroundError return c.backgroundError
case connClientStateRecord:
default:
return fmt.Errorf("not recording")
} }
if streamType == StreamTypeRtp { if *c.streamProtocol == StreamProtocolUDP {
return c.udpRtpListeners[trackId].write(content) if streamType == StreamTypeRtp {
} return c.udpRtpListeners[trackId].write(content)
return c.udpRtcpListeners[trackId].write(content) }
} return c.udpRtcpListeners[trackId].write(content)
func (c *ConnClient) writeFrameTCP(trackId int, streamType StreamType, content []byte) error {
if c.state.load() != connClientStateRecord {
return fmt.Errorf("not recording")
} }
c.nconn.SetWriteDeadline(time.Now().Add(c.d.WriteTimeout)) c.nconn.SetWriteDeadline(time.Now().Add(c.d.WriteTimeout))
@@ -146,9 +149,3 @@ func (c *ConnClient) writeFrameTCP(trackId int, streamType StreamType, content [
} }
return frame.Write(c.bw) return frame.Write(c.bw)
} }
// WriteFrame writes a frame.
// This can be used only after Record().
func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error {
return c.writeFrameFunc(trackId, streamType, content)
}

View File

@@ -11,7 +11,7 @@ import (
// Play writes a PLAY request and reads a Response. // Play writes a PLAY request and reads a Response.
// This can be called only after Setup(). // This can be called only after Setup().
func (c *ConnClient) Play() (*base.Response, error) { func (c *ConnClient) Play() (*base.Response, error) {
_, err := c.checkState(map[connClientState]struct{}{ err := c.checkState(map[connClientState]struct{}{
connClientStatePrePlay: {}, connClientStatePrePlay: {},
}) })
if err != nil { if err != nil {
@@ -30,7 +30,7 @@ func (c *ConnClient) Play() (*base.Response, error) {
return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage)
} }
c.state.store(connClientStatePlay) c.state = connClientStatePlay
c.readFrame = make(chan base.InterleavedFrame) c.readFrame = make(chan base.InterleavedFrame)
c.backgroundTerminate = make(chan struct{}) c.backgroundTerminate = make(chan struct{})
@@ -80,13 +80,13 @@ func (c *ConnClient) backgroundPlayUDP() {
// disable deadline // disable deadline
c.nconn.SetReadDeadline(time.Time{}) c.nconn.SetReadDeadline(time.Time{})
readDone := make(chan error) readerDone := make(chan error)
go func() { go func() {
for { for {
var res base.Response var res base.Response
err := res.Read(c.br) err := res.Read(c.br)
if err != nil { if err != nil {
readDone <- err readerDone <- err
return return
} }
} }
@@ -105,7 +105,7 @@ func (c *ConnClient) backgroundPlayUDP() {
select { select {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readerDone
c.backgroundError = fmt.Errorf("terminated") c.backgroundError = fmt.Errorf("terminated")
return return
@@ -130,7 +130,7 @@ func (c *ConnClient) backgroundPlayUDP() {
}) })
if err != nil { if err != nil {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readerDone
c.backgroundError = err c.backgroundError = err
return return
} }
@@ -143,13 +143,13 @@ func (c *ConnClient) backgroundPlayUDP() {
if now.Sub(last) >= c.d.ReadTimeout { if now.Sub(last) >= c.d.ReadTimeout {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readerDone
c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)")
return return
} }
} }
case err := <-readDone: case err := <-readerDone:
c.backgroundError = err c.backgroundError = err
return return
} }
@@ -168,7 +168,7 @@ func (c *ConnClient) backgroundPlayTCP() {
close(ch) close(ch)
}() }()
readDone := make(chan error) readerDone := make(chan error)
go func() { go func() {
for { for {
c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout))
@@ -177,7 +177,7 @@ func (c *ConnClient) backgroundPlayTCP() {
} }
err := frame.Read(c.br) err := frame.Read(c.br)
if err != nil { if err != nil {
readDone <- err readerDone <- err
return return
} }
@@ -194,7 +194,7 @@ func (c *ConnClient) backgroundPlayTCP() {
select { select {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readerDone
c.backgroundError = fmt.Errorf("terminated") c.backgroundError = fmt.Errorf("terminated")
return return
@@ -210,7 +210,7 @@ func (c *ConnClient) backgroundPlayTCP() {
frame.Write(c.bw) frame.Write(c.bw)
} }
case err := <-readDone: case err := <-readerDone:
c.backgroundError = err c.backgroundError = err
return return
} }

View File

@@ -88,14 +88,10 @@ func (d Dialer) Dial(host string) (*ConnClient, error) {
} }
return &ConnClient{ return &ConnClient{
d: d, d: d,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(nconn, clientReadBufferSize), br: bufio.NewReaderSize(nconn, clientReadBufferSize),
bw: bufio.NewWriterSize(nconn, clientWriteBufferSize), bw: bufio.NewWriterSize(nconn, clientWriteBufferSize),
state: func() *connClientState {
v := connClientState(0)
return &v
}(),
rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver),
udpLastFrameTimes: make(map[int]*int64), udpLastFrameTimes: make(map[int]*int64),
udpRtpListeners: make(map[int]*connClientUDPListener), udpRtpListeners: make(map[int]*connClientUDPListener),

View File

@@ -144,9 +144,9 @@ func TestDialReadParallel(t *testing.T) {
conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") conn, err := dialer.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
readDone := make(chan struct{}) readerDone := make(chan struct{})
go func() { go func() {
defer close(readDone) defer close(readerDone)
for { for {
_, _, _, err := conn.ReadFrame() _, _, _, err := conn.ReadFrame()
@@ -159,7 +159,7 @@ func TestDialReadParallel(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
conn.Close() conn.Close()
<-readDone <-readerDone
}) })
} }
} }
@@ -287,9 +287,9 @@ func TestDialReadPauseParallel(t *testing.T) {
conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") conn, err := dialer.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
readDone := make(chan struct{}) readerDone := make(chan struct{})
go func() { go func() {
defer close(readDone) defer close(readerDone)
for { for {
_, _, _, err := conn.ReadFrame() _, _, _, err := conn.ReadFrame()
@@ -301,8 +301,9 @@ func TestDialReadPauseParallel(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
conn.Pause() _, err = conn.Pause()
<-readDone require.NoError(t, err)
<-readerDone
conn.Close() conn.Close()
}) })
@@ -415,8 +416,8 @@ func TestDialPublishParallel(t *testing.T) {
track, err := NewTrackH264(0, sps, pps) track, err := NewTrackH264(0, sps, pps)
require.NoError(t, err) require.NoError(t, err)
writeDone := make(chan struct{}) writerDone := make(chan struct{})
defer func() { <-writeDone }() defer func() { <-writerDone }()
var conn *ConnClient var conn *ConnClient
defer func() { conn.Close() }() defer func() { conn.Close() }()
@@ -429,7 +430,7 @@ func TestDialPublishParallel(t *testing.T) {
}() }()
go func() { go func() {
defer close(writeDone) defer close(writerDone)
port := "8554" port := "8554"
if ca.server == "ffmpeg" { if ca.server == "ffmpeg" {
@@ -542,3 +543,71 @@ func TestDialPublishPause(t *testing.T) {
}) })
} }
} }
func TestDialPublishPauseParallel(t *testing.T) {
for _, proto := range []string{
"udp",
"tcp",
} {
t.Run(proto, func(t *testing.T) {
cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"})
require.NoError(t, err)
defer cnt1.close()
time.Sleep(1 * time.Second)
pc, err := net.ListenPacket("udp4", "127.0.0.1:0")
require.NoError(t, err)
defer pc.Close()
cnt2, err := newContainer("gstreamer", "source", []string{
"filesrc location=emptyvideo.ts ! tsdemux ! video/x-h264" +
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=" + strconv.FormatInt(int64(pc.LocalAddr().(*net.UDPAddr).Port), 10),
})
require.NoError(t, err)
defer cnt2.close()
decoder := rtph264.NewDecoderFromPacketConn(pc)
sps, pps, err := decoder.ReadSPSPPS()
require.NoError(t, err)
track, err := NewTrackH264(0, sps, pps)
require.NoError(t, err)
dialer := func() Dialer {
if proto == "udp" {
return Dialer{}
}
return Dialer{StreamProtocol: StreamProtocolTCP}
}()
conn, err := dialer.DialPublish("rtsp://localhost:8554/teststream",
Tracks{track})
require.NoError(t, err)
writerDone := make(chan struct{})
go func() {
defer close(writerDone)
buf := make([]byte, 2048)
for {
n, _, err := pc.ReadFrom(buf)
require.NoError(t, err)
err = conn.WriteFrame(track.Id, StreamTypeRtp, buf[:n])
if err != nil {
break
}
}
}()
time.Sleep(1 * time.Second)
_, err = conn.Pause()
require.NoError(t, err)
<-writerDone
conn.Close()
})
}
}