diff --git a/connclient.go b/connclient.go index b5363dd7..e4d928d3 100644 --- a/connclient.go +++ b/connclient.go @@ -79,13 +79,20 @@ type ConnClient struct { udpRtcpListeners map[int]*connClientUDPListener tcpFrameBuffer *multibuffer.MultiBuffer getParameterSupported bool - backgroundError error - writeFrameMutex sync.RWMutex - writeFrameOpen bool - readCB func(int, StreamType, []byte, error) + // read only + readCB func(int, StreamType, []byte) + + // publish only + publishError error + publishMutex sync.RWMutex + publishOpen bool + + // in backgroundTerminate chan struct{} - backgroundDone chan struct{} + + // out + backgroundDone chan struct{} } // Close closes all the ConnClient resources. diff --git a/connclientpublish.go b/connclientpublish.go index ce491087..2b53635f 100644 --- a/connclientpublish.go +++ b/connclientpublish.go @@ -61,7 +61,7 @@ func (c *ConnClient) Record() (*base.Response, error) { } c.state = connClientStateRecord - c.writeFrameOpen = true + c.publishOpen = true c.backgroundTerminate = make(chan struct{}) c.backgroundDone = make(chan struct{}) @@ -78,9 +78,9 @@ func (c *ConnClient) backgroundRecordUDP() { defer close(c.backgroundDone) defer func() { - c.writeFrameMutex.Lock() - defer c.writeFrameMutex.Unlock() - c.writeFrameOpen = false + c.publishMutex.Lock() + defer c.publishMutex.Unlock() + c.publishOpen = false }() // disable deadline @@ -102,11 +102,11 @@ func (c *ConnClient) backgroundRecordUDP() { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) <-readerDone - c.backgroundError = fmt.Errorf("terminated") + c.publishError = fmt.Errorf("terminated") return case err := <-readerDone: - c.backgroundError = err + c.publishError = err return } } @@ -115,9 +115,9 @@ func (c *ConnClient) backgroundRecordTCP() { defer close(c.backgroundDone) defer func() { - c.writeFrameMutex.Lock() - defer c.writeFrameMutex.Unlock() - c.writeFrameOpen = false + c.publishMutex.Lock() + defer c.publishMutex.Unlock() + c.publishOpen = false }() <-c.backgroundTerminate @@ -126,11 +126,11 @@ func (c *ConnClient) backgroundRecordTCP() { // WriteFrame writes a frame. // This can be used only after Record(). func (c *ConnClient) WriteFrame(trackId int, streamType StreamType, content []byte) error { - c.writeFrameMutex.RLock() - defer c.writeFrameMutex.RUnlock() + c.publishMutex.RLock() + defer c.publishMutex.RUnlock() - if !c.writeFrameOpen { - return c.backgroundError + if !c.publishOpen { + return c.publishError } if *c.streamProtocol == StreamProtocolUDP { diff --git a/connclientread.go b/connclientread.go index 7a687dbd..992c9096 100644 --- a/connclientread.go +++ b/connclientread.go @@ -33,16 +33,18 @@ func (c *ConnClient) Play() (*base.Response, error) { return res, nil } -func (c *ConnClient) backgroundPlayUDP() { +func (c *ConnClient) backgroundPlayUDP(onFrameDone chan error) { defer close(c.backgroundDone) + var returnError error + defer func() { for trackId := range c.udpRtpListeners { c.udpRtpListeners[trackId].stop() c.udpRtcpListeners[trackId].stop() } - c.readCB(0, 0, nil, c.backgroundError) + onFrameDone <- returnError }() for trackId := range c.udpRtpListeners { @@ -79,7 +81,7 @@ func (c *ConnClient) backgroundPlayUDP() { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) <-readerDone - c.backgroundError = fmt.Errorf("terminated") + returnError = fmt.Errorf("terminated") return case <-reportTicker.C: @@ -104,7 +106,7 @@ func (c *ConnClient) backgroundPlayUDP() { if err != nil { c.nconn.SetReadDeadline(time.Now()) <-readerDone - c.backgroundError = err + returnError = err return } @@ -117,22 +119,26 @@ func (c *ConnClient) backgroundPlayUDP() { if now.Sub(last) >= c.d.ReadTimeout { c.nconn.SetReadDeadline(time.Now()) <-readerDone - c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") + returnError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") return } } case err := <-readerDone: - c.backgroundError = err + returnError = err return } } } -func (c *ConnClient) backgroundPlayTCP() { +func (c *ConnClient) backgroundPlayTCP(onFrameDone chan error) { defer close(c.backgroundDone) - defer c.readCB(0, 0, nil, c.backgroundError) + var returnError error + + defer func() { + onFrameDone <- returnError + }() readerDone := make(chan error) go func() { @@ -148,7 +154,7 @@ func (c *ConnClient) backgroundPlayTCP() { c.rtcpReceivers[frame.TrackId].OnFrame(frame.StreamType, frame.Content) - c.readCB(frame.TrackId, frame.StreamType, frame.Content, nil) + c.readCB(frame.TrackId, frame.StreamType, frame.Content) } }() @@ -169,7 +175,7 @@ func (c *ConnClient) backgroundPlayTCP() { case <-c.backgroundTerminate: c.nconn.SetReadDeadline(time.Now()) <-readerDone - c.backgroundError = fmt.Errorf("terminated") + returnError = fmt.Errorf("terminated") return case <-reportTicker.C: @@ -185,14 +191,27 @@ func (c *ConnClient) backgroundPlayTCP() { } case err := <-readerDone: - c.backgroundError = err + returnError = err return } } } // OnFrame sets a callback that is called when a frame is received. -func (c *ConnClient) OnFrame(cb func(int, StreamType, []byte, error)) { +// it returns a channel that is called when the reading stops. +// routines. +func (c *ConnClient) OnFrame(cb func(int, StreamType, []byte)) chan error { + // channel is buffered, since listening to it is not mandatory + onFrameDone := make(chan error, 1) + + err := c.checkState(map[connClientState]struct{}{ + connClientStatePrePlay: {}, + }) + if err != nil { + onFrameDone <- err + return onFrameDone + } + c.state = connClientStatePlay c.readCB = cb c.backgroundTerminate = make(chan struct{}) @@ -208,8 +227,10 @@ func (c *ConnClient) OnFrame(cb func(int, StreamType, []byte, error)) { []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } - go c.backgroundPlayUDP() + go c.backgroundPlayUDP(onFrameDone) } else { - go c.backgroundPlayTCP() + go c.backgroundPlayTCP(onFrameDone) } + + return onFrameDone } diff --git a/connclientudpl.go b/connclientudpl.go index ecbcfb32..1098e048 100644 --- a/connclientudpl.go +++ b/connclientudpl.go @@ -75,7 +75,7 @@ func (l *connClientUDPListener) run() { l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n]) - l.c.readCB(l.trackId, l.streamType, buf[:n], nil) + l.c.readCB(l.trackId, l.streamType, buf[:n]) } } diff --git a/dialer.go b/dialer.go index 5b67b55a..766f99c5 100644 --- a/dialer.go +++ b/dialer.go @@ -97,7 +97,7 @@ func (d Dialer) Dial(host string) (*ConnClient, error) { udpRtpListeners: make(map[int]*connClientUDPListener), udpRtcpListeners: make(map[int]*connClientUDPListener), tcpFrameBuffer: multibuffer.New(d.ReadBufferCount, clientTCPFrameReadBufferSize), - backgroundError: fmt.Errorf("not running"), + publishError: fmt.Errorf("not running"), }, nil } diff --git a/dialer_test.go b/dialer_test.go index e3b253cf..de7c89f4 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -97,13 +97,7 @@ func TestDialReadParallel(t *testing.T) { var firstFrame int32 frameRecv := make(chan struct{}) - readerDone := make(chan struct{}) - conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { - if err != nil { - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ StreamType, content []byte) { if atomic.SwapInt32(&firstFrame, 1) == 0 { close(frameRecv) } @@ -111,14 +105,12 @@ func TestDialReadParallel(t *testing.T) { <-frameRecv conn.Close() - <-readerDone + <-done - readerDone = make(chan struct{}) - conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { - require.Error(t, err) - close(readerDone) + done = conn.OnFrame(func(id int, typ StreamType, content []byte) { + t.Error("should not happen") }) - <-readerDone + <-done }) } } @@ -155,13 +147,7 @@ func TestDialReadRedirectParallel(t *testing.T) { var firstFrame int32 frameRecv := make(chan struct{}) - readerDone := make(chan struct{}) - conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { - if err != nil { - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ StreamType, content []byte) { if atomic.SwapInt32(&firstFrame, 1) == 0 { close(frameRecv) } @@ -169,7 +155,7 @@ func TestDialReadRedirectParallel(t *testing.T) { <-frameRecv conn.Close() - <-readerDone + <-done } func TestDialReadPauseParallel(t *testing.T) { @@ -210,13 +196,7 @@ func TestDialReadPauseParallel(t *testing.T) { firstFrame := int32(0) frameRecv := make(chan struct{}) - readerDone := make(chan struct{}) - conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { - if err != nil { - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ StreamType, content []byte) { if atomic.SwapInt32(&firstFrame, 1) == 0 { close(frameRecv) } @@ -225,20 +205,14 @@ func TestDialReadPauseParallel(t *testing.T) { <-frameRecv _, err = conn.Pause() require.NoError(t, err) - <-readerDone + <-done _, err = conn.Play() require.NoError(t, err) firstFrame = int32(0) frameRecv = make(chan struct{}) - readerDone = make(chan struct{}) - conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { - if err != nil { - close(readerDone) - return - } - + done = conn.OnFrame(func(id int, typ StreamType, content []byte) { if atomic.SwapInt32(&firstFrame, 1) == 0 { close(frameRecv) } @@ -246,7 +220,7 @@ func TestDialReadPauseParallel(t *testing.T) { <-frameRecv conn.Close() - <-readerDone + <-done }) } } diff --git a/examples/client-read-pause.go b/examples/client-read-pause.go index 7611d59f..f8db4ecb 100644 --- a/examples/client-read-pause.go +++ b/examples/client-read-pause.go @@ -4,6 +4,7 @@ package main import ( "fmt" + "time" "github.com/aler9/gortsplib" ) @@ -24,13 +25,7 @@ func main() { for { // read frames from the server - readerDone := make(chan struct{}) - conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte, err error) { - if err != nil { - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) }) @@ -44,13 +39,13 @@ func main() { } // join reader - <-readerDone + <-done // wait time.Sleep(5 * time.Second) // play again - _, err := conn.Play() + _, err = conn.Play() if err != nil { panic(err) } diff --git a/examples/client-read-tcp.go b/examples/client-read-tcp.go index a3203ede..ef7289b2 100644 --- a/examples/client-read-tcp.go +++ b/examples/client-read-tcp.go @@ -23,18 +23,10 @@ func main() { } defer conn.Close() - readerDone := make(chan struct{}) - // read frames from the server - conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte, err error) { - if err != nil { - fmt.Printf("ERR: %v\n", err) - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) }) - <-readerDone + <-done } diff --git a/examples/client-read-udp.go b/examples/client-read-udp.go index 778f95e9..90c09658 100644 --- a/examples/client-read-udp.go +++ b/examples/client-read-udp.go @@ -20,18 +20,10 @@ func main() { } defer conn.Close() - readerDone := make(chan struct{}) - // read frames from the server - conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte, err error) { - if err != nil { - fmt.Printf("ERR: %v\n", err) - close(readerDone) - return - } - + done := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) }) - <-readerDone + <-done }