From 2ed4b079e85b75db8ec68baff2127aee39ab2395 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 28 Mar 2021 16:04:29 +0200 Subject: [PATCH] server: add new test --- clientconn.go | 6 +- serverconn.go | 89 ++++++++++++------------- serverconnread_test.go | 145 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 191 insertions(+), 49 deletions(-) diff --git a/clientconn.go b/clientconn.go index 5d62b63d..a87bab97 100644 --- a/clientconn.go +++ b/clientconn.go @@ -79,13 +79,13 @@ type ClientConn struct { udpRTCPListeners map[int]*clientConnUDPListener getParameterSupported bool - // read only + // read rtpInfo *headers.RTPInfo rtcpReceivers map[int]*rtcpreceiver.RTCPReceiver tcpFrameBuffer *multibuffer.MultiBuffer readCB func(int, StreamType, []byte) - // publish only + // publish rtcpSenders map[int]*rtcpsender.RTCPSender publishError error publishWriteMutex sync.RWMutex @@ -187,7 +187,7 @@ func (cc *ClientConn) reset() { cc.udpRTCPListeners = make(map[int]*clientConnUDPListener) cc.getParameterSupported = false - // read only + // read cc.rtpInfo = nil cc.rtcpReceivers = nil cc.tcpFrameBuffer = nil diff --git a/serverconn.go b/serverconn.go index ffb6df52..bd290d4c 100644 --- a/serverconn.go +++ b/serverconn.go @@ -261,18 +261,18 @@ type ServerConn struct { setupPath *string setupQuery *string - // frame mode only - doEnableFrames bool - framesEnabled bool - readTimeoutEnabled bool - tcpFrameBuffer *multibuffer.MultiBuffer - frameRingBuffer *ringbuffer.RingBuffer - backgroundWriteDone chan struct{} + // TCP stream protocol + doEnableTCPFrame bool + tcpFrameEnabled bool + tcpFrameTimeout bool + tcpFrameBuffer *multibuffer.MultiBuffer + tcpFrameWriteBuffer *ringbuffer.RingBuffer + tcpBackgroundWriteDone chan struct{} - // read only + // read readHandlers ServerConnReadHandlers - // publish only + // publish announcedTracks []ServerConnAnnouncedTrack backgroundRecordTerminate chan struct{} backgroundRecordDone chan struct{} @@ -294,15 +294,16 @@ func newServerConn(conf ServerConf, }() return &ServerConn{ - conf: conf, - udpRTPListener: udpRTPListener, - udpRTCPListener: udpRTCPListener, - nconn: nconn, - br: bufio.NewReaderSize(conn, serverConnReadBufferSize), - bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), - frameRingBuffer: ringbuffer.New(uint64(conf.ReadBufferCount)), - backgroundWriteDone: make(chan struct{}), - terminate: make(chan struct{}), + conf: conf, + udpRTPListener: udpRTPListener, + udpRTCPListener: udpRTCPListener, + nconn: nconn, + br: bufio.NewReaderSize(conn, serverConnReadBufferSize), + bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), + // always instantiate to allow writing to it before Play() + tcpFrameWriteBuffer: ringbuffer.New(uint64(conf.ReadBufferCount)), + tcpBackgroundWriteDone: make(chan struct{}), + terminate: make(chan struct{}), } } @@ -333,11 +334,11 @@ func (sc *ServerConn) AnnouncedTracks() []ServerConnAnnouncedTrack { return sc.announcedTracks } -func (sc *ServerConn) backgroundWrite() { - defer close(sc.backgroundWriteDone) +func (sc *ServerConn) tcpBackgroundWrite() { + defer close(sc.tcpBackgroundWriteDone) for { - what, ok := sc.frameRingBuffer.Pull() + what, ok := sc.tcpFrameWriteBuffer.Pull() if !ok { return } @@ -385,7 +386,7 @@ func (sc *ServerConn) frameModeEnable() { switch sc.state { case ServerConnStatePlay: if *sc.setupProtocol == StreamProtocolTCP { - sc.doEnableFrames = true + sc.doEnableTCPFrame = true } else { // readers can send RTCP frames, they cannot sent RTP frames for trackID, track := range sc.setuppedTracks { @@ -395,8 +396,8 @@ func (sc *ServerConn) frameModeEnable() { case ServerConnStateRecord: if *sc.setupProtocol == StreamProtocolTCP { - sc.doEnableFrames = true - sc.readTimeoutEnabled = true + sc.doEnableTCPFrame = true + sc.tcpFrameTimeout = true } else { for trackID, track := range sc.setuppedTracks { @@ -421,9 +422,9 @@ func (sc *ServerConn) frameModeDisable() { switch sc.state { case ServerConnStatePlay: if *sc.setupProtocol == StreamProtocolTCP { - sc.framesEnabled = false - sc.frameRingBuffer.Close() - <-sc.backgroundWriteDone + sc.tcpFrameEnabled = false + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpBackgroundWriteDone } else { for _, track := range sc.setuppedTracks { @@ -436,12 +437,12 @@ func (sc *ServerConn) frameModeDisable() { <-sc.backgroundRecordDone if *sc.setupProtocol == StreamProtocolTCP { - sc.readTimeoutEnabled = false + sc.tcpFrameTimeout = false sc.nconn.SetReadDeadline(time.Time{}) - sc.framesEnabled = false - sc.frameRingBuffer.Close() - <-sc.backgroundWriteDone + sc.tcpFrameEnabled = false + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpBackgroundWriteDone } else { for _, track := range sc.setuppedTracks { @@ -1045,9 +1046,9 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { } switch { - case sc.doEnableFrames: // start background write - sc.doEnableFrames = false - sc.framesEnabled = true + case sc.doEnableTCPFrame: // start background write + sc.doEnableTCPFrame = false + sc.tcpFrameEnabled = true if sc.state == ServerConnStateRecord { sc.tcpFrameBuffer = multibuffer.New(uint64(sc.conf.ReadBufferCount), uint64(sc.conf.ReadBufferSize)) @@ -1064,12 +1065,12 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { res.Write(sc.bw) // start background write - sc.frameRingBuffer.Reset() - sc.backgroundWriteDone = make(chan struct{}) - go sc.backgroundWrite() + sc.tcpFrameWriteBuffer.Reset() + sc.tcpBackgroundWriteDone = make(chan struct{}) + go sc.tcpBackgroundWrite() - case sc.framesEnabled: // write to background write - sc.frameRingBuffer.Push(res) + case sc.tcpFrameEnabled: // write to background write + sc.tcpFrameWriteBuffer.Push(res) default: // write directly sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) @@ -1086,11 +1087,11 @@ func (sc *ServerConn) backgroundRead() error { var frame base.InterleavedFrame for { - if sc.readTimeoutEnabled { - sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) - } + if sc.tcpFrameEnabled { + if sc.tcpFrameTimeout { + sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) + } - if sc.framesEnabled { frame.Payload = sc.tcpFrameBuffer.Next() what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) if err != nil { @@ -1169,7 +1170,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b return } - sc.frameRingBuffer.Push(&base.InterleavedFrame{ + sc.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, Payload: payload, diff --git a/serverconnread_test.go b/serverconnread_test.go index 65bd5c37..8eb05f30 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -531,7 +531,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { require.NoError(t, err) } -func TestServerReadPlayMultiple(t *testing.T) { +func TestServerReadPlayPlay(t *testing.T) { s, err := Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -645,7 +645,148 @@ func TestServerReadPlayMultiple(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) } -func TestServerReadPauseMultiple(t *testing.T) { +func TestServerReadPlayPausePlay(t *testing.T) { + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + writerStarted := false + writerDone := make(chan struct{}) + defer func() { <-writerDone }() + writerTerminate := make(chan struct{}) + defer close(writerTerminate) + + onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { + if !writerStarted { + writerStarted = true + go func() { + defer close(writerDone) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-t.C: + conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + case <-writerTerminate: + return + } + } + }() + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + onPause := func(ctx *ServerConnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + <-conn.Read(ServerConnReadHandlers{ + OnSetup: onSetup, + OnPlay: onPlay, + OnPause: onPause, + }) + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Play, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Pause, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + buf := make([]byte, 2048) + err = res.ReadIgnoreFrames(bconn.Reader, buf) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Play, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var fr base.InterleavedFrame + fr.Payload = make([]byte, 2048) + err = fr.Read(bconn.Reader) + require.NoError(t, err) +} + +func TestServerReadPlayPausePause(t *testing.T) { s, err := Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close()