diff --git a/client.go b/client.go index e5116287..2513bafb 100644 --- a/client.go +++ b/client.go @@ -882,6 +882,7 @@ func (c *Client) runInner() error { } case err := <-c.chReadError: + c.reader.close() c.reader = nil return err @@ -918,6 +919,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) { } case err := <-c.chReadError: + c.reader.close() c.reader = nil return nil, err @@ -975,7 +977,7 @@ func (c *Client) doClose() { if c.reader != nil { c.nconn.Close() - c.reader.wait() + c.reader.close() c.reader = nil c.nconn = nil c.conn = nil diff --git a/client_reader.go b/client_reader.go index d68fa9a0..a275d304 100644 --- a/client_reader.go +++ b/client_reader.go @@ -12,9 +12,16 @@ type clientReader struct { mutex sync.Mutex allowInterleavedFrames bool + + terminate chan struct{} + + done chan struct{} } func (r *clientReader) start() { + r.terminate = make(chan struct{}) + r.done = make(chan struct{}) + go r.run() } @@ -24,19 +31,20 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) { r.allowInterleavedFrames = v } -func (r *clientReader) wait() { - for { - select { - case <-r.c.chResponse: - case <-r.c.chRequest: - case <-r.c.chReadError: - return - } - } +func (r *clientReader) close() { + close(r.terminate) + <-r.done } func (r *clientReader) run() { - r.c.chReadError <- r.runInner() + defer close(r.done) + + err := r.runInner() + + select { + case r.c.chReadError <- err: + case <-r.terminate: + } } func (r *clientReader) runInner() error { @@ -48,10 +56,16 @@ func (r *clientReader) runInner() error { switch what := what.(type) { case *base.Response: - r.c.chResponse <- what + select { + case r.c.chResponse <- what: + case <-r.terminate: + } case *base.Request: - r.c.chRequest <- what + select { + case r.c.chRequest <- what: + case <-r.terminate: + } case *base.InterleavedFrame: r.mutex.Lock() diff --git a/server.go b/server.go index 54ac8d0a..f7966568 100644 --- a/server.go +++ b/server.go @@ -401,6 +401,9 @@ func (s *Server) runInner() error { sc.Close() case req := <-s.chHandleHTTPChannel: + if _, ok := s.conns[req.sc]; !ok { + continue + } if !req.write { req.sc.httpReadTunnelID = req.tunnelID s.httpReadChannels[req.sc] = req.res @@ -551,8 +554,6 @@ func (s *Server) handleHTTPChannel(req sessionHandleHTTPChannelReq) error { select { case s.chHandleHTTPChannel <- req: - case <-req.sc.ctx.Done(): - return fmt.Errorf("terminated") case <-s.ctx.Done(): return fmt.Errorf("terminated") } diff --git a/server_conn.go b/server_conn.go index da57efc1..22d737d8 100644 --- a/server_conn.go +++ b/server_conn.go @@ -16,6 +16,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/auth" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/bytecounter" + "github.com/bluenviron/gortsplib/v4/pkg/conn" "github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/liberrors" @@ -205,8 +206,8 @@ type ServerConn struct { userData interface{} remoteAddr *net.TCPAddr bc *bytecounter.ByteCounter + conn *conn.Conn session *ServerSession - reader *serverConnReader authNonce string httpReadBuf *bufio.Reader httpReadTunnelID string @@ -362,10 +363,10 @@ func (sc *ServerConn) run() { }) } - sc.reader = &serverConnReader{ + reader := &serverConnReader{ sc: sc, } - sc.reader.initialize() + reader.initialize() err := sc.runInner() @@ -375,9 +376,7 @@ func (sc *ServerConn) run() { sc.nconn.Close() } - if sc.reader != nil { - sc.reader.wait() - } + reader.wait() if sc.session != nil { sc.session.removeConn(sc) @@ -400,7 +399,6 @@ func (sc *ServerConn) runInner() error { req.res <- sc.handleRequestOuter(req.req) case err := <-sc.chReadError: - sc.reader = nil return err case ss := <-sc.chRemoveSession: @@ -629,7 +627,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { } sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) - err2 := sc.reader.conn.WriteResponse(res) + err2 := sc.conn.WriteResponse(res) if err == nil && err2 != nil { err = err2 } diff --git a/server_conn_reader.go b/server_conn_reader.go index 595d5d86..149b76b1 100644 --- a/server_conn_reader.go +++ b/server_conn_reader.go @@ -36,28 +36,30 @@ func isSwitchReadFuncError(err error) bool { } type serverConnReader struct { - sc *ServerConn - conn *conn.Conn + sc *ServerConn + + done chan struct{} } func (cr *serverConnReader) initialize() { + cr.done = make(chan struct{}) + go cr.run() } func (cr *serverConnReader) wait() { - for { - select { - case <-cr.sc.chReadError: - return - - case req := <-cr.sc.chRequest: - req.res <- fmt.Errorf("terminated") - } - } + <-cr.done } func (cr *serverConnReader) run() { - cr.sc.chReadError <- cr.runInner() + defer close(cr.done) + + err := cr.runInner() + + select { + case cr.sc.chReadError <- err: + case <-cr.sc.ctx.Done(): + } } func (cr *serverConnReader) runInner() error { @@ -71,7 +73,7 @@ func (cr *serverConnReader) runInner() error { } } - cr.conn = conn.NewConn(bufio.NewReader(rw), rw) + cr.sc.conn = conn.NewConn(bufio.NewReader(rw), rw) readFunc := cr.readFuncStandard @@ -171,7 +173,7 @@ func (cr *serverConnReader) readFuncStandard() error { cr.sc.nconn.SetReadDeadline(time.Time{}) for { - what, err := cr.conn.Read() + what, err := cr.sc.conn.Read() if err != nil { return err } @@ -180,7 +182,12 @@ func (cr *serverConnReader) readFuncStandard() error { case *base.Request: cres := make(chan error) req := readReq{req: what, res: cres} - cr.sc.chRequest <- req + + select { + case cr.sc.chRequest <- req: + case <-cr.sc.ctx.Done(): + return fmt.Errorf("terminated") + } err = <-cres if err != nil { @@ -207,7 +214,7 @@ func (cr *serverConnReader) readFuncTCP() error { cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout)) } - what, err := cr.conn.Read() + what, err := cr.sc.conn.Read() if err != nil { return err } @@ -216,7 +223,12 @@ func (cr *serverConnReader) readFuncTCP() error { case *base.Request: cres := make(chan error) req := readReq{req: what, res: cres} - cr.sc.chRequest <- req + + select { + case cr.sc.chRequest <- req: + case <-cr.sc.ctx.Done(): + return fmt.Errorf("terminated") + } err = <-cres if err != nil { diff --git a/server_session_format.go b/server_session_format.go index dc2677c2..4d080afd 100644 --- a/server_session_format.go +++ b/server_session_format.go @@ -179,7 +179,7 @@ func (sf *serverSessionFormat) writePacketRTPInQueueTCP(payload []byte) error { sf.sm.ss.tcpFrame.Channel = sf.sm.tcpChannel sf.sm.ss.tcpFrame.Payload = payload sf.sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sf.sm.ss.s.WriteTimeout)) - err := sf.sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer) + err := sf.sm.ss.tcpConn.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer) if err != nil { return err } diff --git a/server_session_media.go b/server_session_media.go index 4d9939a5..247175ce 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -483,7 +483,7 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error { sm.ss.tcpFrame.Channel = sm.tcpChannel + 1 sm.ss.tcpFrame.Payload = payload sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) - err := sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer) + err := sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer) if err != nil { return err }