diff --git a/client_publish_test.go b/client_publish_test.go index ab4b0eb4..b19368e3 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -181,12 +181,16 @@ func TestClientPublishSerial(t *testing.T) { require.NoError(t, err) recvDone := make(chan struct{}) - done := conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { - require.Equal(t, 0, trackID) - require.Equal(t, StreamTypeRTCP, streamType) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload) - close(recvDone) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { + require.Equal(t, 0, trackID) + require.Equal(t, StreamTypeRTCP, streamType) + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, payload) + close(recvDone) + }) + }() err = conn.WriteFrame(track.ID, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) diff --git a/client_read_test.go b/client_read_test.go index 48b4b192..4bfb9dfc 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -359,20 +359,24 @@ func TestClientRead(t *testing.T) { conn, err := c.DialRead(scheme + "://localhost:8554/teststream") require.NoError(t, err) - done := conn.ReadFrames(func(id int, streamType StreamType, payload []byte) { - require.Equal(t, 0, id) - require.Equal(t, StreamTypeRTP, streamType) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, streamType StreamType, payload []byte) { + require.Equal(t, 0, id) + require.Equal(t, StreamTypeRTP, streamType) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) - err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) - require.NoError(t, err) - }) + err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) + require.NoError(t, err) + }) + }() <-frameRecv conn.Close() <-done - <-conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { }) }) } @@ -590,9 +594,13 @@ func TestClientReadAnyPort(t *testing.T) { require.NoError(t, err) frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + close(frameRecv) + }) + }() <-frameRecv conn.Close() @@ -704,9 +712,13 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.NoError(t, err) frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + close(frameRecv) + }) + }() <-frameRecv conn.Close() @@ -872,9 +884,13 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.NoError(t, err) frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + close(frameRecv) + }) + }() <-frameRecv conn.Close() @@ -1013,9 +1029,13 @@ func TestClientReadRedirect(t *testing.T) { require.NoError(t, err) frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - close(frameRecv) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + close(frameRecv) + }) + }() <-frameRecv conn.Close() @@ -1214,18 +1234,22 @@ func TestClientReadPause(t *testing.T) { firstFrame := int32(0) frameRecv := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - if atomic.SwapInt32(&firstFrame, 1) == 0 { - close(frameRecv) - } - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + }() <-frameRecv _, err = conn.Pause() require.NoError(t, err) <-done - <-conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { }) _, err = conn.Play() @@ -1233,11 +1257,15 @@ func TestClientReadPause(t *testing.T) { firstFrame = int32(0) frameRecv = make(chan struct{}) - done = conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - if atomic.SwapInt32(&firstFrame, 1) == 0 { - close(frameRecv) - } - }) + done = make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + }() <-frameRecv conn.Close() @@ -1398,12 +1426,16 @@ func TestClientReadRTCPReport(t *testing.T) { recv := 0 recvDone := make(chan struct{}) - done := conn.ReadFrames(func(id int, typ StreamType, payload []byte) { - recv++ - if recv >= 3 { - close(recvDone) - } - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, typ StreamType, payload []byte) { + recv++ + if recv >= 3 { + close(recvDone) + } + }) + }() time.Sleep(1300 * time.Millisecond) @@ -1559,7 +1591,7 @@ func TestClientReadErrorTimeout(t *testing.T) { require.NoError(t, err) defer conn.Close() - err = <-conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { + err = conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { }) switch proto { @@ -1688,13 +1720,17 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { conn, err := c.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) - defer conn.Close() recv := make(chan struct{}) - conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { - close(recv) - }) - require.NoError(t, err) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) { + close(recv) + }) + }() <-recv + conn.Close() + <-done } diff --git a/clientconn.go b/clientconn.go index bb115613..2a91aff2 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1164,7 +1164,7 @@ func (cc *ClientConn) doSetup( } proto := func() StreamProtocol { - // protocol set by previous Setup() or ReadFrames() + // protocol set by previous Setup() or switchProtocolIfTimeout() if cc.streamProtocol != nil { return *cc.streamProtocol } @@ -1566,8 +1566,7 @@ func (cc *ClientConn) Pause() (*base.Response, error) { } // ReadFrames starts reading frames. -// it returns a channel that is written when the reading stops. -func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan error { +func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) error { cc.readCBMutex.Lock() cc.readCB = onFrame cc.readCBMutex.Unlock() @@ -1578,12 +1577,8 @@ func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan err cc.readCBSet = nil } - ch := make(chan error, 1) - go func() { - <-cc.backgroundDone - ch <- cc.backgroundErr - }() - return ch + <-cc.backgroundDone + return cc.backgroundErr } // WriteFrame writes a frame. diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index 8a3180ca..c31dd1c5 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -38,7 +38,7 @@ func main() { dec := rtph264.NewDecoder() // read RTP frames - err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { + err = conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { if trackID == h264Track { // convert RTP frames into H264 NALUs nalus, _, err := dec.Decode(buf) diff --git a/examples/client-read-options/main.go b/examples/client-read-options/main.go index 77a9110c..cd8854f0 100644 --- a/examples/client-read-options/main.go +++ b/examples/client-read-options/main.go @@ -30,7 +30,7 @@ func main() { defer conn.Close() // read RTP frames - err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { + err = conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v, size %d\n", trackID, typ, len(buf)) }) panic(err) diff --git a/examples/client-read-partial/main.go b/examples/client-read-partial/main.go index 618eae96..856ee0f9 100644 --- a/examples/client-read-partial/main.go +++ b/examples/client-read-partial/main.go @@ -52,7 +52,7 @@ func main() { } // read RTP frames - err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { + err = conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v, size %d\n", trackID, typ, len(buf)) }) panic(err) diff --git a/examples/client-read-pause/main.go b/examples/client-read-pause/main.go index 23f84ee2..e5bb96e5 100644 --- a/examples/client-read-pause/main.go +++ b/examples/client-read-pause/main.go @@ -23,9 +23,13 @@ func main() { for { // read RTP frames - done := conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { - fmt.Printf("frame from track %d, type %v, size %d\n", trackID, typ, len(buf)) - }) + done := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { + fmt.Printf("frame from track %d, type %v, size %d\n", trackID, typ, len(buf)) + }) + }() // wait time.Sleep(5 * time.Second) diff --git a/examples/client-read/main.go b/examples/client-read/main.go index f904e49c..f3dbc907 100644 --- a/examples/client-read/main.go +++ b/examples/client-read/main.go @@ -18,7 +18,7 @@ func main() { defer conn.Close() // read RTP frames - err = <-conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { + err = conn.ReadFrames(func(trackID int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v, size %d\n", trackID, typ, len(buf)) }) panic(err)