diff --git a/async_processor.go b/async_processor.go index 5f1fc271..d63230f8 100644 --- a/async_processor.go +++ b/async_processor.go @@ -44,6 +44,6 @@ func (w *asyncProcessor) run() { } } -func (w *asyncProcessor) queue(cb func()) { - w.buffer.Push(cb) +func (w *asyncProcessor) push(cb func()) bool { + return w.buffer.Push(cb) } diff --git a/client.go b/client.go index 26f64f95..c3108ee3 100644 --- a/client.go +++ b/client.go @@ -1785,8 +1785,7 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, cm := c.medias[medi] ct := cm.formats[pkt.PayloadType] - ct.writePacketRTP(byts, pkt, ntp) - return nil + return ct.writePacketRTP(byts, pkt, ntp) } // WritePacketRTCP writes a RTCP packet to the server. @@ -1803,8 +1802,7 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error } cm := c.medias[medi] - cm.writePacketRTCP(byts) - return nil + return cm.writePacketRTCP(byts) } // PacketPTS returns the PTS of an incoming RTP packet. diff --git a/client_format.go b/client_format.go index 62496c47..d07b4b88 100644 --- a/client_format.go +++ b/client_format.go @@ -8,6 +8,7 @@ import ( "github.com/pion/rtp" "github.com/bluenviron/gortsplib/v4/pkg/format" + "github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/rtcpreceiver" "github.com/bluenviron/gortsplib/v4/pkg/rtcpsender" "github.com/bluenviron/gortsplib/v4/pkg/rtplossdetector" @@ -78,12 +79,17 @@ func (ct *clientFormat) stop() { } } -func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) { +func (ct *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error { ct.rtcpSender.ProcessPacket(pkt, ntp, ct.format.PTSEqualsDTS(pkt)) - ct.cm.c.writer.queue(func() { + ok := ct.cm.c.writer.push(func() { ct.cm.writePacketRTPInQueue(byts) }) + if !ok { + return liberrors.ErrClientWriteQueueFull{} + } + + return nil } func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) { diff --git a/client_media.go b/client_media.go index 9e49b403..e2eca9fd 100644 --- a/client_media.go +++ b/client_media.go @@ -10,6 +10,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/description" + "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) type clientMedia struct { @@ -168,10 +169,15 @@ func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) { cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck } -func (cm *clientMedia) writePacketRTCP(byts []byte) { - cm.c.writer.queue(func() { +func (cm *clientMedia) writePacketRTCP(byts []byte) error { + ok := cm.c.writer.push(func() { cm.writePacketRTCPInQueue(byts) }) + if !ok { + return liberrors.ErrClientWriteQueueFull{} + } + + return nil } func (cm *clientMedia) readRTPTCPPlay(payload []byte) { diff --git a/client_record_test.go b/client_record_test.go index e3a04e85..9564e30f 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -776,6 +776,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) { require.NoError(t, err) defer l.Close() + recv := make(chan struct{}) + serverDone := make(chan struct{}) defer func() { <-serverDone }() go func() { @@ -866,6 +868,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, testRTPPacket, pkt) + close(recv) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -887,6 +891,8 @@ func TestClientRecordAutomaticProtocol(t *testing.T) { err = c.WritePacketRTP(medi, &testRTPPacket) require.NoError(t, err) + + <-recv } func TestClientRecordDecodeErrors(t *testing.T) { diff --git a/pkg/liberrors/client.go b/pkg/liberrors/client.go index 90d94dbd..d4f9a761 100644 --- a/pkg/liberrors/client.go +++ b/pkg/liberrors/client.go @@ -256,3 +256,11 @@ type ErrClientUnhandledMethod struct { func (e ErrClientUnhandledMethod) Error() string { return fmt.Sprintf("unhandled method: %v", e.Method) } + +// ErrClientWriteQueueFull is an error that can be returned by a client. +type ErrClientWriteQueueFull struct{} + +// Error implements the error interface. +func (e ErrClientWriteQueueFull) Error() string { + return "write queue is full" +} diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 274cc371..4fd602bd 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -252,10 +252,18 @@ func (e ErrServerUnexpectedFrame) Error() string { return "received unexpected interleaved frame" } -// ErrServerUnexpectedResponse is an error that can be returned by a client. +// ErrServerUnexpectedResponse is an error that can be returned by a server. type ErrServerUnexpectedResponse struct{} // Error implements the error interface. func (e ErrServerUnexpectedResponse) Error() string { return "received unexpected response" } + +// ErrServerWriteQueueFull is an error that can be returned by a server. +type ErrServerWriteQueueFull struct{} + +// Error implements the error interface. +func (e ErrServerWriteQueueFull) Error() string { + return "write queue is full" +} diff --git a/server_handler.go b/server_handler.go index d0e49218..28aff138 100644 --- a/server_handler.go +++ b/server_handler.go @@ -215,3 +215,15 @@ type ServerHandlerOnDecodeError interface { // called when a non-fatal decode error occurs. OnDecodeError(*ServerHandlerOnDecodeErrorCtx) } + +// ServerHandlerOnStreamWriteErrorCtx is the context of OnStreamWriteError. +type ServerHandlerOnStreamWriteErrorCtx struct { + Session *ServerSession + Error error +} + +// ServerHandlerOnStreamWriteError can be implemented by a ServerHandler. +type ServerHandlerOnStreamWriteError interface { + // called when a write error occurs when writing a stream. + OnStreamWriteError(*ServerHandlerOnStreamWriteErrorCtx) +} diff --git a/server_multicast_writer.go b/server_multicast_writer.go index 5c2e3524..cc2af0ef 100644 --- a/server_multicast_writer.go +++ b/server_multicast_writer.go @@ -2,6 +2,8 @@ package gortsplib import ( "net" + + "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) type serverMulticastWriter struct { @@ -62,14 +64,24 @@ func (h *serverMulticastWriter) ip() net.IP { return h.rtpl.ip() } -func (h *serverMulticastWriter) writePacketRTP(payload []byte) { - h.writer.queue(func() { +func (h *serverMulticastWriter) writePacketRTP(payload []byte) error { + ok := h.writer.push(func() { h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck }) + if !ok { + return liberrors.ErrServerWriteQueueFull{} + } + + return nil } -func (h *serverMulticastWriter) writePacketRTCP(payload []byte) { - h.writer.queue(func() { +func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error { + ok := h.writer.push(func() { h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck }) + if !ok { + return liberrors.ErrServerWriteQueueFull{} + } + + return nil } diff --git a/server_session.go b/server_session.go index 9ccca424..0d567068 100644 --- a/server_session.go +++ b/server_session.go @@ -297,6 +297,17 @@ func (ss *ServerSession) onDecodeError(err error) { } } +func (ss *ServerSession) onStreamWriteError(err error) { + if h, ok := ss.s.Handler.(ServerHandlerOnStreamWriteError); ok { + h.OnStreamWriteError(&ServerHandlerOnStreamWriteErrorCtx{ + Session: ss, + Error: err, + }) + } else { + log.Println(err.Error()) + } +} + func (ss *ServerSession) checkState(allowed map[ServerSessionState]struct{}) error { if _, ok := allowed[ss.state]; ok { return nil @@ -1186,9 +1197,9 @@ func (ss *ServerSession) OnPacketRTCP(medi *description.Media, cb OnPacketRTCPFu sm.onPacketRTCP = cb } -func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) { +func (ss *ServerSession) writePacketRTP(medi *description.Media, byts []byte) error { sm := ss.setuppedMedias[medi] - sm.writePacketRTP(byts) + return sm.writePacketRTP(byts) } // WritePacketRTP writes a RTP packet to the session. @@ -1200,13 +1211,12 @@ func (ss *ServerSession) WritePacketRTP(medi *description.Media, pkt *rtp.Packet } byts = byts[:n] - ss.writePacketRTP(medi, byts) - return nil + return ss.writePacketRTP(medi, byts) } -func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) { +func (ss *ServerSession) writePacketRTCP(medi *description.Media, byts []byte) error { sm := ss.setuppedMedias[medi] - sm.writePacketRTCP(byts) + return sm.writePacketRTCP(byts) } // WritePacketRTCP writes a RTCP packet to the session. @@ -1216,8 +1226,7 @@ func (ss *ServerSession) WritePacketRTCP(medi *description.Media, pkt rtcp.Packe return err } - ss.writePacketRTCP(medi, byts) - return nil + return ss.writePacketRTCP(medi, byts) } // PacketPTS returns the PTS of an incoming RTP packet. diff --git a/server_session_media.go b/server_session_media.go index d979793b..a7bdc03a 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -11,6 +11,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/description" + "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) type serverSessionMedia struct { @@ -142,16 +143,26 @@ func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) { sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck } -func (sm *serverSessionMedia) writePacketRTP(payload []byte) { - sm.ss.writer.queue(func() { +func (sm *serverSessionMedia) writePacketRTP(payload []byte) error { + ok := sm.ss.writer.push(func() { sm.writePacketRTPInQueue(payload) }) + if !ok { + return liberrors.ErrServerWriteQueueFull{} + } + + return nil } -func (sm *serverSessionMedia) writePacketRTCP(payload []byte) { - sm.ss.writer.queue(func() { +func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error { + ok := sm.ss.writer.push(func() { sm.writePacketRTCPInQueue(payload) }) + if !ok { + return liberrors.ErrServerWriteQueueFull{} + } + + return nil } func (sm *serverSessionMedia) readRTCPUDPPlay(payload []byte) { diff --git a/server_stream.go b/server_stream.go index 80083765..2d58f5af 100644 --- a/server_stream.go +++ b/server_stream.go @@ -256,8 +256,7 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp. sm := st.streamMedias[medi] sf := sm.formats[pkt.PayloadType] - sf.writePacketRTP(byts, pkt, ntp) - return nil + return sf.writePacketRTP(byts, pkt, ntp) } // WritePacketRTCP writes a RTCP packet to all the readers of the stream. @@ -275,6 +274,5 @@ func (st *ServerStream) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet } sm := st.streamMedias[medi] - sm.writePacketRTCP(byts) - return nil + return sm.writePacketRTCP(byts) } diff --git a/server_stream_format.go b/server_stream_format.go index f1efc436..6afa385f 100644 --- a/server_stream_format.go +++ b/server_stream_format.go @@ -36,19 +36,27 @@ func newServerStreamFormat(sm *serverStreamMedia, forma format.Format) *serverSt return sf } -func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) { +func (sf *serverStreamFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error { sf.rtcpSender.ProcessPacket(pkt, ntp, sf.format.PTSEqualsDTS(pkt)) // send unicast for r := range sf.sm.st.activeUnicastReaders { sm, ok := r.setuppedMedias[sf.sm.media] if ok { - sm.writePacketRTP(byts) + err := sm.writePacketRTP(byts) + if err != nil { + r.onStreamWriteError(err) + } } } // send multicast if sf.sm.multicastWriter != nil { - sf.sm.multicastWriter.writePacketRTP(byts) + err := sf.sm.multicastWriter.writePacketRTP(byts) + if err != nil { + return err + } } + + return nil } diff --git a/server_stream_media.go b/server_stream_media.go index 619c83c2..19798003 100644 --- a/server_stream_media.go +++ b/server_stream_media.go @@ -53,17 +53,25 @@ func (sm *serverStreamMedia) allocateMulticastHandler(s *Server) error { return nil } -func (sm *serverStreamMedia) writePacketRTCP(byts []byte) { +func (sm *serverStreamMedia) writePacketRTCP(byts []byte) error { // send unicast for r := range sm.st.activeUnicastReaders { sm, ok := r.setuppedMedias[sm.media] if ok { - sm.writePacketRTCP(byts) + err := sm.writePacketRTCP(byts) + if err != nil { + r.onStreamWriteError(err) + } } } // send multicast if sm.multicastWriter != nil { - sm.multicastWriter.writePacketRTCP(byts) + err := sm.multicastWriter.writePacketRTCP(byts) + if err != nil { + return err + } } + + return nil }