server: fix panic due to regression in #887 (#892)

This happened when writing a TCP packet to a conn after a read error.
This commit is contained in:
Alessandro Ros
2025-09-16 10:36:06 +02:00
committed by GitHub
parent 3f446ed08d
commit 8c7e4320bc
7 changed files with 69 additions and 42 deletions

View File

@@ -882,6 +882,7 @@ func (c *Client) runInner() error {
} }
case err := <-c.chReadError: case err := <-c.chReadError:
c.reader.close()
c.reader = nil c.reader = nil
return err return err
@@ -918,6 +919,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
} }
case err := <-c.chReadError: case err := <-c.chReadError:
c.reader.close()
c.reader = nil c.reader = nil
return nil, err return nil, err
@@ -975,7 +977,7 @@ func (c *Client) doClose() {
if c.reader != nil { if c.reader != nil {
c.nconn.Close() c.nconn.Close()
c.reader.wait() c.reader.close()
c.reader = nil c.reader = nil
c.nconn = nil c.nconn = nil
c.conn = nil c.conn = nil

View File

@@ -12,9 +12,16 @@ type clientReader struct {
mutex sync.Mutex mutex sync.Mutex
allowInterleavedFrames bool allowInterleavedFrames bool
terminate chan struct{}
done chan struct{}
} }
func (r *clientReader) start() { func (r *clientReader) start() {
r.terminate = make(chan struct{})
r.done = make(chan struct{})
go r.run() go r.run()
} }
@@ -24,19 +31,20 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
r.allowInterleavedFrames = v r.allowInterleavedFrames = v
} }
func (r *clientReader) wait() { func (r *clientReader) close() {
for { close(r.terminate)
select { <-r.done
case <-r.c.chResponse:
case <-r.c.chRequest:
case <-r.c.chReadError:
return
}
}
} }
func (r *clientReader) run() { func (r *clientReader) run() {
r.c.chReadError <- r.runInner() defer close(r.done)
err := r.runInner()
select {
case r.c.chReadError <- err:
case <-r.terminate:
}
} }
func (r *clientReader) runInner() error { func (r *clientReader) runInner() error {
@@ -48,10 +56,16 @@ func (r *clientReader) runInner() error {
switch what := what.(type) { switch what := what.(type) {
case *base.Response: case *base.Response:
r.c.chResponse <- what select {
case r.c.chResponse <- what:
case <-r.terminate:
}
case *base.Request: case *base.Request:
r.c.chRequest <- what select {
case r.c.chRequest <- what:
case <-r.terminate:
}
case *base.InterleavedFrame: case *base.InterleavedFrame:
r.mutex.Lock() r.mutex.Lock()

View File

@@ -401,6 +401,9 @@ func (s *Server) runInner() error {
sc.Close() sc.Close()
case req := <-s.chHandleHTTPChannel: case req := <-s.chHandleHTTPChannel:
if _, ok := s.conns[req.sc]; !ok {
continue
}
if !req.write { if !req.write {
req.sc.httpReadTunnelID = req.tunnelID req.sc.httpReadTunnelID = req.tunnelID
s.httpReadChannels[req.sc] = req.res s.httpReadChannels[req.sc] = req.res
@@ -551,8 +554,6 @@ func (s *Server) handleHTTPChannel(req sessionHandleHTTPChannelReq) error {
select { select {
case s.chHandleHTTPChannel <- req: case s.chHandleHTTPChannel <- req:
case <-req.sc.ctx.Done():
return fmt.Errorf("terminated")
case <-s.ctx.Done(): case <-s.ctx.Done():
return fmt.Errorf("terminated") return fmt.Errorf("terminated")
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/auth" "github.com/bluenviron/gortsplib/v4/pkg/auth"
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/bytecounter" "github.com/bluenviron/gortsplib/v4/pkg/bytecounter"
"github.com/bluenviron/gortsplib/v4/pkg/conn"
"github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/liberrors"
@@ -205,8 +206,8 @@ type ServerConn struct {
userData interface{} userData interface{}
remoteAddr *net.TCPAddr remoteAddr *net.TCPAddr
bc *bytecounter.ByteCounter bc *bytecounter.ByteCounter
conn *conn.Conn
session *ServerSession session *ServerSession
reader *serverConnReader
authNonce string authNonce string
httpReadBuf *bufio.Reader httpReadBuf *bufio.Reader
httpReadTunnelID string httpReadTunnelID string
@@ -362,10 +363,10 @@ func (sc *ServerConn) run() {
}) })
} }
sc.reader = &serverConnReader{ reader := &serverConnReader{
sc: sc, sc: sc,
} }
sc.reader.initialize() reader.initialize()
err := sc.runInner() err := sc.runInner()
@@ -375,9 +376,7 @@ func (sc *ServerConn) run() {
sc.nconn.Close() sc.nconn.Close()
} }
if sc.reader != nil { reader.wait()
sc.reader.wait()
}
if sc.session != nil { if sc.session != nil {
sc.session.removeConn(sc) sc.session.removeConn(sc)
@@ -400,7 +399,6 @@ func (sc *ServerConn) runInner() error {
req.res <- sc.handleRequestOuter(req.req) req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadError: case err := <-sc.chReadError:
sc.reader = nil
return err return err
case ss := <-sc.chRemoveSession: case ss := <-sc.chRemoveSession:
@@ -629,7 +627,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
} }
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
err2 := sc.reader.conn.WriteResponse(res) err2 := sc.conn.WriteResponse(res)
if err == nil && err2 != nil { if err == nil && err2 != nil {
err = err2 err = err2
} }

View File

@@ -37,27 +37,29 @@ func isSwitchReadFuncError(err error) bool {
type serverConnReader struct { type serverConnReader struct {
sc *ServerConn sc *ServerConn
conn *conn.Conn
done chan struct{}
} }
func (cr *serverConnReader) initialize() { func (cr *serverConnReader) initialize() {
cr.done = make(chan struct{})
go cr.run() go cr.run()
} }
func (cr *serverConnReader) wait() { func (cr *serverConnReader) wait() {
for { <-cr.done
select {
case <-cr.sc.chReadError:
return
case req := <-cr.sc.chRequest:
req.res <- fmt.Errorf("terminated")
}
}
} }
func (cr *serverConnReader) run() { func (cr *serverConnReader) run() {
cr.sc.chReadError <- cr.runInner() defer close(cr.done)
err := cr.runInner()
select {
case cr.sc.chReadError <- err:
case <-cr.sc.ctx.Done():
}
} }
func (cr *serverConnReader) runInner() error { func (cr *serverConnReader) runInner() error {
@@ -71,7 +73,7 @@ func (cr *serverConnReader) runInner() error {
} }
} }
cr.conn = conn.NewConn(bufio.NewReader(rw), rw) cr.sc.conn = conn.NewConn(bufio.NewReader(rw), rw)
readFunc := cr.readFuncStandard readFunc := cr.readFuncStandard
@@ -171,7 +173,7 @@ func (cr *serverConnReader) readFuncStandard() error {
cr.sc.nconn.SetReadDeadline(time.Time{}) cr.sc.nconn.SetReadDeadline(time.Time{})
for { for {
what, err := cr.conn.Read() what, err := cr.sc.conn.Read()
if err != nil { if err != nil {
return err return err
} }
@@ -180,7 +182,12 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
cr.sc.chRequest <- req
select {
case cr.sc.chRequest <- req:
case <-cr.sc.ctx.Done():
return fmt.Errorf("terminated")
}
err = <-cres err = <-cres
if err != nil { if err != nil {
@@ -207,7 +214,7 @@ func (cr *serverConnReader) readFuncTCP() error {
cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout)) cr.sc.nconn.SetReadDeadline(time.Now().Add(cr.sc.s.ReadTimeout))
} }
what, err := cr.conn.Read() what, err := cr.sc.conn.Read()
if err != nil { if err != nil {
return err return err
} }
@@ -216,7 +223,12 @@ func (cr *serverConnReader) readFuncTCP() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
cr.sc.chRequest <- req
select {
case cr.sc.chRequest <- req:
case <-cr.sc.ctx.Done():
return fmt.Errorf("terminated")
}
err = <-cres err = <-cres
if err != nil { if err != nil {

View File

@@ -179,7 +179,7 @@ func (sf *serverSessionFormat) writePacketRTPInQueueTCP(payload []byte) error {
sf.sm.ss.tcpFrame.Channel = sf.sm.tcpChannel sf.sm.ss.tcpFrame.Channel = sf.sm.tcpChannel
sf.sm.ss.tcpFrame.Payload = payload sf.sm.ss.tcpFrame.Payload = payload
sf.sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sf.sm.ss.s.WriteTimeout)) sf.sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sf.sm.ss.s.WriteTimeout))
err := sf.sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer) err := sf.sm.ss.tcpConn.conn.WriteInterleavedFrame(sf.sm.ss.tcpFrame, sf.sm.ss.tcpBuffer)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -483,7 +483,7 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
sm.ss.tcpFrame.Channel = sm.tcpChannel + 1 sm.ss.tcpFrame.Channel = sm.tcpChannel + 1
sm.ss.tcpFrame.Payload = payload sm.ss.tcpFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
err := sm.ss.tcpConn.reader.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer) err := sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.ss.tcpFrame, sm.ss.tcpBuffer)
if err != nil { if err != nil {
return err return err
} }