From d04381d7877660c743a44197611fc38b394179ea Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 12 Nov 2021 15:04:57 +0100 Subject: [PATCH] client: allow calling Close() during a request --- client.go | 81 ++++++++++++++++++++++++++++++-------------------- client_test.go | 47 +++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 33 deletions(-) diff --git a/client.go b/client.go index 4ad1fadf..045c4d89 100644 --- a/client.go +++ b/client.go @@ -224,9 +224,15 @@ type Client struct { checkStreamInitial bool tcpLastFrameTime int64 keepaliveTimer *time.Timer - readerRunning bool finalErr error + // connCloser channels + connCloserTerminate chan struct{} + connCloserDone chan struct{} + + // reader channels + readerErr chan error + // in options chan optionsReq describe chan describeReq @@ -237,8 +243,7 @@ type Client struct { pause chan pauseReq // out - readerErr chan error - done chan struct{} + done chan struct{} } // Dial connects to a server. @@ -558,7 +563,7 @@ func (c *Client) run() { c.keepaliveTimer = time.NewTimer(clientUDPKeepalivePeriod) case err := <-c.readerErr: - c.readerRunning = false + c.readerErr = nil return err case <-c.ctx.Done(): @@ -569,12 +574,12 @@ func (c *Client) run() { c.ctxCancel() - c.doClose() + c.doClose(true) } -func (c *Client) doClose() { +func (c *Client) doClose(isClosing bool) { if c.state == clientStatePlay || c.state == clientStateRecord { - c.playRecordClose() + c.playRecordClose(isClosing) c.do(&base.Request{ Method: base.Teardown, @@ -590,13 +595,14 @@ func (c *Client) doClose() { } if c.nconn != nil { + c.connCloserStop() c.nconn.Close() c.nconn = nil } } func (c *Client) reset() { - c.doClose() + c.doClose(false) c.state = clientStateInitial c.session = "" @@ -654,6 +660,9 @@ func (c *Client) trySwitchingProtocol() error { } func (c *Client) playRecordStart() { + // stop connCloser + c.connCloserStop() + // allow writing c.writeMutex.Lock() c.writeFrameAllowed = true @@ -695,7 +704,6 @@ func (c *Client) playRecordStart() { c.nconn.SetReadDeadline(time.Time{}) // start reader - c.readerRunning = true c.readerErr = make(chan error) go func() { c.readerErr <- c.runReader() @@ -775,9 +783,9 @@ func (c *Client) runReader() error { } } -func (c *Client) playRecordClose() { +func (c *Client) playRecordClose(isClosing bool) { // stop reader - if c.readerRunning { + if c.readerErr != nil { c.nconn.SetReadDeadline(time.Now()) <-c.readerErr } @@ -799,6 +807,11 @@ func (c *Client) playRecordClose() { c.writeMutex.Lock() c.writeFrameAllowed = false c.writeMutex.Unlock() + + // start connCloser + if !isClosing { + c.connCloserStart() + } } func (c *Client) connOpen() error { @@ -832,9 +845,32 @@ func (c *Client) connOpen() error { c.nconn = nconn c.br = bufio.NewReaderSize(conn, clientReadBufferSize) c.bw = bufio.NewWriterSize(conn, clientWriteBufferSize) + c.connCloserStart() return nil } +func (c *Client) connCloserStart() { + c.connCloserTerminate = make(chan struct{}) + c.connCloserDone = make(chan struct{}) + go func() { + defer close(c.connCloserDone) + select { + case <-c.ctx.Done(): + c.nconn.Close() + + case <-c.connCloserTerminate: + } + }() +} + +func (c *Client) connCloserStop() { + if c.connCloserDone != nil { + close(c.connCloserTerminate) + <-c.connCloserDone + c.connCloserDone = nil + } +} + func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error) { if c.nconn == nil { err := c.connOpen() @@ -867,27 +903,6 @@ func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error var res base.Response err := func() error { - // the only two do() with skipResponses are - // - TEARDOWN -> ctx is already canceled, so this can't be used - // - keepalives -> if ctx is canceled during a keepalive, - // it's better not to stop the request, but wait until teardown - if !skipResponse { - ctxHandlerDone := make(chan struct{}) - defer func() { <-ctxHandlerDone }() - - ctxHandlerTerminate := make(chan struct{}) - defer close(ctxHandlerTerminate) - - go func() { - defer close(ctxHandlerDone) - select { - case <-c.ctx.Done(): - c.nconn.Close() - case <-ctxHandlerTerminate: - } - }() - } - c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) err := req.Write(c.bw) if err != nil { @@ -1620,7 +1635,7 @@ func (c *Client) doPause() (*base.Response, error) { return nil, err } - c.playRecordClose() + c.playRecordClose(false) res, err := c.do(&base.Request{ Method: base.Pause, diff --git a/client_test.go b/client_test.go index 921790ee..88996042 100644 --- a/client_test.go +++ b/client_test.go @@ -245,3 +245,50 @@ func TestClientDescribeCharset(t *testing.T) { _, _, _, err = c.Describe(u) require.NoError(t, err) } + +func TestClientCloseDuringRequest(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + requestReceived := make(chan struct{}) + releaseConn := make(chan struct{}) + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + req, err := readRequest(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + close(requestReceived) + <-releaseConn + }() + + u, err := base.ParseURL("rtsp://localhost:8554/teststream") + require.NoError(t, err) + + c := Client{} + + err = c.Dial(u.Scheme, u.Host) + require.NoError(t, err) + + optionsDone := make(chan struct{}) + go func() { + defer close(optionsDone) + _, err := c.Options(u) + require.Error(t, err) + }() + + <-requestReceived + c.Close() + <-optionsDone + close(releaseConn) +}