From e13b4289ec331c31e9f184d607a5e5717603df56 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 12 Nov 2021 15:48:50 +0100 Subject: [PATCH] client: simplify API, add StartReadingWait --- client.go | 44 +++--- client_publish_test.go | 16 +-- client_read_test.go | 136 ++++-------------- client_test.go | 8 +- examples/client-publish-aac/main.go | 2 +- examples/client-publish-h264/main.go | 2 +- examples/client-publish-options/main.go | 2 +- examples/client-publish-opus/main.go | 2 +- examples/client-publish-pause/main.go | 8 +- examples/client-query/main.go | 2 +- .../client-read-h264-save-to-disk/main.go | 10 +- examples/client-read-h264/main.go | 10 +- examples/client-read-options/main.go | 10 +- examples/client-read-partial/main.go | 7 +- examples/client-read-pause/main.go | 12 +- examples/client-read/main.go | 10 +- 16 files changed, 82 insertions(+), 199 deletions(-) diff --git a/client.go b/client.go index 045c4d89..1593332a 100644 --- a/client.go +++ b/client.go @@ -246,8 +246,8 @@ type Client struct { done chan struct{} } -// Dial connects to a server. -func (c *Client) Dial(scheme string, host string) error { +// Start initializes the connection to a server. +func (c *Client) Start(scheme string, host string) error { // callbacks if c.OnPacketRTP == nil { c.OnPacketRTP = func(c *Client, trackID int, payload []byte) { @@ -317,14 +317,14 @@ func (c *Client) Dial(scheme string, host string) error { return nil } -// DialRead connects to the address and starts reading all tracks. -func (c *Client) DialRead(address string) error { +// StartReading connects to the address and starts reading all tracks. +func (c *Client) StartReading(address string) error { u, err := base.ParseURL(address) if err != nil { return err } - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) if err != nil { return err } @@ -358,14 +358,25 @@ func (c *Client) DialRead(address string) error { return nil } -// DialPublish connects to the address and starts publishing the tracks. -func (c *Client) DialPublish(address string, tracks Tracks) error { +// StartReadingAndWait connects to the address, starts reading all tracks and waits +// until a read error. +func (c *Client) StartReadingAndWait(address string) error { + err := c.StartReading(address) + if err != nil { + return err + } + + return c.Wait() +} + +// StartPublishing connects to the address and starts publishing the tracks. +func (c *Client) StartPublishing(address string, tracks Tracks) error { u, err := base.ParseURL(address) if err != nil { return err } - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) if err != nil { return err } @@ -399,11 +410,18 @@ func (c *Client) DialPublish(address string, tracks Tracks) error { return nil } -// Close closes all the client resources and waits for them to exit. +// Close closes all client resources and waits for them to close. func (c *Client) Close() error { c.ctxCancel() <-c.done - return nil + return c.finalErr +} + +// Wait waits until all client resources are closed. +// This can happen when a read error occurs or when Close() is called. +func (c *Client) Wait() error { + <-c.done + return c.finalErr } // Tracks returns all the tracks that the client is reading or publishing. @@ -1685,12 +1703,6 @@ func (c *Client) Seek(ra *headers.Range) (*base.Response, error) { return c.Play(ra) } -// ReadFrames starts reading frames. -func (c *Client) ReadFrames() error { - <-c.done - return c.finalErr -} - // WritePacketRTP writes a RTP packet. func (c *Client) WritePacketRTP(trackID int, payload []byte) error { select { diff --git a/client_publish_test.go b/client_publish_test.go index 6772b801..0bf6cdf7 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -179,14 +179,14 @@ func TestClientPublishSerial(t *testing.T) { track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) require.NoError(t, err) - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) - c.ReadFrames() + c.Wait() }() err = c.WritePacketRTP(0, @@ -317,7 +317,7 @@ func TestClientPublishParallel(t *testing.T) { writerDone := make(chan struct{}) defer func() { <-writerDone }() - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) defer c.Close() @@ -471,7 +471,7 @@ func TestClientPublishPauseSerial(t *testing.T) { track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) require.NoError(t, err) - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) defer c.Close() @@ -608,7 +608,7 @@ func TestClientPublishPauseParallel(t *testing.T) { track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) require.NoError(t, err) - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) @@ -748,7 +748,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { c := Client{} - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) defer c.Close() @@ -889,7 +889,7 @@ func TestClientPublishRTCPReport(t *testing.T) { track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) require.NoError(t, err) - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) defer c.Close() @@ -1027,7 +1027,7 @@ func TestClientPublishIgnoreTCPRTPPackets(t *testing.T) { track, err := NewTrackH264(96, &TrackConfigH264{[]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}}) require.NoError(t, err) - err = c.DialPublish("rtsp://localhost:8554/teststream", + err = c.StartPublishing("rtsp://localhost:8554/teststream", Tracks{track}) require.NoError(t, err) defer c.Close() diff --git a/client_read_test.go b/client_read_test.go index fc9a3b06..5a3d786e 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -132,7 +132,7 @@ func TestClientReadTracks(t *testing.T) { c := Client{} - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) defer c.Close() @@ -429,19 +429,11 @@ func TestClientRead(t *testing.T) { }, } - err = c.DialRead(scheme + "://" + listenIP + ":8554/test/stream?param=value") + err = c.StartReading(scheme + "://" + listenIP + ":8554/test/stream?param=value") require.NoError(t, err) - - done := make(chan struct{}) - - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done }) } } @@ -558,18 +550,11 @@ func TestClientReadNonStandardFrameSize(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done } func TestClientReadPartial(t *testing.T) { @@ -682,8 +667,9 @@ func TestClientReadPartial(t *testing.T) { u, err := base.ParseURL("rtsp://" + listenIP + ":8554/teststream") require.NoError(t, err) - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) + defer c.Close() tracks, baseURL, _, err := c.Describe(u) require.NoError(t, err) @@ -694,15 +680,7 @@ func TestClientReadPartial(t *testing.T) { _, err = c.Play(nil) require.NoError(t, err) - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() - <-frameRecv - c.Close() - <-done } func TestClientReadNoContentBase(t *testing.T) { @@ -805,7 +783,7 @@ func TestClientReadNoContentBase(t *testing.T) { c := Client{} - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) c.Close() } @@ -929,18 +907,11 @@ func TestClientReadAnyPort(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done }) } } @@ -1053,18 +1024,11 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done }) t.Run("switch after timeout", func(t *testing.T) { @@ -1262,18 +1226,11 @@ func TestClientReadAutomaticProtocol(t *testing.T) { }, } - err = c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream") + err = c.StartReading("rtsp://myuser:mypass@localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done }) } @@ -1394,18 +1351,11 @@ func TestClientReadDifferentInterleavedIDs(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done } func TestClientReadRedirect(t *testing.T) { @@ -1545,18 +1495,11 @@ func TestClientReadRedirect(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/path1") + err = c.StartReading("rtsp://localhost:8554/path1") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv - c.Close() - <-done } func TestClientReadPause(t *testing.T) { @@ -1754,16 +1697,12 @@ func TestClientReadPause(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-frameRecv + _, err = c.Pause() require.NoError(t, err) @@ -1774,8 +1713,6 @@ func TestClientReadPause(t *testing.T) { require.NoError(t, err) <-frameRecv - c.Close() - <-done }) } } @@ -1941,20 +1878,11 @@ func TestClientReadRTCPReport(t *testing.T) { receiverReportPeriod: 1 * time.Second, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() - - time.Sleep(1300 * time.Millisecond) + defer c.Close() <-recvDone - c.Close() - <-done } func TestClientReadErrorTimeout(t *testing.T) { @@ -2101,11 +2029,10 @@ func TestClientReadErrorTimeout(t *testing.T) { ReadTimeout: 1 * time.Second, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - defer c.Close() - err = c.ReadFrames() + err = c.Wait() switch transport { case "udp", "auto": @@ -2236,18 +2163,11 @@ func TestClientReadIgnoreTCPInvalidTrack(t *testing.T) { }, } - err = c.DialRead("rtsp://localhost:8554/teststream") + err = c.StartReading("rtsp://localhost:8554/teststream") require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() + defer c.Close() <-recv - c.Close() - <-done } func TestClientReadSeek(t *testing.T) { @@ -2389,7 +2309,7 @@ func TestClientReadSeek(t *testing.T) { u, err := base.ParseURL("rtsp://localhost:8554/teststream") require.NoError(t, err) - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() diff --git a/client_test.go b/client_test.go index 88996042..320828f0 100644 --- a/client_test.go +++ b/client_test.go @@ -90,7 +90,7 @@ func TestClientSession(t *testing.T) { c := Client{} - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() @@ -171,7 +171,7 @@ func TestClientAuth(t *testing.T) { c := Client{} - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() @@ -235,7 +235,7 @@ func TestClientDescribeCharset(t *testing.T) { c := Client{} - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() @@ -277,7 +277,7 @@ func TestClientCloseDuringRequest(t *testing.T) { c := Client{} - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) optionsDone := make(chan struct{}) diff --git a/examples/client-publish-aac/main.go b/examples/client-publish-aac/main.go index 39f57a61..d16ef0ae 100644 --- a/examples/client-publish-aac/main.go +++ b/examples/client-publish-aac/main.go @@ -42,7 +42,7 @@ func main() { c := gortsplib.Client{} // connect to the server and start publishing the track - err = c.DialPublish("rtsp://localhost:8554/mystream", + err = c.StartPublishing("rtsp://localhost:8554/mystream", gortsplib.Tracks{track}) if err != nil { panic(err) diff --git a/examples/client-publish-h264/main.go b/examples/client-publish-h264/main.go index c8261ea8..77c4cbd6 100644 --- a/examples/client-publish-h264/main.go +++ b/examples/client-publish-h264/main.go @@ -43,7 +43,7 @@ func main() { c := gortsplib.Client{} // connect to the server and start publishing the track - err = c.DialPublish("rtsp://localhost:8554/mystream", + err = c.StartPublishing("rtsp://localhost:8554/mystream", gortsplib.Tracks{track}) if err != nil { panic(err) diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index be6d566f..774b9613 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -52,7 +52,7 @@ func main() { } // connect to the server and start publishing the track - err = c.DialPublish("rtsp://localhost:8554/mystream", + err = c.StartPublishing("rtsp://localhost:8554/mystream", gortsplib.Tracks{track}) if err != nil { panic(err) diff --git a/examples/client-publish-opus/main.go b/examples/client-publish-opus/main.go index 020476fd..f213212a 100644 --- a/examples/client-publish-opus/main.go +++ b/examples/client-publish-opus/main.go @@ -42,7 +42,7 @@ func main() { c := gortsplib.Client{} // connect to the server and start publishing the track - err = c.DialPublish("rtsp://localhost:8554/mystream", + err = c.StartPublishing("rtsp://localhost:8554/mystream", gortsplib.Tracks{track}) if err != nil { panic(err) diff --git a/examples/client-publish-pause/main.go b/examples/client-publish-pause/main.go index 0719fc22..1d04a29d 100644 --- a/examples/client-publish-pause/main.go +++ b/examples/client-publish-pause/main.go @@ -45,7 +45,7 @@ func main() { c := gortsplib.Client{} // connect to the server and start publishing the track - err = c.DialPublish("rtsp://localhost:8554/mystream", + err = c.StartPublishing("rtsp://localhost:8554/mystream", gortsplib.Tracks{track}) if err != nil { panic(err) @@ -53,10 +53,7 @@ func main() { defer c.Close() for { - writerDone := make(chan struct{}) go func() { - defer close(writerDone) - buf := make([]byte, 2048) for { // read packets from the source @@ -82,9 +79,6 @@ func main() { panic(err) } - // join writer - <-writerDone - // wait time.Sleep(5 * time.Second) diff --git a/examples/client-query/main.go b/examples/client-query/main.go index 2cad2946..169c3ea0 100644 --- a/examples/client-query/main.go +++ b/examples/client-query/main.go @@ -19,7 +19,7 @@ func main() { c := gortsplib.Client{} - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) if err != nil { panic(err) } diff --git a/examples/client-read-h264-save-to-disk/main.go b/examples/client-read-h264-save-to-disk/main.go index 6b737a61..c5f0a0a1 100644 --- a/examples/client-read-h264-save-to-disk/main.go +++ b/examples/client-read-h264-save-to-disk/main.go @@ -163,13 +163,5 @@ func main() { } // connect to the server and start reading all tracks - err = c.DialRead(inputStream) - if err != nil { - panic(err) - } - defer c.Close() - - // read packets - err = c.ReadFrames() - panic(err) + panic(c.StartReadingAndWait(inputStream)) } diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index 9143bc65..7ed4fdc7 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -64,13 +64,5 @@ func main() { } // connect to the server and start reading all tracks - err := c.DialRead("rtsp://localhost:8554/mystream") - if err != nil { - panic(err) - } - defer c.Close() - - // read packets - err = c.ReadFrames() - panic(err) + panic(c.StartReadingAndWait("rtsp://localhost:8554/mystream")) } diff --git a/examples/client-read-options/main.go b/examples/client-read-options/main.go index b71627aa..4134e6b3 100644 --- a/examples/client-read-options/main.go +++ b/examples/client-read-options/main.go @@ -31,13 +31,5 @@ func main() { } // connect to the server and start reading all tracks - err := c.DialRead("rtsp://localhost:8554/mystream") - if err != nil { - panic(err) - } - defer c.Close() - - // read packets - err = c.ReadFrames() - panic(err) + panic(c.StartReadingAndWait("rtsp://localhost:8554/mystream")) } diff --git a/examples/client-read-partial/main.go b/examples/client-read-partial/main.go index 7ce65338..d583db5d 100644 --- a/examples/client-read-partial/main.go +++ b/examples/client-read-partial/main.go @@ -30,7 +30,7 @@ func main() { }, } - err = c.Dial(u.Scheme, u.Host) + err = c.Start(u.Scheme, u.Host) if err != nil { panic(err) } @@ -62,7 +62,6 @@ func main() { panic(err) } - // read packets - err = c.ReadFrames() - panic(err) + // wait until a fatal error + panic(c.Wait()) } diff --git a/examples/client-read-pause/main.go b/examples/client-read-pause/main.go index 4895e3d9..361de88c 100644 --- a/examples/client-read-pause/main.go +++ b/examples/client-read-pause/main.go @@ -26,20 +26,13 @@ func main() { } // connect to the server and start reading all tracks - err := c.DialRead("rtsp://localhost:8554/mystream") + err := c.StartReading("rtsp://localhost:8554/mystream") if err != nil { panic(err) } defer c.Close() for { - // read packets - done := make(chan struct{}) - go func() { - defer close(done) - c.ReadFrames() - }() - // wait time.Sleep(5 * time.Second) @@ -49,9 +42,6 @@ func main() { panic(err) } - // join reader - <-done - // wait time.Sleep(5 * time.Second) diff --git a/examples/client-read/main.go b/examples/client-read/main.go index 4b43a6a8..a398c2ac 100644 --- a/examples/client-read/main.go +++ b/examples/client-read/main.go @@ -22,13 +22,5 @@ func main() { } // connect to the server and start reading all tracks - err := c.DialRead("rtsp://localhost:8554/mystream") - if err != nil { - panic(err) - } - defer c.Close() - - // read packets - err = c.ReadFrames() - panic(err) + panic(c.StartReadingAndWait("rtsp://localhost:8554/mystream")) }