diff --git a/serverconn.go b/serverconn.go index b4b0867b..340ee89b 100644 --- a/serverconn.go +++ b/serverconn.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "net" "net/url" "strings" @@ -34,18 +35,13 @@ type ServerConn struct { s *Server conn net.Conn - ctx context.Context - ctxCancel func() - remoteAddr *net.TCPAddr - br *bufio.Reader - sessions map[string]*ServerSession - tcpFrameEnabled bool - tcpSession *ServerSession - tcpFrameTimeout bool - tcpReadBuffer *multibuffer.MultiBuffer - tcpRTPPacketBuffer *rtpPacketMultiBuffer - tcpProcessFunc func(int, bool, []byte) - tcpWriterRunning bool + ctx context.Context + ctxCancel func() + remoteAddr *net.TCPAddr + br *bufio.Reader + sessions map[string]*ServerSession + readFunc func(readRequest chan readReq) error + tcpSession *ServerSession // in sessionRemove chan *ServerSession @@ -76,6 +72,8 @@ func newServerConn( done: make(chan struct{}), } + sc.readFunc = sc.readFuncStandard + s.wg.Add(1) go sc.run() @@ -117,77 +115,7 @@ func (sc *ServerConn) run() { readRequest := make(chan readReq) readErr := make(chan error) readDone := make(chan struct{}) - go func() { - defer close(readDone) - err := func() error { - var req base.Request - var frame base.InterleavedFrame - - for { - if sc.tcpFrameEnabled { - if sc.tcpFrameTimeout { - sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) - } - - frame.Payload = sc.tcpReadBuffer.Next() - what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) - if err != nil { - return err - } - - switch what.(type) { - case *base.InterleavedFrame: - channel := frame.Channel - isRTP := true - if (channel % 2) != 0 { - channel-- - isRTP = false - } - - // forward frame only if it has been set up - if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok { - sc.tcpProcessFunc(trackID, isRTP, frame.Payload) - } - - case *base.Request: - cres := make(chan error) - select { - case readRequest <- readReq{req: &req, res: cres}: - err := <-cres - if err != nil { - return err - } - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - } else { - err := req.Read(sc.br) - if err != nil { - return err - } - - cres := make(chan error) - select { - case readRequest <- readReq{req: &req, res: cres}: - err = <-cres - if err != nil { - return err - } - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - } - }() - - select { - case readErr <- err: - case <-sc.ctx.Done(): - } - }() + go sc.runReader(readRequest, readErr, readDone) err := func() error { for { @@ -239,53 +167,165 @@ func (sc *ServerConn) run() { } } -func (sc *ServerConn) tcpProcessPlay(trackID int, isRTP bool, payload []byte) { - if !isRTP { - packets, err := rtcp.Unmarshal(payload) - if err != nil { - return +var errSwitchReadFunc = errors.New("switch read function") + +func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) { + defer close(readDone) + + for { + err := sc.readFunc(readRequest) + + if err == errSwitchReadFunc { + continue } - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { - for _, pkt := range packets { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Packet: pkt, - }) + select { + case readErr <- err: + case <-sc.ctx.Done(): + } + break + } +} + +func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { + // reset deadline + sc.conn.SetReadDeadline(time.Time{}) + + var req base.Request + + for { + err := req.Read(sc.br) + if err != nil { + return err + } + + cres := make(chan error) + select { + case readRequest <- readReq{req: &req, res: cres}: + err = <-cres + if err != nil { + return err } + + case <-sc.ctx.Done(): + return liberrors.ErrServerTerminated{} } } } -func (sc *ServerConn) tcpProcessRecord(trackID int, isRTP bool, payload []byte) { - if isRTP { - pkt := sc.tcpRTPPacketBuffer.next() - err := pkt.Unmarshal(payload) - if err != nil { - return - } +func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { + // reset deadline + sc.conn.SetReadDeadline(time.Time{}) - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Packet: pkt, - }) + select { + case sc.tcpSession.startWriter <- struct{}{}: + case <-sc.tcpSession.ctx.Done(): + } + + var tcpReadBuffer *multibuffer.MultiBuffer + var processFunc func(int, bool, []byte) + + if sc.tcpSession.state == ServerSessionStateRead { + // when playing, tcpReadBuffer is only used to receive RTCP receiver reports, + // that are much smaller than RTP packets and are sent at a fixed interval. + // decrease RAM consumption by allocating less buffers. + tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize)) + + processFunc = func(trackID int, isRTP bool, payload []byte) { + if !isRTP { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { + for _, pkt := range packets { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Packet: pkt, + }) + } + } + } } } else { - packets, err := rtcp.Unmarshal(payload) - if err != nil { - return + tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize)) + tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount)) + + processFunc = func(trackID int, isRTP bool, payload []byte) { + if isRTP { + pkt := tcpRTPPacketBuffer.next() + err := pkt.Unmarshal(payload) + if err != nil { + return + } + + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Packet: pkt, + }) + } + } else { + packets, err := rtcp.Unmarshal(payload) + if err != nil { + return + } + + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { + for _, pkt := range packets { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Packet: pkt, + }) + } + } + } + } + } + + var req base.Request + var frame base.InterleavedFrame + + for { + if sc.tcpSession.state == ServerSessionStatePublish { + sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) } - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { - for _, pkt := range packets { - h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ - Session: sc.tcpSession, - TrackID: trackID, - Packet: pkt, - }) + frame.Payload = tcpReadBuffer.Next() + what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) + if err != nil { + return err + } + + switch what.(type) { + case *base.InterleavedFrame: + channel := frame.Channel + isRTP := true + if (channel % 2) != 0 { + channel-- + isRTP = false + } + + // forward frame only if it has been set up + if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok { + processFunc(trackID, isRTP, frame.Payload) + } + + case *base.Request: + cres := make(chan error) + select { + case readRequest <- readReq{req: &req, res: cres}: + err := <-cres + if err != nil { + return err + } + + case <-sc.ctx.Done(): + return liberrors.ErrServerTerminated{} } } } @@ -503,15 +543,6 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { sc.conn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) sc.conn.Write(buf.Bytes()) - // start writer after sending the response - if sc.tcpFrameEnabled && !sc.tcpWriterRunning { - sc.tcpWriterRunning = true - select { - case sc.tcpSession.startWriter <- struct{}{}: - case <-sc.tcpSession.ctx.Done(): - } - } - return err } diff --git a/serversession.go b/serversession.go index 270a1ff0..4060d2fe 100644 --- a/serversession.go +++ b/serversession.go @@ -17,7 +17,6 @@ import ( "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/liberrors" - "github.com/aler9/gortsplib/pkg/multibuffer" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" ) @@ -883,17 +882,12 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn = sc ss.tcpConn.tcpSession = ss - ss.tcpConn.tcpFrameEnabled = true - ss.tcpConn.tcpFrameTimeout = false - // when playing, tcpReadBuffer is only used to receive RTCP receiver reports, - // that are much smaller than RTP packets and are sent at a fixed interval. - // decrease RAM consumption by allocating less buffers. - ss.tcpConn.tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize)) - ss.tcpConn.tcpProcessFunc = sc.tcpProcessPlay + + ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP + err = errSwitchReadFunc ss.writeBuffer = ringbuffer.New(uint64(ss.s.ReadBufferCount)) - // run writer after sending the response - ss.tcpConn.tcpWriterRunning = false + // runWriter() is called by conn after sending the response } // add RTP-Info @@ -1016,18 +1010,15 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base default: // TCP ss.tcpConn = sc ss.tcpConn.tcpSession = ss - ss.tcpConn.tcpFrameEnabled = true - ss.tcpConn.tcpFrameTimeout = true - ss.tcpConn.tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize)) - ss.tcpConn.tcpRTPPacketBuffer = newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount)) - ss.tcpConn.tcpProcessFunc = sc.tcpProcessRecord + + ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP + err = errSwitchReadFunc // when recording, writeBuffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. // decrease RAM consumption by allocating less buffers. ss.writeBuffer = ringbuffer.New(uint64(8)) - // run writer after sending the response - ss.tcpConn.tcpWriterRunning = false + // runWriter() is called by conn after sending the response } return res, err @@ -1089,9 +1080,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base case TransportUDPMulticast: default: // TCP + ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard + err = errSwitchReadFunc + ss.tcpConn.tcpSession = nil - ss.tcpConn.tcpFrameEnabled = false - ss.tcpConn.tcpReadBuffer = nil ss.tcpConn = nil } @@ -1108,10 +1100,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base case TransportUDPMulticast: default: // TCP + ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard + err = errSwitchReadFunc + ss.tcpConn.tcpSession = nil - ss.tcpConn.tcpFrameEnabled = false - ss.tcpConn.tcpReadBuffer = nil - ss.tcpConn.conn.SetReadDeadline(time.Time{}) ss.tcpConn = nil } }