close connections in case of write errors (#613) (#655)

This commit is contained in:
Alessandro Ros
2024-12-14 13:45:11 +01:00
committed by GitHub
parent a2df9d83b3
commit 8f74559616
12 changed files with 427 additions and 350 deletions

127
client.go
View File

@@ -335,22 +335,19 @@ type Client struct {
keepalivePeriod time.Duration
keepaliveTimer *time.Timer
closeError error
writer asyncProcessor
writer *asyncProcessor
reader *clientReader
timeDecoder *rtptime.GlobalDecoder2
mustClose bool
// in
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq
chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq
// out
done chan struct{}
@@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
c.chPlay = make(chan playReq)
c.chRecord = make(chan recordReq)
c.chPause = make(chan pauseReq)
c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{})
go c.run()
@@ -530,6 +524,34 @@ func (c *Client) run() {
func (c *Client) runInner() error {
for {
chReaderResponse := func() chan *base.Response {
if c.reader != nil {
return c.reader.chResponse
}
return nil
}()
chReaderRequest := func() chan *base.Request {
if c.reader != nil {
return c.reader.chRequest
}
return nil
}()
chReaderError := func() chan error {
if c.reader != nil {
return c.reader.chError
}
return nil
}()
chWriterError := func() chan error {
if c.writer != nil {
return c.writer.chError
}
return nil
}()
select {
case req := <-c.chOptions:
res, err := c.doOptions(req.url)
@@ -601,15 +623,18 @@ func (c *Client) runInner() error {
}
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
case err := <-c.chReadError:
case err := <-chWriterError:
return err
case err := <-chReaderError:
c.reader = nil
return err
case res := <-c.chReadResponse:
case res := <-chReaderResponse:
c.OnResponse(res)
// these are responses to keepalives, ignore them.
case req := <-c.chReadRequest:
case req := <-chReaderRequest:
err := c.handleServerRequest(req)
if err != nil {
return err
@@ -630,11 +655,11 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
case <-t.C:
return nil, liberrors.ErrClientRequestTimedOut{}
case err := <-c.chReadError:
case err := <-c.reader.chError:
c.reader = nil
return nil, err
case res := <-c.chReadResponse:
case res := <-c.reader.chResponse:
c.OnResponse(res)
// accept response if CSeq equals request CSeq, or if CSeq is not present
@@ -642,7 +667,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
return res, nil
}
case req := <-c.chReadRequest:
case req := <-c.reader.chRequest:
err := c.handleServerRequest(req)
if err != nil {
return nil, err
@@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {
func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord {
c.stopWriter()
c.stopReadRoutines()
c.writer.stop()
c.stopTransportRoutines()
}
if c.nconn != nil && c.baseURL != nil {
@@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
return c.doSetup(baseURL, medi, 0, 0)
}
func (c *Client) startReadRoutines() {
func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped {
c.writer.allocateBuffer(c.WriteQueueSize)
c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else {
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
c.writer.allocateBuffer(8)
c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
}
c.timeDecoder = rtptime.NewGlobalDecoder2()
@@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
}
}
func (c *Client) stopReadRoutines() {
func (c *Client) stopTransportRoutines() {
if c.reader != nil {
c.reader.setAllowInterleavedFrames(false)
}
@@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
}
c.timeDecoder = nil
}
func (c *Client) startWriter() {
c.writer.start()
}
func (c *Client) stopWriter() {
c.writer.stop()
c.writer = nil
}
func (c *Client) connOpen() error {
@@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
c.state = clientStatePlay
c.startReadRoutines()
c.startTransportRoutines()
// Range is mandatory in Parrot Streaming Server
if ra == nil {
@@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, err
}
if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
@@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
}
c.startWriter()
c.writer.start()
c.lastRange = ra
return res, nil
@@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
}
c.state = clientStateRecord
c.startReadRoutines()
c.startTransportRoutines()
res, err := c.do(&base.Request{
Method: base.Record,
URL: c.baseURL,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, err
}
if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
c.startWriter()
c.writer.start()
return nil, nil
}
@@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err
}
c.stopWriter()
c.writer.stop()
res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.startWriter()
c.writer.start()
return nil, err
}
if res.StatusCode != base.StatusOK {
c.startWriter()
c.writer.start()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
c.stopReadRoutines()
c.stopTransportRoutines()
switch c.state {
case clientStatePlay:
@@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
}
func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}
func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}
func (c *Client) readError(err error) {
c.chReadError <- err
}