client: allow calling Close() during a request

This commit is contained in:
aler9
2021-11-12 15:04:57 +01:00
committed by Alessandro Ros
parent 07b1fe6a05
commit d04381d787
2 changed files with 95 additions and 33 deletions

View File

@@ -224,9 +224,15 @@ type Client struct {
checkStreamInitial bool checkStreamInitial bool
tcpLastFrameTime int64 tcpLastFrameTime int64
keepaliveTimer *time.Timer keepaliveTimer *time.Timer
readerRunning bool
finalErr error finalErr error
// connCloser channels
connCloserTerminate chan struct{}
connCloserDone chan struct{}
// reader channels
readerErr chan error
// in // in
options chan optionsReq options chan optionsReq
describe chan describeReq describe chan describeReq
@@ -237,8 +243,7 @@ type Client struct {
pause chan pauseReq pause chan pauseReq
// out // out
readerErr chan error done chan struct{}
done chan struct{}
} }
// Dial connects to a server. // Dial connects to a server.
@@ -558,7 +563,7 @@ func (c *Client) run() {
c.keepaliveTimer = time.NewTimer(clientUDPKeepalivePeriod) c.keepaliveTimer = time.NewTimer(clientUDPKeepalivePeriod)
case err := <-c.readerErr: case err := <-c.readerErr:
c.readerRunning = false c.readerErr = nil
return err return err
case <-c.ctx.Done(): case <-c.ctx.Done():
@@ -569,12 +574,12 @@ func (c *Client) run() {
c.ctxCancel() c.ctxCancel()
c.doClose() c.doClose(true)
} }
func (c *Client) doClose() { func (c *Client) doClose(isClosing bool) {
if c.state == clientStatePlay || c.state == clientStateRecord { if c.state == clientStatePlay || c.state == clientStateRecord {
c.playRecordClose() c.playRecordClose(isClosing)
c.do(&base.Request{ c.do(&base.Request{
Method: base.Teardown, Method: base.Teardown,
@@ -590,13 +595,14 @@ func (c *Client) doClose() {
} }
if c.nconn != nil { if c.nconn != nil {
c.connCloserStop()
c.nconn.Close() c.nconn.Close()
c.nconn = nil c.nconn = nil
} }
} }
func (c *Client) reset() { func (c *Client) reset() {
c.doClose() c.doClose(false)
c.state = clientStateInitial c.state = clientStateInitial
c.session = "" c.session = ""
@@ -654,6 +660,9 @@ func (c *Client) trySwitchingProtocol() error {
} }
func (c *Client) playRecordStart() { func (c *Client) playRecordStart() {
// stop connCloser
c.connCloserStop()
// allow writing // allow writing
c.writeMutex.Lock() c.writeMutex.Lock()
c.writeFrameAllowed = true c.writeFrameAllowed = true
@@ -695,7 +704,6 @@ func (c *Client) playRecordStart() {
c.nconn.SetReadDeadline(time.Time{}) c.nconn.SetReadDeadline(time.Time{})
// start reader // start reader
c.readerRunning = true
c.readerErr = make(chan error) c.readerErr = make(chan error)
go func() { go func() {
c.readerErr <- c.runReader() c.readerErr <- c.runReader()
@@ -775,9 +783,9 @@ func (c *Client) runReader() error {
} }
} }
func (c *Client) playRecordClose() { func (c *Client) playRecordClose(isClosing bool) {
// stop reader // stop reader
if c.readerRunning { if c.readerErr != nil {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-c.readerErr <-c.readerErr
} }
@@ -799,6 +807,11 @@ func (c *Client) playRecordClose() {
c.writeMutex.Lock() c.writeMutex.Lock()
c.writeFrameAllowed = false c.writeFrameAllowed = false
c.writeMutex.Unlock() c.writeMutex.Unlock()
// start connCloser
if !isClosing {
c.connCloserStart()
}
} }
func (c *Client) connOpen() error { func (c *Client) connOpen() error {
@@ -832,9 +845,32 @@ func (c *Client) connOpen() error {
c.nconn = nconn c.nconn = nconn
c.br = bufio.NewReaderSize(conn, clientReadBufferSize) c.br = bufio.NewReaderSize(conn, clientReadBufferSize)
c.bw = bufio.NewWriterSize(conn, clientWriteBufferSize) c.bw = bufio.NewWriterSize(conn, clientWriteBufferSize)
c.connCloserStart()
return nil return nil
} }
func (c *Client) connCloserStart() {
c.connCloserTerminate = make(chan struct{})
c.connCloserDone = make(chan struct{})
go func() {
defer close(c.connCloserDone)
select {
case <-c.ctx.Done():
c.nconn.Close()
case <-c.connCloserTerminate:
}
}()
}
func (c *Client) connCloserStop() {
if c.connCloserDone != nil {
close(c.connCloserTerminate)
<-c.connCloserDone
c.connCloserDone = nil
}
}
func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error) { func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error) {
if c.nconn == nil { if c.nconn == nil {
err := c.connOpen() err := c.connOpen()
@@ -867,27 +903,6 @@ func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error
var res base.Response var res base.Response
err := func() error { err := func() error {
// the only two do() with skipResponses are
// - TEARDOWN -> ctx is already canceled, so this can't be used
// - keepalives -> if ctx is canceled during a keepalive,
// it's better not to stop the request, but wait until teardown
if !skipResponse {
ctxHandlerDone := make(chan struct{})
defer func() { <-ctxHandlerDone }()
ctxHandlerTerminate := make(chan struct{})
defer close(ctxHandlerTerminate)
go func() {
defer close(ctxHandlerDone)
select {
case <-c.ctx.Done():
c.nconn.Close()
case <-ctxHandlerTerminate:
}
}()
}
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
err := req.Write(c.bw) err := req.Write(c.bw)
if err != nil { if err != nil {
@@ -1620,7 +1635,7 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err return nil, err
} }
c.playRecordClose() c.playRecordClose(false)
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Pause, Method: base.Pause,

View File

@@ -245,3 +245,50 @@ func TestClientDescribeCharset(t *testing.T) {
_, _, _, err = c.Describe(u) _, _, _, err = c.Describe(u)
require.NoError(t, err) require.NoError(t, err)
} }
func TestClientCloseDuringRequest(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
requestReceived := make(chan struct{})
releaseConn := make(chan struct{})
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
req, err := readRequest(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
close(requestReceived)
<-releaseConn
}()
u, err := base.ParseURL("rtsp://localhost:8554/teststream")
require.NoError(t, err)
c := Client{}
err = c.Dial(u.Scheme, u.Host)
require.NoError(t, err)
optionsDone := make(chan struct{})
go func() {
defer close(optionsDone)
_, err := c.Options(u)
require.Error(t, err)
}()
<-requestReceived
c.Close()
<-optionsDone
close(releaseConn)
}