cleanup code

This commit is contained in:
aler9
2022-03-02 22:17:14 +01:00
parent e642b964b0
commit dd0904407f
3 changed files with 306 additions and 309 deletions

View File

@@ -430,7 +430,14 @@ func (c *Client) Tracks() Tracks {
func (c *Client) run() { func (c *Client) run() {
defer close(c.done) defer close(c.done)
c.closeError = func() error { c.closeError = c.runInner()
c.ctxCancel()
c.doClose()
}
func (c *Client) runInner() error {
for { for {
select { select {
case req := <-c.options: case req := <-c.options:
@@ -570,11 +577,6 @@ func (c *Client) run() {
return liberrors.ErrClientTerminated{} return liberrors.ErrClientTerminated{}
} }
} }
}()
c.ctxCancel()
c.doClose()
} }
func (c *Client) doClose() { func (c *Client) doClose() {
@@ -957,22 +959,18 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
c.OnRequest(req) c.OnRequest(req)
} }
var res base.Response
err := func() error {
var buf bytes.Buffer var buf bytes.Buffer
req.Write(&buf) req.Write(&buf)
c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
_, err := c.conn.Write(buf.Bytes()) _, err := c.conn.Write(buf.Bytes())
if err != nil { if err != nil {
return err return nil, err
} }
if skipResponse { var res base.Response
return nil
}
if !skipResponse {
c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
if allowFrames { if allowFrames {
@@ -983,20 +981,14 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
buf := make([]byte, c.ReadBufferSize) buf := make([]byte, c.ReadBufferSize)
err = res.ReadIgnoreFrames(c.br, buf) err = res.ReadIgnoreFrames(c.br, buf)
if err != nil { if err != nil {
return err return nil, err
} }
} else { } else {
err = res.Read(c.br) err = res.Read(c.br)
if err != nil {
return err
}
}
return nil
}()
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if c.OnResponse != nil { if c.OnResponse != nil {
c.OnResponse(&res) c.OnResponse(&res)
@@ -1029,6 +1021,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
return c.do(req, skipResponse, allowFrames) return c.do(req, skipResponse, allowFrames)
} }
}
return &res, nil return &res, nil
} }

View File

@@ -115,25 +115,7 @@ func (sc *ServerConn) run() {
readDone := make(chan struct{}) readDone := make(chan struct{})
go sc.runReader(readRequest, readErr, readDone) go sc.runReader(readRequest, readErr, readDone)
err := func() error { err := sc.runInner(readRequest, readErr)
for {
select {
case req := <-readRequest:
req.res <- sc.handleRequestOuter(req.req)
case err := <-readErr:
return err
case ss := <-sc.sessionRemove:
if sc.session == ss {
sc.session = nil
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}()
sc.ctxCancel() sc.ctxCancel()
@@ -160,6 +142,26 @@ func (sc *ServerConn) run() {
} }
} }
func (sc *ServerConn) runInner(readRequest chan readReq, readErr chan error) error {
for {
select {
case req := <-readRequest:
req.res <- sc.handleRequestOuter(req.req)
case err := <-readErr:
return err
case ss := <-sc.sessionRemove:
if sc.session == ss {
sc.session = nil
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}
var errSwitchReadFunc = errors.New("switch read function") var errSwitchReadFunc = errors.New("switch read function")
func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) { func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) {

View File

@@ -267,7 +267,63 @@ func (ss *ServerSession) run() {
}) })
} }
err := func() error { err := ss.runInner()
ss.ctxCancel()
switch ss.state {
case ServerSessionStatePlay:
ss.setuppedStream.readerSetInactive(ss)
if *ss.setuppedTransport == TransportUDP {
ss.s.udpRTCPListener.removeClient(ss)
}
case ServerSessionStateRecord:
if *ss.setuppedTransport == TransportUDP {
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
}
}
if ss.setuppedStream != nil {
ss.setuppedStream.readerRemove(ss)
}
if ss.writerRunning {
ss.writeBuffer.Close()
<-ss.writerDone
ss.writerRunning = false
}
for sc := range ss.conns {
if sc == ss.tcpConn {
sc.Close()
// make sure that OnFrame() is never called after OnSessionClose()
<-sc.done
}
select {
case sc.sessionRemove <- ss:
case <-sc.ctx.Done():
}
}
select {
case ss.s.sessionClose <- ss:
case <-ss.s.ctx.Done():
}
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(&ServerHandlerOnSessionCloseCtx{
Session: ss,
Error: err,
})
}
}
func (ss *ServerSession) runInner() error {
for { for {
select { select {
case req := <-ss.request: case req := <-ss.request:
@@ -369,60 +425,6 @@ func (ss *ServerSession) run() {
return liberrors.ErrServerTerminated{} return liberrors.ErrServerTerminated{}
} }
} }
}()
ss.ctxCancel()
switch ss.state {
case ServerSessionStatePlay:
ss.setuppedStream.readerSetInactive(ss)
if *ss.setuppedTransport == TransportUDP {
ss.s.udpRTCPListener.removeClient(ss)
}
case ServerSessionStateRecord:
if *ss.setuppedTransport == TransportUDP {
ss.s.udpRTPListener.removeClient(ss)
ss.s.udpRTCPListener.removeClient(ss)
}
}
if ss.setuppedStream != nil {
ss.setuppedStream.readerRemove(ss)
}
if ss.writerRunning {
ss.writeBuffer.Close()
<-ss.writerDone
ss.writerRunning = false
}
for sc := range ss.conns {
if sc == ss.tcpConn {
sc.Close()
// make sure that OnFrame() is never called after OnSessionClose()
<-sc.done
}
select {
case sc.sessionRemove <- ss:
case <-sc.ctx.Done():
}
}
select {
case ss.s.sessionClose <- ss:
case <-ss.s.ctx.Done():
}
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(&ServerHandlerOnSessionCloseCtx{
Session: ss,
Error: err,
})
}
} }
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {