diff --git a/clientconf.go b/clientconf.go index fd0f281d..1a35f3b8 100644 --- a/clientconf.go +++ b/clientconf.go @@ -55,7 +55,7 @@ type ClientConf struct { // If greater than 1, allows to pass buffers to routines different than the one // that is reading frames. // It defaults to 1. - ReadBufferCount int + ReadBufferCount uint64 // callback called before every request. OnRequest func(req *base.Request) diff --git a/clientconn.go b/clientconn.go index 0031f243..411183da 100644 --- a/clientconn.go +++ b/clientconn.go @@ -33,7 +33,7 @@ const ( clientConnSenderReportPeriod = 10 * time.Second clientConnUDPCheckStreamPeriod = 5 * time.Second clientConnUDPKeepalivePeriod = 30 * time.Second - clientConnTCPFrameReadBufferSize = 128 * 1024 + clientConnTCPFrameReadBufferSize = 2048 ) type clientConnState int diff --git a/pkg/multibuffer/multibuffer.go b/pkg/multibuffer/multibuffer.go index e92d7185..2962e386 100644 --- a/pkg/multibuffer/multibuffer.go +++ b/pkg/multibuffer/multibuffer.go @@ -2,17 +2,17 @@ package multibuffer // MultiBuffer implements software multi buffering, that allows to reuse -// existing buffers without creating new ones, increasing performance. +// existing buffers without creating new ones, improving performance. type MultiBuffer struct { - count int + count uint64 buffers [][]byte - cur int + cur uint64 } // New allocates a MultiBuffer. -func New(count int, size int) *MultiBuffer { +func New(count uint64, size uint64) *MultiBuffer { buffers := make([][]byte, count) - for i := 0; i < count; i++ { + for i := uint64(0); i < count; i++ { buffers[i] = make([]byte, size) } @@ -24,10 +24,7 @@ func New(count int, size int) *MultiBuffer { // Next gets the current buffer and sets the next buffer as the current one. func (mb *MultiBuffer) Next() []byte { - ret := mb.buffers[mb.cur] + ret := mb.buffers[mb.cur%mb.count] mb.cur++ - if mb.cur >= mb.count { - mb.cur = 0 - } return ret } diff --git a/pkg/ringbuffer/ringbuffer.go b/pkg/ringbuffer/ringbuffer.go new file mode 100644 index 00000000..e000cc3d --- /dev/null +++ b/pkg/ringbuffer/ringbuffer.go @@ -0,0 +1,67 @@ +package ringbuffer + +import ( + "sync/atomic" + "time" + "unsafe" +) + +// RingBuffer is a ring buffer. +type RingBuffer struct { + bufferSize uint64 + readIndex uint64 + writeIndex uint64 + closed int64 + buffer []unsafe.Pointer +} + +// New allocates a RingBuffer. +func New(size uint64) *RingBuffer { + return &RingBuffer{ + bufferSize: size, + readIndex: 1, + writeIndex: 0, + buffer: make([]unsafe.Pointer, size), + } +} + +// Close makes Pull() return false. +func (r *RingBuffer) Close() { + atomic.StoreInt64(&r.closed, 1) +} + +// Reset restores Pull(). +func (r *RingBuffer) Reset() { + for i := uint64(0); i < r.bufferSize; i++ { + atomic.SwapPointer(&r.buffer[i], nil) + } + atomic.SwapUint64(&r.writeIndex, 0) + r.readIndex = 1 + atomic.StoreInt64(&r.closed, 0) +} + +// Push pushes some data at the end of the buffer. +func (r *RingBuffer) Push(data interface{}) { + writeIndex := atomic.AddUint64(&r.writeIndex, 1) + i := writeIndex % r.bufferSize + atomic.SwapPointer(&r.buffer[i], unsafe.Pointer(&data)) +} + +// Pull pulls some data from the beginning of the buffer. +func (r *RingBuffer) Pull() (interface{}, bool) { + for { + if atomic.SwapInt64(&r.closed, 0) == 1 { + return nil, false + } + + i := r.readIndex % r.bufferSize + res := (*interface{})(atomic.SwapPointer(&r.buffer[i], nil)) + if res == nil { + time.Sleep(10 * time.Millisecond) + continue + } + + r.readIndex++ + return *res, true + } +} diff --git a/server.go b/server.go index 9015f806..4a2aee04 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,7 @@ func newServer(conf ServerConf, address string) (*Server, error) { conf.WriteTimeout = 10 * time.Second } if conf.ReadBufferCount == 0 { - conf.ReadBufferCount = 1 + conf.ReadBufferCount = 1024 } if conf.Listen == nil { conf.Listen = net.Listen @@ -36,11 +36,8 @@ func newServer(conf ServerConf, address string) (*Server, error) { } if conf.UDPRTPListener != nil { - conf.UDPRTPListener.streamType = StreamTypeRTP - conf.UDPRTPListener.writeTimeout = conf.WriteTimeout - - conf.UDPRTCPListener.streamType = StreamTypeRTCP - conf.UDPRTCPListener.writeTimeout = conf.WriteTimeout + conf.UDPRTPListener.initialize(conf, StreamTypeRTP) + conf.UDPRTCPListener.initialize(conf, StreamTypeRTCP) } listener, err := conf.Listen("tcp", address) diff --git a/serverconf.go b/serverconf.go index c87800fb..3dc45304 100644 --- a/serverconf.go +++ b/serverconf.go @@ -39,8 +39,10 @@ type ServerConf struct { // Read buffer count. // If greater than 1, allows to pass buffers to routines different than the one // that is reading frames. - // It defaults to 1 - ReadBufferCount int + // It also allows to buffer routed frames and mitigate network fluctuations + // that are particularly high when using UDP. + // It defaults to 1024 + ReadBufferCount uint64 // Function used to initialize the TCP listener. // It defaults to net.Listen diff --git a/serverconn.go b/serverconn.go index 7c45818a..e2586f93 100644 --- a/serverconn.go +++ b/serverconn.go @@ -8,13 +8,13 @@ import ( "net" "strconv" "strings" - "sync" "sync/atomic" "time" "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/multibuffer" + "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" ) @@ -23,11 +23,13 @@ const ( serverConnWriteBufferSize = 4096 serverConnCheckStreamInterval = 5 * time.Second serverConnReceiverReportInterval = 10 * time.Second + serverConnTCPFrameReadBufferSize = 2048 ) // server errors. var ( - ErrServerTeardown = errors.New("teardown") + ErrServerTeardown = errors.New("teardown") + errServerCSeqMissing = errors.New("CSeq is missing") ) // ServerConnState is the state of the connection. @@ -138,20 +140,24 @@ type ServerConn struct { state ServerConnState tracks map[int]ServerConnTrack tracksProtocol *StreamProtocol - rtcpReceivers []*rtcpreceiver.RTCPReceiver - udpLastFrameTimes []*int64 - writeMutex sync.Mutex readHandlers ServerConnReadHandlers - nextFramesEnabled bool + rtcpReceivers []*rtcpreceiver.RTCPReceiver + doEnableFrames bool framesEnabled bool readTimeoutEnabled bool - udpTimeout *int32 + + // writer + frameRingBuffer *ringbuffer.RingBuffer + backgroundWriteDone chan struct{} + + // background record + backgroundRecordTerminate chan struct{} + backgroundRecordDone chan struct{} + udpTimeout int32 + udpLastFrameTimes []*int64 // in terminate chan struct{} - - backgroundRecordTerminate chan struct{} - backgroundRecordDone chan struct{} } func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { @@ -163,13 +169,14 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { }() return &ServerConn{ - conf: conf, - nconn: nconn, - br: bufio.NewReaderSize(conn, serverConnReadBufferSize), - bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), - tracks: make(map[int]ServerConnTrack), - udpTimeout: new(int32), - terminate: make(chan struct{}), + conf: conf, + nconn: nconn, + br: bufio.NewReaderSize(conn, serverConnReadBufferSize), + bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), + tracks: make(map[int]ServerConnTrack), + frameRingBuffer: ringbuffer.New(conf.ReadBufferCount), + backgroundWriteDone: make(chan struct{}), + terminate: make(chan struct{}), } } @@ -206,6 +213,30 @@ func (sc *ServerConn) Tracks() map[int]ServerConnTrack { return sc.tracks } +func (sc *ServerConn) backgroundWrite() { + defer close(sc.backgroundWriteDone) + + for { + what, ok := sc.frameRingBuffer.Pull() + if !ok { + return + } + + switch w := what.(type) { + case *base.InterleavedFrame: + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + w.Write(sc.bw) + + case *base.Response: + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + w.Write(sc.bw) + + default: + panic(fmt.Errorf("unsupported type: %T", what)) + } + } +} + func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error { if _, ok := allowed[sc.state]; ok { return nil @@ -236,12 +267,12 @@ func (sc *ServerConn) frameModeEnable() { switch sc.state { case ServerConnStatePlay: if *sc.tracksProtocol == StreamProtocolTCP { - sc.nextFramesEnabled = true + sc.doEnableFrames = true } case ServerConnStateRecord: if *sc.tracksProtocol == StreamProtocolTCP { - sc.nextFramesEnabled = true + sc.doEnableFrames = true sc.readTimeoutEnabled = true } else { @@ -266,16 +297,25 @@ func (sc *ServerConn) frameModeEnable() { func (sc *ServerConn) frameModeDisable() { switch sc.state { case ServerConnStatePlay: - sc.nextFramesEnabled = false + if *sc.tracksProtocol == StreamProtocolTCP { + sc.framesEnabled = false + sc.frameRingBuffer.Close() + <-sc.backgroundWriteDone + } case ServerConnStateRecord: close(sc.backgroundRecordTerminate) <-sc.backgroundRecordDone - sc.nextFramesEnabled = false - sc.readTimeoutEnabled = false + if *sc.tracksProtocol == StreamProtocolTCP { + sc.readTimeoutEnabled = false + sc.nconn.SetReadDeadline(time.Time{}) - if *sc.tracksProtocol == StreamProtocolUDP { + sc.framesEnabled = false + sc.frameRingBuffer.Close() + <-sc.backgroundWriteDone + + } else { for _, track := range sc.tracks { sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) @@ -285,6 +325,13 @@ func (sc *ServerConn) frameModeDisable() { } func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { + if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 { + return &base.Response{ + StatusCode: base.StatusBadRequest, + Header: base.Header{}, + }, errServerCSeqMissing + } + if sc.readHandlers.OnRequest != nil { sc.readHandlers.OnRequest(req) } @@ -676,19 +723,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { func (sc *ServerConn) backgroundRead() error { handleRequestOuter := func(req *base.Request) error { - // check cseq - cseq, ok := req.Header["CSeq"] - if !ok || len(cseq) != 1 { - sc.writeMutex.Lock() - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) - base.Response{ - StatusCode: base.StatusBadRequest, - Header: base.Header{}, - }.Write(sc.bw) - sc.writeMutex.Unlock() - return errors.New("CSeq is missing") - } - res, err := sc.handleRequest(req) if res.Header == nil { @@ -696,7 +730,9 @@ func (sc *ServerConn) backgroundRead() error { } // add cseq - res.Header["CSeq"] = cseq + if err != errServerCSeqMissing { + res.Header["CSeq"] = req.Header["CSeq"] + } // add server res.Header["Server"] = base.HeaderValue{"gortsplib"} @@ -705,33 +741,42 @@ func (sc *ServerConn) backgroundRead() error { sc.readHandlers.OnResponse(res) } - sc.writeMutex.Lock() + // start background write + if sc.doEnableFrames { + sc.doEnableFrames = false + sc.framesEnabled = true - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) - res.Write(sc.bw) + // write response before frames + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + res.Write(sc.bw) - // set framesEnabled after sending the response - // in order to start sending frames after the response, never before - if sc.framesEnabled != sc.nextFramesEnabled { - sc.framesEnabled = sc.nextFramesEnabled + // start background write + sc.frameRingBuffer.Reset() + sc.backgroundWriteDone = make(chan struct{}) + go sc.backgroundWrite() + + // write to background write + } else if sc.framesEnabled { + sc.frameRingBuffer.Push(res) + + // write directly + } else { + sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + res.Write(sc.bw) } - sc.writeMutex.Unlock() - return err } var req base.Request var frame base.InterleavedFrame - tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientConnTCPFrameReadBufferSize) + tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, serverConnTCPFrameReadBufferSize) var errRet error outer: for { if sc.readTimeoutEnabled { sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) - } else { - sc.nconn.SetReadDeadline(time.Time{}) } if sc.framesEnabled { @@ -764,7 +809,7 @@ outer: } else { err := req.Read(sc.br) if err != nil { - if atomic.LoadInt32(sc.udpTimeout) == 1 { + if atomic.LoadInt32(&sc.udpTimeout) == 1 { errRet = fmt.Errorf("no UDP packets received recently (maybe there's a firewall/NAT in between)") } else { errRet = err @@ -801,41 +846,34 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { } // WriteFrame writes a frame. -func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error { - sc.writeMutex.Lock() - defer sc.writeMutex.Unlock() - +func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { if *sc.tracksProtocol == StreamProtocolUDP { track := sc.tracks[trackID] if streamType == StreamTypeRTP { - return sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ + sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtpPort, }) + return } - return sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ + sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtcpPort, }) + return } // StreamProtocolTCP - if !sc.framesEnabled { - return errors.New("frames are disabled") - } - - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) - frame := base.InterleavedFrame{ + sc.frameRingBuffer.Push(&base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, Payload: payload, - } - return frame.Write(sc.bw) + }) } func (sc *ServerConn) backgroundRecord() { @@ -859,7 +897,7 @@ func (sc *ServerConn) backgroundRecord() { last := time.Unix(atomic.LoadInt64(lastUnix), 0) if now.Sub(last) >= sc.conf.ReadTimeout { - atomic.StoreInt32(sc.udpTimeout, 1) + atomic.StoreInt32(&sc.udpTimeout, 1) sc.nconn.Close() return } diff --git a/serverudpl.go b/serverudpl.go index 4f1cdb31..a1208820 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -7,15 +7,19 @@ import ( "time" "github.com/aler9/gortsplib/pkg/multibuffer" + "github.com/aler9/gortsplib/pkg/ringbuffer" ) const ( - // use the same buffer size as gstreamer's rtspsrc - kernelReadBufferSize = 0x80000 - - readBufferSize = 2048 + serverConnUDPListenerKernelReadBufferSize = 0x80000 // same as gstreamer's rtspsrc + serverConnUDPListenerReadBufferSize = 2048 ) +type bufAddrPair struct { + buf []byte + addr *net.UDPAddr +} + type publisherData struct { publisher *ServerConn trackID int @@ -39,14 +43,14 @@ func (p *publisherAddr) fill(ip net.IP, port int) { // ServerUDPListener is a UDP server that can be used to send and receive RTP and RTCP packets. type ServerUDPListener struct { - streamType StreamType - writeTimeout time.Duration - pc *net.UDPConn + initialized bool + streamType StreamType + writeTimeout time.Duration readBuf *multibuffer.MultiBuffer publishersMutex sync.RWMutex publishers map[publisherAddr]*publisherData - writeMutex sync.Mutex + ringBuffer *ringbuffer.RingBuffer // out done chan struct{} @@ -60,70 +64,102 @@ func NewServerUDPListener(address string) (*ServerUDPListener, error) { } pc := tmp.(*net.UDPConn) - err = pc.SetReadBuffer(kernelReadBufferSize) + err = pc.SetReadBuffer(serverConnUDPListenerKernelReadBufferSize) if err != nil { return nil, err } - s := &ServerUDPListener{ + return &ServerUDPListener{ pc: pc, - readBuf: multibuffer.New(1, readBufferSize), publishers: make(map[publisherAddr]*publisherData), done: make(chan struct{}), - } - - go s.run() - - return s, nil + }, nil } // Close closes the listener. func (s *ServerUDPListener) Close() { s.pc.Close() - <-s.done + + if s.initialized { + s.ringBuffer.Close() + <-s.done + } +} + +func (s *ServerUDPListener) initialize(conf ServerConf, streamType StreamType) { + if s.initialized { + return + } + + s.initialized = true + s.streamType = streamType + s.writeTimeout = conf.WriteTimeout + s.readBuf = multibuffer.New(conf.ReadBufferCount, serverConnUDPListenerReadBufferSize) + s.ringBuffer = ringbuffer.New(conf.ReadBufferCount) + go s.run() } func (s *ServerUDPListener) run() { defer close(s.done) - for { - buf := s.readBuf.Next() - n, addr, err := s.pc.ReadFromUDP(buf) - if err != nil { - break + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for { + buf := s.readBuf.Next() + n, addr, err := s.pc.ReadFromUDP(buf) + if err != nil { + break + } + + func() { + s.publishersMutex.RLock() + defer s.publishersMutex.RUnlock() + + // find publisher data + var pubAddr publisherAddr + pubAddr.fill(addr.IP, addr.Port) + pubData, ok := s.publishers[pubAddr] + if !ok { + return + } + + now := time.Now() + atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix()) + pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n]) + pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n]) + }() } + }() - func() { - s.publishersMutex.RLock() - defer s.publishersMutex.RUnlock() + wg.Add(1) + go func() { + defer wg.Done() - // find publisher data - var pubAddr publisherAddr - pubAddr.fill(addr.IP, addr.Port) - pubData, ok := s.publishers[pubAddr] + for { + tmp, ok := s.ringBuffer.Pull() if !ok { return } + pair := tmp.(bufAddrPair) - now := time.Now() - atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix()) - pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n]) - pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n]) - }() - } + s.pc.SetWriteDeadline(time.Now().Add(s.writeTimeout)) + s.pc.WriteTo(pair.buf, pair.addr) + } + }() + + wg.Wait() } func (s *ServerUDPListener) port() int { return s.pc.LocalAddr().(*net.UDPAddr).Port } -func (s *ServerUDPListener) write(buf []byte, addr *net.UDPAddr) error { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - - s.pc.SetWriteDeadline(time.Now().Add(s.writeTimeout)) - _, err := s.pc.WriteTo(buf, addr) - return err +func (s *ServerUDPListener) write(buf []byte, addr *net.UDPAddr) { + s.ringBuffer.Push(bufAddrPair{buf, addr}) } func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) {