mirror of
https://github.com/aler9/gortsplib
synced 2025-09-26 19:21:20 +08:00
This happened when writing a TCP packet to a conn after a read error.
This commit is contained in:
@@ -882,6 +882,7 @@ func (c *Client) runInner() error {
|
||||
}
|
||||
|
||||
case err := <-c.chReadError:
|
||||
c.reader.close()
|
||||
c.reader = nil
|
||||
return err
|
||||
|
||||
@@ -918,6 +919,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
|
||||
}
|
||||
|
||||
case err := <-c.chReadError:
|
||||
c.reader.close()
|
||||
c.reader = nil
|
||||
return nil, err
|
||||
|
||||
@@ -975,7 +977,7 @@ func (c *Client) doClose() {
|
||||
|
||||
if c.reader != nil {
|
||||
c.nconn.Close()
|
||||
c.reader.wait()
|
||||
c.reader.close()
|
||||
c.reader = nil
|
||||
c.nconn = nil
|
||||
c.conn = nil
|
||||
|
@@ -12,9 +12,16 @@ type clientReader struct {
|
||||
|
||||
mutex sync.Mutex
|
||||
allowInterleavedFrames bool
|
||||
|
||||
terminate chan struct{}
|
||||
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (r *clientReader) start() {
|
||||
r.terminate = make(chan struct{})
|
||||
r.done = make(chan struct{})
|
||||
|
||||
go r.run()
|
||||
}
|
||||
|
||||
@@ -24,19 +31,20 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
|
||||
r.allowInterleavedFrames = v
|
||||
}
|
||||
|
||||
func (r *clientReader) wait() {
|
||||
for {
|
||||
select {
|
||||
case <-r.c.chResponse:
|
||||
case <-r.c.chRequest:
|
||||
case <-r.c.chReadError:
|
||||
return
|
||||
}
|
||||
}
|
||||
func (r *clientReader) close() {
|
||||
close(r.terminate)
|
||||
<-r.done
|
||||
}
|
||||
|
||||
func (r *clientReader) run() {
|
||||
r.c.chReadError <- r.runInner()
|
||||
defer close(r.done)
|
||||
|
||||
err := r.runInner()
|
||||
|
||||
select {
|
||||
case r.c.chReadError <- err:
|
||||
case <-r.terminate:
|
||||
}
|
||||
}
|
||||
|
||||
func (r *clientReader) runInner() error {
|
||||
@@ -48,10 +56,16 @@ func (r *clientReader) runInner() error {
|
||||
|
||||
switch what := what.(type) {
|
||||
case *base.Response:
|
||||
r.c.chResponse <- what
|
||||
select {
|
||||
case r.c.chResponse <- what:
|
||||
case <-r.terminate:
|
||||
}
|
||||
|
||||
case *base.Request:
|
||||
r.c.chRequest <- what
|
||||
select {
|
||||
case r.c.chRequest <- what:
|
||||
case <-r.terminate:
|
||||
}
|
||||
|
||||
case *base.InterleavedFrame:
|
||||
r.mutex.Lock()
|
||||
|
@@ -401,6 +401,9 @@ func (s *Server) runInner() error {
|
||||
sc.Close()
|
||||
|
||||
case req := <-s.chHandleHTTPChannel:
|
||||
if _, ok := s.conns[req.sc]; !ok {
|
||||
continue
|
||||
}
|
||||
if !req.write {
|
||||
req.sc.httpReadTunnelID = req.tunnelID
|
||||
s.httpReadChannels[req.sc] = req.res
|
||||
@@ -551,8 +554,6 @@ func (s *Server) handleHTTPChannel(req sessionHandleHTTPChannelReq) error {
|
||||
|
||||
select {
|
||||
case s.chHandleHTTPChannel <- req:
|
||||
case <-req.sc.ctx.Done():
|
||||
return fmt.Errorf("terminated")
|
||||
case <-s.ctx.Done():
|
||||
return fmt.Errorf("terminated")
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/auth"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/base"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/bytecounter"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/conn"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/description"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/headers"
|
||||
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
|
||||
@@ -205,8 +206,8 @@ type ServerConn struct {
|
||||
userData interface{}
|
||||
remoteAddr *net.TCPAddr
|
||||
bc *bytecounter.ByteCounter
|
||||
conn *conn.Conn
|
||||
session *ServerSession
|
||||
reader *serverConnReader
|
||||
authNonce string
|
||||
httpReadBuf *bufio.Reader
|
||||
httpReadTunnelID string
|
||||
@@ -362,10 +363,10 @@ func (sc *ServerConn) run() {
|
||||
})
|
||||
}
|
||||
|
||||
sc.reader = &serverConnReader{
|
||||
reader := &serverConnReader{
|
||||
sc: sc,
|
||||
}
|
||||
sc.reader.initialize()
|
||||
reader.initialize()
|
||||
|
||||
err := sc.runInner()
|
||||
|
||||
@@ -375,9 +376,7 @@ func (sc *ServerConn) run() {
|
||||
sc.nconn.Close()
|
||||
}
|
||||
|
||||
if sc.reader != nil {
|
||||
sc.reader.wait()
|
||||
}
|
||||
reader.wait()
|
||||
|
||||
if sc.session != nil {
|
||||
sc.session.removeConn(sc)
|
||||
@@ -400,7 +399,6 @@ func (sc *ServerConn) runInner() error {
|
||||
req.res <- sc.handleRequestOuter(req.req)
|
||||
|
||||
case err := <-sc.chReadError:
|
||||
sc.reader = nil
|
||||
return err
|
||||
|
||||
case ss := <-sc.chRemoveSession:
|
||||
@@ -629,7 +627,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
}
|
||||
|
||||
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
|
||||
err2 := sc.reader.conn.WriteResponse(res)
|
||||
err2 := sc.conn.WriteResponse(res)
|
||||
if err == nil && err2 != nil {
|
||||
err = err2
|
||||
}
|
||||
|
@@ -36,28 +36,30 @@ func isSwitchReadFuncError(err error) bool {
|
||||
}
|
||||
|
||||
type serverConnReader struct {
|
||||
sc *ServerConn
|
||||
conn *conn.Conn
|
||||
sc *ServerConn
|
||||
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (cr *serverConnReader) initialize() {
|
||||
cr.done = make(chan struct{})
|
||||
|
||||
go cr.run()
|
||||
}
|
||||
|
||||
func (cr *serverConnReader) wait() {
|
||||
for {
|
||||
select {
|
||||
case <-cr.sc.chReadError:
|
||||
return
|
||||
|
||||
case req := <-cr.sc.chRequest:
|
||||
req.res <- fmt.Errorf("terminated")
|
||||
}
|
||||
}
|
||||
<-cr.done
|
||||
}
|
||||
|
||||
func (cr *serverConnReader) run() {
|
||||
cr.sc.chReadError <- cr.runInner()
|
||||
defer close(cr.done)
|
||||
|
||||
err := cr.runInner()
|
||||
|
||||
select {
|
||||
case cr.sc.chReadError <- err:
|
||||
case <-cr.sc.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (cr *serverConnReader) runInner() error {
|
||||
@@ -71,7 +73,7 @@ func (cr *serverConnReader) runInner() error {
|
||||
}
|
||||
}
|
||||
|
||||
cr.conn = conn.NewConn(bufio.NewReader(rw), rw)
|
||||
cr.sc.conn = conn.NewConn(bufio.NewReader(rw), rw)
|
||||
|
||||
readFunc := cr.readFuncStandard
|
||||
|
||||
@@ -171,7 +173,7 @@ func (cr *serverConnReader) readFuncStandard() error {
|
||||
cr.sc.nconn.SetReadDeadline(time.Time{})
|
||||
|
||||
for {
|
||||
what, err := cr.conn.Read()
|
||||
what, err := cr.sc.conn.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -180,7 +182,12 @@ func (cr *serverConnReader) readFuncStandard() error {
|
||||
case *base.Request:
|
||||
cres := make(chan error)
|
||||
req := readReq{req: what, res: cres}
|
||||
cr.sc.chRequest <- req
|
||||
|
||||
select {
|
||||
case cr.sc.chRequest <- req:
|
||||
case <-cr.sc.ctx.Done():
|
||||
return fmt.Errorf("terminated")
|
||||
}
|
||||
|
||||
err = <-cres
|
||||
if err != nil {
|
||||
@@ -207,7 +214,7 @@ func (cr *serverConnReader) readFuncTCP() error {
|
||||
cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout))
|
||||
}
|
||||
|
||||
what, err := cr.conn.Read()
|
||||
what, err := cr.sc.conn.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -216,7 +223,12 @@ func (cr *serverConnReader) readFuncTCP() error {
|
||||
case *base.Request:
|
||||
cres := make(chan error)
|
||||
req := readReq{req: what, res: cres}
|
||||
cr.sc.chRequest <- req
|
||||
|
||||
select {
|
||||
case cr.sc.chRequest <- req:
|
||||
case <-cr.sc.ctx.Done():
|
||||
return fmt.Errorf("terminated")
|
||||
}
|
||||
|
||||
err = <-cres
|
||||
if err != nil {
|
||||
|
@@ -179,7 +179,7 @@ func (sf *serverSessionFormat) writePacketRTPInQueueTCP(payload []byte) error {
|
||||
sf.sm.ss.tcpFrame.Channel = sf.sm.tcpChannel
|
||||
sf.sm.ss.tcpFrame.Payload = payload
|
||||
sf.sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sf.sm.ss.s.WriteTimeout))
|
||||
err := sf.sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer)
|
||||
err := sf.sm.ss.tcpConn.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -483,7 +483,7 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
|
||||
sm.ss.tcpFrame.Channel = sm.tcpChannel + 1
|
||||
sm.ss.tcpFrame.Payload = payload
|
||||
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
|
||||
err := sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer)
|
||||
err := sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Reference in New Issue
Block a user