server: save RAM by releasing read buffers earlier

This commit is contained in:
aler9
2022-02-17 23:50:50 +01:00
committed by Alessandro Ros
parent 8c7b4c1ce7
commit d44f1eb03a
2 changed files with 172 additions and 149 deletions

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"net"
"net/url"
"strings"
@@ -34,18 +35,13 @@ type ServerConn struct {
s *Server
conn net.Conn
ctx context.Context
ctxCancel func()
remoteAddr *net.TCPAddr
br *bufio.Reader
sessions map[string]*ServerSession
tcpFrameEnabled bool
tcpSession *ServerSession
tcpFrameTimeout bool
tcpReadBuffer *multibuffer.MultiBuffer
tcpRTPPacketBuffer *rtpPacketMultiBuffer
tcpProcessFunc func(int, bool, []byte)
tcpWriterRunning bool
ctx context.Context
ctxCancel func()
remoteAddr *net.TCPAddr
br *bufio.Reader
sessions map[string]*ServerSession
readFunc func(readRequest chan readReq) error
tcpSession *ServerSession
// in
sessionRemove chan *ServerSession
@@ -76,6 +72,8 @@ func newServerConn(
done: make(chan struct{}),
}
sc.readFunc = sc.readFuncStandard
s.wg.Add(1)
go sc.run()
@@ -117,77 +115,7 @@ func (sc *ServerConn) run() {
readRequest := make(chan readReq)
readErr := make(chan error)
readDone := make(chan struct{})
go func() {
defer close(readDone)
err := func() error {
var req base.Request
var frame base.InterleavedFrame
for {
if sc.tcpFrameEnabled {
if sc.tcpFrameTimeout {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
}
frame.Payload = sc.tcpReadBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil {
return err
}
switch what.(type) {
case *base.InterleavedFrame:
channel := frame.Channel
isRTP := true
if (channel % 2) != 0 {
channel--
isRTP = false
}
// forward frame only if it has been set up
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
sc.tcpProcessFunc(trackID, isRTP, frame.Payload)
}
case *base.Request:
cres := make(chan error)
select {
case readRequest <- readReq{req: &req, res: cres}:
err := <-cres
if err != nil {
return err
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
} else {
err := req.Read(sc.br)
if err != nil {
return err
}
cres := make(chan error)
select {
case readRequest <- readReq{req: &req, res: cres}:
err = <-cres
if err != nil {
return err
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}
}()
select {
case readErr <- err:
case <-sc.ctx.Done():
}
}()
go sc.runReader(readRequest, readErr, readDone)
err := func() error {
for {
@@ -239,53 +167,165 @@ func (sc *ServerConn) run() {
}
}
func (sc *ServerConn) tcpProcessPlay(trackID int, isRTP bool, payload []byte) {
if !isRTP {
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
var errSwitchReadFunc = errors.New("switch read function")
func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) {
defer close(readDone)
for {
err := sc.readFunc(readRequest)
if err == errSwitchReadFunc {
continue
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
select {
case readErr <- err:
case <-sc.ctx.Done():
}
break
}
}
func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error {
// reset deadline
sc.conn.SetReadDeadline(time.Time{})
var req base.Request
for {
err := req.Read(sc.br)
if err != nil {
return err
}
cres := make(chan error)
select {
case readRequest <- readReq{req: &req, res: cres}:
err = <-cres
if err != nil {
return err
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}
func (sc *ServerConn) tcpProcessRecord(trackID int, isRTP bool, payload []byte) {
if isRTP {
pkt := sc.tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload)
if err != nil {
return
}
func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
// reset deadline
sc.conn.SetReadDeadline(time.Time{})
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
select {
case sc.tcpSession.startWriter <- struct{}{}:
case <-sc.tcpSession.ctx.Done():
}
var tcpReadBuffer *multibuffer.MultiBuffer
var processFunc func(int, bool, []byte)
if sc.tcpSession.state == ServerSessionStateRead {
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
processFunc = func(trackID int, isRTP bool, payload []byte) {
if !isRTP {
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
}
}
}
}
} else {
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount))
processFunc = func(trackID int, isRTP bool, payload []byte) {
if isRTP {
pkt := tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload)
if err != nil {
return
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
}
} else {
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
}
}
}
}
}
var req base.Request
var frame base.InterleavedFrame
for {
if sc.tcpSession.state == ServerSessionStatePublish {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Packet: pkt,
})
frame.Payload = tcpReadBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil {
return err
}
switch what.(type) {
case *base.InterleavedFrame:
channel := frame.Channel
isRTP := true
if (channel % 2) != 0 {
channel--
isRTP = false
}
// forward frame only if it has been set up
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
processFunc(trackID, isRTP, frame.Payload)
}
case *base.Request:
cres := make(chan error)
select {
case readRequest <- readReq{req: &req, res: cres}:
err := <-cres
if err != nil {
return err
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}
@@ -503,15 +543,6 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
sc.conn.Write(buf.Bytes())
// start writer after sending the response
if sc.tcpFrameEnabled && !sc.tcpWriterRunning {
sc.tcpWriterRunning = true
select {
case sc.tcpSession.startWriter <- struct{}{}:
case <-sc.tcpSession.ctx.Done():
}
}
return err
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
)
@@ -883,17 +882,12 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
default: // TCP
ss.tcpConn = sc
ss.tcpConn.tcpSession = ss
ss.tcpConn.tcpFrameEnabled = true
ss.tcpConn.tcpFrameTimeout = false
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.tcpConn.tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
ss.tcpConn.tcpProcessFunc = sc.tcpProcessPlay
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc
ss.writeBuffer = ringbuffer.New(uint64(ss.s.ReadBufferCount))
// run writer after sending the response
ss.tcpConn.tcpWriterRunning = false
// runWriter() is called by conn after sending the response
}
// add RTP-Info
@@ -1016,18 +1010,15 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
default: // TCP
ss.tcpConn = sc
ss.tcpConn.tcpSession = ss
ss.tcpConn.tcpFrameEnabled = true
ss.tcpConn.tcpFrameTimeout = true
ss.tcpConn.tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
ss.tcpConn.tcpRTPPacketBuffer = newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount))
ss.tcpConn.tcpProcessFunc = sc.tcpProcessRecord
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writeBuffer = ringbuffer.New(uint64(8))
// run writer after sending the response
ss.tcpConn.tcpWriterRunning = false
// runWriter() is called by conn after sending the response
}
return res, err
@@ -1089,9 +1080,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
case TransportUDPMulticast:
default: // TCP
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
err = errSwitchReadFunc
ss.tcpConn.tcpSession = nil
ss.tcpConn.tcpFrameEnabled = false
ss.tcpConn.tcpReadBuffer = nil
ss.tcpConn = nil
}
@@ -1108,10 +1100,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
case TransportUDPMulticast:
default: // TCP
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
err = errSwitchReadFunc
ss.tcpConn.tcpSession = nil
ss.tcpConn.tcpFrameEnabled = false
ss.tcpConn.tcpReadBuffer = nil
ss.tcpConn.conn.SetReadDeadline(time.Time{})
ss.tcpConn = nil
}
}