fix various race conditions when writing packets to closed clients or server sessions (#684)

This commit is contained in:
Alessandro Ros
2025-01-19 12:07:59 +01:00
committed by GitHub
parent b2cfa93d13
commit ca6286321d
12 changed files with 438 additions and 219 deletions

View File

@@ -7,6 +7,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -253,14 +254,15 @@ type ServerSession struct {
udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer
writer *asyncProcessor
writerMutex sync.RWMutex
timeDecoder *rtptime.GlobalDecoder2
tcpFrame *base.InterleavedFrame
tcpBuffer []byte
// in
chHandleRequest chan sessionRequestReq
chRemoveConn chan *ServerConn
chStartWriter chan struct{}
chHandleRequest chan sessionRequestReq
chRemoveConn chan *ServerConn
chAsyncStartWriter chan struct{}
}
func (ss *ServerSession) initialize() {
@@ -278,7 +280,7 @@ func (ss *ServerSession) initialize() {
ss.chHandleRequest = make(chan sessionRequestReq)
ss.chRemoveConn = make(chan *ServerConn)
ss.chStartWriter = make(chan struct{})
ss.chAsyncStartWriter = make(chan struct{})
ss.s.wg.Add(1)
go ss.run()
@@ -575,6 +577,37 @@ func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) err
return liberrors.ErrServerInvalidState{AllowedList: allowedList, State: ss.state}
}
func (ss *ServerSession) createWriter() {
ss.writerMutex.Lock()
ss.writer = &asyncProcessor{
bufferSize: func() int {
if ss.state == ServerSessionStatePrePlay {
return ss.s.WriteQueueSize
}
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
return 8
}(),
}
ss.writer.initialize()
ss.writerMutex.Unlock()
}
func (ss *ServerSession) startWriter() {
ss.writer.start()
}
func (ss *ServerSession) destroyWriter() {
ss.writerMutex.Lock()
ss.writer = nil
ss.writerMutex.Unlock()
}
func (ss *ServerSession) run() {
defer ss.s.wg.Done()
@@ -611,7 +644,7 @@ func (ss *ServerSession) run() {
}
if ss.writer != nil {
ss.writer.stop()
ss.destroyWriter()
}
ss.s.closeSession(ss)
@@ -627,8 +660,8 @@ func (ss *ServerSession) run() {
func (ss *ServerSession) runInner() error {
for {
chWriterError := func() chan struct{} {
if ss.writer != nil && ss.writer.running {
return ss.writer.stopped
if ss.writer != nil {
return ss.writer.chStopped
}
return nil
}()
@@ -703,11 +736,11 @@ func (ss *ServerSession) runInner() error {
return liberrors.ErrServerSessionNotInUse{}
}
case <-ss.chStartWriter:
case <-ss.chAsyncStartWriter:
if (ss.state == ServerSessionStateRecord ||
ss.state == ServerSessionStatePlay) &&
*ss.setuppedTransport == TransportTCP {
ss.writer.start()
ss.startWriter()
}
case <-ss.udpCheckStreamTimer.C:
@@ -1118,15 +1151,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
}
// allocate writeBuffer before calling OnPlay().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.writer = &asyncProcessor{
bufferSize: ss.s.WriteQueueSize,
}
ss.writer.initialize()
ss.createWriter()
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -1138,8 +1165,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
})
if res.StatusCode != base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.writer = nil
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.destroyWriter()
}
return res, err
}
@@ -1167,7 +1195,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.writer.start()
ss.startWriter()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
@@ -1175,7 +1203,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// writer.start() is called by ServerConn after the response has been sent
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
ss.setuppedStream.readerSetActive(ss)
@@ -1218,16 +1247,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
}
// allocate writeBuffer before calling OnRecord().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writer = &asyncProcessor{
bufferSize: 8,
}
ss.writer.initialize()
ss.createWriter()
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss,
@@ -1238,7 +1258,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
})
if res.StatusCode != base.StatusOK {
ss.writer = nil
ss.destroyWriter()
return res, err
}
@@ -1261,12 +1281,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.writer.start()
ss.startWriter()
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// runWriter() is called by conn after sending the response
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
return res, err
@@ -1297,6 +1318,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
ss.destroyWriter()
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
@@ -1305,8 +1328,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
sm.stop()
}
ss.writer.stop()
ss.timeDecoder = nil
switch ss.state {
@@ -1446,6 +1467,13 @@ func (ss *ServerSession) writePacketRTP(medi *description.Media, payloadType uin
sm := ss.setuppedMedias[medi]
sf := sm.formats[payloadType]
ss.writerMutex.RLock()
defer ss.writerMutex.RUnlock()
if ss.writer == nil {
return nil
}
ok := ss.writer.push(func() error {
return sf.writePacketRTPInQueue(byts)
})
@@ -1471,6 +1499,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
ss.writerMutex.RLock()
defer ss.writerMutex.RUnlock()
if ss.writer == nil {
return nil
}
ok := ss.writer.push(func() error {
return sm.writePacketRTCPInQueue(byts)
})
@@ -1543,9 +1578,9 @@ func (ss *ServerSession) removeConn(sc *ServerConn) {
}
}
func (ss *ServerSession) startWriter() {
func (ss *ServerSession) asyncStartWriter() {
select {
case ss.chStartWriter <- struct{}{}:
case ss.chAsyncStartWriter <- struct{}{}:
case <-ss.ctx.Done():
}
}