close connections in case of write errors (#613) (#655)

This commit is contained in:
Alessandro Ros
2024-12-14 13:45:11 +01:00
committed by GitHub
parent a2df9d83b3
commit 8f74559616
12 changed files with 427 additions and 350 deletions

View File

@@ -4,46 +4,57 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/ringbuffer" "github.com/bluenviron/gortsplib/v4/pkg/ringbuffer"
) )
// this struct contains a queue that allows to detach the routine that is reading a stream // this is an asynchronous queue processor
// that allows to detach the routine that is reading a stream
// from the routine that is writing a stream. // from the routine that is writing a stream.
type asyncProcessor struct { type asyncProcessor struct {
bufferSize int
running bool running bool
buffer *ringbuffer.RingBuffer buffer *ringbuffer.RingBuffer
done chan struct{} chError chan error
} }
func (w *asyncProcessor) allocateBuffer(size int) { func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(size)) w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
} }
func (w *asyncProcessor) start() { func (w *asyncProcessor) start() {
w.running = true w.running = true
w.done = make(chan struct{}) w.chError = make(chan error)
go w.run() go w.run()
} }
func (w *asyncProcessor) stop() { func (w *asyncProcessor) stop() {
if w.running { if !w.running {
w.buffer.Close() panic("should not happen")
<-w.done
w.running = false
} }
w.buffer.Close()
<-w.chError
w.running = false
} }
func (w *asyncProcessor) run() { func (w *asyncProcessor) run() {
defer close(w.done) err := w.runInner()
w.chError <- err
close(w.chError)
}
func (w *asyncProcessor) runInner() error {
for { for {
tmp, ok := w.buffer.Pull() tmp, ok := w.buffer.Pull()
if !ok { if !ok {
return return nil
} }
tmp.(func())() err := tmp.(func() error)()
if err != nil {
return err
}
} }
} }
func (w *asyncProcessor) push(cb func()) bool { func (w *asyncProcessor) push(cb func() error) bool {
return w.buffer.Push(cb) return w.buffer.Push(cb)
} }

113
client.go
View File

@@ -335,7 +335,7 @@ type Client struct {
keepalivePeriod time.Duration keepalivePeriod time.Duration
keepaliveTimer *time.Timer keepaliveTimer *time.Timer
closeError error closeError error
writer asyncProcessor writer *asyncProcessor
reader *clientReader reader *clientReader
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
mustClose bool mustClose bool
@@ -348,9 +348,6 @@ type Client struct {
chPlay chan playReq chPlay chan playReq
chRecord chan recordReq chRecord chan recordReq
chPause chan pauseReq chPause chan pauseReq
chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
// out // out
done chan struct{} done chan struct{}
@@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
c.chPlay = make(chan playReq) c.chPlay = make(chan playReq)
c.chRecord = make(chan recordReq) c.chRecord = make(chan recordReq)
c.chPause = make(chan pauseReq) c.chPause = make(chan pauseReq)
c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{}) c.done = make(chan struct{})
go c.run() go c.run()
@@ -530,6 +524,34 @@ func (c *Client) run() {
func (c *Client) runInner() error { func (c *Client) runInner() error {
for { for {
chReaderResponse := func() chan *base.Response {
if c.reader != nil {
return c.reader.chResponse
}
return nil
}()
chReaderRequest := func() chan *base.Request {
if c.reader != nil {
return c.reader.chRequest
}
return nil
}()
chReaderError := func() chan error {
if c.reader != nil {
return c.reader.chError
}
return nil
}()
chWriterError := func() chan error {
if c.writer != nil {
return c.writer.chError
}
return nil
}()
select { select {
case req := <-c.chOptions: case req := <-c.chOptions:
res, err := c.doOptions(req.url) res, err := c.doOptions(req.url)
@@ -601,15 +623,18 @@ func (c *Client) runInner() error {
} }
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
case err := <-c.chReadError: case err := <-chWriterError:
return err
case err := <-chReaderError:
c.reader = nil c.reader = nil
return err return err
case res := <-c.chReadResponse: case res := <-chReaderResponse:
c.OnResponse(res) c.OnResponse(res)
// these are responses to keepalives, ignore them. // these are responses to keepalives, ignore them.
case req := <-c.chReadRequest: case req := <-chReaderRequest:
err := c.handleServerRequest(req) err := c.handleServerRequest(req)
if err != nil { if err != nil {
return err return err
@@ -630,11 +655,11 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
case <-t.C: case <-t.C:
return nil, liberrors.ErrClientRequestTimedOut{} return nil, liberrors.ErrClientRequestTimedOut{}
case err := <-c.chReadError: case err := <-c.reader.chError:
c.reader = nil c.reader = nil
return nil, err return nil, err
case res := <-c.chReadResponse: case res := <-c.reader.chResponse:
c.OnResponse(res) c.OnResponse(res)
// accept response if CSeq equals request CSeq, or if CSeq is not present // accept response if CSeq equals request CSeq, or if CSeq is not present
@@ -642,7 +667,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
return res, nil return res, nil
} }
case req := <-c.chReadRequest: case req := <-c.reader.chRequest:
err := c.handleServerRequest(req) err := c.handleServerRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {
func (c *Client) doClose() { func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord { if c.state == clientStatePlay || c.state == clientStateRecord {
c.stopWriter() c.writer.stop()
c.stopReadRoutines() c.stopTransportRoutines()
} }
if c.nconn != nil && c.baseURL != nil { if c.nconn != nil && c.baseURL != nil {
@@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
return c.doSetup(baseURL, medi, 0, 0) return c.doSetup(baseURL, medi, 0, 0)
} }
func (c *Client) startReadRoutines() { func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender // allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped { if c.state == clientStateRecord || c.backChannelSetupped {
c.writer.allocateBuffer(c.WriteQueueSize) c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else { } else {
// when reading, buffer is only used to send RTCP receiver reports, // when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
c.writer.allocateBuffer(8) c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
} }
c.timeDecoder = rtptime.NewGlobalDecoder2() c.timeDecoder = rtptime.NewGlobalDecoder2()
@@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
} }
} }
func (c *Client) stopReadRoutines() { func (c *Client) stopTransportRoutines() {
if c.reader != nil { if c.reader != nil {
c.reader.setAllowInterleavedFrames(false) c.reader.setAllowInterleavedFrames(false)
} }
@@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
} }
c.timeDecoder = nil c.timeDecoder = nil
}
func (c *Client) startWriter() { c.writer = nil
c.writer.start()
}
func (c *Client) stopWriter() {
c.writer.stop()
} }
func (c *Client) connOpen() error { func (c *Client) connOpen() error {
@@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
} }
c.state = clientStatePlay c.state = clientStatePlay
c.startReadRoutines() c.startTransportRoutines()
// Range is mandatory in Parrot Streaming Server // Range is mandatory in Parrot Streaming Server
if ra == nil { if ra == nil {
@@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header, Header: header,
}, false) }, false)
if err != nil { if err != nil {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePrePlay c.state = clientStatePrePlay
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePrePlay c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
@@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
} }
} }
c.startWriter() c.writer.start()
c.lastRange = ra c.lastRange = ra
return res, nil return res, nil
@@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
} }
c.state = clientStateRecord c.state = clientStateRecord
c.startReadRoutines() c.startTransportRoutines()
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Record, Method: base.Record,
URL: c.baseURL, URL: c.baseURL,
}, false) }, false)
if err != nil { if err != nil {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePreRecord c.state = clientStatePreRecord
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePreRecord c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
c.startWriter() c.writer.start()
return nil, nil return nil, nil
} }
@@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err return nil, err
} }
c.stopWriter() c.writer.stop()
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Pause, Method: base.Pause,
URL: c.baseURL, URL: c.baseURL,
}, false) }, false)
if err != nil { if err != nil {
c.startWriter() c.writer.start()
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.startWriter() c.writer.start()
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
c.stopReadRoutines() c.stopTransportRoutines()
switch c.state { switch c.state {
case clientStatePlay: case clientStatePlay:
@@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType] ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp) return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
} }
func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}
func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}
func (c *Client) readError(err error) {
c.chReadError <- err
}

View File

@@ -74,8 +74,8 @@ func (cf *clientFormat) stop() {
func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error { func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt)) cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
ok := cf.cm.c.writer.push(func() { ok := cf.cm.c.writer.push(func() error {
cf.cm.writePacketRTPInQueue(byts) return cf.cm.writePacketRTPInQueue(byts)
}) })
if !ok { if !ok {
return liberrors.ErrClientWriteQueueFull{} return liberrors.ErrClientWriteQueueFull{}

View File

@@ -25,8 +25,8 @@ type clientMedia struct {
tcpRTPFrame *base.InterleavedFrame tcpRTPFrame *base.InterleavedFrame
tcpRTCPFrame *base.InterleavedFrame tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte tcpBuffer []byte
writePacketRTPInQueue func([]byte) writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) writePacketRTCPInQueue func([]byte) error
} }
func (cm *clientMedia) close() { func (cm *clientMedia) close() {
@@ -152,29 +152,29 @@ func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat {
return nil return nil
} }
func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) { func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) error {
cm.udpRTPListener.write(payload) //nolint:errcheck return cm.udpRTPListener.write(payload)
} }
func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) { func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) error {
cm.udpRTCPListener.write(payload) //nolint:errcheck return cm.udpRTCPListener.write(payload)
} }
func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) { func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) error {
cm.tcpRTPFrame.Payload = payload cm.tcpRTPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) //nolint:errcheck return cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer)
} }
func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) { func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) error {
cm.tcpRTCPFrame.Payload = payload cm.tcpRTCPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck return cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer)
} }
func (cm *clientMedia) writePacketRTCP(byts []byte) error { func (cm *clientMedia) writePacketRTCP(byts []byte) error {
ok := cm.c.writer.push(func() { ok := cm.c.writer.push(func() error {
cm.writePacketRTCPInQueue(byts) return cm.writePacketRTCPInQueue(byts)
}) })
if !ok { if !ok {
return liberrors.ErrClientWriteQueueFull{} return liberrors.ErrClientWriteQueueFull{}

View File

@@ -12,9 +12,17 @@ type clientReader struct {
mutex sync.Mutex mutex sync.Mutex
allowInterleavedFrames bool allowInterleavedFrames bool
chResponse chan *base.Response
chRequest chan *base.Request
chError chan error
} }
func (r *clientReader) start() { func (r *clientReader) start() {
r.chResponse = make(chan *base.Response)
r.chRequest = make(chan *base.Request)
r.chError = make(chan error)
go r.run() go r.run()
} }
@@ -27,18 +35,17 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
func (r *clientReader) wait() { func (r *clientReader) wait() {
for { for {
select { select {
case <-r.c.chReadError: case <-r.chError:
return return
case <-r.c.chReadResponse: case <-r.chResponse:
case <-r.c.chReadRequest: case <-r.chRequest:
} }
} }
} }
func (r *clientReader) run() { func (r *clientReader) run() {
err := r.runInner() r.chError <- r.runInner()
r.c.readError(err)
} }
func (r *clientReader) runInner() error { func (r *clientReader) runInner() error {
@@ -50,10 +57,10 @@ func (r *clientReader) runInner() error {
switch what := what.(type) { switch what := what.(type) {
case *base.Response: case *base.Response:
r.c.readResponse(what) r.chResponse <- what
case *base.Request: case *base.Request:
r.c.readRequest(what) r.chRequest <- what
case *base.InterleavedFrame: case *base.InterleavedFrame:
r.mutex.Lock() r.mutex.Lock()

View File

@@ -126,7 +126,7 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
} }
} }
func TestClientRecordSerial(t *testing.T) { func TestClientRecord(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -350,7 +350,7 @@ func TestClientRecordSerial(t *testing.T) {
} }
} }
func TestClientRecordParallel(t *testing.T) { func TestClientRecordSocketError(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -446,15 +446,6 @@ func TestClientRecordParallel(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}) })
require.NoError(t, err2) require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}() }()
c := Client{ c := Client{
@@ -471,9 +462,6 @@ func TestClientRecordParallel(t *testing.T) {
}(), }(),
} }
writerDone := make(chan struct{})
defer func() { <-writerDone }()
medi := testH264Media medi := testH264Media
medias := []*description.Media{medi} medias := []*description.Media{medi}
@@ -481,21 +469,15 @@ func TestClientRecordParallel(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
go func() { ti := time.NewTicker(50 * time.Millisecond)
defer close(writerDone) defer ti.Stop()
t := time.NewTicker(50 * time.Millisecond) for range ti.C {
defer t.Stop()
for range t.C {
err := c.WritePacketRTP(medi, &testRTPPacket) err := c.WritePacketRTP(medi, &testRTPPacket)
if err != nil { if err != nil {
return break
} }
} }
}()
time.Sleep(1 * time.Second)
}) })
} }
} }
@@ -645,143 +627,6 @@ func TestClientRecordPauseSerial(t *testing.T) {
} }
} }
func TestClientRecordPauseParallel(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
} {
t.Run(transport, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
string(base.Pause),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
}
if transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Pause, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
Transport: func() *Transport {
if transport == "udp" {
v := TransportUDP
return &v
}
v := TransportTCP
return &v
}(),
}
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
writerDone := make(chan struct{})
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err2 := c.WritePacketRTP(medi, &testRTPPacket)
if err2 != nil {
return
}
}
}()
time.Sleep(1 * time.Second)
_, err = c.Pause()
require.NoError(t, err)
c.Close()
<-writerDone
})
}
}
func TestClientRecordAutomaticProtocol(t *testing.T) { func TestClientRecordAutomaticProtocol(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554") l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err) require.NoError(t, err)

View File

@@ -63,10 +63,9 @@ type ServerConn struct {
bc *bytecounter.ByteCounter bc *bytecounter.ByteCounter
conn *conn.Conn conn *conn.Conn
session *ServerSession session *ServerSession
reader *serverConnReader
// in // in
chReadRequest chan readReq
chReadError chan error
chRemoveSession chan *ServerSession chRemoveSession chan *ServerSession
// out // out
@@ -84,8 +83,6 @@ func (sc *ServerConn) initialize() {
sc.ctx = ctx sc.ctx = ctx
sc.ctxCancel = ctxCancel sc.ctxCancel = ctxCancel
sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr) sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr)
sc.chReadRequest = make(chan readReq)
sc.chReadError = make(chan error)
sc.chRemoveSession = make(chan *ServerSession) sc.chRemoveSession = make(chan *ServerSession)
sc.done = make(chan struct{}) sc.done = make(chan struct{})
@@ -142,10 +139,10 @@ func (sc *ServerConn) run() {
} }
sc.conn = conn.NewConn(sc.bc) sc.conn = conn.NewConn(sc.bc)
cr := &serverConnReader{ sc.reader = &serverConnReader{
sc: sc, sc: sc,
} }
cr.initialize() sc.reader.initialize()
err := sc.runInner() err := sc.runInner()
@@ -153,7 +150,9 @@ func (sc *ServerConn) run() {
sc.nconn.Close() sc.nconn.Close()
cr.wait() if sc.reader != nil {
sc.reader.wait()
}
if sc.session != nil { if sc.session != nil {
sc.session.removeConn(sc) sc.session.removeConn(sc)
@@ -172,10 +171,11 @@ func (sc *ServerConn) run() {
func (sc *ServerConn) runInner() error { func (sc *ServerConn) runInner() error {
for { for {
select { select {
case req := <-sc.chReadRequest: case req := <-sc.reader.chRequest:
req.res <- sc.handleRequestOuter(req.req) req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadError: case err := <-sc.reader.chError:
sc.reader = nil
return err return err
case ss := <-sc.chRemoveSession: case ss := <-sc.chRemoveSession:
@@ -446,20 +446,3 @@ func (sc *ServerConn) removeSession(ss *ServerSession) {
case <-sc.ctx.Done(): case <-sc.ctx.Done():
} }
} }
func (sc *ServerConn) readRequest(req readReq) error {
select {
case sc.chReadRequest <- req:
return <-req.res
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
func (sc *ServerConn) readError(err error) {
select {
case sc.chReadError <- err:
case <-sc.ctx.Done():
}
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import ( import (
"errors" "errors"
"fmt"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -25,26 +26,35 @@ func isSwitchReadFuncError(err error) bool {
type serverConnReader struct { type serverConnReader struct {
sc *ServerConn sc *ServerConn
chReadDone chan struct{} chRequest chan readReq
chError chan error
} }
func (cr *serverConnReader) initialize() { func (cr *serverConnReader) initialize() {
cr.chReadDone = make(chan struct{}) cr.chRequest = make(chan readReq)
cr.chError = make(chan error)
go cr.run() go cr.run()
} }
func (cr *serverConnReader) wait() { func (cr *serverConnReader) wait() {
<-cr.chReadDone for {
select {
case <-cr.chError:
return
case req := <-cr.chRequest:
req.res <- fmt.Errorf("terminated")
}
}
} }
func (cr *serverConnReader) run() { func (cr *serverConnReader) run() {
defer close(cr.chReadDone)
readFunc := cr.readFuncStandard readFunc := cr.readFuncStandard
for { for {
err := readFunc() err := readFunc()
var eerr switchReadFuncError var eerr switchReadFuncError
if errors.As(err, &eerr) { if errors.As(err, &eerr) {
if eerr.tcp { if eerr.tcp {
@@ -55,7 +65,7 @@ func (cr *serverConnReader) run() {
continue continue
} }
cr.sc.readError(err) cr.chError <- err
break break
} }
} }
@@ -74,7 +84,9 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req) cr.chRequest <- req
err := <-cres
if err != nil { if err != nil {
return err return err
} }
@@ -108,7 +120,9 @@ func (cr *serverConnReader) readFuncTCP() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req) cr.chRequest <- req
err := <-cres
if err != nil { if err != nil {
return err return err
} }

View File

@@ -11,7 +11,7 @@ type serverMulticastWriter struct {
rtpl *serverUDPListener rtpl *serverUDPListener
rtcpl *serverUDPListener rtcpl *serverUDPListener
writer asyncProcessor writer *asyncProcessor
rtpAddr *net.UDPAddr rtpAddr *net.UDPAddr
rtcpAddr *net.UDPAddr rtcpAddr *net.UDPAddr
} }
@@ -48,7 +48,10 @@ func (h *serverMulticastWriter) initialize() error {
h.rtpAddr = rtpAddr h.rtpAddr = rtpAddr
h.rtcpAddr = rtcpAddr h.rtcpAddr = rtcpAddr
h.writer.allocateBuffer(h.s.WriteQueueSize) h.writer = &asyncProcessor{
bufferSize: h.s.WriteQueueSize,
}
h.writer.initialize()
h.writer.start() h.writer.start()
return nil return nil
@@ -65,8 +68,8 @@ func (h *serverMulticastWriter) ip() net.IP {
} }
func (h *serverMulticastWriter) writePacketRTP(payload []byte) error { func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
ok := h.writer.push(func() { ok := h.writer.push(func() error {
h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck return h.rtpl.write(payload, h.rtpAddr)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}
@@ -76,8 +79,8 @@ func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
} }
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error { func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error {
ok := h.writer.push(func() { ok := h.writer.push(func() error {
h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck return h.rtcpl.write(payload, h.rtcpAddr)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}

View File

@@ -765,7 +765,7 @@ func TestServerPlay(t *testing.T) {
var l1 net.PacketConn var l1 net.PacketConn
var l2 net.PacketConn var l2 net.PacketConn
switch transport { switch transport { //nolint:dupl
case "udp": case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
@@ -942,6 +942,186 @@ func TestServerPlay(t *testing.T) {
} }
} }
func TestServerPlaySocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"multicast",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var stream *ServerStream
connClosed := make(chan struct{})
writeDone := make(chan struct{})
listenIP := multicastCapableIP(t)
s := &Server{
Handler: &testServerHandler{
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(connClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
defer close(writeDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
if err != nil {
return
}
}
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: listenIP + ":8554",
}
switch transport {
case "udp":
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
case "multicast":
s.MulticastIPRange = "224.1.0.0/16"
s.MulticastRTPPort = 8000
s.MulticastRTCPPort = 8001
case "tls":
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
func() {
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
}()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Mode: transportModePtr(headers.TransportModePlay),
}
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467}
case "multicast":
v := headers.TransportDeliveryMulticast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
default:
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{5, 6} // odd value
}
res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { //nolint:dupl
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
l1, err = net.ListenPacket("udp", listenIP+":35466")
require.NoError(t, err)
defer l1.Close()
l2, err = net.ListenPacket("udp", listenIP+":35467")
require.NoError(t, err)
defer l2.Close()
case "multicast":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery)
l1, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err)
defer l1.Close()
p := ipv4.NewPacketConn(l1)
var intfs []net.Interface
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
l2, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[1]), 10))
require.NoError(t, err)
defer l2.Close()
p = ipv4.NewPacketConn(l2)
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
default:
require.Equal(t, headers.TransportProtocolTCP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
}
session := readSession(t, res)
doPlay(t, conn, "rtsp://"+listenIP+":8554/teststream", session)
}()
<-connClosed
stream.Close()
<-writeDone
})
}
}
func TestServerPlayDecodeErrors(t *testing.T) { func TestServerPlayDecodeErrors(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
proto string proto string

View File

@@ -252,7 +252,7 @@ type ServerSession struct {
announcedDesc *description.Session // publish announcedDesc *description.Session // publish
udpLastPacketTime *int64 // publish udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer udpCheckStreamTimer *time.Timer
writer asyncProcessor writer *asyncProcessor
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
// in // in
@@ -425,12 +425,14 @@ func (ss *ServerSession) run() {
ss.setuppedStream.readerRemove(ss) ss.setuppedStream.readerRemove(ss)
} }
ss.writer.stop()
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.stop() sm.stop()
} }
if ss.writer != nil {
ss.writer.stop()
}
ss.s.closeSession(ss) ss.s.closeSession(ss)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
@@ -443,6 +445,13 @@ func (ss *ServerSession) run() {
func (ss *ServerSession) runInner() error { func (ss *ServerSession) runInner() error {
for { for {
chWriterError := func() chan error {
if ss.writer != nil {
return ss.writer.chError
}
return nil
}()
select { select {
case req := <-ss.chHandleRequest: case req := <-ss.chHandleRequest:
ss.lastRequestTime = ss.s.timeNow() ss.lastRequestTime = ss.s.timeNow()
@@ -539,6 +548,9 @@ func (ss *ServerSession) runInner() error {
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
case err := <-chWriterError:
return err
case <-ss.ctx.Done(): case <-ss.ctx.Done():
return liberrors.ErrServerTerminated{} return liberrors.ErrServerTerminated{}
} }
@@ -930,7 +942,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// inside the callback. // inside the callback.
if ss.state != ServerSessionStatePlay && if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast { *ss.setuppedTransport != TransportUDPMulticast {
ss.writer.allocateBuffer(ss.s.WriteQueueSize) ss.writer = &asyncProcessor{
bufferSize: ss.s.WriteQueueSize,
}
ss.writer.initialize()
} }
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -1023,7 +1038,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// when recording, writeBuffer is only used to send RTCP receiver reports, // when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
ss.writer.allocateBuffer(8) ss.writer = &asyncProcessor{
bufferSize: 8,
}
ss.writer.initialize()
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss, Session: ss,
@@ -1087,16 +1105,18 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
return res, err return res, err
} }
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
if ss.setuppedStream != nil { if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss) ss.setuppedStream.readerSetInactive(ss)
} }
ss.writer.stop()
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.stop() sm.stop()
} }
ss.writer.stop()
ss.writer = nil
ss.timeDecoder = nil ss.timeDecoder = nil
switch ss.state { switch ss.state {
@@ -1127,6 +1147,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.state = ServerSessionStatePreRecord ss.state = ServerSessionStatePreRecord
} }
}
return res, err return res, err

View File

@@ -27,8 +27,8 @@ type serverSessionMedia struct {
tcpRTCPFrame *base.InterleavedFrame tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte tcpBuffer []byte
formats map[uint8]*serverSessionFormat // record only formats map[uint8]*serverSessionFormat // record only
writePacketRTPInQueue func([]byte) writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) writePacketRTCPInQueue func([]byte) error
} }
func (sm *serverSessionMedia) initialize() { func (sm *serverSessionMedia) initialize() {
@@ -115,33 +115,33 @@ func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionForm
return nil return nil
} }
func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) { func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck return sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr)
} }
func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) { func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr) //nolint:errcheck return sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr)
} }
func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) { func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTPFrame.Payload = payload sm.tcpRTPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer) //nolint:errcheck return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer)
} }
func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) { func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTCPFrame.Payload = payload sm.tcpRTCPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer)
} }
func (sm *serverSessionMedia) writePacketRTP(payload []byte) error { func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
ok := sm.ss.writer.push(func() { ok := sm.ss.writer.push(func() error {
sm.writePacketRTPInQueue(payload) return sm.writePacketRTPInQueue(payload)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}
@@ -151,8 +151,8 @@ func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
} }
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error { func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error {
ok := sm.ss.writer.push(func() { ok := sm.ss.writer.push(func() error {
sm.writePacketRTCPInQueue(payload) return sm.writePacketRTCPInQueue(payload)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}