mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 15:16:51 +08:00
server: save RAM by releasing read buffers earlier
This commit is contained in:
285
serverconn.go
285
serverconn.go
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -34,18 +35,13 @@ type ServerConn struct {
|
||||
s *Server
|
||||
conn net.Conn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
remoteAddr *net.TCPAddr
|
||||
br *bufio.Reader
|
||||
sessions map[string]*ServerSession
|
||||
tcpFrameEnabled bool
|
||||
tcpSession *ServerSession
|
||||
tcpFrameTimeout bool
|
||||
tcpReadBuffer *multibuffer.MultiBuffer
|
||||
tcpRTPPacketBuffer *rtpPacketMultiBuffer
|
||||
tcpProcessFunc func(int, bool, []byte)
|
||||
tcpWriterRunning bool
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
remoteAddr *net.TCPAddr
|
||||
br *bufio.Reader
|
||||
sessions map[string]*ServerSession
|
||||
readFunc func(readRequest chan readReq) error
|
||||
tcpSession *ServerSession
|
||||
|
||||
// in
|
||||
sessionRemove chan *ServerSession
|
||||
@@ -76,6 +72,8 @@ func newServerConn(
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
sc.readFunc = sc.readFuncStandard
|
||||
|
||||
s.wg.Add(1)
|
||||
go sc.run()
|
||||
|
||||
@@ -117,77 +115,7 @@ func (sc *ServerConn) run() {
|
||||
readRequest := make(chan readReq)
|
||||
readErr := make(chan error)
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(readDone)
|
||||
err := func() error {
|
||||
var req base.Request
|
||||
var frame base.InterleavedFrame
|
||||
|
||||
for {
|
||||
if sc.tcpFrameEnabled {
|
||||
if sc.tcpFrameTimeout {
|
||||
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
|
||||
}
|
||||
|
||||
frame.Payload = sc.tcpReadBuffer.Next()
|
||||
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch what.(type) {
|
||||
case *base.InterleavedFrame:
|
||||
channel := frame.Channel
|
||||
isRTP := true
|
||||
if (channel % 2) != 0 {
|
||||
channel--
|
||||
isRTP = false
|
||||
}
|
||||
|
||||
// forward frame only if it has been set up
|
||||
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
|
||||
sc.tcpProcessFunc(trackID, isRTP, frame.Payload)
|
||||
}
|
||||
|
||||
case *base.Request:
|
||||
cres := make(chan error)
|
||||
select {
|
||||
case readRequest <- readReq{req: &req, res: cres}:
|
||||
err := <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case <-sc.ctx.Done():
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := req.Read(sc.br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cres := make(chan error)
|
||||
select {
|
||||
case readRequest <- readReq{req: &req, res: cres}:
|
||||
err = <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case <-sc.ctx.Done():
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case readErr <- err:
|
||||
case <-sc.ctx.Done():
|
||||
}
|
||||
}()
|
||||
go sc.runReader(readRequest, readErr, readDone)
|
||||
|
||||
err := func() error {
|
||||
for {
|
||||
@@ -239,53 +167,165 @@ func (sc *ServerConn) run() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *ServerConn) tcpProcessPlay(trackID int, isRTP bool, payload []byte) {
|
||||
if !isRTP {
|
||||
packets, err := rtcp.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
var errSwitchReadFunc = errors.New("switch read function")
|
||||
|
||||
func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) {
|
||||
defer close(readDone)
|
||||
|
||||
for {
|
||||
err := sc.readFunc(readRequest)
|
||||
|
||||
if err == errSwitchReadFunc {
|
||||
continue
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
select {
|
||||
case readErr <- err:
|
||||
case <-sc.ctx.Done():
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error {
|
||||
// reset deadline
|
||||
sc.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
var req base.Request
|
||||
|
||||
for {
|
||||
err := req.Read(sc.br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cres := make(chan error)
|
||||
select {
|
||||
case readRequest <- readReq{req: &req, res: cres}:
|
||||
err = <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case <-sc.ctx.Done():
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *ServerConn) tcpProcessRecord(trackID int, isRTP bool, payload []byte) {
|
||||
if isRTP {
|
||||
pkt := sc.tcpRTPPacketBuffer.next()
|
||||
err := pkt.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
// reset deadline
|
||||
sc.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
|
||||
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
select {
|
||||
case sc.tcpSession.startWriter <- struct{}{}:
|
||||
case <-sc.tcpSession.ctx.Done():
|
||||
}
|
||||
|
||||
var tcpReadBuffer *multibuffer.MultiBuffer
|
||||
var processFunc func(int, bool, []byte)
|
||||
|
||||
if sc.tcpSession.state == ServerSessionStateRead {
|
||||
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
|
||||
|
||||
processFunc = func(trackID int, isRTP bool, payload []byte) {
|
||||
if !isRTP {
|
||||
packets, err := rtcp.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
packets, err := rtcp.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
|
||||
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount))
|
||||
|
||||
processFunc = func(trackID int, isRTP bool, payload []byte) {
|
||||
if isRTP {
|
||||
pkt := tcpRTPPacketBuffer.next()
|
||||
err := pkt.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
|
||||
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
packets, err := rtcp.Unmarshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var req base.Request
|
||||
var frame base.InterleavedFrame
|
||||
|
||||
for {
|
||||
if sc.tcpSession.state == ServerSessionStatePublish {
|
||||
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
frame.Payload = tcpReadBuffer.Next()
|
||||
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch what.(type) {
|
||||
case *base.InterleavedFrame:
|
||||
channel := frame.Channel
|
||||
isRTP := true
|
||||
if (channel % 2) != 0 {
|
||||
channel--
|
||||
isRTP = false
|
||||
}
|
||||
|
||||
// forward frame only if it has been set up
|
||||
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
|
||||
processFunc(trackID, isRTP, frame.Payload)
|
||||
}
|
||||
|
||||
case *base.Request:
|
||||
cres := make(chan error)
|
||||
select {
|
||||
case readRequest <- readReq{req: &req, res: cres}:
|
||||
err := <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case <-sc.ctx.Done():
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -503,15 +543,6 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout))
|
||||
sc.conn.Write(buf.Bytes())
|
||||
|
||||
// start writer after sending the response
|
||||
if sc.tcpFrameEnabled && !sc.tcpWriterRunning {
|
||||
sc.tcpWriterRunning = true
|
||||
select {
|
||||
case sc.tcpSession.startWriter <- struct{}{}:
|
||||
case <-sc.tcpSession.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user