From ca6286321d976ae7e5c5730ab28dcee44ba7e93e Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sun, 19 Jan 2025 12:07:59 +0100 Subject: [PATCH] fix various race conditions when writing packets to closed clients or server sessions (#684) --- async_processor.go | 21 ++-- async_processor_test.go | 6 +- client.go | 94 +++++++++++----- client_media.go | 4 +- client_play_test.go | 2 +- client_record_test.go | 186 +++++++++++++++++++++++++++++- client_udp_listener.go | 2 +- server_conn_reader.go | 2 +- server_multicast_writer.go | 4 +- server_play_test.go | 225 +++++++++++++++---------------------- server_session.go | 109 ++++++++++++------ server_udp_listener.go | 2 +- 12 files changed, 438 insertions(+), 219 deletions(-) diff --git a/async_processor.go b/async_processor.go index 2a0a9f0a..7a5e12d6 100644 --- a/async_processor.go +++ b/async_processor.go @@ -14,30 +14,29 @@ type asyncProcessor struct { buffer *ringbuffer.RingBuffer stopError error - stopped chan struct{} + chStopped chan struct{} } func (w *asyncProcessor) initialize() { w.buffer, _ = ringbuffer.New(uint64(w.bufferSize)) } -func (w *asyncProcessor) start() { - w.running = true - w.stopped = make(chan struct{}) - go w.run() -} - -func (w *asyncProcessor) stop() { +func (w *asyncProcessor) close() { if w.running { w.buffer.Close() - <-w.stopped - w.running = false + <-w.chStopped } } +func (w *asyncProcessor) start() { + w.running = true + w.chStopped = make(chan struct{}) + go w.run() +} + func (w *asyncProcessor) run() { w.stopError = w.runInner() - close(w.stopped) + close(w.chStopped) } func (w *asyncProcessor) runInner() error { diff --git a/async_processor_test.go b/async_processor_test.go index 2f881bd6..e2e8456f 100644 --- a/async_processor_test.go +++ b/async_processor_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAsyncProcessorStopAfterError(t *testing.T) { +func TestAsyncProcessorCloseAfterError(t *testing.T) { p := &asyncProcessor{bufferSize: 8} p.initialize() @@ -17,8 +17,8 @@ func TestAsyncProcessorStopAfterError(t *testing.T) { p.start() - <-p.stopped + <-p.chStopped require.EqualError(t, p.stopError, "ok") - p.stop() + p.close() } diff --git a/client.go b/client.go index be014ed7..42e8f7e3 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "net" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -340,6 +341,7 @@ type Client struct { keepaliveTimer *time.Timer closeError error writer *asyncProcessor + writerMutex sync.RWMutex reader *clientReader timeDecoder *rtptime.GlobalDecoder2 mustClose bool @@ -560,8 +562,8 @@ func (c *Client) runInner() error { }() chWriterError := func() chan struct{} { - if c.writer != nil && c.writer.running { - return c.writer.stopped + if c.writer != nil { + return c.writer.chStopped } return nil }() @@ -721,7 +723,7 @@ func (c *Client) handleServerRequest(req *base.Request) error { func (c *Client) doClose() { if c.state == clientStatePlay || c.state == clientStateRecord { - c.writer.stop() + c.destroyWriter() c.stopTransportRoutines() } @@ -848,22 +850,6 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR } func (c *Client) startTransportRoutines() { - // allocate writer here because it's needed by RTCP receiver / sender - if c.state == clientStateRecord || c.backChannelSetupped { - c.writer = &asyncProcessor{ - bufferSize: c.WriteQueueSize, - } - c.writer.initialize() - } else { - // when reading, buffer is only used to send RTCP receiver reports, - // that are much smaller than RTP packets and are sent at a fixed interval. - // decrease RAM consumption by allocating less buffers. - c.writer = &asyncProcessor{ - bufferSize: 8, - } - c.writer.initialize() - } - c.timeDecoder = rtptime.NewGlobalDecoder2() for _, cm := range c.setuppedMedias { @@ -913,6 +899,39 @@ func (c *Client) stopTransportRoutines() { c.timeDecoder = nil } +func (c *Client) createWriter() { + c.writerMutex.Lock() + + c.writer = &asyncProcessor{ + bufferSize: func() int { + if c.state == clientStateRecord || c.backChannelSetupped { + return c.WriteQueueSize + } + + // when reading, buffer is only used to send RTCP receiver reports, + // that are much smaller than RTP packets and are sent at a fixed interval. + // decrease RAM consumption by allocating less buffers. + return 8 + }(), + } + + c.writer.initialize() + + c.writerMutex.Unlock() +} + +func (c *Client) startWriter() { + c.writer.start() +} + +func (c *Client) destroyWriter() { + c.writer.close() + + c.writerMutex.Lock() + c.writer = nil + c.writerMutex.Unlock() +} + func (c *Client) connOpen() error { if c.nconn != nil { return nil @@ -1389,7 +1408,7 @@ func (c *Client) doSetup( return nil, liberrors.ErrClientUDPPortsNotConsecutive{} } - err = cm.allocateUDPListeners( + err = cm.createUDPListeners( false, nil, net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), @@ -1544,7 +1563,7 @@ func (c *Client) doSetup( readIP = c.nconn.RemoteAddr().(*net.TCPAddr).IP } - err = cm.allocateUDPListeners( + err = cm.createUDPListeners( true, readIP, net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), @@ -1680,6 +1699,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { c.state = clientStatePlay c.startTransportRoutines() + c.createWriter() // Range is mandatory in Parrot Streaming Server if ra == nil { @@ -1704,12 +1724,14 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { Header: header, }, false) if err != nil { + c.destroyWriter() c.stopTransportRoutines() c.state = clientStatePrePlay return nil, err } if res.StatusCode != base.StatusOK { + c.destroyWriter() c.stopTransportRoutines() c.state = clientStatePrePlay return nil, liberrors.ErrClientBadStatusCode{ @@ -1731,7 +1753,8 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { } } - c.writer.start() + c.startWriter() + c.lastRange = ra return res, nil @@ -1761,18 +1784,21 @@ func (c *Client) doRecord() (*base.Response, error) { c.state = clientStateRecord c.startTransportRoutines() + c.createWriter() res, err := c.do(&base.Request{ Method: base.Record, URL: c.baseURL, }, false) if err != nil { + c.destroyWriter() c.stopTransportRoutines() c.state = clientStatePreRecord return nil, err } if res.StatusCode != base.StatusOK { + c.destroyWriter() c.stopTransportRoutines() c.state = clientStatePreRecord return nil, liberrors.ErrClientBadStatusCode{ @@ -1780,7 +1806,7 @@ func (c *Client) doRecord() (*base.Response, error) { } } - c.writer.start() + c.startWriter() return nil, nil } @@ -1808,19 +1834,21 @@ func (c *Client) doPause() (*base.Response, error) { return nil, err } - c.writer.stop() + c.destroyWriter() res, err := c.do(&base.Request{ Method: base.Pause, URL: c.baseURL, }, false) if err != nil { - c.writer.start() + c.createWriter() + c.startWriter() return nil, err } if res.StatusCode != base.StatusOK { - c.writer.start() + c.createWriter() + c.startWriter() return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } @@ -1918,6 +1946,13 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, default: } + c.writerMutex.RLock() + defer c.writerMutex.RUnlock() + + if c.writer == nil { + return nil + } + cm := c.setuppedMedias[medi] cf := cm.formats[pkt.PayloadType] @@ -1946,6 +1981,13 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error default: } + c.writerMutex.RLock() + defer c.writerMutex.RUnlock() + + if c.writer == nil { + return nil + } + cm := c.setuppedMedias[medi] ok := c.writer.push(func() error { diff --git a/client_media.go b/client_media.go index a108497b..c6c4f3ba 100644 --- a/client_media.go +++ b/client_media.go @@ -59,7 +59,7 @@ func (cm *clientMedia) close() { } } -func (cm *clientMedia) allocateUDPListeners( +func (cm *clientMedia) createUDPListeners( multicastEnable bool, multicastSourceIP net.IP, rtpAddress string, @@ -94,7 +94,7 @@ func (cm *clientMedia) allocateUDPListeners( } var err error - cm.udpRTPListener, cm.udpRTCPListener, err = allocateUDPListenerPair(cm.c) + cm.udpRTPListener, cm.udpRTCPListener, err = createUDPListenerPair(cm.c) return err } diff --git a/client_play_test.go b/client_play_test.go index 05a6dc61..ea49ee63 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -1813,7 +1813,7 @@ func TestClientPlayRedirect(t *testing.T) { } } -func TestClientPlayPause(t *testing.T) { +func TestClientPlayPausePlay(t *testing.T) { writeFrames := func(inTH *headers.Transport, conn *conn.Conn) (chan struct{}, chan struct{}) { writerTerminate := make(chan struct{}) writerDone := make(chan struct{}) diff --git a/client_record_test.go b/client_record_test.go index edc10b0a..e5408ac7 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -482,7 +482,7 @@ func TestClientRecordSocketError(t *testing.T) { } } -func TestClientRecordPauseSerial(t *testing.T) { +func TestClientRecordPauseRecordSerial(t *testing.T) { for _, transport := range []string{ "udp", "tcp", @@ -618,6 +618,9 @@ func TestClientRecordPauseSerial(t *testing.T) { _, err = c.Pause() require.NoError(t, err) + err = c.WritePacketRTP(medi, &testRTPPacket) + require.NoError(t, err) + _, err = c.Record() require.NoError(t, err) @@ -627,6 +630,187 @@ func TestClientRecordPauseSerial(t *testing.T) { } } +func TestClientRecordPauseRecordParallel(t *testing.T) { + for _, transport := range []string{ + "udp", + "tcp", + } { + t.Run(transport, func(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + nconn, err2 := l.Accept() + require.NoError(t, err2) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err2 := conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Announce), + string(base.Setup), + string(base.Record), + string(base.Pause), + }, ", ")}, + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Announce, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err2 = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err2) + + th := headers.Transport{ + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + } + + if transport == "udp" { + th.Protocol = headers.TransportProtocolUDP + th.ServerPorts = &[2]int{34556, 34557} + th.ClientPorts = inTH.ClientPorts + } else { + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = inTH.InterleavedIDs + } + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Record, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + if transport == "tcp" { + _, err2 = conn.ReadInterleavedFrame() + require.NoError(t, err2) + } + + req, err2 = readRequestIgnoreFrames(conn) + require.NoError(t, err2) + require.Equal(t, base.Pause, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Record, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + + if transport == "tcp" { + _, err2 = conn.ReadInterleavedFrame() + require.NoError(t, err2) + } + + req, err2 = readRequestIgnoreFrames(conn) + require.NoError(t, err2) + require.Equal(t, base.Teardown, req.Method) + + err2 = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err2) + }() + + c := Client{ + Transport: func() *Transport { + if transport == "udp" { + v := TransportUDP + return &v + } + v := TransportTCP + return &v + }(), + } + + medi := testH264Media + medias := []*description.Media{medi} + + err = record(&c, "rtsp://localhost:8554/teststream", medias, nil) + require.NoError(t, err) + defer c.Close() + + writerTerminate := make(chan struct{}) + writerDone := make(chan struct{}) + + defer func() { + close(writerTerminate) + <-writerDone + }() + + go func() { + defer close(writerDone) + + ti := time.NewTicker(50 * time.Millisecond) + defer ti.Stop() + + for { + select { + case <-ti.C: + err2 := c.WritePacketRTP(medi, &testRTPPacket) + require.NoError(t, err2) + + case <-writerTerminate: + return + } + } + }() + + time.Sleep(500 * time.Millisecond) + + _, err = c.Pause() + require.NoError(t, err) + + time.Sleep(500 * time.Millisecond) + + _, err = c.Record() + require.NoError(t, err) + + time.Sleep(500 * time.Millisecond) + }) + } +} + func TestClientRecordAutomaticProtocol(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) diff --git a/client_udp_listener.go b/client_udp_listener.go index 6c0fc575..0c7bb2de 100644 --- a/client_udp_listener.go +++ b/client_udp_listener.go @@ -24,7 +24,7 @@ func randInRange(maxVal int) (int, error) { return int(n.Int64()), nil } -func allocateUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) { +func createUDPListenerPair(c *Client) (*clientUDPListener, *clientUDPListener, error) { // choose two consecutive ports in range 65535-10000 // RTP port must be even and RTCP port odd for { diff --git a/server_conn_reader.go b/server_conn_reader.go index 461f610f..4329881a 100644 --- a/server_conn_reader.go +++ b/server_conn_reader.go @@ -103,7 +103,7 @@ func (cr *serverConnReader) readFuncTCP() error { // reset deadline cr.sc.nconn.SetReadDeadline(time.Time{}) - cr.sc.session.startWriter() + cr.sc.session.asyncStartWriter() for { if cr.sc.session.state == ServerSessionStateRecord { diff --git a/server_multicast_writer.go b/server_multicast_writer.go index 1d214d71..e8bb1d8c 100644 --- a/server_multicast_writer.go +++ b/server_multicast_writer.go @@ -22,7 +22,7 @@ func (h *serverMulticastWriter) initialize() error { return err } - rtpl, rtcpl, err := allocateUDPListenerMulticastPair( + rtpl, rtcpl, err := createUDPListenerMulticastPair( h.s.ListenPacket, h.s.WriteTimeout, h.s.MulticastRTPPort, @@ -60,7 +60,7 @@ func (h *serverMulticastWriter) initialize() error { func (h *serverMulticastWriter) close() { h.rtpl.close() h.rtcpl.close() - h.writer.stop() + h.writer.close() } func (h *serverMulticastWriter) ip() net.IP { diff --git a/server_play_test.go b/server_play_test.go index 7c40b200..679ae61e 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -1528,62 +1528,7 @@ func TestServerPlayTCPResponseBeforeFrames(t *testing.T) { require.NoError(t, err) } -func TestServerPlayPlayPlay(t *testing.T) { - var stream *ServerStream - - s := &Server{ - Handler: &testServerHandler{ - onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - }, - }, - UDPRTPAddress: "127.0.0.1:8000", - UDPRTCPAddress: "127.0.0.1:8001", - RTSPAddress: "localhost:8554", - } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}}) - defer stream.Close() - - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - desc := doDescribe(t, conn) - - inTH := &headers.Transport{ - Protocol: headers.TransportProtocolUDP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - ClientPorts: &[2]int{30450, 30451}, - } - - res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") - - session := readSession(t, res) - - doPlay(t, conn, "rtsp://localhost:8554/teststream", session) - doPlay(t, conn, "rtsp://localhost:8554/teststream", session) -} - -func TestServerPlayPlayPausePlay(t *testing.T) { +func TestServerPlayPause(t *testing.T) { var stream *ServerStream writerStarted := false writerDone := make(chan struct{}) @@ -1666,91 +1611,105 @@ func TestServerPlayPlayPausePlay(t *testing.T) { doPlay(t, conn, "rtsp://localhost:8554/teststream", session) doPause(t, conn, "rtsp://localhost:8554/teststream", session) - doPlay(t, conn, "rtsp://localhost:8554/teststream", session) } -func TestServerPlayPlayPausePause(t *testing.T) { - var stream *ServerStream - writerDone := make(chan struct{}) - writerTerminate := make(chan struct{}) +func TestServerPlayPlayPausePausePlay(t *testing.T) { + for _, ca := range []string{"stream", "direct"} { + t.Run(ca, func(t *testing.T) { + var stream *ServerStream + writerStarted := false + writerDone := make(chan struct{}) + writerTerminate := make(chan struct{}) - s := &Server{ - Handler: &testServerHandler{ - onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { - close(writerTerminate) - <-writerDone - }, - onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) { - go func() { - defer close(writerDone) + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { + close(writerTerminate) + <-writerDone + }, + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + if !writerStarted { + writerStarted = true + go func() { + defer close(writerDone) - ti := time.NewTicker(50 * time.Millisecond) - defer ti.Stop() + ti := time.NewTicker(50 * time.Millisecond) + defer ti.Stop() - for { - select { - case <-ti.C: - err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket) - require.NoError(t, err) - case <-writerTerminate: - return + for { + select { + case <-ti.C: + if ca == "stream" { + err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket) + require.NoError(t, err) + } else { + err := ctx.Session.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket) + require.NoError(t, err) + } + + case <-writerTerminate: + return + } + } + }() } - } - }() - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - }, - onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - }, - }, - RTSPAddress: "localhost:8554", + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}}) + defer stream.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: transportModePtr(headers.TransportModePlay), + InterleavedIDs: &[2]int{0, 1}, + } + + res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") + + session := readSession(t, res) + + doPlay(t, conn, "rtsp://localhost:8554/teststream", session) + doPlay(t, conn, "rtsp://localhost:8554/teststream", session) + doPause(t, conn, "rtsp://localhost:8554/teststream", session) + doPause(t, conn, "rtsp://localhost:8554/teststream", session) + time.Sleep(500 * time.Millisecond) + doPlay(t, conn, "rtsp://localhost:8554/teststream", session) + }) } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}}) - defer stream.Close() - - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - desc := doDescribe(t, conn) - - inTH := &headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - InterleavedIDs: &[2]int{0, 1}, - } - - res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") - - session := readSession(t, res) - - doPlay(t, conn, "rtsp://localhost:8554/teststream", session) - - doPause(t, conn, "rtsp://localhost:8554/teststream", session) - - doPause(t, conn, "rtsp://localhost:8554/teststream", session) } func TestServerPlayTimeout(t *testing.T) { diff --git a/server_session.go b/server_session.go index 1c3c77bc..300e5f8e 100644 --- a/server_session.go +++ b/server_session.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -253,14 +254,15 @@ type ServerSession struct { udpLastPacketTime *int64 // publish udpCheckStreamTimer *time.Timer writer *asyncProcessor + writerMutex sync.RWMutex timeDecoder *rtptime.GlobalDecoder2 tcpFrame *base.InterleavedFrame tcpBuffer []byte // in - chHandleRequest chan sessionRequestReq - chRemoveConn chan *ServerConn - chStartWriter chan struct{} + chHandleRequest chan sessionRequestReq + chRemoveConn chan *ServerConn + chAsyncStartWriter chan struct{} } func (ss *ServerSession) initialize() { @@ -278,7 +280,7 @@ func (ss *ServerSession) initialize() { ss.chHandleRequest = make(chan sessionRequestReq) ss.chRemoveConn = make(chan *ServerConn) - ss.chStartWriter = make(chan struct{}) + ss.chAsyncStartWriter = make(chan struct{}) ss.s.wg.Add(1) go ss.run() @@ -575,6 +577,37 @@ func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) err return liberrors.ErrServerInvalidState{AllowedList: allowedList, State: ss.state} } +func (ss *ServerSession) createWriter() { + ss.writerMutex.Lock() + + ss.writer = &asyncProcessor{ + bufferSize: func() int { + if ss.state == ServerSessionStatePrePlay { + return ss.s.WriteQueueSize + } + + // when recording, writeBuffer is only used to send RTCP receiver reports, + // that are much smaller than RTP packets and are sent at a fixed interval. + // decrease RAM consumption by allocating less buffers. + return 8 + }(), + } + + ss.writer.initialize() + + ss.writerMutex.Unlock() +} + +func (ss *ServerSession) startWriter() { + ss.writer.start() +} + +func (ss *ServerSession) destroyWriter() { + ss.writerMutex.Lock() + ss.writer = nil + ss.writerMutex.Unlock() +} + func (ss *ServerSession) run() { defer ss.s.wg.Done() @@ -611,7 +644,7 @@ func (ss *ServerSession) run() { } if ss.writer != nil { - ss.writer.stop() + ss.destroyWriter() } ss.s.closeSession(ss) @@ -627,8 +660,8 @@ func (ss *ServerSession) run() { func (ss *ServerSession) runInner() error { for { chWriterError := func() chan struct{} { - if ss.writer != nil && ss.writer.running { - return ss.writer.stopped + if ss.writer != nil { + return ss.writer.chStopped } return nil }() @@ -703,11 +736,11 @@ func (ss *ServerSession) runInner() error { return liberrors.ErrServerSessionNotInUse{} } - case <-ss.chStartWriter: + case <-ss.chAsyncStartWriter: if (ss.state == ServerSessionStateRecord || ss.state == ServerSessionStatePlay) && *ss.setuppedTransport == TransportTCP { - ss.writer.start() + ss.startWriter() } case <-ss.udpCheckStreamTimer.C: @@ -1118,15 +1151,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path} } - // allocate writeBuffer before calling OnPlay(). - // in this way it's possible to call ServerSession.WritePacket*() - // inside the callback. if ss.state != ServerSessionStatePlay && *ss.setuppedTransport != TransportUDPMulticast { - ss.writer = &asyncProcessor{ - bufferSize: ss.s.WriteQueueSize, - } - ss.writer.initialize() + ss.createWriter() } res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ @@ -1138,8 +1165,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }) if res.StatusCode != base.StatusOK { - if ss.state != ServerSessionStatePlay { - ss.writer = nil + if ss.state != ServerSessionStatePlay && + *ss.setuppedTransport != TransportUDPMulticast { + ss.destroyWriter() } return res, err } @@ -1167,7 +1195,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( switch *ss.setuppedTransport { case TransportUDP: ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - ss.writer.start() + ss.startWriter() case TransportUDPMulticast: ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) @@ -1175,7 +1203,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( default: // TCP ss.tcpConn = sc err = switchReadFuncError{true} - // writer.start() is called by ServerConn after the response has been sent + // startWriter() is called by ServerConn, through chAsyncStartWriter, + // after the response has been sent } ss.setuppedStream.readerSetActive(ss) @@ -1218,16 +1247,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, liberrors.ErrServerPathHasChanged{Prev: ss.setuppedPath, Cur: path} } - // allocate writeBuffer before calling OnRecord(). - // in this way it's possible to call ServerSession.WritePacket*() - // inside the callback. - // when recording, writeBuffer is only used to send RTCP receiver reports, - // that are much smaller than RTP packets and are sent at a fixed interval. - // decrease RAM consumption by allocating less buffers. - ss.writer = &asyncProcessor{ - bufferSize: 8, - } - ss.writer.initialize() + ss.createWriter() res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ Session: ss, @@ -1238,7 +1258,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }) if res.StatusCode != base.StatusOK { - ss.writer = nil + ss.destroyWriter() return res, err } @@ -1261,12 +1281,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( switch *ss.setuppedTransport { case TransportUDP: ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - ss.writer.start() + ss.startWriter() default: // TCP ss.tcpConn = sc err = switchReadFuncError{true} - // runWriter() is called by conn after sending the response + // startWriter() is called by ServerConn, through chAsyncStartWriter, + // after the response has been sent } return res, err @@ -1297,6 +1318,8 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord { + ss.destroyWriter() + if ss.setuppedStream != nil { ss.setuppedStream.readerSetInactive(ss) } @@ -1305,8 +1328,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( sm.stop() } - ss.writer.stop() - ss.timeDecoder = nil switch ss.state { @@ -1446,6 +1467,13 @@ func (ss *ServerSession) writePacketRTP(medi *description.Media, payloadType uin sm := ss.setuppedMedias[medi] sf := sm.formats[payloadType] + ss.writerMutex.RLock() + defer ss.writerMutex.RUnlock() + + if ss.writer == nil { + return nil + } + ok := ss.writer.push(func() error { return sf.writePacketRTPInQueue(byts) }) @@ -1471,6 +1499,13 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error { sm := ss.setuppedMedias[medi] + ss.writerMutex.RLock() + defer ss.writerMutex.RUnlock() + + if ss.writer == nil { + return nil + } + ok := ss.writer.push(func() error { return sm.writePacketRTCPInQueue(byts) }) @@ -1543,9 +1578,9 @@ func (ss *ServerSession) removeConn(sc *ServerConn) { } } -func (ss *ServerSession) startWriter() { +func (ss *ServerSession) asyncStartWriter() { select { - case ss.chStartWriter <- struct{}{}: + case ss.chAsyncStartWriter <- struct{}{}: case <-ss.ctx.Done(): } } diff --git a/server_udp_listener.go b/server_udp_listener.go index 62e0e64f..96dd756b 100644 --- a/server_udp_listener.go +++ b/server_udp_listener.go @@ -25,7 +25,7 @@ func (p *clientAddr) fill(ip net.IP, port int) { } } -func allocateUDPListenerMulticastPair( +func createUDPListenerMulticastPair( listenPacket func(network, address string) (net.PacketConn, error), writeTimeout time.Duration, multicastRTPPort int,