server: fix panic due to regression in #887 (#892)

This happened when writing a TCP packet to a conn after a read error.
This commit is contained in:
Alessandro Ros
2025-09-16 10:36:06 +02:00
committed by GitHub
parent 3f446ed08d
commit 8c7e4320bc
7 changed files with 69 additions and 42 deletions

View File

@@ -882,6 +882,7 @@ func (c *Client) runInner() error {
}
case err := <-c.chReadError:
c.reader.close()
c.reader = nil
return err
@@ -918,6 +919,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
}
case err := <-c.chReadError:
c.reader.close()
c.reader = nil
return nil, err
@@ -975,7 +977,7 @@ func (c *Client) doClose() {
if c.reader != nil {
c.nconn.Close()
c.reader.wait()
c.reader.close()
c.reader = nil
c.nconn = nil
c.conn = nil

View File

@@ -12,9 +12,16 @@ type clientReader struct {
mutex sync.Mutex
allowInterleavedFrames bool
terminate chan struct{}
done chan struct{}
}
func (r *clientReader) start() {
r.terminate = make(chan struct{})
r.done = make(chan struct{})
go r.run()
}
@@ -24,19 +31,20 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
r.allowInterleavedFrames = v
}
func (r *clientReader) wait() {
for {
select {
case <-r.c.chResponse:
case <-r.c.chRequest:
case <-r.c.chReadError:
return
}
}
func (r *clientReader) close() {
close(r.terminate)
<-r.done
}
func (r *clientReader) run() {
r.c.chReadError <- r.runInner()
defer close(r.done)
err := r.runInner()
select {
case r.c.chReadError <- err:
case <-r.terminate:
}
}
func (r *clientReader) runInner() error {
@@ -48,10 +56,16 @@ func (r *clientReader) runInner() error {
switch what := what.(type) {
case *base.Response:
r.c.chResponse <- what
select {
case r.c.chResponse <- what:
case <-r.terminate:
}
case *base.Request:
r.c.chRequest <- what
select {
case r.c.chRequest <- what:
case <-r.terminate:
}
case *base.InterleavedFrame:
r.mutex.Lock()

View File

@@ -401,6 +401,9 @@ func (s *Server) runInner() error {
sc.Close()
case req := <-s.chHandleHTTPChannel:
if _, ok := s.conns[req.sc]; !ok {
continue
}
if !req.write {
req.sc.httpReadTunnelID = req.tunnelID
s.httpReadChannels[req.sc] = req.res
@@ -551,8 +554,6 @@ func (s *Server) handleHTTPChannel(req sessionHandleHTTPChannelReq) error {
select {
case s.chHandleHTTPChannel <- req:
case <-req.sc.ctx.Done():
return fmt.Errorf("terminated")
case <-s.ctx.Done():
return fmt.Errorf("terminated")
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/auth"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/bytecounter"
"github.com/bluenviron/gortsplib/v4/pkg/conn"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
@@ -205,8 +206,8 @@ type ServerConn struct {
userData interface{}
remoteAddr *net.TCPAddr
bc *bytecounter.ByteCounter
conn *conn.Conn
session *ServerSession
reader *serverConnReader
authNonce string
httpReadBuf *bufio.Reader
httpReadTunnelID string
@@ -362,10 +363,10 @@ func (sc *ServerConn) run() {
})
}
sc.reader = &serverConnReader{
reader := &serverConnReader{
sc: sc,
}
sc.reader.initialize()
reader.initialize()
err := sc.runInner()
@@ -375,9 +376,7 @@ func (sc *ServerConn) run() {
sc.nconn.Close()
}
if sc.reader != nil {
sc.reader.wait()
}
reader.wait()
if sc.session != nil {
sc.session.removeConn(sc)
@@ -400,7 +399,6 @@ func (sc *ServerConn) runInner() error {
req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadError:
sc.reader = nil
return err
case ss := <-sc.chRemoveSession:
@@ -629,7 +627,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
err2 := sc.reader.conn.WriteResponse(res)
err2 := sc.conn.WriteResponse(res)
if err == nil && err2 != nil {
err = err2
}

View File

@@ -36,28 +36,30 @@ func isSwitchReadFuncError(err error) bool {
}
type serverConnReader struct {
sc *ServerConn
conn *conn.Conn
sc *ServerConn
done chan struct{}
}
func (cr *serverConnReader) initialize() {
cr.done = make(chan struct{})
go cr.run()
}
func (cr *serverConnReader) wait() {
for {
select {
case <-cr.sc.chReadError:
return
case req := <-cr.sc.chRequest:
req.res <- fmt.Errorf("terminated")
}
}
<-cr.done
}
func (cr *serverConnReader) run() {
cr.sc.chReadError <- cr.runInner()
defer close(cr.done)
err := cr.runInner()
select {
case cr.sc.chReadError <- err:
case <-cr.sc.ctx.Done():
}
}
func (cr *serverConnReader) runInner() error {
@@ -71,7 +73,7 @@ func (cr *serverConnReader) runInner() error {
}
}
cr.conn = conn.NewConn(bufio.NewReader(rw), rw)
cr.sc.conn = conn.NewConn(bufio.NewReader(rw), rw)
readFunc := cr.readFuncStandard
@@ -171,7 +173,7 @@ func (cr *serverConnReader) readFuncStandard() error {
cr.sc.nconn.SetReadDeadline(time.Time{})
for {
what, err := cr.conn.Read()
what, err := cr.sc.conn.Read()
if err != nil {
return err
}
@@ -180,7 +182,12 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request:
cres := make(chan error)
req := readReq{req: what, res: cres}
cr.sc.chRequest <- req
select {
case cr.sc.chRequest <- req:
case <-cr.sc.ctx.Done():
return fmt.Errorf("terminated")
}
err = <-cres
if err != nil {
@@ -207,7 +214,7 @@ func (cr *serverConnReader) readFuncTCP() error {
cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout))
}
what, err := cr.conn.Read()
what, err := cr.sc.conn.Read()
if err != nil {
return err
}
@@ -216,7 +223,12 @@ func (cr *serverConnReader) readFuncTCP() error {
case *base.Request:
cres := make(chan error)
req := readReq{req: what, res: cres}
cr.sc.chRequest <- req
select {
case cr.sc.chRequest <- req:
case <-cr.sc.ctx.Done():
return fmt.Errorf("terminated")
}
err = <-cres
if err != nil {

View File

@@ -179,7 +179,7 @@ func (sf *serverSessionFormat) writePacketRTPInQueueTCP(payload []byte) error {
sf.sm.ss.tcpFrame.Channel = sf.sm.tcpChannel
sf.sm.ss.tcpFrame.Payload = payload
sf.sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sf.sm.ss.s.WriteTimeout))
err := sf.sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer)
err := sf.sm.ss.tcpConn.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer)
if err != nil {
return err
}

View File

@@ -483,7 +483,7 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
sm.ss.tcpFrame.Channel = sm.tcpChannel + 1
sm.ss.tcpFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
err := sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer)
err := sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer)
if err != nil {
return err
}