fix various race conditions when writing packets to closed clients or server sessions (#684)

This commit is contained in:
Alessandro Ros
2025-01-19 12:07:59 +01:00
committed by GitHub
parent b2cfa93d13
commit ca6286321d
12 changed files with 438 additions and 219 deletions

View File

@@ -14,30 +14,29 @@ type asyncProcessor struct {
buffer *ringbuffer.RingBuffer
stopError error
stopped chan struct{}
chStopped chan struct{}
}
func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
}
func (w *asyncProcessor) start() {
w.running = true
w.stopped = make(chan struct{})
go w.run()
}
func (w *asyncProcessor) stop() {
func (w *asyncProcessor) close() {
if w.running {
w.buffer.Close()
<-w.stopped
w.running = false
<-w.chStopped
}
}
func (w *asyncProcessor) start() {
w.running = true
w.chStopped = make(chan struct{})
go w.run()
}
func (w *asyncProcessor) run() {
w.stopError = w.runInner()
close(w.stopped)
close(w.chStopped)
}
func (w *asyncProcessor) runInner() error {

View File

@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestAsyncProcessorStopAfterError(t *testing.T) {
func TestAsyncProcessorCloseAfterError(t *testing.T) {
p := &asyncProcessor{bufferSize: 8}
p.initialize()
@@ -17,8 +17,8 @@ func TestAsyncProcessorStopAfterError(t *testing.T) {
p.start()
<-p.stopped
<-p.chStopped
require.EqualError(t, p.stopError, "ok")
p.stop()
p.close()
}

View File

@@ -13,6 +13,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -340,6 +341,7 @@ type Client struct {
keepaliveTimer *time.Timer
closeError error
writer *asyncProcessor
writerMutex sync.RWMutex
reader *clientReader
timeDecoder *rtptime.GlobalDecoder2
mustClose bool
@@ -560,8 +562,8 @@ func (c *Client) runInner() error {
}()
chWriterError := func() chan struct{} {
if c.writer != nil && c.writer.running {
return c.writer.stopped
if c.writer != nil {
return c.writer.chStopped
}
return nil
}()
@@ -721,7 +723,7 @@ func (c *Client) handleServerRequest(req *base.Request) error {
func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord {
c.writer.stop()
c.destroyWriter()
c.stopTransportRoutines()
}
@@ -848,22 +850,6 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
}
func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped {
c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else {
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
}
c.timeDecoder = rtptime.NewGlobalDecoder2()
for _, cm := range c.setuppedMedias {
@@ -913,6 +899,39 @@ func (c *Client) stopTransportRoutines() {
c.timeDecoder = nil
}
func (c *Client) createWriter() {
c.writerMutex.Lock()
c.writer = &asyncProcessor{
bufferSize: func() int {
if c.state == clientStateRecord || c.backChannelSetupped {
return c.WriteQueueSize
}
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
return 8
}(),
}
c.writer.initialize()
c.writerMutex.Unlock()
}
func (c *Client) startWriter() {
c.writer.start()
}
func (c *Client) destroyWriter() {
c.writer.close()
c.writerMutex.Lock()
c.writer = nil
c.writerMutex.Unlock()
}
func (c *Client) connOpen() error {
if c.nconn != nil {
return nil
@@ -1389,7 +1408,7 @@ func (c *Client) doSetup(
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
}
err = cm.allocateUDPListeners(
err = cm.createUDPListeners(
false,
nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
@@ -1544,7 +1563,7 @@ func (c *Client) doSetup(
readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
}
err = cm.allocateUDPListeners(
err = cm.createUDPListeners(
true,
readIP,
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
@@ -1680,6 +1699,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
c.state = clientStatePlay
c.startTransportRoutines()
c.createWriter()
// Range is mandatory in Parrot Streaming Server
if ra == nil {
@@ -1704,12 +1724,14 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header,
}, false)
if err != nil {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, err
}
if res.StatusCode != base.StatusOK {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{
@@ -1731,7 +1753,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
}
c.writer.start()
c.startWriter()
c.lastRange = ra
return res, nil
@@ -1761,18 +1784,21 @@ func (c *Client) doRecord() (*base.Response, error) {
c.state = clientStateRecord
c.startTransportRoutines()
c.createWriter()
res, err := c.do(&base.Request{
Method: base.Record,
URL: c.baseURL,
}, false)
if err != nil {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, err
}
if res.StatusCode != base.StatusOK {
c.destroyWriter()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{
@@ -1780,7 +1806,7 @@ func (c *Client) doRecord() (*base.Response, error) {
}
}
c.writer.start()
c.startWriter()
return nil, nil
}
@@ -1808,19 +1834,21 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err
}
c.writer.stop()
c.destroyWriter()
res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.writer.start()
c.createWriter()
c.startWriter()
return nil, err
}
if res.StatusCode != base.StatusOK {
c.writer.start()
c.createWriter()
c.startWriter()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
@@ -1918,6 +1946,13 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
default:
}
c.writerMutex.RLock()
defer c.writerMutex.RUnlock()
if c.writer == nil {
return nil
}
cm := c.setuppedMedias[medi]
cf := cm.formats[pkt.PayloadType]
@@ -1946,6 +1981,13 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
default:
}
c.writerMutex.RLock()
defer c.writerMutex.RUnlock()
if c.writer == nil {
return nil
}
cm := c.setuppedMedias[medi]
ok := c.writer.push(func() error {

View File

@@ -59,7 +59,7 @@ func (cm *clientMedia) close() {
}
}
func (cm *clientMedia) allocateUDPListeners(
func (cm *clientMedia) createUDPListeners(
multicastEnable bool,
multicastSourceIP net.IP,
rtpAddress string,
@@ -94,7 +94,7 @@ func (cm *clientMedia) allocateUDPListeners(
}
var err error
cm.udpRTPListener, cm.udpRTCPListener, err = allocateUDPListenerPair(cm.c)
cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c)
return err
}

View File

@@ -1813,7 +1813,7 @@ func TestClientPlayRedirect(t *testing.T) {
}
}
func TestClientPlayPause(t *testing.T) {
func TestClientPlayPausePlay(t *testing.T) {
writeFrames := func(inTH *headers.Transport, conn *conn.Conn) (chan struct{}, chan struct{}) {
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})

View File

@@ -482,7 +482,7 @@ func TestClientRecordSocketError(t *testing.T) {
}
}
func TestClientRecordPauseSerial(t *testing.T) {
func TestClientRecordPauseRecordSerial(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
@@ -618,6 +618,9 @@ func TestClientRecordPauseSerial(t *testing.T) {
_, err = c.Pause()
require.NoError(t, err)
err = c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err)
_, err = c.Record()
require.NoError(t, err)
@@ -627,6 +630,187 @@ func TestClientRecordPauseSerial(t *testing.T) {
}
}
func TestClientRecordPauseRecordParallel(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
} {
t.Run(transport, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
string(base.Pause),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
}
if transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
if transport == "tcp" {
_, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
}
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Pause, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
if transport == "tcp" {
_, err2 = conn.ReadInterleavedFrame()
require.NoError(t, err2)
}
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
Transport: func() *Transport {
if transport == "udp" {
v := TransportUDP
return &v
}
v := TransportTCP
return &v
}(),
}
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
defer c.Close()
writerTerminate := make(chan struct{})
writerDone := make(chan struct{})
defer func() {
close(writerTerminate)
<-writerDone
}()
go func() {
defer close(writerDone)
ti := time.NewTicker(50 * time.Millisecond)
defer ti.Stop()
for {
select {
case <-ti.C:
err2 := c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err2)
case <-writerTerminate:
return
}
}
}()
time.Sleep(500 * time.Millisecond)
_, err = c.Pause()
require.NoError(t, err)
time.Sleep(500 * time.Millisecond)
_, err = c.Record()
require.NoError(t, err)
time.Sleep(500 * time.Millisecond)
})
}
}
func TestClientRecordAutomaticProtocol(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)

View File

@@ -24,7 +24,7 @@ func randInRange(maxVal int) (int, error) {
return int(n.Int64()), nil
}
func allocateUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) {
func createUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) {
// choose two consecutive ports in range 65535-10000
// RTP port must be even and RTCP port odd
for {

View File

@@ -103,7 +103,7 @@ func (cr *serverConnReader) readFuncTCP() error {
// reset deadline
cr.sc.nconn.SetReadDeadline(time.Time{})
cr.sc.session.startWriter()
cr.sc.session.asyncStartWriter()
for {
if cr.sc.session.state == ServerSessionStateRecord {

View File

@@ -22,7 +22,7 @@ func (h *serverMulticastWriter) initialize() error {
return err
}
rtpl, rtcpl, err := allocateUDPListenerMulticastPair(
rtpl, rtcpl, err := createUDPListenerMulticastPair(
h.s.ListenPacket,
h.s.WriteTimeout,
h.s.MulticastRTPPort,
@@ -60,7 +60,7 @@ func (h *serverMulticastWriter) initialize() error {
func (h *serverMulticastWriter) close() {
h.rtpl.close()
h.rtcpl.close()
h.writer.stop()
h.writer.close()
}
func (h *serverMulticastWriter) ip() net.IP {

View File

@@ -1528,62 +1528,7 @@ func TestServerPlayTCPResponseBeforeFrames(t *testing.T) {
require.NoError(t, err)
}
func TestServerPlayPlayPlay(t *testing.T) {
var stream *ServerStream
s := &Server{
Handler: &testServerHandler{
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001",
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolUDP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
ClientPorts: &[2]int{30450, 30451},
}
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session := readSession(t, res)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
}
func TestServerPlayPlayPausePlay(t *testing.T) {
func TestServerPlayPause(t *testing.T) {
var stream *ServerStream
writerStarted := false
writerDone := make(chan struct{})
@@ -1666,91 +1611,105 @@ func TestServerPlayPlayPausePlay(t *testing.T) {
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
}
func TestServerPlayPlayPausePause(t *testing.T) {
var stream *ServerStream
writerDone := make(chan struct{})
writerTerminate := make(chan struct{})
func TestServerPlayPlayPausePausePlay(t *testing.T) {
for _, ca := range []string{"stream", "direct"} {
t.Run(ca, func(t *testing.T) {
var stream *ServerStream
writerStarted := false
writerDone := make(chan struct{})
writerTerminate := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(writerTerminate)
<-writerDone
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
defer close(writerDone)
s := &Server{
Handler: &testServerHandler{
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(writerTerminate)
<-writerDone
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
if !writerStarted {
writerStarted = true
go func() {
defer close(writerDone)
ti := time.NewTicker(50 * time.Millisecond)
defer ti.Stop()
ti := time.NewTicker(50 * time.Millisecond)
defer ti.Stop()
for {
select {
case <-ti.C:
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
require.NoError(t, err)
case <-writerTerminate:
return
for {
select {
case <-ti.C:
if ca == "stream" {
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
require.NoError(t, err)
} else {
err := ctx.Session.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
require.NoError(t, err)
}
case <-writerTerminate:
return
}
}
}()
}
}
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session := readSession(t, res)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
time.Sleep(500 * time.Millisecond)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
})
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session := readSession(t, res)
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
}
func TestServerPlayTimeout(t *testing.T) {

View File

@@ -7,6 +7,7 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -253,14 +254,15 @@ type ServerSession struct {
udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer
writer *asyncProcessor
writerMutex sync.RWMutex
timeDecoder *rtptime.GlobalDecoder2
tcpFrame *base.InterleavedFrame
tcpBuffer []byte
// in
chHandleRequest chan sessionRequestReq
chRemoveConn chan *ServerConn
chStartWriter chan struct{}
chHandleRequest chan sessionRequestReq
chRemoveConn chan *ServerConn
chAsyncStartWriter chan struct{}
}
func (ss *ServerSession) initialize() {
@@ -278,7 +280,7 @@ func (ss *ServerSession) initialize() {
ss.chHandleRequest = make(chan sessionRequestReq)
ss.chRemoveConn = make(chan *ServerConn)
ss.chStartWriter = make(chan struct{})
ss.chAsyncStartWriter = make(chan struct{})
ss.s.wg.Add(1)
go ss.run()
@@ -575,6 +577,37 @@ func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) err
return liberrors.ErrServerInvalidState{AllowedList: allowedList, State: ss.state}
}
func (ss *ServerSession) createWriter() {
ss.writerMutex.Lock()
ss.writer = &asyncProcessor{
bufferSize: func() int {
if ss.state == ServerSessionStatePrePlay {
return ss.s.WriteQueueSize
}
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
return 8
}(),
}
ss.writer.initialize()
ss.writerMutex.Unlock()
}
func (ss *ServerSession) startWriter() {
ss.writer.start()
}
func (ss *ServerSession) destroyWriter() {
ss.writerMutex.Lock()
ss.writer = nil
ss.writerMutex.Unlock()
}
func (ss *ServerSession) run() {
defer ss.s.wg.Done()
@@ -611,7 +644,7 @@ func (ss *ServerSession) run() {
}
if ss.writer != nil {
ss.writer.stop()
ss.destroyWriter()
}
ss.s.closeSession(ss)
@@ -627,8 +660,8 @@ func (ss *ServerSession) run() {
func (ss *ServerSession) runInner() error {
for {
chWriterError := func() chan struct{} {
if ss.writer != nil && ss.writer.running {
return ss.writer.stopped
if ss.writer != nil {
return ss.writer.chStopped
}
return nil
}()
@@ -703,11 +736,11 @@ func (ss *ServerSession) runInner() error {
return liberrors.ErrServerSessionNotInUse{}
}
case <-ss.chStartWriter:
case <-ss.chAsyncStartWriter:
if (ss.state == ServerSessionStateRecord ||
ss.state == ServerSessionStatePlay) &&
*ss.setuppedTransport == TransportTCP {
ss.writer.start()
ss.startWriter()
}
case <-ss.udpCheckStreamTimer.C:
@@ -1118,15 +1151,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
}
// allocate writeBuffer before calling OnPlay().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.writer = &asyncProcessor{
bufferSize: ss.s.WriteQueueSize,
}
ss.writer.initialize()
ss.createWriter()
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -1138,8 +1165,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
})
if res.StatusCode != base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.writer = nil
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.destroyWriter()
}
return res, err
}
@@ -1167,7 +1195,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.writer.start()
ss.startWriter()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
@@ -1175,7 +1203,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// writer.start() is called by ServerConn after the response has been sent
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
ss.setuppedStream.readerSetActive(ss)
@@ -1218,16 +1247,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
}
// allocate writeBuffer before calling OnRecord().
// in this way it's possible to call ServerSession.WritePacket*()
// inside the callback.
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writer = &asyncProcessor{
bufferSize: 8,
}
ss.writer.initialize()
ss.createWriter()
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss,
@@ -1238,7 +1258,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
})
if res.StatusCode != base.StatusOK {
ss.writer = nil
ss.destroyWriter()
return res, err
}
@@ -1261,12 +1281,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.writer.start()
ss.startWriter()
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// runWriter() is called by conn after sending the response
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
return res, err
@@ -1297,6 +1318,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
ss.destroyWriter()
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
@@ -1305,8 +1328,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
sm.stop()
}
ss.writer.stop()
ss.timeDecoder = nil
switch ss.state {
@@ -1446,6 +1467,13 @@ func (ss *ServerSession) writePacketRTP(medi *description.Media, payloadType uin
sm := ss.setuppedMedias[medi]
sf := sm.formats[payloadType]
ss.writerMutex.RLock()
defer ss.writerMutex.RUnlock()
if ss.writer == nil {
return nil
}
ok := ss.writer.push(func() error {
return sf.writePacketRTPInQueue(byts)
})
@@ -1471,6 +1499,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
ss.writerMutex.RLock()
defer ss.writerMutex.RUnlock()
if ss.writer == nil {
return nil
}
ok := ss.writer.push(func() error {
return sm.writePacketRTCPInQueue(byts)
})
@@ -1543,9 +1578,9 @@ func (ss *ServerSession) removeConn(sc *ServerConn) {
}
}
func (ss *ServerSession) startWriter() {
func (ss *ServerSession) asyncStartWriter() {
select {
case ss.chStartWriter <- struct{}{}:
case ss.chAsyncStartWriter <- struct{}{}:
case <-ss.ctx.Done():
}
}

View File

@@ -25,7 +25,7 @@ func (p *clientAddr) fill(ip net.IP, port int) {
}
}
func allocateUDPListenerMulticastPair(
func createUDPListenerMulticastPair(
listenPacket func(network, address string) (net.PacketConn, error),
writeTimeout time.Duration,
multicastRTPPort int,