diff --git a/connclient.go b/connclient.go index 11bc15c4..e1ca31bc 100644 --- a/connclient.go +++ b/connclient.go @@ -80,12 +80,12 @@ type ConnClient struct { tcpFrameBuffer *multibuffer.MultiBuffer getParameterSupported bool backgroundError error + writeFrameMutex sync.RWMutex + writeFrameOpen bool + readCB func(int, StreamType, []byte, error) backgroundTerminate chan struct{} backgroundDone chan struct{} - readFrame chan base.InterleavedFrame - writeFrameMutex sync.RWMutex - writeFrameOpen bool } // Close closes all the ConnClient resources. diff --git a/connclientpublish.go b/connclientpublish.go index 00f9c962..ce491087 100644 --- a/connclientpublish.go +++ b/connclientpublish.go @@ -61,7 +61,6 @@ func (c *ConnClient) Record() (*base.Response, error) { } c.state = connClientStateRecord - c.writeFrameOpen = true c.backgroundTerminate = make(chan struct{}) c.backgroundDone = make(chan struct{}) diff --git a/connclientread.go b/connclientread.go index 0089d4a3..a83e385c 100644 --- a/connclientread.go +++ b/connclientread.go @@ -30,42 +30,19 @@ func (c *ConnClient) Play() (*base.Response, error) { return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } - c.state = connClientStatePlay - - c.readFrame = make(chan base.InterleavedFrame) - c.backgroundTerminate = make(chan struct{}) - c.backgroundDone = make(chan struct{}) - - if *c.streamProtocol == StreamProtocolUDP { - // open the firewall by sending packets to the counterpart - for trackId := range c.udpRtpListeners { - c.udpRtpListeners[trackId].write( - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - - c.udpRtcpListeners[trackId].write( - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) - } - - go c.backgroundPlayUDP() - } else { - go c.backgroundPlayTCP() - } - return res, nil } func (c *ConnClient) backgroundPlayUDP() { defer close(c.backgroundDone) - readFrame := c.readFrame - defer func() { for trackId := range c.udpRtpListeners { c.udpRtpListeners[trackId].stop() c.udpRtcpListeners[trackId].stop() } - close(readFrame) + c.readCB(0, 0, nil, c.backgroundError) }() for trackId := range c.udpRtpListeners { @@ -100,10 +77,6 @@ func (c *ConnClient) backgroundPlayUDP() { for { select { case <-c.backgroundTerminate: - go func() { - for range readFrame { - } - }() c.nconn.SetReadDeadline(time.Now()) <-readerDone c.backgroundError = fmt.Errorf("terminated") @@ -129,10 +102,6 @@ func (c *ConnClient) backgroundPlayUDP() { SkipResponse: true, }) if err != nil { - go func() { - for range readFrame { - } - }() c.nconn.SetReadDeadline(time.Now()) <-readerDone c.backgroundError = err @@ -146,10 +115,6 @@ func (c *ConnClient) backgroundPlayUDP() { last := time.Unix(atomic.LoadInt64(lastUnix), 0) if now.Sub(last) >= c.d.ReadTimeout { - go func() { - for range readFrame { - } - }() c.nconn.SetReadDeadline(time.Now()) <-readerDone c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") @@ -158,10 +123,6 @@ func (c *ConnClient) backgroundPlayUDP() { } case err := <-readerDone: - go func() { - for range readFrame { - } - }() c.backgroundError = err return } @@ -171,11 +132,7 @@ func (c *ConnClient) backgroundPlayUDP() { func (c *ConnClient) backgroundPlayTCP() { defer close(c.backgroundDone) - readFrame := c.readFrame - - defer func() { - close(readFrame) - }() + defer c.readCB(0, 0, nil, c.backgroundError) readerDone := make(chan error) go func() { @@ -192,7 +149,7 @@ func (c *ConnClient) backgroundPlayTCP() { c.rtcpReceivers[frame.TrackId].OnFrame(frame.StreamType, frame.Content) - readFrame <- frame + c.readCB(frame.TrackId, frame.StreamType, frame.Content, nil) } }() @@ -202,10 +159,6 @@ func (c *ConnClient) backgroundPlayTCP() { for { select { case <-c.backgroundTerminate: - go func() { - for range readFrame { - } - }() c.nconn.SetReadDeadline(time.Now()) <-readerDone c.backgroundError = fmt.Errorf("terminated") @@ -224,23 +177,31 @@ func (c *ConnClient) backgroundPlayTCP() { } case err := <-readerDone: - go func() { - for range readFrame { - } - }() c.backgroundError = err return } } } -// ReadFrame reads a frame. -// This can be used only after Play(). -func (c *ConnClient) ReadFrame() (int, StreamType, []byte, error) { - f, ok := <-c.readFrame - if !ok { - return 0, 0, nil, c.backgroundError - } +// OnFrame sets a callback that is called when a frame is received. +func (c *ConnClient) OnFrame(cb func(int, StreamType, []byte, error)) { + c.state = connClientStatePlay + c.readCB = cb + c.backgroundTerminate = make(chan struct{}) + c.backgroundDone = make(chan struct{}) - return f.TrackId, f.StreamType, f.Content, nil + if *c.streamProtocol == StreamProtocolUDP { + // open the firewall by sending packets to the counterpart + for trackId := range c.udpRtpListeners { + c.udpRtpListeners[trackId].write( + []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + + c.udpRtcpListeners[trackId].write( + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + } + + go c.backgroundPlayUDP() + } else { + go c.backgroundPlayTCP() + } } diff --git a/connclientudpl.go b/connclientudpl.go index 324bcd5a..925c8255 100644 --- a/connclientudpl.go +++ b/connclientudpl.go @@ -6,7 +6,6 @@ import ( "sync/atomic" "time" - "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/multibuffer" ) @@ -33,7 +32,7 @@ func newConnClientUDPListener(c *ConnClient, port int) (*connClientUDPListener, return &connClientUDPListener{ c: c, pc: pc, - udpFrameBuffer: multibuffer.New(c.d.ReadBufferCount+1, 2048), + udpFrameBuffer: multibuffer.New(c.d.ReadBufferCount, 2048), }, nil } @@ -76,11 +75,7 @@ func (l *connClientUDPListener) run() { l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n]) - l.c.readFrame <- base.InterleavedFrame{ - TrackId: l.trackId, - StreamType: l.streamType, - Content: buf[:n], - } + l.c.readCB(l.trackId, l.streamType, buf[:n], nil) } } diff --git a/dialer.go b/dialer.go index 5f16e2a3..5b67b55a 100644 --- a/dialer.go +++ b/dialer.go @@ -96,7 +96,7 @@ func (d Dialer) Dial(host string) (*ConnClient, error) { udpLastFrameTimes: make(map[int]*int64), udpRtpListeners: make(map[int]*connClientUDPListener), udpRtcpListeners: make(map[int]*connClientUDPListener), - tcpFrameBuffer: multibuffer.New(d.ReadBufferCount+1, clientTCPFrameReadBufferSize), + tcpFrameBuffer: multibuffer.New(d.ReadBufferCount, clientTCPFrameReadBufferSize), backgroundError: fmt.Errorf("not running"), }, nil } diff --git a/dialer_test.go b/dialer_test.go index 2eaaa3f7..e3b253cf 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -5,6 +5,7 @@ import ( "os" "os/exec" "strconv" + "sync/atomic" "testing" "time" @@ -58,56 +59,6 @@ func (c *container) wait() int { return int(code) } -func TestDialRead(t *testing.T) { - for _, proto := range []string{ - "udp", - "tcp", - } { - t.Run(proto, func(t *testing.T) { - cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"}) - require.NoError(t, err) - defer cnt1.close() - - time.Sleep(1 * time.Second) - - cnt2, err := newContainer("ffmpeg", "publish", []string{ - "-re", - "-stream_loop", "-1", - "-i", "/emptyvideo.ts", - "-c", "copy", - "-f", "rtsp", - "-rtsp_transport", "udp", - "rtsp://localhost:8554/teststream", - }) - require.NoError(t, err) - defer cnt2.close() - - time.Sleep(1 * time.Second) - - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() - - conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") - require.NoError(t, err) - - id, typ, _, err := conn.ReadFrame() - require.NoError(t, err) - - require.Equal(t, 0, id) - require.Equal(t, StreamTypeRtp, typ) - - conn.Close() - - _, _, _, err = conn.ReadFrame() - require.Error(t, err) - }) - } -} - func TestDialReadParallel(t *testing.T) { for _, proto := range []string{ "udp", @@ -144,27 +95,35 @@ func TestDialReadParallel(t *testing.T) { conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) + var firstFrame int32 + frameRecv := make(chan struct{}) readerDone := make(chan struct{}) - go func() { - defer close(readerDone) - - for { - _, _, _, err := conn.ReadFrame() - if err != nil { - break - } + conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { + if err != nil { + close(readerDone) + return } - }() - time.Sleep(1 * time.Second) + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + <-frameRecv conn.Close() <-readerDone + + readerDone = make(chan struct{}) + conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { + require.Error(t, err) + close(readerDone) + }) + <-readerDone }) } } -func TestDialReadRedirect(t *testing.T) { +func TestDialReadRedirectParallel(t *testing.T) { cnt1, err := newContainer("rtsp-simple-server", "server", []string{ "paths:\n" + " path1:\n" + @@ -193,62 +152,24 @@ func TestDialReadRedirect(t *testing.T) { conn, err := DialRead("rtsp://localhost:8554/path1") require.NoError(t, err) - defer conn.Close() - _, _, _, err = conn.ReadFrame() - require.NoError(t, err) -} + var firstFrame int32 + frameRecv := make(chan struct{}) + readerDone := make(chan struct{}) + conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { + if err != nil { + close(readerDone) + return + } -func TestDialReadPause(t *testing.T) { - for _, proto := range []string{ - "udp", - "tcp", - } { - t.Run(proto, func(t *testing.T) { - cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"}) - require.NoError(t, err) - defer cnt1.close() + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) - time.Sleep(1 * time.Second) - - cnt2, err := newContainer("ffmpeg", "publish", []string{ - "-re", - "-stream_loop", "-1", - "-i", "/emptyvideo.ts", - "-c", "copy", - "-f", "rtsp", - "-rtsp_transport", "udp", - "rtsp://localhost:8554/teststream", - }) - require.NoError(t, err) - defer cnt2.close() - - time.Sleep(1 * time.Second) - - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() - - conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") - require.NoError(t, err) - defer conn.Close() - - _, _, _, err = conn.ReadFrame() - require.NoError(t, err) - - _, err = conn.Pause() - require.NoError(t, err) - - _, err = conn.Play() - require.NoError(t, err) - - _, _, _, err = conn.ReadFrame() - require.NoError(t, err) - }) - } + <-frameRecv + conn.Close() + <-readerDone } func TestDialReadPauseParallel(t *testing.T) { @@ -287,30 +208,50 @@ func TestDialReadPauseParallel(t *testing.T) { conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) + firstFrame := int32(0) + frameRecv := make(chan struct{}) readerDone := make(chan struct{}) - go func() { - defer close(readerDone) - - for { - _, _, _, err := conn.ReadFrame() - if err != nil { - break - } + conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { + if err != nil { + close(readerDone) + return } - }() - time.Sleep(1 * time.Second) + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + <-frameRecv _, err = conn.Pause() require.NoError(t, err) <-readerDone + _, err = conn.Play() + require.NoError(t, err) + + firstFrame = int32(0) + frameRecv = make(chan struct{}) + readerDone = make(chan struct{}) + conn.OnFrame(func(id int, typ StreamType, content []byte, err error) { + if err != nil { + close(readerDone) + return + } + + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + + <-frameRecv conn.Close() + <-readerDone }) } } -func TestDialPublish(t *testing.T) { +func TestDialPublishSerial(t *testing.T) { for _, proto := range []string{ "udp", "tcp", @@ -476,7 +417,7 @@ func TestDialPublishParallel(t *testing.T) { } } -func TestDialPublishPause(t *testing.T) { +func TestDialPublishPauseSerial(t *testing.T) { for _, proto := range []string{ "udp", "tcp", diff --git a/examples/client-read-tcp.go b/examples/client-read-tcp.go index 3ee042a8..192dfd0f 100644 --- a/examples/client-read-tcp.go +++ b/examples/client-read-tcp.go @@ -22,15 +22,18 @@ func main() { } defer conn.Close() + readerDone := make(chan struct{}) + defer func() { <-readerDone }() + // read frames - for { - id, typ, buf, err := conn.ReadFrame() + conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte, err error) { if err != nil { - fmt.Printf("connection is closed (%s)\n", err) - break + fmt.Printf("ERR: %v\n", err) + close(readerDone) + return } fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) - } + }) } diff --git a/examples/client-read-udp.go b/examples/client-read-udp.go index 3f2fa0ff..944e523c 100644 --- a/examples/client-read-udp.go +++ b/examples/client-read-udp.go @@ -19,15 +19,18 @@ func main() { } defer conn.Close() + readerDone := make(chan struct{}) + defer func() { <-readerDone }() + // read frames - for { - id, typ, buf, err := conn.ReadFrame() + conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte, err error) { if err != nil { - fmt.Printf("connection is closed (%s)\n", err) - break + fmt.Printf("ERR: %v\n", err) + close(readerDone) + return } fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) - } + }) }