make WritePacket*() return errors when write queue is full (#388)

This commit is contained in:
Alessandro Ros
2023-08-26 18:09:45 +02:00
committed by GitHub
parent 9453e55f3d
commit 3bdae4ed46
14 changed files with 127 additions and 37 deletions

View File

@@ -44,6 +44,6 @@ func (w *asyncProcessor) run() {
} }
} }
func (w *asyncProcessor) queue(cb func()) { func (w *asyncProcessor) push(cb func()) bool {
w.buffer.Push(cb) return w.buffer.Push(cb)
} }

View File

@@ -1785,8 +1785,7 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
cm := c.medias[medi] cm := c.medias[medi]
ct := cm.formats[pkt.PayloadType] ct := cm.formats[pkt.PayloadType]
ct.writePacketRTP(byts, pkt, ntp) return ct.writePacketRTP(byts, pkt, ntp)
return nil
} }
// WritePacketRTCP writes a RTCP packet to the server. // WritePacketRTCP writes a RTCP packet to the server.
@@ -1803,8 +1802,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
} }
cm := c.medias[medi] cm := c.medias[medi]
cm.writePacketRTCP(byts) return cm.writePacketRTCP(byts)
return nil
} }
// PacketPTS returns the PTS of an incoming RTP packet. // PacketPTS returns the PTS of an incoming RTP packet.

View File

@@ -8,6 +8,7 @@ import (
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/bluenviron/gortsplib/v4/pkg/format" "github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver" "github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver"
"github.com/bluenviron/gortsplib/v4/pkg/rtcpsender" "github.com/bluenviron/gortsplib/v4/pkg/rtcpsender"
"github.com/bluenviron/gortsplib/v4/pkg/rtplossdetector" "github.com/bluenviron/gortsplib/v4/pkg/rtplossdetector"
@@ -78,12 +79,17 @@ func (ct *clientFormat) stop() {
} }
} }
func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) { func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt)) ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt))
ct.cm.c.writer.queue(func() { ok := ct.cm.c.writer.push(func() {
ct.cm.writePacketRTPInQueue(byts) ct.cm.writePacketRTPInQueue(byts)
}) })
if !ok {
return liberrors.ErrClientWriteQueueFull{}
}
return nil
} }
func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) { func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) {

View File

@@ -10,6 +10,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
type clientMedia struct { type clientMedia struct {
@@ -168,10 +169,15 @@ func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) {
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck
} }
func (cm *clientMedia) writePacketRTCP(byts []byte) { func (cm *clientMedia) writePacketRTCP(byts []byte) error {
cm.c.writer.queue(func() { ok := cm.c.writer.push(func() {
cm.writePacketRTCPInQueue(byts) cm.writePacketRTCPInQueue(byts)
}) })
if !ok {
return liberrors.ErrClientWriteQueueFull{}
}
return nil
} }
func (cm *clientMedia) readRTPTCPPlay(payload []byte) { func (cm *clientMedia) readRTPTCPPlay(payload []byte) {

View File

@@ -776,6 +776,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer l.Close() defer l.Close()
recv := make(chan struct{})
serverDone := make(chan struct{}) serverDone := make(chan struct{})
defer func() { <-serverDone }() defer func() { <-serverDone }()
go func() { go func() {
@@ -866,6 +868,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testRTPPacket, pkt) require.Equal(t, testRTPPacket, pkt)
close(recv)
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
@@ -887,6 +891,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) {
err = c.WritePacketRTP(medi, &testRTPPacket) err = c.WritePacketRTP(medi, &testRTPPacket)
require.NoError(t, err) require.NoError(t, err)
<-recv
} }
func TestClientRecordDecodeErrors(t *testing.T) { func TestClientRecordDecodeErrors(t *testing.T) {

View File

@@ -256,3 +256,11 @@ type ErrClientUnhandledMethod struct {
func (e ErrClientUnhandledMethod) Error() string { func (e ErrClientUnhandledMethod) Error() string {
return fmt.Sprintf("unhandled method: %v", e.Method) return fmt.Sprintf("unhandled method: %v", e.Method)
} }
// ErrClientWriteQueueFull is an error that can be returned by a client.
type ErrClientWriteQueueFull struct{}
// Error implements the error interface.
func (e ErrClientWriteQueueFull) Error() string {
return "write queue is full"
}

View File

@@ -252,10 +252,18 @@ func (e ErrServerUnexpectedFrame) Error() string {
return "received unexpected interleaved frame" return "received unexpected interleaved frame"
} }
// ErrServerUnexpectedResponse is an error that can be returned by a client. // ErrServerUnexpectedResponse is an error that can be returned by a server.
type ErrServerUnexpectedResponse struct{} type ErrServerUnexpectedResponse struct{}
// Error implements the error interface. // Error implements the error interface.
func (e ErrServerUnexpectedResponse) Error() string { func (e ErrServerUnexpectedResponse) Error() string {
return "received unexpected response" return "received unexpected response"
} }
// ErrServerWriteQueueFull is an error that can be returned by a server.
type ErrServerWriteQueueFull struct{}
// Error implements the error interface.
func (e ErrServerWriteQueueFull) Error() string {
return "write queue is full"
}

View File

@@ -215,3 +215,15 @@ type ServerHandlerOnDecodeError interface {
// called when a non-fatal decode error occurs. // called when a non-fatal decode error occurs.
OnDecodeError(*ServerHandlerOnDecodeErrorCtx) OnDecodeError(*ServerHandlerOnDecodeErrorCtx)
} }
// ServerHandlerOnStreamWriteErrorCtx is the context of OnStreamWriteError.
type ServerHandlerOnStreamWriteErrorCtx struct {
Session *ServerSession
Error error
}
// ServerHandlerOnStreamWriteError can be implemented by a ServerHandler.
type ServerHandlerOnStreamWriteError interface {
// called when a write error occurs when writing a stream.
OnStreamWriteError(*ServerHandlerOnStreamWriteErrorCtx)
}

View File

@@ -2,6 +2,8 @@ package gortsplib
import ( import (
"net" "net"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
type serverMulticastWriter struct { type serverMulticastWriter struct {
@@ -62,14 +64,24 @@ func (h *serverMulticastWriter) ip() net.IP {
return h.rtpl.ip() return h.rtpl.ip()
} }
func (h *serverMulticastWriter) writePacketRTP(payload []byte) { func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
h.writer.queue(func() { ok := h.writer.push(func() {
h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck
}) })
if !ok {
return liberrors.ErrServerWriteQueueFull{}
} }
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) { return nil
h.writer.queue(func() { }
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error {
ok := h.writer.push(func() {
h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck
}) })
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
} }

View File

@@ -297,6 +297,17 @@ func (ss *ServerSession) onDecodeError(err error) {
} }
} }
func (ss *ServerSession) onStreamWriteError(err error) {
if h, ok := ss.s.Handler.(ServerHandlerOnStreamWriteError); ok {
h.OnStreamWriteError(&ServerHandlerOnStreamWriteErrorCtx{
Session: ss,
Error: err,
})
} else {
log.Println(err.Error())
}
}
func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error { func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error {
if _, ok := allowed[ss.state]; ok { if _, ok := allowed[ss.state]; ok {
return nil return nil
@@ -1186,9 +1197,9 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu
sm.onPacketRTCP = cb sm.onPacketRTCP = cb
} }
func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) { func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi] sm := ss.setuppedMedias[medi]
sm.writePacketRTP(byts) return sm.writePacketRTP(byts)
} }
// WritePacketRTP writes a RTP packet to the session. // WritePacketRTP writes a RTP packet to the session.
@@ -1200,13 +1211,12 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet
} }
byts = byts[:n] byts = byts[:n]
ss.writePacketRTP(medi, byts) return ss.writePacketRTP(medi, byts)
return nil
} }
func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) { func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error {
sm := ss.setuppedMedias[medi] sm := ss.setuppedMedias[medi]
sm.writePacketRTCP(byts) return sm.writePacketRTCP(byts)
} }
// WritePacketRTCP writes a RTCP packet to the session. // WritePacketRTCP writes a RTCP packet to the session.
@@ -1216,8 +1226,7 @@ func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packe
return err return err
} }
ss.writePacketRTCP(medi, byts) return ss.writePacketRTCP(medi, byts)
return nil
} }
// PacketPTS returns the PTS of an incoming RTP packet. // PacketPTS returns the PTS of an incoming RTP packet.

View File

@@ -11,6 +11,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
type serverSessionMedia struct { type serverSessionMedia struct {
@@ -142,16 +143,26 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) {
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck
} }
func (sm *serverSessionMedia) writePacketRTP(payload []byte) { func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
sm.ss.writer.queue(func() { ok := sm.ss.writer.push(func() {
sm.writePacketRTPInQueue(payload) sm.writePacketRTPInQueue(payload)
}) })
if !ok {
return liberrors.ErrServerWriteQueueFull{}
} }
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) { return nil
sm.ss.writer.queue(func() { }
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error {
ok := sm.ss.writer.push(func() {
sm.writePacketRTCPInQueue(payload) sm.writePacketRTCPInQueue(payload)
}) })
if !ok {
return liberrors.ErrServerWriteQueueFull{}
}
return nil
} }
func (sm *serverSessionMedia) readRTCPUDPPlay(payload []byte) { func (sm *serverSessionMedia) readRTCPUDPPlay(payload []byte) {

View File

@@ -256,8 +256,7 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.
sm := st.streamMedias[medi] sm := st.streamMedias[medi]
sf := sm.formats[pkt.PayloadType] sf := sm.formats[pkt.PayloadType]
sf.writePacketRTP(byts, pkt, ntp) return sf.writePacketRTP(byts, pkt, ntp)
return nil
} }
// WritePacketRTCP writes a RTCP packet to all the readers of the stream. // WritePacketRTCP writes a RTCP packet to all the readers of the stream.
@@ -275,6 +274,5 @@ func (st *ServerStream) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet
} }
sm := st.streamMedias[medi] sm := st.streamMedias[medi]
sm.writePacketRTCP(byts) return sm.writePacketRTCP(byts)
return nil
} }

View File

@@ -36,19 +36,27 @@ func newServerStreamFormat(sm *serverStreamMedia, forma format.Format) *serverSt
return sf return sf
} }
func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) { func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt)) sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt))
// send unicast // send unicast
for r := range sf.sm.st.activeUnicastReaders { for r := range sf.sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sf.sm.media] sm, ok := r.setuppedMedias[sf.sm.media]
if ok { if ok {
sm.writePacketRTP(byts) err := sm.writePacketRTP(byts)
if err != nil {
r.onStreamWriteError(err)
}
} }
} }
// send multicast // send multicast
if sf.sm.multicastWriter != nil { if sf.sm.multicastWriter != nil {
sf.sm.multicastWriter.writePacketRTP(byts) err := sf.sm.multicastWriter.writePacketRTP(byts)
if err != nil {
return err
} }
} }
return nil
}

View File

@@ -53,17 +53,25 @@ func (sm *serverStreamMedia) allocateMulticastHandler(s *Server) error {
return nil return nil
} }
func (sm *serverStreamMedia) writePacketRTCP(byts []byte) { func (sm *serverStreamMedia) writePacketRTCP(byts []byte) error {
// send unicast // send unicast
for r := range sm.st.activeUnicastReaders { for r := range sm.st.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media] sm, ok := r.setuppedMedias[sm.media]
if ok { if ok {
sm.writePacketRTCP(byts) err := sm.writePacketRTCP(byts)
if err != nil {
r.onStreamWriteError(err)
}
} }
} }
// send multicast // send multicast
if sm.multicastWriter != nil { if sm.multicastWriter != nil {
sm.multicastWriter.writePacketRTCP(byts) err := sm.multicastWriter.writePacketRTCP(byts)
if err != nil {
return err
} }
} }
return nil
}