perform frame readings and writings in separate routines, in order to increase UDP throughput and avoid freezes caused by a single laggy reader (https://github.com/aler9/rtsp-simple-server/issues/125) (https://github.com/aler9/rtsp-simple-server/issues/162)

This commit is contained in:
aler9
2021-01-09 22:59:41 +01:00
parent 87bd5bde32
commit 7d91c13972
8 changed files with 263 additions and 126 deletions

View File

@@ -8,13 +8,13 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
)
@@ -23,11 +23,13 @@ const (
serverConnWriteBufferSize = 4096
serverConnCheckStreamInterval = 5 * time.Second
serverConnReceiverReportInterval = 10 * time.Second
serverConnTCPFrameReadBufferSize = 2048
)
// server errors.
var (
ErrServerTeardown = errors.New("teardown")
ErrServerTeardown = errors.New("teardown")
errServerCSeqMissing = errors.New("CSeq is missing")
)
// ServerConnState is the state of the connection.
@@ -138,20 +140,24 @@ type ServerConn struct {
state ServerConnState
tracks map[int]ServerConnTrack
tracksProtocol *StreamProtocol
rtcpReceivers []*rtcpreceiver.RTCPReceiver
udpLastFrameTimes []*int64
writeMutex sync.Mutex
readHandlers ServerConnReadHandlers
nextFramesEnabled bool
rtcpReceivers []*rtcpreceiver.RTCPReceiver
doEnableFrames bool
framesEnabled bool
readTimeoutEnabled bool
udpTimeout *int32
// writer
frameRingBuffer *ringbuffer.RingBuffer
backgroundWriteDone chan struct{}
// background record
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
udpTimeout int32
udpLastFrameTimes []*int64
// in
terminate chan struct{}
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
}
func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
@@ -163,13 +169,14 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
}()
return &ServerConn{
conf: conf,
nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
tracks: make(map[int]ServerConnTrack),
udpTimeout: new(int32),
terminate: make(chan struct{}),
conf: conf,
nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
tracks: make(map[int]ServerConnTrack),
frameRingBuffer: ringbuffer.New(conf.ReadBufferCount),
backgroundWriteDone: make(chan struct{}),
terminate: make(chan struct{}),
}
}
@@ -206,6 +213,30 @@ func (sc *ServerConn) Tracks() map[int]ServerConnTrack {
return sc.tracks
}
func (sc *ServerConn) backgroundWrite() {
defer close(sc.backgroundWriteDone)
for {
what, ok := sc.frameRingBuffer.Pull()
if !ok {
return
}
switch w := what.(type) {
case *base.InterleavedFrame:
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
w.Write(sc.bw)
case *base.Response:
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
w.Write(sc.bw)
default:
panic(fmt.Errorf("unsupported type: %T", what))
}
}
}
func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error {
if _, ok := allowed[sc.state]; ok {
return nil
@@ -236,12 +267,12 @@ func (sc *ServerConn) frameModeEnable() {
switch sc.state {
case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP {
sc.nextFramesEnabled = true
sc.doEnableFrames = true
}
case ServerConnStateRecord:
if *sc.tracksProtocol == StreamProtocolTCP {
sc.nextFramesEnabled = true
sc.doEnableFrames = true
sc.readTimeoutEnabled = true
} else {
@@ -266,16 +297,25 @@ func (sc *ServerConn) frameModeEnable() {
func (sc *ServerConn) frameModeDisable() {
switch sc.state {
case ServerConnStatePlay:
sc.nextFramesEnabled = false
if *sc.tracksProtocol == StreamProtocolTCP {
sc.framesEnabled = false
sc.frameRingBuffer.Close()
<-sc.backgroundWriteDone
}
case ServerConnStateRecord:
close(sc.backgroundRecordTerminate)
<-sc.backgroundRecordDone
sc.nextFramesEnabled = false
sc.readTimeoutEnabled = false
if *sc.tracksProtocol == StreamProtocolTCP {
sc.readTimeoutEnabled = false
sc.nconn.SetReadDeadline(time.Time{})
if *sc.tracksProtocol == StreamProtocolUDP {
sc.framesEnabled = false
sc.frameRingBuffer.Close()
<-sc.backgroundWriteDone
} else {
for _, track := range sc.tracks {
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
@@ -285,6 +325,13 @@ func (sc *ServerConn) frameModeDisable() {
}
func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 {
return &base.Response{
StatusCode: base.StatusBadRequest,
Header: base.Header{},
}, errServerCSeqMissing
}
if sc.readHandlers.OnRequest != nil {
sc.readHandlers.OnRequest(req)
}
@@ -676,19 +723,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
func (sc *ServerConn) backgroundRead() error {
handleRequestOuter := func(req *base.Request) error {
// check cseq
cseq, ok := req.Header["CSeq"]
if !ok || len(cseq) != 1 {
sc.writeMutex.Lock()
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
base.Response{
StatusCode: base.StatusBadRequest,
Header: base.Header{},
}.Write(sc.bw)
sc.writeMutex.Unlock()
return errors.New("CSeq is missing")
}
res, err := sc.handleRequest(req)
if res.Header == nil {
@@ -696,7 +730,9 @@ func (sc *ServerConn) backgroundRead() error {
}
// add cseq
res.Header["CSeq"] = cseq
if err != errServerCSeqMissing {
res.Header["CSeq"] = req.Header["CSeq"]
}
// add server
res.Header["Server"] = base.HeaderValue{"gortsplib"}
@@ -705,33 +741,42 @@ func (sc *ServerConn) backgroundRead() error {
sc.readHandlers.OnResponse(res)
}
sc.writeMutex.Lock()
// start background write
if sc.doEnableFrames {
sc.doEnableFrames = false
sc.framesEnabled = true
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
res.Write(sc.bw)
// write response before frames
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
res.Write(sc.bw)
// set framesEnabled after sending the response
// in order to start sending frames after the response, never before
if sc.framesEnabled != sc.nextFramesEnabled {
sc.framesEnabled = sc.nextFramesEnabled
// start background write
sc.frameRingBuffer.Reset()
sc.backgroundWriteDone = make(chan struct{})
go sc.backgroundWrite()
// write to background write
} else if sc.framesEnabled {
sc.frameRingBuffer.Push(res)
// write directly
} else {
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
res.Write(sc.bw)
}
sc.writeMutex.Unlock()
return err
}
var req base.Request
var frame base.InterleavedFrame
tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientConnTCPFrameReadBufferSize)
tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, serverConnTCPFrameReadBufferSize)
var errRet error
outer:
for {
if sc.readTimeoutEnabled {
sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout))
} else {
sc.nconn.SetReadDeadline(time.Time{})
}
if sc.framesEnabled {
@@ -764,7 +809,7 @@ outer:
} else {
err := req.Read(sc.br)
if err != nil {
if atomic.LoadInt32(sc.udpTimeout) == 1 {
if atomic.LoadInt32(&sc.udpTimeout) == 1 {
errRet = fmt.Errorf("no UDP packets received recently (maybe there's a firewall/NAT in between)")
} else {
errRet = err
@@ -801,41 +846,34 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
}
// WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error {
sc.writeMutex.Lock()
defer sc.writeMutex.Unlock()
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *sc.tracksProtocol == StreamProtocolUDP {
track := sc.tracks[trackID]
if streamType == StreamTypeRTP {
return sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{
sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtpPort,
})
return
}
return sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{
sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{
IP: sc.ip(),
Zone: sc.zone(),
Port: track.rtcpPort,
})
return
}
// StreamProtocolTCP
if !sc.framesEnabled {
return errors.New("frames are disabled")
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout))
frame := base.InterleavedFrame{
sc.frameRingBuffer.Push(&base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,
}
return frame.Write(sc.bw)
})
}
func (sc *ServerConn) backgroundRecord() {
@@ -859,7 +897,7 @@ func (sc *ServerConn) backgroundRecord() {
last := time.Unix(atomic.LoadInt64(lastUnix), 0)
if now.Sub(last) >= sc.conf.ReadTimeout {
atomic.StoreInt32(sc.udpTimeout, 1)
atomic.StoreInt32(&sc.udpTimeout, 1)
sc.nconn.Close()
return
}