make WritePacket*() return errors when write queue is full (#388)

This commit is contained in:
Alessandro Ros
2023-08-26 18:09:45 +02:00
committed by GitHub
parent 9453e55f3d
commit 3bdae4ed46
14 changed files with 127 additions and 37 deletions

View File

@@ -44,6 +44,6 @@ func (w *asyncProcessor) run() {
}
}
func (w *asyncProcessor) queue(cb func()) {
w.buffer.Push(cb)
func (w *asyncProcessor) push(cb func()) bool {
return w.buffer.Push(cb)
}

View File

@@ -1785,8 +1785,7 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
cm := c.medias[medi]
ct := cm.formats[pkt.PayloadType]
ct.writePacketRTP(byts, pkt, ntp)
return nil
return ct.writePacketRTP(byts, pkt, ntp)
}
// WritePacketRTCP writes a RTCP packet to the server.
@@ -1803,8 +1802,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
}
cm := c.medias[medi]
cm.writePacketRTCP(byts)
return nil
return cm.writePacketRTCP(byts)
}
// PacketPTS returns the PTS of an incoming RTP packet.

View File

@@ -8,6 +8,7 @@ import (
"github.com/pion/rtp"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtplossdetector"
@@ -78,12 +79,17 @@ func (ct *clientFormat) stop() {
}
}
func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) {
func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt))
ct.cm.c.writer.queue(func() {
ok := ct.cm.c.writer.push(func() {
ct.cm.writePacketRTPInQueue(byts)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}
}
return nil
}
func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) {

View File

@@ -10,6 +10,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
)
type clientMedia struct {
@@ -168,10 +169,15 @@ func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) {
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck
}
func (cm *clientMedia) writePacketRTCP(byts []byte) {
cm.c.writer.queue(func() {
func (cm *clientMedia) writePacketRTCP(byts []byte) error {
ok := cm.c.writer.push(func() {
cm.writePacketRTCPInQueue(byts)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}
}
return nil
}
func (cm *clientMedia) readRTPTCPPlay(payload []byte) {

View File

@@ -776,6 +776,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
require.NoError(t, err)
defer l.Close()
recv := make(chan struct{})
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
@@ -866,6 +868,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
require.NoError(t, err)
require.Equal(t, testRTPPacket, pkt)
close(recv)
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
@@ -887,6 +891,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
err = c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err)
<-recv
}
func TestClientRecordDecodeErrors(t *testing.T) {

View File

@@ -256,3 +256,11 @@ type ErrClientUnhandledMethod struct {
func (e ErrClientUnhandledMethod) Error() string {
return fmt.Sprintf("unhandled method: %v", e.Method)
}
// ErrClientWriteQueueFull is an error that can be returned by a client.
type ErrClientWriteQueueFull struct{}
// Error implements the error interface.
func (e ErrClientWriteQueueFull) Error() string {
return "write queue is full"
}

View File

@@ -252,10 +252,18 @@ func (e ErrServerUnexpectedFrame) Error() string {
return "received unexpected interleaved frame"
}
// ErrServerUnexpectedResponse is an error that can be returned by a client.
// ErrServerUnexpectedResponse is an error that can be returned by a server.
type ErrServerUnexpectedResponse struct{}
// Error implements the error interface.
func (e ErrServerUnexpectedResponse) Error() string {
return "received unexpected response"
}
// ErrServerWriteQueueFull is an error that can be returned by a server.
type ErrServerWriteQueueFull struct{}
// Error implements the error interface.
func (e ErrServerWriteQueueFull) Error() string {
return "write queue is full"
}

View File

@@ -215,3 +215,15 @@ type ServerHandlerOnDecodeError interface {
// called when a non-fatal decode error occurs.
OnDecodeError(*ServerHandlerOnDecodeErrorCtx)
}
// ServerHandlerOnStreamWriteErrorCtx is the context of OnStreamWriteError.
type ServerHandlerOnStreamWriteErrorCtx struct {
Session *ServerSession
Error error
}
// ServerHandlerOnStreamWriteError can be implemented by a ServerHandler.
type ServerHandlerOnStreamWriteError interface {
// called when a write error occurs when writing a stream.
OnStreamWriteError(*ServerHandlerOnStreamWriteErrorCtx)
}

View File

@@ -2,6 +2,8 @@ package gortsplib
import (
"net"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
)
type serverMulticastWriter struct {
@@ -62,14 +64,24 @@ func (h *serverMulticastWriter) ip() net.IP {
return h.rtpl.ip()
}
func (h *serverMulticastWriter) writePacketRTP(payload []byte) {
h.writer.queue(func() {
func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
ok := h.writer.push(func() {
h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
}
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) {
h.writer.queue(func() {
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error {
ok := h.writer.push(func() {
h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
}

View File

@@ -297,6 +297,17 @@ func (ss *ServerSession) onDecodeError(err error) {
}
}
func (ss *ServerSession) onStreamWriteError(err error) {
if h, ok := ss.s.Handler.(ServerHandlerOnStreamWriteError); ok {
h.OnStreamWriteError(&ServerHandlerOnStreamWriteErrorCtx{
Session: ss,
Error: err,
})
} else {
log.Println(err.Error())
}
}
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
if _, ok := allowed[ss.state]; ok {
return nil
@@ -1186,9 +1197,9 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu
sm.onPacketRTCP = cb
}
func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) {
func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
sm.writePacketRTP(byts)
return sm.writePacketRTP(byts)
}
// WritePacketRTP writes a RTP packet to the session.
@@ -1200,13 +1211,12 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
}
byts = byts[:n]
ss.writePacketRTP(medi, byts)
return nil
return ss.writePacketRTP(medi, byts)
}
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) {
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi]
sm.writePacketRTCP(byts)
return sm.writePacketRTCP(byts)
}
// WritePacketRTCP writes a RTCP packet to the session.
@@ -1216,8 +1226,7 @@ func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packe
return err
}
ss.writePacketRTCP(medi, byts)
return nil
return ss.writePacketRTCP(medi, byts)
}
// PacketPTS returns the PTS of an incoming RTP packet.

View File

@@ -11,6 +11,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
)
type serverSessionMedia struct {
@@ -142,16 +143,26 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) {
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck
}
func (sm *serverSessionMedia) writePacketRTP(payload []byte) {
sm.ss.writer.queue(func() {
func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
ok := sm.ss.writer.push(func() {
sm.writePacketRTPInQueue(payload)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
}
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) {
sm.ss.writer.queue(func() {
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error {
ok := sm.ss.writer.push(func() {
sm.writePacketRTCPInQueue(payload)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
}
func (sm *serverSessionMedia) readRTCPUDPPlay(payload []byte) {

View File

@@ -256,8 +256,7 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.
sm := st.streamMedias[medi]
sf := sm.formats[pkt.PayloadType]
sf.writePacketRTP(byts, pkt, ntp)
return nil
return sf.writePacketRTP(byts, pkt, ntp)
}
// WritePacketRTCP writes a RTCP packet to all the readers of the stream.
@@ -275,6 +274,5 @@ func (st *ServerStream) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet
}
sm := st.streamMedias[medi]
sm.writePacketRTCP(byts)
return nil
return sm.writePacketRTCP(byts)
}

View File

@@ -36,19 +36,27 @@ func newServerStreamFormat(sm *serverStreamMedia, forma format.Format) *serverSt
return sf
}
func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) {
func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt))
// send unicast
for r := range sf.sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sf.sm.media]
if ok {
sm.writePacketRTP(byts)
err := sm.writePacketRTP(byts)
if err != nil {
r.onStreamWriteError(err)
}
}
}
// send multicast
if sf.sm.multicastWriter != nil {
sf.sm.multicastWriter.writePacketRTP(byts)
err := sf.sm.multicastWriter.writePacketRTP(byts)
if err != nil {
return err
}
}
return nil
}

View File

@@ -53,17 +53,25 @@ func (sm *serverStreamMedia) allocateMulticastHandler(s *Server) error {
return nil
}
func (sm *serverStreamMedia) writePacketRTCP(byts []byte) {
func (sm *serverStreamMedia) writePacketRTCP(byts []byte) error {
// send unicast
for r := range sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media]
if ok {
sm.writePacketRTCP(byts)
err := sm.writePacketRTCP(byts)
if err != nil {
r.onStreamWriteError(err)
}
}
}
// send multicast
if sm.multicastWriter != nil {
sm.multicastWriter.writePacketRTCP(byts)
err := sm.multicastWriter.writePacketRTCP(byts)
if err != nil {
return err
}
}
return nil
}