close connections in case of write errors (#613) (#655)

This commit is contained in:
Alessandro Ros
2024-12-14 13:45:11 +01:00
committed by GitHub
parent a2df9d83b3
commit 8f74559616
12 changed files with 427 additions and 350 deletions

View File

@@ -4,46 +4,57 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/ringbuffer"
)
// this struct contains a queue that allows to detach the routine that is reading a stream
// this is an asynchronous queue processor
// that allows to detach the routine that is reading a stream
// from the routine that is writing a stream.
type asyncProcessor struct {
bufferSize int
running bool
buffer *ringbuffer.RingBuffer
done chan struct{}
chError chan error
}
func (w *asyncProcessor) allocateBuffer(size int) {
w.buffer, _ = ringbuffer.New(uint64(size))
func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
}
func (w *asyncProcessor) start() {
w.running = true
w.done = make(chan struct{})
w.chError = make(chan error)
go w.run()
}
func (w *asyncProcessor) stop() {
if w.running {
w.buffer.Close()
<-w.done
w.running = false
if !w.running {
panic("should not happen")
}
w.buffer.Close()
<-w.chError
w.running = false
}
func (w *asyncProcessor) run() {
defer close(w.done)
err := w.runInner()
w.chError <- err
close(w.chError)
}
func (w *asyncProcessor) runInner() error {
for {
tmp, ok := w.buffer.Pull()
if !ok {
return
return nil
}
tmp.(func())()
err := tmp.(func() error)()
if err != nil {
return err
}
}
}
func (w *asyncProcessor) push(cb func()) bool {
func (w *asyncProcessor) push(cb func() error) bool {
return w.buffer.Push(cb)
}

127
client.go
View File

@@ -335,22 +335,19 @@ type Client struct {
keepalivePeriod time.Duration
keepaliveTimer *time.Timer
closeError error
writer asyncProcessor
writer *asyncProcessor
reader *clientReader
timeDecoder *rtptime.GlobalDecoder2
mustClose bool
// in
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq
chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq
// out
done chan struct{}
@@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
c.chPlay = make(chan playReq)
c.chRecord = make(chan recordReq)
c.chPause = make(chan pauseReq)
c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{})
go c.run()
@@ -530,6 +524,34 @@ func (c *Client) run() {
func (c *Client) runInner() error {
for {
chReaderResponse := func() chan *base.Response {
if c.reader != nil {
return c.reader.chResponse
}
return nil
}()
chReaderRequest := func() chan *base.Request {
if c.reader != nil {
return c.reader.chRequest
}
return nil
}()
chReaderError := func() chan error {
if c.reader != nil {
return c.reader.chError
}
return nil
}()
chWriterError := func() chan error {
if c.writer != nil {
return c.writer.chError
}
return nil
}()
select {
case req := <-c.chOptions:
res, err := c.doOptions(req.url)
@@ -601,15 +623,18 @@ func (c *Client) runInner() error {
}
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
case err := <-c.chReadError:
case err := <-chWriterError:
return err
case err := <-chReaderError:
c.reader = nil
return err
case res := <-c.chReadResponse:
case res := <-chReaderResponse:
c.OnResponse(res)
// these are responses to keepalives, ignore them.
case req := <-c.chReadRequest:
case req := <-chReaderRequest:
err := c.handleServerRequest(req)
if err != nil {
return err
@@ -630,11 +655,11 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
case <-t.C:
return nil, liberrors.ErrClientRequestTimedOut{}
case err := <-c.chReadError:
case err := <-c.reader.chError:
c.reader = nil
return nil, err
case res := <-c.chReadResponse:
case res := <-c.reader.chResponse:
c.OnResponse(res)
// accept response if CSeq equals request CSeq, or if CSeq is not present
@@ -642,7 +667,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
return res, nil
}
case req := <-c.chReadRequest:
case req := <-c.reader.chRequest:
err := c.handleServerRequest(req)
if err != nil {
return nil, err
@@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {
func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord {
c.stopWriter()
c.stopReadRoutines()
c.writer.stop()
c.stopTransportRoutines()
}
if c.nconn != nil && c.baseURL != nil {
@@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
return c.doSetup(baseURL, medi, 0, 0)
}
func (c *Client) startReadRoutines() {
func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped {
c.writer.allocateBuffer(c.WriteQueueSize)
c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else {
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
c.writer.allocateBuffer(8)
c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
}
c.timeDecoder = rtptime.NewGlobalDecoder2()
@@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
}
}
func (c *Client) stopReadRoutines() {
func (c *Client) stopTransportRoutines() {
if c.reader != nil {
c.reader.setAllowInterleavedFrames(false)
}
@@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
}
c.timeDecoder = nil
}
func (c *Client) startWriter() {
c.writer.start()
}
func (c *Client) stopWriter() {
c.writer.stop()
c.writer = nil
}
func (c *Client) connOpen() error {
@@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
c.state = clientStatePlay
c.startReadRoutines()
c.startTransportRoutines()
// Range is mandatory in Parrot Streaming Server
if ra == nil {
@@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, err
}
if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
@@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
}
c.startWriter()
c.writer.start()
c.lastRange = ra
return res, nil
@@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
}
c.state = clientStateRecord
c.startReadRoutines()
c.startTransportRoutines()
res, err := c.do(&base.Request{
Method: base.Record,
URL: c.baseURL,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, err
}
if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
c.startWriter()
c.writer.start()
return nil, nil
}
@@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err
}
c.stopWriter()
c.writer.stop()
res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.startWriter()
c.writer.start()
return nil, err
}
if res.StatusCode != base.StatusOK {
c.startWriter()
c.writer.start()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}
c.stopReadRoutines()
c.stopTransportRoutines()
switch c.state {
case clientStatePlay:
@@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
}
func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}
func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}
func (c *Client) readError(err error) {
c.chReadError <- err
}

View File

@@ -74,8 +74,8 @@ func (cf *clientFormat) stop() {
func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
ok := cf.cm.c.writer.push(func() {
cf.cm.writePacketRTPInQueue(byts)
ok := cf.cm.c.writer.push(func() error {
return cf.cm.writePacketRTPInQueue(byts)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}

View File

@@ -25,8 +25,8 @@ type clientMedia struct {
tcpRTPFrame *base.InterleavedFrame
tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte
writePacketRTPInQueue func([]byte)
writePacketRTCPInQueue func([]byte)
writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) error
}
func (cm *clientMedia) close() {
@@ -152,29 +152,29 @@ func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat {
return nil
}
func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) {
cm.udpRTPListener.write(payload) //nolint:errcheck
func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) error {
return cm.udpRTPListener.write(payload)
}
func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) {
cm.udpRTCPListener.write(payload) //nolint:errcheck
func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) error {
return cm.udpRTCPListener.write(payload)
}
func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) {
func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) error {
cm.tcpRTPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) //nolint:errcheck
return cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer)
}
func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) {
func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) error {
cm.tcpRTCPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck
return cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer)
}
func (cm *clientMedia) writePacketRTCP(byts []byte) error {
ok := cm.c.writer.push(func() {
cm.writePacketRTCPInQueue(byts)
ok := cm.c.writer.push(func() error {
return cm.writePacketRTCPInQueue(byts)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}

View File

@@ -12,9 +12,17 @@ type clientReader struct {
mutex sync.Mutex
allowInterleavedFrames bool
chResponse chan *base.Response
chRequest chan *base.Request
chError chan error
}
func (r *clientReader) start() {
r.chResponse = make(chan *base.Response)
r.chRequest = make(chan *base.Request)
r.chError = make(chan error)
go r.run()
}
@@ -27,18 +35,17 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
func (r *clientReader) wait() {
for {
select {
case <-r.c.chReadError:
case <-r.chError:
return
case <-r.c.chReadResponse:
case <-r.c.chReadRequest:
case <-r.chResponse:
case <-r.chRequest:
}
}
}
func (r *clientReader) run() {
err := r.runInner()
r.c.readError(err)
r.chError <- r.runInner()
}
func (r *clientReader) runInner() error {
@@ -50,10 +57,10 @@ func (r *clientReader) runInner() error {
switch what := what.(type) {
case *base.Response:
r.c.readResponse(what)
r.chResponse <- what
case *base.Request:
r.c.readRequest(what)
r.chRequest <- what
case *base.InterleavedFrame:
r.mutex.Lock()

View File

@@ -126,7 +126,7 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
}
}
func TestClientRecordSerial(t *testing.T) {
func TestClientRecord(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
@@ -350,7 +350,7 @@ func TestClientRecordSerial(t *testing.T) {
}
}
func TestClientRecordParallel(t *testing.T) {
func TestClientRecordSocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
@@ -446,15 +446,6 @@ func TestClientRecordParallel(t *testing.T) {
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
@@ -471,9 +462,6 @@ func TestClientRecordParallel(t *testing.T) {
}(),
}
writerDone := make(chan struct{})
defer func() { <-writerDone }()
medi := testH264Media
medias := []*description.Media{medi}
@@ -481,21 +469,15 @@ func TestClientRecordParallel(t *testing.T) {
require.NoError(t, err)
defer c.Close()
go func() {
defer close(writerDone)
ti := time.NewTicker(50 * time.Millisecond)
defer ti.Stop()
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err := c.WritePacketRTP(medi, &testRTPPacket)
if err != nil {
return
}
for range ti.C {
err := c.WritePacketRTP(medi, &testRTPPacket)
if err != nil {
break
}
}()
time.Sleep(1 * time.Second)
}
})
}
}
@@ -645,143 +627,6 @@ func TestClientRecordPauseSerial(t *testing.T) {
}
}
func TestClientRecordPauseParallel(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
} {
t.Run(transport, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
string(base.Pause),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
}
if transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Pause, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
Transport: func() *Transport {
if transport == "udp" {
v := TransportUDP
return &v
}
v := TransportTCP
return &v
}(),
}
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
writerDone := make(chan struct{})
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err2 := c.WritePacketRTP(medi, &testRTPPacket)
if err2 != nil {
return
}
}
}()
time.Sleep(1 * time.Second)
_, err = c.Pause()
require.NoError(t, err)
c.Close()
<-writerDone
})
}
}
func TestClientRecordAutomaticProtocol(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)

View File

@@ -63,10 +63,9 @@ type ServerConn struct {
bc *bytecounter.ByteCounter
conn *conn.Conn
session *ServerSession
reader *serverConnReader
// in
chReadRequest chan readReq
chReadError chan error
chRemoveSession chan *ServerSession
// out
@@ -84,8 +83,6 @@ func (sc *ServerConn) initialize() {
sc.ctx = ctx
sc.ctxCancel = ctxCancel
sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr)
sc.chReadRequest = make(chan readReq)
sc.chReadError = make(chan error)
sc.chRemoveSession = make(chan *ServerSession)
sc.done = make(chan struct{})
@@ -142,10 +139,10 @@ func (sc *ServerConn) run() {
}
sc.conn = conn.NewConn(sc.bc)
cr := &serverConnReader{
sc.reader = &serverConnReader{
sc: sc,
}
cr.initialize()
sc.reader.initialize()
err := sc.runInner()
@@ -153,7 +150,9 @@ func (sc *ServerConn) run() {
sc.nconn.Close()
cr.wait()
if sc.reader != nil {
sc.reader.wait()
}
if sc.session != nil {
sc.session.removeConn(sc)
@@ -172,10 +171,11 @@ func (sc *ServerConn) run() {
func (sc *ServerConn) runInner() error {
for {
select {
case req := <-sc.chReadRequest:
case req := <-sc.reader.chRequest:
req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadError:
case err := <-sc.reader.chError:
sc.reader = nil
return err
case ss := <-sc.chRemoveSession:
@@ -446,20 +446,3 @@ func (sc *ServerConn) removeSession(ss *ServerSession) {
case <-sc.ctx.Done():
}
}
func (sc *ServerConn) readRequest(req readReq) error {
select {
case sc.chReadRequest <- req:
return <-req.res
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
func (sc *ServerConn) readError(err error) {
select {
case sc.chReadError <- err:
case <-sc.ctx.Done():
}
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import (
"errors"
"fmt"
"sync/atomic"
"time"
@@ -25,26 +26,35 @@ func isSwitchReadFuncError(err error) bool {
type serverConnReader struct {
sc *ServerConn
chReadDone chan struct{}
chRequest chan readReq
chError chan error
}
func (cr *serverConnReader) initialize() {
cr.chReadDone = make(chan struct{})
cr.chRequest = make(chan readReq)
cr.chError = make(chan error)
go cr.run()
}
func (cr *serverConnReader) wait() {
<-cr.chReadDone
for {
select {
case <-cr.chError:
return
case req := <-cr.chRequest:
req.res <- fmt.Errorf("terminated")
}
}
}
func (cr *serverConnReader) run() {
defer close(cr.chReadDone)
readFunc := cr.readFuncStandard
for {
err := readFunc()
var eerr switchReadFuncError
if errors.As(err, &eerr) {
if eerr.tcp {
@@ -55,7 +65,7 @@ func (cr *serverConnReader) run() {
continue
}
cr.sc.readError(err)
cr.chError <- err
break
}
}
@@ -74,7 +84,9 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request:
cres := make(chan error)
req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req)
cr.chRequest <- req
err := <-cres
if err != nil {
return err
}
@@ -108,7 +120,9 @@ func (cr *serverConnReader) readFuncTCP() error {
case *base.Request:
cres := make(chan error)
req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req)
cr.chRequest <- req
err := <-cres
if err != nil {
return err
}

View File

@@ -11,7 +11,7 @@ type serverMulticastWriter struct {
rtpl *serverUDPListener
rtcpl *serverUDPListener
writer asyncProcessor
writer *asyncProcessor
rtpAddr *net.UDPAddr
rtcpAddr *net.UDPAddr
}
@@ -48,7 +48,10 @@ func (h *serverMulticastWriter) initialize() error {
h.rtpAddr = rtpAddr
h.rtcpAddr = rtcpAddr
h.writer.allocateBuffer(h.s.WriteQueueSize)
h.writer = &asyncProcessor{
bufferSize: h.s.WriteQueueSize,
}
h.writer.initialize()
h.writer.start()
return nil
@@ -65,8 +68,8 @@ func (h *serverMulticastWriter) ip() net.IP {
}
func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
ok := h.writer.push(func() {
h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck
ok := h.writer.push(func() error {
return h.rtpl.write(payload, h.rtpAddr)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
@@ -76,8 +79,8 @@ func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
}
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error {
ok := h.writer.push(func() {
h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck
ok := h.writer.push(func() error {
return h.rtcpl.write(payload, h.rtcpAddr)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}

View File

@@ -765,7 +765,7 @@ func TestServerPlay(t *testing.T) {
var l1 net.PacketConn
var l2 net.PacketConn
switch transport {
switch transport { //nolint:dupl
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
@@ -942,6 +942,186 @@ func TestServerPlay(t *testing.T) {
}
}
func TestServerPlaySocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"multicast",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var stream *ServerStream
connClosed := make(chan struct{})
writeDone := make(chan struct{})
listenIP := multicastCapableIP(t)
s := &Server{
Handler: &testServerHandler{
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(connClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
defer close(writeDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
if err != nil {
return
}
}
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: listenIP + ":8554",
}
switch transport {
case "udp":
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
case "multicast":
s.MulticastIPRange = "224.1.0.0/16"
s.MulticastRTPPort = 8000
s.MulticastRTCPPort = 8001
case "tls":
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
func() {
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
}()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Mode: transportModePtr(headers.TransportModePlay),
}
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467}
case "multicast":
v := headers.TransportDeliveryMulticast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
default:
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{5, 6} // odd value
}
res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { //nolint:dupl
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
l1, err = net.ListenPacket("udp", listenIP+":35466")
require.NoError(t, err)
defer l1.Close()
l2, err = net.ListenPacket("udp", listenIP+":35467")
require.NoError(t, err)
defer l2.Close()
case "multicast":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery)
l1, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err)
defer l1.Close()
p := ipv4.NewPacketConn(l1)
var intfs []net.Interface
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
l2, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[1]), 10))
require.NoError(t, err)
defer l2.Close()
p = ipv4.NewPacketConn(l2)
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
default:
require.Equal(t, headers.TransportProtocolTCP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
}
session := readSession(t, res)
doPlay(t, conn, "rtsp://"+listenIP+":8554/teststream", session)
}()
<-connClosed
stream.Close()
<-writeDone
})
}
}
func TestServerPlayDecodeErrors(t *testing.T) {
for _, ca := range []struct {
proto string

View File

@@ -252,7 +252,7 @@ type ServerSession struct {
announcedDesc *description.Session // publish
udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer
writer asyncProcessor
writer *asyncProcessor
timeDecoder *rtptime.GlobalDecoder2
// in
@@ -425,12 +425,14 @@ func (ss *ServerSession) run() {
ss.setuppedStream.readerRemove(ss)
}
ss.writer.stop()
for _, sm := range ss.setuppedMedias {
sm.stop()
}
if ss.writer != nil {
ss.writer.stop()
}
ss.s.closeSession(ss)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
@@ -443,6 +445,13 @@ func (ss *ServerSession) run() {
func (ss *ServerSession) runInner() error {
for {
chWriterError := func() chan error {
if ss.writer != nil {
return ss.writer.chError
}
return nil
}()
select {
case req := <-ss.chHandleRequest:
ss.lastRequestTime = ss.s.timeNow()
@@ -539,6 +548,9 @@ func (ss *ServerSession) runInner() error {
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
case err := <-chWriterError:
return err
case <-ss.ctx.Done():
return liberrors.ErrServerTerminated{}
}
@@ -930,7 +942,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// inside the callback.
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.writer.allocateBuffer(ss.s.WriteQueueSize)
ss.writer = &asyncProcessor{
bufferSize: ss.s.WriteQueueSize,
}
ss.writer.initialize()
}
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -1023,7 +1038,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
ss.writer.allocateBuffer(8)
ss.writer = &asyncProcessor{
bufferSize: 8,
}
ss.writer.initialize()
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss,
@@ -1087,45 +1105,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
return res, err
}
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
ss.writer.stop()
for _, sm := range ss.setuppedMedias {
sm.stop()
}
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
case ServerSessionStateRecord:
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
for _, sm := range ss.setuppedMedias {
sm.stop()
}
ss.state = ServerSessionStatePreRecord
ss.writer.stop()
ss.writer = nil
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
case ServerSessionStateRecord:
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
ss.state = ServerSessionStatePreRecord
}
}
return res, err

View File

@@ -27,8 +27,8 @@ type serverSessionMedia struct {
tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte
formats map[uint8]*serverSessionFormat // record only
writePacketRTPInQueue func([]byte)
writePacketRTCPInQueue func([]byte)
writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) error
}
func (sm *serverSessionMedia) initialize() {
@@ -115,33 +115,33 @@ func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionForm
return nil
}
func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) {
func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck
return sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr)
}
func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) {
func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr) //nolint:errcheck
return sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr)
}
func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) {
func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer) //nolint:errcheck
return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer)
}
func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) {
func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTCPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck
return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer)
}
func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
ok := sm.ss.writer.push(func() {
sm.writePacketRTPInQueue(payload)
ok := sm.ss.writer.push(func() error {
return sm.writePacketRTPInQueue(payload)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}
@@ -151,8 +151,8 @@ func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
}
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error {
ok := sm.ss.writer.push(func() {
sm.writePacketRTCPInQueue(payload)
ok := sm.ss.writer.push(func() error {
return sm.writePacketRTCPInQueue(payload)
})
if !ok {
return liberrors.ErrServerWriteQueueFull{}