server: add new test

This commit is contained in:
aler9
2021-03-28 16:04:29 +02:00
parent 213eb4908e
commit 2ed4b079e8
3 changed files with 191 additions and 49 deletions

View File

@@ -79,13 +79,13 @@ type ClientConn struct {
udpRTCPListeners map[int]*clientConnUDPListener udpRTCPListeners map[int]*clientConnUDPListener
getParameterSupported bool getParameterSupported bool
// read only // read
rtpInfo *headers.RTPInfo rtpInfo *headers.RTPInfo
rtcpReceivers map[int]*rtcpreceiver.RTCPReceiver rtcpReceivers map[int]*rtcpreceiver.RTCPReceiver
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
readCB func(int, StreamType, []byte) readCB func(int, StreamType, []byte)
// publish only // publish
rtcpSenders map[int]*rtcpsender.RTCPSender rtcpSenders map[int]*rtcpsender.RTCPSender
publishError error publishError error
publishWriteMutex sync.RWMutex publishWriteMutex sync.RWMutex
@@ -187,7 +187,7 @@ func (cc *ClientConn) reset() {
cc.udpRTCPListeners = make(map[int]*clientConnUDPListener) cc.udpRTCPListeners = make(map[int]*clientConnUDPListener)
cc.getParameterSupported = false cc.getParameterSupported = false
// read only // read
cc.rtpInfo = nil cc.rtpInfo = nil
cc.rtcpReceivers = nil cc.rtcpReceivers = nil
cc.tcpFrameBuffer = nil cc.tcpFrameBuffer = nil

View File

@@ -261,18 +261,18 @@ type ServerConn struct {
setupPath *string setupPath *string
setupQuery *string setupQuery *string
// frame mode only // TCP stream protocol
doEnableFrames bool doEnableTCPFrame bool
framesEnabled bool tcpFrameEnabled bool
readTimeoutEnabled bool tcpFrameTimeout bool
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
frameRingBuffer *ringbuffer.RingBuffer tcpFrameWriteBuffer *ringbuffer.RingBuffer
backgroundWriteDone chan struct{} tcpBackgroundWriteDone chan struct{}
// read only // read
readHandlers ServerConnReadHandlers readHandlers ServerConnReadHandlers
// publish only // publish
announcedTracks []ServerConnAnnouncedTrack announcedTracks []ServerConnAnnouncedTrack
backgroundRecordTerminate chan struct{} backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{} backgroundRecordDone chan struct{}
@@ -294,15 +294,16 @@ func newServerConn(conf ServerConf,
}() }()
return &ServerConn{ return &ServerConn{
conf: conf, conf: conf,
udpRTPListener: udpRTPListener, udpRTPListener: udpRTPListener,
udpRTCPListener: udpRTCPListener, udpRTCPListener: udpRTCPListener,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize), br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
frameRingBuffer: ringbuffer.New(uint64(conf.ReadBufferCount)), // always instantiate to allow writing to it before Play()
backgroundWriteDone: make(chan struct{}), tcpFrameWriteBuffer: ringbuffer.New(uint64(conf.ReadBufferCount)),
terminate: make(chan struct{}), tcpBackgroundWriteDone: make(chan struct{}),
terminate: make(chan struct{}),
} }
} }
@@ -333,11 +334,11 @@ func (sc *ServerConn) AnnouncedTracks() []ServerConnAnnouncedTrack {
return sc.announcedTracks return sc.announcedTracks
} }
func (sc *ServerConn) backgroundWrite() { func (sc *ServerConn) tcpBackgroundWrite() {
defer close(sc.backgroundWriteDone) defer close(sc.tcpBackgroundWriteDone)
for { for {
what, ok := sc.frameRingBuffer.Pull() what, ok := sc.tcpFrameWriteBuffer.Pull()
if !ok { if !ok {
return return
} }
@@ -385,7 +386,7 @@ func (sc *ServerConn) frameModeEnable() {
switch sc.state { switch sc.state {
case ServerConnStatePlay: case ServerConnStatePlay:
if *sc.setupProtocol == StreamProtocolTCP { if *sc.setupProtocol == StreamProtocolTCP {
sc.doEnableFrames = true sc.doEnableTCPFrame = true
} else { } else {
// readers can send RTCP frames, they cannot sent RTP frames // readers can send RTCP frames, they cannot sent RTP frames
for trackID, track := range sc.setuppedTracks { for trackID, track := range sc.setuppedTracks {
@@ -395,8 +396,8 @@ func (sc *ServerConn) frameModeEnable() {
case ServerConnStateRecord: case ServerConnStateRecord:
if *sc.setupProtocol == StreamProtocolTCP { if *sc.setupProtocol == StreamProtocolTCP {
sc.doEnableFrames = true sc.doEnableTCPFrame = true
sc.readTimeoutEnabled = true sc.tcpFrameTimeout = true
} else { } else {
for trackID, track := range sc.setuppedTracks { for trackID, track := range sc.setuppedTracks {
@@ -421,9 +422,9 @@ func (sc *ServerConn) frameModeDisable() {
switch sc.state { switch sc.state {
case ServerConnStatePlay: case ServerConnStatePlay:
if *sc.setupProtocol == StreamProtocolTCP { if *sc.setupProtocol == StreamProtocolTCP {
sc.framesEnabled = false sc.tcpFrameEnabled = false
sc.frameRingBuffer.Close() sc.tcpFrameWriteBuffer.Close()
<-sc.backgroundWriteDone <-sc.tcpBackgroundWriteDone
} else { } else {
for _, track := range sc.setuppedTracks { for _, track := range sc.setuppedTracks {
@@ -436,12 +437,12 @@ func (sc *ServerConn) frameModeDisable() {
<-sc.backgroundRecordDone <-sc.backgroundRecordDone
if *sc.setupProtocol == StreamProtocolTCP { if *sc.setupProtocol == StreamProtocolTCP {
sc.readTimeoutEnabled = false sc.tcpFrameTimeout = false
sc.nconn.SetReadDeadline(time.Time{}) sc.nconn.SetReadDeadline(time.Time{})
sc.framesEnabled = false sc.tcpFrameEnabled = false
sc.frameRingBuffer.Close() sc.tcpFrameWriteBuffer.Close()
<-sc.backgroundWriteDone <-sc.tcpBackgroundWriteDone
} else { } else {
for _, track := range sc.setuppedTracks { for _, track := range sc.setuppedTracks {
@@ -1045,9 +1046,9 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
} }
switch { switch {
case sc.doEnableFrames: // start background write case sc.doEnableTCPFrame: // start background write
sc.doEnableFrames = false sc.doEnableTCPFrame = false
sc.framesEnabled = true sc.tcpFrameEnabled = true
if sc.state == ServerConnStateRecord { if sc.state == ServerConnStateRecord {
sc.tcpFrameBuffer = multibuffer.New(uint64(sc.conf.ReadBufferCount), uint64(sc.conf.ReadBufferSize)) sc.tcpFrameBuffer = multibuffer.New(uint64(sc.conf.ReadBufferCount), uint64(sc.conf.ReadBufferSize))
@@ -1064,12 +1065,12 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
res.Write(sc.bw) res.Write(sc.bw)
// start background write // start background write
sc.frameRingBuffer.Reset() sc.tcpFrameWriteBuffer.Reset()
sc.backgroundWriteDone = make(chan struct{}) sc.tcpBackgroundWriteDone = make(chan struct{})
go sc.backgroundWrite() go sc.tcpBackgroundWrite()
case sc.framesEnabled: // write to background write case sc.tcpFrameEnabled: // write to background write
sc.frameRingBuffer.Push(res) sc.tcpFrameWriteBuffer.Push(res)
default: // write directly default: // write directly
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
@@ -1086,11 +1087,11 @@ func (sc *ServerConn) backgroundRead() error {
var frame base.InterleavedFrame var frame base.InterleavedFrame
for { for {
if sc.readTimeoutEnabled { if sc.tcpFrameEnabled {
sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) if sc.tcpFrameTimeout {
} sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout))
}
if sc.framesEnabled {
frame.Payload = sc.tcpFrameBuffer.Next() frame.Payload = sc.tcpFrameBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil { if err != nil {
@@ -1169,7 +1170,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
return return
} }
sc.frameRingBuffer.Push(&base.InterleavedFrame{ sc.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,
StreamType: streamType, StreamType: streamType,
Payload: payload, Payload: payload,

View File

@@ -531,7 +531,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestServerReadPlayMultiple(t *testing.T) { func TestServerReadPlayPlay(t *testing.T) {
s, err := Serve("127.0.0.1:8554") s, err := Serve("127.0.0.1:8554")
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@@ -645,7 +645,148 @@ func TestServerReadPlayMultiple(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
} }
func TestServerReadPauseMultiple(t *testing.T) { func TestServerReadPlayPausePlay(t *testing.T) {
s, err := Serve("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := s.Accept()
require.NoError(t, err)
defer conn.Close()
writerStarted := false
writerDone := make(chan struct{})
defer func() { <-writerDone }()
writerTerminate := make(chan struct{})
defer close(writerTerminate)
onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) {
if !writerStarted {
writerStarted = true
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for {
select {
case <-t.C:
conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00"))
case <-writerTerminate:
return
}
}
}()
}
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
onPause := func(ctx *ServerConnPauseCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
<-conn.Read(ServerConnReadHandlers{
OnSetup: onSetup,
OnPlay: onPlay,
OnPause: onPause,
})
}()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
var res base.Response
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Pause,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
buf := make([]byte, 2048)
err = res.ReadIgnoreFrames(bconn.Reader, buf)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = base.Request{
Method: base.Play,
URL: base.MustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = res.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame
fr.Payload = make([]byte, 2048)
err = fr.Read(bconn.Reader)
require.NoError(t, err)
}
func TestServerReadPlayPausePause(t *testing.T) {
s, err := Serve("127.0.0.1:8554") s, err := Serve("127.0.0.1:8554")
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()