mirror of
https://github.com/aler9/gortsplib
synced 2025-09-27 03:25:52 +08:00
fix various race conditions when writing packets to closed clients or server sessions (#684)
This commit is contained in:
@@ -14,30 +14,29 @@ type asyncProcessor struct {
|
||||
buffer *ringbuffer.RingBuffer
|
||||
stopError error
|
||||
|
||||
stopped chan struct{}
|
||||
chStopped chan struct{}
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) initialize() {
|
||||
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) start() {
|
||||
w.running = true
|
||||
w.stopped = make(chan struct{})
|
||||
go w.run()
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) stop() {
|
||||
func (w *asyncProcessor) close() {
|
||||
if w.running {
|
||||
w.buffer.Close()
|
||||
<-w.stopped
|
||||
w.running = false
|
||||
<-w.chStopped
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) start() {
|
||||
w.running = true
|
||||
w.chStopped = make(chan struct{})
|
||||
go w.run()
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) run() {
|
||||
w.stopError = w.runInner()
|
||||
close(w.stopped)
|
||||
close(w.chStopped)
|
||||
}
|
||||
|
||||
func (w *asyncProcessor) runInner() error {
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAsyncProcessorStopAfterError(t *testing.T) {
|
||||
func TestAsyncProcessorCloseAfterError(t *testing.T) {
|
||||
p := &asyncProcessor{bufferSize: 8}
|
||||
p.initialize()
|
||||
|
||||
@@ -17,8 +17,8 @@ func TestAsyncProcessorStopAfterError(t *testing.T) {
|
||||
|
||||
p.start()
|
||||
|
||||
<-p.stopped
|
||||
<-p.chStopped
|
||||
require.EqualError(t, p.stopError, "ok")
|
||||
|
||||
p.stop()
|
||||
p.close()
|
||||
}
|
||||
|
94
client.go
94
client.go
@@ -13,6 +13,7 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -340,6 +341,7 @@ type Client struct {
|
||||
keepaliveTimer *time.Timer
|
||||
closeError error
|
||||
writer *asyncProcessor
|
||||
writerMutex sync.RWMutex
|
||||
reader *clientReader
|
||||
timeDecoder *rtptime.GlobalDecoder2
|
||||
mustClose bool
|
||||
@@ -560,8 +562,8 @@ func (c *Client) runInner() error {
|
||||
}()
|
||||
|
||||
chWriterError := func() chan struct{} {
|
||||
if c.writer != nil && c.writer.running {
|
||||
return c.writer.stopped
|
||||
if c.writer != nil {
|
||||
return c.writer.chStopped
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
@@ -721,7 +723,7 @@ func (c *Client) handleServerRequest(req *base.Request) error {
|
||||
|
||||
func (c *Client) doClose() {
|
||||
if c.state == clientStatePlay || c.state == clientStateRecord {
|
||||
c.writer.stop()
|
||||
c.destroyWriter()
|
||||
c.stopTransportRoutines()
|
||||
}
|
||||
|
||||
@@ -848,22 +850,6 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
|
||||
}
|
||||
|
||||
func (c *Client) startTransportRoutines() {
|
||||
// allocate writer here because it's needed by RTCP receiver / sender
|
||||
if c.state == clientStateRecord || c.backChannelSetupped {
|
||||
c.writer = &asyncProcessor{
|
||||
bufferSize: c.WriteQueueSize,
|
||||
}
|
||||
c.writer.initialize()
|
||||
} else {
|
||||
// when reading, buffer is only used to send RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
c.writer = &asyncProcessor{
|
||||
bufferSize: 8,
|
||||
}
|
||||
c.writer.initialize()
|
||||
}
|
||||
|
||||
c.timeDecoder = rtptime.NewGlobalDecoder2()
|
||||
|
||||
for _, cm := range c.setuppedMedias {
|
||||
@@ -913,6 +899,39 @@ func (c *Client) stopTransportRoutines() {
|
||||
c.timeDecoder = nil
|
||||
}
|
||||
|
||||
func (c *Client) createWriter() {
|
||||
c.writerMutex.Lock()
|
||||
|
||||
c.writer = &asyncProcessor{
|
||||
bufferSize: func() int {
|
||||
if c.state == clientStateRecord || c.backChannelSetupped {
|
||||
return c.WriteQueueSize
|
||||
}
|
||||
|
||||
// when reading, buffer is only used to send RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
return 8
|
||||
}(),
|
||||
}
|
||||
|
||||
c.writer.initialize()
|
||||
|
||||
c.writerMutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) startWriter() {
|
||||
c.writer.start()
|
||||
}
|
||||
|
||||
func (c *Client) destroyWriter() {
|
||||
c.writer.close()
|
||||
|
||||
c.writerMutex.Lock()
|
||||
c.writer = nil
|
||||
c.writerMutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) connOpen() error {
|
||||
if c.nconn != nil {
|
||||
return nil
|
||||
@@ -1389,7 +1408,7 @@ func (c *Client) doSetup(
|
||||
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
|
||||
}
|
||||
|
||||
err = cm.allocateUDPListeners(
|
||||
err = cm.createUDPListeners(
|
||||
false,
|
||||
nil,
|
||||
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
|
||||
@@ -1544,7 +1563,7 @@ func (c *Client) doSetup(
|
||||
readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP
|
||||
}
|
||||
|
||||
err = cm.allocateUDPListeners(
|
||||
err = cm.createUDPListeners(
|
||||
true,
|
||||
readIP,
|
||||
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
|
||||
@@ -1680,6 +1699,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
|
||||
|
||||
c.state = clientStatePlay
|
||||
c.startTransportRoutines()
|
||||
c.createWriter()
|
||||
|
||||
// Range is mandatory in Parrot Streaming Server
|
||||
if ra == nil {
|
||||
@@ -1704,12 +1724,14 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
|
||||
Header: header,
|
||||
}, false)
|
||||
if err != nil {
|
||||
c.destroyWriter()
|
||||
c.stopTransportRoutines()
|
||||
c.state = clientStatePrePlay
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
c.destroyWriter()
|
||||
c.stopTransportRoutines()
|
||||
c.state = clientStatePrePlay
|
||||
return nil, liberrors.ErrClientBadStatusCode{
|
||||
@@ -1731,7 +1753,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
|
||||
}
|
||||
}
|
||||
|
||||
c.writer.start()
|
||||
c.startWriter()
|
||||
|
||||
c.lastRange = ra
|
||||
|
||||
return res, nil
|
||||
@@ -1761,18 +1784,21 @@ func (c *Client) doRecord() (*base.Response, error) {
|
||||
|
||||
c.state = clientStateRecord
|
||||
c.startTransportRoutines()
|
||||
c.createWriter()
|
||||
|
||||
res, err := c.do(&base.Request{
|
||||
Method: base.Record,
|
||||
URL: c.baseURL,
|
||||
}, false)
|
||||
if err != nil {
|
||||
c.destroyWriter()
|
||||
c.stopTransportRoutines()
|
||||
c.state = clientStatePreRecord
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
c.destroyWriter()
|
||||
c.stopTransportRoutines()
|
||||
c.state = clientStatePreRecord
|
||||
return nil, liberrors.ErrClientBadStatusCode{
|
||||
@@ -1780,7 +1806,7 @@ func (c *Client) doRecord() (*base.Response, error) {
|
||||
}
|
||||
}
|
||||
|
||||
c.writer.start()
|
||||
c.startWriter()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1808,19 +1834,21 @@ func (c *Client) doPause() (*base.Response, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.writer.stop()
|
||||
c.destroyWriter()
|
||||
|
||||
res, err := c.do(&base.Request{
|
||||
Method: base.Pause,
|
||||
URL: c.baseURL,
|
||||
}, false)
|
||||
if err != nil {
|
||||
c.writer.start()
|
||||
c.createWriter()
|
||||
c.startWriter()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
c.writer.start()
|
||||
c.createWriter()
|
||||
c.startWriter()
|
||||
return nil, liberrors.ErrClientBadStatusCode{
|
||||
Code: res.StatusCode, Message: res.StatusMessage,
|
||||
}
|
||||
@@ -1918,6 +1946,13 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
|
||||
default:
|
||||
}
|
||||
|
||||
c.writerMutex.RLock()
|
||||
defer c.writerMutex.RUnlock()
|
||||
|
||||
if c.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cm := c.setuppedMedias[medi]
|
||||
cf := cm.formats[pkt.PayloadType]
|
||||
|
||||
@@ -1946,6 +1981,13 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
|
||||
default:
|
||||
}
|
||||
|
||||
c.writerMutex.RLock()
|
||||
defer c.writerMutex.RUnlock()
|
||||
|
||||
if c.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cm := c.setuppedMedias[medi]
|
||||
|
||||
ok := c.writer.push(func() error {
|
||||
|
@@ -59,7 +59,7 @@ func (cm *clientMedia) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *clientMedia) allocateUDPListeners(
|
||||
func (cm *clientMedia) createUDPListeners(
|
||||
multicastEnable bool,
|
||||
multicastSourceIP net.IP,
|
||||
rtpAddress string,
|
||||
@@ -94,7 +94,7 @@ func (cm *clientMedia) allocateUDPListeners(
|
||||
}
|
||||
|
||||
var err error
|
||||
cm.udpRTPListener, cm.udpRTCPListener, err = allocateUDPListenerPair(cm.c)
|
||||
cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -1813,7 +1813,7 @@ func TestClientPlayRedirect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPlayPause(t *testing.T) {
|
||||
func TestClientPlayPausePlay(t *testing.T) {
|
||||
writeFrames := func(inTH *headers.Transport, conn *conn.Conn) (chan struct{}, chan struct{}) {
|
||||
writerTerminate := make(chan struct{})
|
||||
writerDone := make(chan struct{})
|
||||
|
@@ -482,7 +482,7 @@ func TestClientRecordSocketError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRecordPauseSerial(t *testing.T) {
|
||||
func TestClientRecordPauseRecordSerial(t *testing.T) {
|
||||
for _, transport := range []string{
|
||||
"udp",
|
||||
"tcp",
|
||||
@@ -618,6 +618,9 @@ func TestClientRecordPauseSerial(t *testing.T) {
|
||||
_, err = c.Pause()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.WritePacketRTP(medi, &testRTPPacket)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = c.Record()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -627,6 +630,187 @@ func TestClientRecordPauseSerial(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRecordPauseRecordParallel(t *testing.T) {
|
||||
for _, transport := range []string{
|
||||
"udp",
|
||||
"tcp",
|
||||
} {
|
||||
t.Run(transport, func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
defer func() { <-serverDone }()
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
nconn, err2 := l.Accept()
|
||||
require.NoError(t, err2)
|
||||
defer nconn.Close()
|
||||
conn := conn.NewConn(nconn)
|
||||
|
||||
req, err2 := conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Options, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Public": base.HeaderValue{strings.Join([]string{
|
||||
string(base.Announce),
|
||||
string(base.Setup),
|
||||
string(base.Record),
|
||||
string(base.Pause),
|
||||
}, ", ")},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Announce, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Setup, req.Method)
|
||||
|
||||
var inTH headers.Transport
|
||||
err2 = inTH.Unmarshal(req.Header["Transport"])
|
||||
require.NoError(t, err2)
|
||||
|
||||
th := headers.Transport{
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
}
|
||||
|
||||
if transport == "udp" {
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
th.ServerPorts = &[2]int{34556, 34557}
|
||||
th.ClientPorts = inTH.ClientPorts
|
||||
} else {
|
||||
th.Protocol = headers.TransportProtocolTCP
|
||||
th.InterleavedIDs = inTH.InterleavedIDs
|
||||
}
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
Header: base.Header{
|
||||
"Transport": th.Marshal(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Record, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
if transport == "tcp" {
|
||||
_, err2 = conn.ReadInterleavedFrame()
|
||||
require.NoError(t, err2)
|
||||
}
|
||||
|
||||
req, err2 = readRequestIgnoreFrames(conn)
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Pause, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
req, err2 = conn.ReadRequest()
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Record, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
|
||||
if transport == "tcp" {
|
||||
_, err2 = conn.ReadInterleavedFrame()
|
||||
require.NoError(t, err2)
|
||||
}
|
||||
|
||||
req, err2 = readRequestIgnoreFrames(conn)
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, base.Teardown, req.Method)
|
||||
|
||||
err2 = conn.WriteResponse(&base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
})
|
||||
require.NoError(t, err2)
|
||||
}()
|
||||
|
||||
c := Client{
|
||||
Transport: func() *Transport {
|
||||
if transport == "udp" {
|
||||
v := TransportUDP
|
||||
return &v
|
||||
}
|
||||
v := TransportTCP
|
||||
return &v
|
||||
}(),
|
||||
}
|
||||
|
||||
medi := testH264Media
|
||||
medias := []*description.Media{medi}
|
||||
|
||||
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
writerTerminate := make(chan struct{})
|
||||
writerDone := make(chan struct{})
|
||||
|
||||
defer func() {
|
||||
close(writerTerminate)
|
||||
<-writerDone
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer close(writerDone)
|
||||
|
||||
ti := time.NewTicker(50 * time.Millisecond)
|
||||
defer ti.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ti.C:
|
||||
err2 := c.WritePacketRTP(medi, &testRTPPacket)
|
||||
require.NoError(t, err2)
|
||||
|
||||
case <-writerTerminate:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
_, err = c.Pause()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
_, err = c.Record()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRecordAutomaticProtocol(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
|
@@ -24,7 +24,7 @@ func randInRange(maxVal int) (int, error) {
|
||||
return int(n.Int64()), nil
|
||||
}
|
||||
|
||||
func allocateUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) {
|
||||
func createUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) {
|
||||
// choose two consecutive ports in range 65535-10000
|
||||
// RTP port must be even and RTCP port odd
|
||||
for {
|
||||
|
@@ -103,7 +103,7 @@ func (cr *serverConnReader) readFuncTCP() error {
|
||||
// reset deadline
|
||||
cr.sc.nconn.SetReadDeadline(time.Time{})
|
||||
|
||||
cr.sc.session.startWriter()
|
||||
cr.sc.session.asyncStartWriter()
|
||||
|
||||
for {
|
||||
if cr.sc.session.state == ServerSessionStateRecord {
|
||||
|
@@ -22,7 +22,7 @@ func (h *serverMulticastWriter) initialize() error {
|
||||
return err
|
||||
}
|
||||
|
||||
rtpl, rtcpl, err := allocateUDPListenerMulticastPair(
|
||||
rtpl, rtcpl, err := createUDPListenerMulticastPair(
|
||||
h.s.ListenPacket,
|
||||
h.s.WriteTimeout,
|
||||
h.s.MulticastRTPPort,
|
||||
@@ -60,7 +60,7 @@ func (h *serverMulticastWriter) initialize() error {
|
||||
func (h *serverMulticastWriter) close() {
|
||||
h.rtpl.close()
|
||||
h.rtcpl.close()
|
||||
h.writer.stop()
|
||||
h.writer.close()
|
||||
}
|
||||
|
||||
func (h *serverMulticastWriter) ip() net.IP {
|
||||
|
@@ -1528,62 +1528,7 @@ func TestServerPlayTCPResponseBeforeFrames(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServerPlayPlayPlay(t *testing.T) {
|
||||
var stream *ServerStream
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
UDPRTPAddress: "127.0.0.1:8000",
|
||||
UDPRTCPAddress: "127.0.0.1:8001",
|
||||
RTSPAddress: "localhost:8554",
|
||||
}
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
|
||||
defer stream.Close()
|
||||
|
||||
nconn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := conn.NewConn(nconn)
|
||||
|
||||
desc := doDescribe(t, conn)
|
||||
|
||||
inTH := &headers.Transport{
|
||||
Protocol: headers.TransportProtocolUDP,
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
Mode: transportModePtr(headers.TransportModePlay),
|
||||
ClientPorts: &[2]int{30450, 30451},
|
||||
}
|
||||
|
||||
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
|
||||
|
||||
session := readSession(t, res)
|
||||
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
}
|
||||
|
||||
func TestServerPlayPlayPausePlay(t *testing.T) {
|
||||
func TestServerPlayPause(t *testing.T) {
|
||||
var stream *ServerStream
|
||||
writerStarted := false
|
||||
writerDone := make(chan struct{})
|
||||
@@ -1666,91 +1611,105 @@ func TestServerPlayPlayPausePlay(t *testing.T) {
|
||||
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
}
|
||||
|
||||
func TestServerPlayPlayPausePause(t *testing.T) {
|
||||
var stream *ServerStream
|
||||
writerDone := make(chan struct{})
|
||||
writerTerminate := make(chan struct{})
|
||||
func TestServerPlayPlayPausePausePlay(t *testing.T) {
|
||||
for _, ca := range []string{"stream", "direct"} {
|
||||
t.Run(ca, func(t *testing.T) {
|
||||
var stream *ServerStream
|
||||
writerStarted := false
|
||||
writerDone := make(chan struct{})
|
||||
writerTerminate := make(chan struct{})
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
|
||||
close(writerTerminate)
|
||||
<-writerDone
|
||||
},
|
||||
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
|
||||
go func() {
|
||||
defer close(writerDone)
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
|
||||
close(writerTerminate)
|
||||
<-writerDone
|
||||
},
|
||||
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, stream, nil
|
||||
},
|
||||
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
|
||||
if !writerStarted {
|
||||
writerStarted = true
|
||||
go func() {
|
||||
defer close(writerDone)
|
||||
|
||||
ti := time.NewTicker(50 * time.Millisecond)
|
||||
defer ti.Stop()
|
||||
ti := time.NewTicker(50 * time.Millisecond)
|
||||
defer ti.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ti.C:
|
||||
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
|
||||
require.NoError(t, err)
|
||||
case <-writerTerminate:
|
||||
return
|
||||
for {
|
||||
select {
|
||||
case <-ti.C:
|
||||
if ca == "stream" {
|
||||
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
err := ctx.Session.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
case <-writerTerminate:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
RTSPAddress: "localhost:8554",
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
RTSPAddress: "localhost:8554",
|
||||
}
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
|
||||
defer stream.Close()
|
||||
|
||||
nconn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := conn.NewConn(nconn)
|
||||
|
||||
desc := doDescribe(t, conn)
|
||||
|
||||
inTH := &headers.Transport{
|
||||
Protocol: headers.TransportProtocolTCP,
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
Mode: transportModePtr(headers.TransportModePlay),
|
||||
InterleavedIDs: &[2]int{0, 1},
|
||||
}
|
||||
|
||||
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
|
||||
|
||||
session := readSession(t, res)
|
||||
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
})
|
||||
}
|
||||
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
|
||||
defer stream.Close()
|
||||
|
||||
nconn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
defer nconn.Close()
|
||||
conn := conn.NewConn(nconn)
|
||||
|
||||
desc := doDescribe(t, conn)
|
||||
|
||||
inTH := &headers.Transport{
|
||||
Protocol: headers.TransportProtocolTCP,
|
||||
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
|
||||
Mode: transportModePtr(headers.TransportModePlay),
|
||||
InterleavedIDs: &[2]int{0, 1},
|
||||
}
|
||||
|
||||
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
|
||||
|
||||
session := readSession(t, res)
|
||||
|
||||
doPlay(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
|
||||
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
|
||||
doPause(t, conn, "rtsp://localhost:8554/teststream", session)
|
||||
}
|
||||
|
||||
func TestServerPlayTimeout(t *testing.T) {
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -253,14 +254,15 @@ type ServerSession struct {
|
||||
udpLastPacketTime *int64 // publish
|
||||
udpCheckStreamTimer *time.Timer
|
||||
writer *asyncProcessor
|
||||
writerMutex sync.RWMutex
|
||||
timeDecoder *rtptime.GlobalDecoder2
|
||||
tcpFrame *base.InterleavedFrame
|
||||
tcpBuffer []byte
|
||||
|
||||
// in
|
||||
chHandleRequest chan sessionRequestReq
|
||||
chRemoveConn chan *ServerConn
|
||||
chStartWriter chan struct{}
|
||||
chHandleRequest chan sessionRequestReq
|
||||
chRemoveConn chan *ServerConn
|
||||
chAsyncStartWriter chan struct{}
|
||||
}
|
||||
|
||||
func (ss *ServerSession) initialize() {
|
||||
@@ -278,7 +280,7 @@ func (ss *ServerSession) initialize() {
|
||||
|
||||
ss.chHandleRequest = make(chan sessionRequestReq)
|
||||
ss.chRemoveConn = make(chan *ServerConn)
|
||||
ss.chStartWriter = make(chan struct{})
|
||||
ss.chAsyncStartWriter = make(chan struct{})
|
||||
|
||||
ss.s.wg.Add(1)
|
||||
go ss.run()
|
||||
@@ -575,6 +577,37 @@ func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) err
|
||||
return liberrors.ErrServerInvalidState{AllowedList: allowedList, State: ss.state}
|
||||
}
|
||||
|
||||
func (ss *ServerSession) createWriter() {
|
||||
ss.writerMutex.Lock()
|
||||
|
||||
ss.writer = &asyncProcessor{
|
||||
bufferSize: func() int {
|
||||
if ss.state == ServerSessionStatePrePlay {
|
||||
return ss.s.WriteQueueSize
|
||||
}
|
||||
|
||||
// when recording, writeBuffer is only used to send RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
return 8
|
||||
}(),
|
||||
}
|
||||
|
||||
ss.writer.initialize()
|
||||
|
||||
ss.writerMutex.Unlock()
|
||||
}
|
||||
|
||||
func (ss *ServerSession) startWriter() {
|
||||
ss.writer.start()
|
||||
}
|
||||
|
||||
func (ss *ServerSession) destroyWriter() {
|
||||
ss.writerMutex.Lock()
|
||||
ss.writer = nil
|
||||
ss.writerMutex.Unlock()
|
||||
}
|
||||
|
||||
func (ss *ServerSession) run() {
|
||||
defer ss.s.wg.Done()
|
||||
|
||||
@@ -611,7 +644,7 @@ func (ss *ServerSession) run() {
|
||||
}
|
||||
|
||||
if ss.writer != nil {
|
||||
ss.writer.stop()
|
||||
ss.destroyWriter()
|
||||
}
|
||||
|
||||
ss.s.closeSession(ss)
|
||||
@@ -627,8 +660,8 @@ func (ss *ServerSession) run() {
|
||||
func (ss *ServerSession) runInner() error {
|
||||
for {
|
||||
chWriterError := func() chan struct{} {
|
||||
if ss.writer != nil && ss.writer.running {
|
||||
return ss.writer.stopped
|
||||
if ss.writer != nil {
|
||||
return ss.writer.chStopped
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
@@ -703,11 +736,11 @@ func (ss *ServerSession) runInner() error {
|
||||
return liberrors.ErrServerSessionNotInUse{}
|
||||
}
|
||||
|
||||
case <-ss.chStartWriter:
|
||||
case <-ss.chAsyncStartWriter:
|
||||
if (ss.state == ServerSessionStateRecord ||
|
||||
ss.state == ServerSessionStatePlay) &&
|
||||
*ss.setuppedTransport == TransportTCP {
|
||||
ss.writer.start()
|
||||
ss.startWriter()
|
||||
}
|
||||
|
||||
case <-ss.udpCheckStreamTimer.C:
|
||||
@@ -1118,15 +1151,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
|
||||
}
|
||||
|
||||
// allocate writeBuffer before calling OnPlay().
|
||||
// in this way it's possible to call ServerSession.WritePacket*()
|
||||
// inside the callback.
|
||||
if ss.state != ServerSessionStatePlay &&
|
||||
*ss.setuppedTransport != TransportUDPMulticast {
|
||||
ss.writer = &asyncProcessor{
|
||||
bufferSize: ss.s.WriteQueueSize,
|
||||
}
|
||||
ss.writer.initialize()
|
||||
ss.createWriter()
|
||||
}
|
||||
|
||||
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
|
||||
@@ -1138,8 +1165,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
if ss.state != ServerSessionStatePlay {
|
||||
ss.writer = nil
|
||||
if ss.state != ServerSessionStatePlay &&
|
||||
*ss.setuppedTransport != TransportUDPMulticast {
|
||||
ss.destroyWriter()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
@@ -1167,7 +1195,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.writer.start()
|
||||
ss.startWriter()
|
||||
|
||||
case TransportUDPMulticast:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
@@ -1175,7 +1203,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// writer.start() is called by ServerConn after the response has been sent
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
|
||||
ss.setuppedStream.readerSetActive(ss)
|
||||
@@ -1218,16 +1247,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
}, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path}
|
||||
}
|
||||
|
||||
// allocate writeBuffer before calling OnRecord().
|
||||
// in this way it's possible to call ServerSession.WritePacket*()
|
||||
// inside the callback.
|
||||
// when recording, writeBuffer is only used to send RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
ss.writer = &asyncProcessor{
|
||||
bufferSize: 8,
|
||||
}
|
||||
ss.writer.initialize()
|
||||
ss.createWriter()
|
||||
|
||||
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
|
||||
Session: ss,
|
||||
@@ -1238,7 +1258,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
ss.writer = nil
|
||||
ss.destroyWriter()
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -1261,12 +1281,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.writer.start()
|
||||
ss.startWriter()
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// runWriter() is called by conn after sending the response
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -1297,6 +1318,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
}
|
||||
|
||||
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
|
||||
ss.destroyWriter()
|
||||
|
||||
if ss.setuppedStream != nil {
|
||||
ss.setuppedStream.readerSetInactive(ss)
|
||||
}
|
||||
@@ -1305,8 +1328,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
sm.stop()
|
||||
}
|
||||
|
||||
ss.writer.stop()
|
||||
|
||||
ss.timeDecoder = nil
|
||||
|
||||
switch ss.state {
|
||||
@@ -1446,6 +1467,13 @@ func (ss *ServerSession) writePacketRTP(medi *description.Media, payloadType uin
|
||||
sm := ss.setuppedMedias[medi]
|
||||
sf := sm.formats[payloadType]
|
||||
|
||||
ss.writerMutex.RLock()
|
||||
defer ss.writerMutex.RUnlock()
|
||||
|
||||
if ss.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ok := ss.writer.push(func() error {
|
||||
return sf.writePacketRTPInQueue(byts)
|
||||
})
|
||||
@@ -1471,6 +1499,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
|
||||
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error {
|
||||
sm := ss.setuppedMedias[medi]
|
||||
|
||||
ss.writerMutex.RLock()
|
||||
defer ss.writerMutex.RUnlock()
|
||||
|
||||
if ss.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ok := ss.writer.push(func() error {
|
||||
return sm.writePacketRTCPInQueue(byts)
|
||||
})
|
||||
@@ -1543,9 +1578,9 @@ func (ss *ServerSession) removeConn(sc *ServerConn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *ServerSession) startWriter() {
|
||||
func (ss *ServerSession) asyncStartWriter() {
|
||||
select {
|
||||
case ss.chStartWriter <- struct{}{}:
|
||||
case ss.chAsyncStartWriter <- struct{}{}:
|
||||
case <-ss.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
@@ -25,7 +25,7 @@ func (p *clientAddr) fill(ip net.IP, port int) {
|
||||
}
|
||||
}
|
||||
|
||||
func allocateUDPListenerMulticastPair(
|
||||
func createUDPListenerMulticastPair(
|
||||
listenPacket func(network, address string) (net.PacketConn, error),
|
||||
writeTimeout time.Duration,
|
||||
multicastRTPPort int,
|
||||
|
Reference in New Issue
Block a user