close connections in case of write errors (#613) (#655)

This commit is contained in:
Alessandro Ros
2024-12-14 13:45:11 +01:00
committed by GitHub
parent a2df9d83b3
commit 8f74559616
12 changed files with 427 additions and 350 deletions

View File

@@ -4,46 +4,57 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/ringbuffer" "github.com/bluenviron/gortsplib/v4/pkg/ringbuffer"
) )
// this struct contains a queue that allows to detach the routine that is reading a stream // this is an asynchronous queue processor
// that allows to detach the routine that is reading a stream
// from the routine that is writing a stream. // from the routine that is writing a stream.
type asyncProcessor struct { type asyncProcessor struct {
bufferSize int
running bool running bool
buffer *ringbuffer.RingBuffer buffer *ringbuffer.RingBuffer
done chan struct{} chError chan error
} }
func (w *asyncProcessor) allocateBuffer(size int) { func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(size)) w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
} }
func (w *asyncProcessor) start() { func (w *asyncProcessor) start() {
w.running = true w.running = true
w.done = make(chan struct{}) w.chError = make(chan error)
go w.run() go w.run()
} }
func (w *asyncProcessor) stop() { func (w *asyncProcessor) stop() {
if w.running { if !w.running {
w.buffer.Close() panic("should not happen")
<-w.done
w.running = false
} }
w.buffer.Close()
<-w.chError
w.running = false
} }
func (w *asyncProcessor) run() { func (w *asyncProcessor) run() {
defer close(w.done) err := w.runInner()
w.chError <- err
close(w.chError)
}
func (w *asyncProcessor) runInner() error {
for { for {
tmp, ok := w.buffer.Pull() tmp, ok := w.buffer.Pull()
if !ok { if !ok {
return return nil
} }
tmp.(func())() err := tmp.(func() error)()
if err != nil {
return err
}
} }
} }
func (w *asyncProcessor) push(cb func()) bool { func (w *asyncProcessor) push(cb func() error) bool {
return w.buffer.Push(cb) return w.buffer.Push(cb)
} }

127
client.go
View File

@@ -335,22 +335,19 @@ type Client struct {
keepalivePeriod time.Duration keepalivePeriod time.Duration
keepaliveTimer *time.Timer keepaliveTimer *time.Timer
closeError error closeError error
writer asyncProcessor writer *asyncProcessor
reader *clientReader reader *clientReader
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
mustClose bool mustClose bool
// in // in
chOptions chan optionsReq chOptions chan optionsReq
chDescribe chan describeReq chDescribe chan describeReq
chAnnounce chan announceReq chAnnounce chan announceReq
chSetup chan setupReq chSetup chan setupReq
chPlay chan playReq chPlay chan playReq
chRecord chan recordReq chRecord chan recordReq
chPause chan pauseReq chPause chan pauseReq
chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
// out // out
done chan struct{} done chan struct{}
@@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
c.chPlay = make(chan playReq) c.chPlay = make(chan playReq)
c.chRecord = make(chan recordReq) c.chRecord = make(chan recordReq)
c.chPause = make(chan pauseReq) c.chPause = make(chan pauseReq)
c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{}) c.done = make(chan struct{})
go c.run() go c.run()
@@ -530,6 +524,34 @@ func (c *Client) run() {
func (c *Client) runInner() error { func (c *Client) runInner() error {
for { for {
chReaderResponse := func() chan *base.Response {
if c.reader != nil {
return c.reader.chResponse
}
return nil
}()
chReaderRequest := func() chan *base.Request {
if c.reader != nil {
return c.reader.chRequest
}
return nil
}()
chReaderError := func() chan error {
if c.reader != nil {
return c.reader.chError
}
return nil
}()
chWriterError := func() chan error {
if c.writer != nil {
return c.writer.chError
}
return nil
}()
select { select {
case req := <-c.chOptions: case req := <-c.chOptions:
res, err := c.doOptions(req.url) res, err := c.doOptions(req.url)
@@ -601,15 +623,18 @@ func (c *Client) runInner() error {
} }
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
case err := <-c.chReadError: case err := <-chWriterError:
return err
case err := <-chReaderError:
c.reader = nil c.reader = nil
return err return err
case res := <-c.chReadResponse: case res := <-chReaderResponse:
c.OnResponse(res) c.OnResponse(res)
// these are responses to keepalives, ignore them. // these are responses to keepalives, ignore them.
case req := <-c.chReadRequest: case req := <-chReaderRequest:
err := c.handleServerRequest(req) err := c.handleServerRequest(req)
if err != nil { if err != nil {
return err return err
@@ -630,11 +655,11 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
case <-t.C: case <-t.C:
return nil, liberrors.ErrClientRequestTimedOut{} return nil, liberrors.ErrClientRequestTimedOut{}
case err := <-c.chReadError: case err := <-c.reader.chError:
c.reader = nil c.reader = nil
return nil, err return nil, err
case res := <-c.chReadResponse: case res := <-c.reader.chResponse:
c.OnResponse(res) c.OnResponse(res)
// accept response if CSeq equals request CSeq, or if CSeq is not present // accept response if CSeq equals request CSeq, or if CSeq is not present
@@ -642,7 +667,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
return res, nil return res, nil
} }
case req := <-c.chReadRequest: case req := <-c.reader.chRequest:
err := c.handleServerRequest(req) err := c.handleServerRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {
func (c *Client) doClose() { func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord { if c.state == clientStatePlay || c.state == clientStateRecord {
c.stopWriter() c.writer.stop()
c.stopReadRoutines() c.stopTransportRoutines()
} }
if c.nconn != nil && c.baseURL != nil { if c.nconn != nil && c.baseURL != nil {
@@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
return c.doSetup(baseURL, medi, 0, 0) return c.doSetup(baseURL, medi, 0, 0)
} }
func (c *Client) startReadRoutines() { func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender // allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped { if c.state == clientStateRecord || c.backChannelSetupped {
c.writer.allocateBuffer(c.WriteQueueSize) c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else { } else {
// when reading, buffer is only used to send RTCP receiver reports, // when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
c.writer.allocateBuffer(8) c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
} }
c.timeDecoder = rtptime.NewGlobalDecoder2() c.timeDecoder = rtptime.NewGlobalDecoder2()
@@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
} }
} }
func (c *Client) stopReadRoutines() { func (c *Client) stopTransportRoutines() {
if c.reader != nil { if c.reader != nil {
c.reader.setAllowInterleavedFrames(false) c.reader.setAllowInterleavedFrames(false)
} }
@@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
} }
c.timeDecoder = nil c.timeDecoder = nil
}
func (c *Client) startWriter() { c.writer = nil
c.writer.start()
}
func (c *Client) stopWriter() {
c.writer.stop()
} }
func (c *Client) connOpen() error { func (c *Client) connOpen() error {
@@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
} }
c.state = clientStatePlay c.state = clientStatePlay
c.startReadRoutines() c.startTransportRoutines()
// Range is mandatory in Parrot Streaming Server // Range is mandatory in Parrot Streaming Server
if ra == nil { if ra == nil {
@@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header, Header: header,
}, false) }, false)
if err != nil { if err != nil {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePrePlay c.state = clientStatePrePlay
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePrePlay c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
@@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
} }
} }
c.startWriter() c.writer.start()
c.lastRange = ra c.lastRange = ra
return res, nil return res, nil
@@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
} }
c.state = clientStateRecord c.state = clientStateRecord
c.startReadRoutines() c.startTransportRoutines()
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Record, Method: base.Record,
URL: c.baseURL, URL: c.baseURL,
}, false) }, false)
if err != nil { if err != nil {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePreRecord c.state = clientStatePreRecord
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.stopReadRoutines() c.stopTransportRoutines()
c.state = clientStatePreRecord c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
c.startWriter() c.writer.start()
return nil, nil return nil, nil
} }
@@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err return nil, err
} }
c.stopWriter() c.writer.stop()
res, err := c.do(&base.Request{ res, err := c.do(&base.Request{
Method: base.Pause, Method: base.Pause,
URL: c.baseURL, URL: c.baseURL,
}, false) }, false)
if err != nil { if err != nil {
c.startWriter() c.writer.start()
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
c.startWriter() c.writer.start()
return nil, liberrors.ErrClientBadStatusCode{ return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage, Code: res.StatusCode, Message: res.StatusMessage,
} }
} }
c.stopReadRoutines() c.stopTransportRoutines()
switch c.state { switch c.state {
case clientStatePlay: case clientStatePlay:
@@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType] ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp) return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
} }
func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}
func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}
func (c *Client) readError(err error) {
c.chReadError <- err
}

View File

@@ -74,8 +74,8 @@ func (cf *clientFormat) stop() {
func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error { func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt)) cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))
ok := cf.cm.c.writer.push(func() { ok := cf.cm.c.writer.push(func() error {
cf.cm.writePacketRTPInQueue(byts) return cf.cm.writePacketRTPInQueue(byts)
}) })
if !ok { if !ok {
return liberrors.ErrClientWriteQueueFull{} return liberrors.ErrClientWriteQueueFull{}

View File

@@ -25,8 +25,8 @@ type clientMedia struct {
tcpRTPFrame *base.InterleavedFrame tcpRTPFrame *base.InterleavedFrame
tcpRTCPFrame *base.InterleavedFrame tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte tcpBuffer []byte
writePacketRTPInQueue func([]byte) writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) writePacketRTCPInQueue func([]byte) error
} }
func (cm *clientMedia) close() { func (cm *clientMedia) close() {
@@ -152,29 +152,29 @@ func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat {
return nil return nil
} }
func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) { func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) error {
cm.udpRTPListener.write(payload) //nolint:errcheck return cm.udpRTPListener.write(payload)
} }
func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) { func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) error {
cm.udpRTCPListener.write(payload) //nolint:errcheck return cm.udpRTCPListener.write(payload)
} }
func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) { func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) error {
cm.tcpRTPFrame.Payload = payload cm.tcpRTPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) //nolint:errcheck return cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer)
} }
func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) { func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) error {
cm.tcpRTCPFrame.Payload = payload cm.tcpRTCPFrame.Payload = payload
cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout))
cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck return cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer)
} }
func (cm *clientMedia) writePacketRTCP(byts []byte) error { func (cm *clientMedia) writePacketRTCP(byts []byte) error {
ok := cm.c.writer.push(func() { ok := cm.c.writer.push(func() error {
cm.writePacketRTCPInQueue(byts) return cm.writePacketRTCPInQueue(byts)
}) })
if !ok { if !ok {
return liberrors.ErrClientWriteQueueFull{} return liberrors.ErrClientWriteQueueFull{}

View File

@@ -12,9 +12,17 @@ type clientReader struct {
mutex sync.Mutex mutex sync.Mutex
allowInterleavedFrames bool allowInterleavedFrames bool
chResponse chan *base.Response
chRequest chan *base.Request
chError chan error
} }
func (r *clientReader) start() { func (r *clientReader) start() {
r.chResponse = make(chan *base.Response)
r.chRequest = make(chan *base.Request)
r.chError = make(chan error)
go r.run() go r.run()
} }
@@ -27,18 +35,17 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) {
func (r *clientReader) wait() { func (r *clientReader) wait() {
for { for {
select { select {
case <-r.c.chReadError: case <-r.chError:
return return
case <-r.c.chReadResponse: case <-r.chResponse:
case <-r.c.chReadRequest: case <-r.chRequest:
} }
} }
} }
func (r *clientReader) run() { func (r *clientReader) run() {
err := r.runInner() r.chError <- r.runInner()
r.c.readError(err)
} }
func (r *clientReader) runInner() error { func (r *clientReader) runInner() error {
@@ -50,10 +57,10 @@ func (r *clientReader) runInner() error {
switch what := what.(type) { switch what := what.(type) {
case *base.Response: case *base.Response:
r.c.readResponse(what) r.chResponse <- what
case *base.Request: case *base.Request:
r.c.readRequest(what) r.chRequest <- what
case *base.InterleavedFrame: case *base.InterleavedFrame:
r.mutex.Lock() r.mutex.Lock()

View File

@@ -126,7 +126,7 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) {
} }
} }
func TestClientRecordSerial(t *testing.T) { func TestClientRecord(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -350,7 +350,7 @@ func TestClientRecordSerial(t *testing.T) {
} }
} }
func TestClientRecordParallel(t *testing.T) { func TestClientRecordSocketError(t *testing.T) {
for _, transport := range []string{ for _, transport := range []string{
"udp", "udp",
"tcp", "tcp",
@@ -446,15 +446,6 @@ func TestClientRecordParallel(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}) })
require.NoError(t, err2) require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Teardown, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}() }()
c := Client{ c := Client{
@@ -471,9 +462,6 @@ func TestClientRecordParallel(t *testing.T) {
}(), }(),
} }
writerDone := make(chan struct{})
defer func() { <-writerDone }()
medi := testH264Media medi := testH264Media
medias := []*description.Media{medi} medias := []*description.Media{medi}
@@ -481,21 +469,15 @@ func TestClientRecordParallel(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
go func() { ti := time.NewTicker(50 * time.Millisecond)
defer close(writerDone) defer ti.Stop()
t := time.NewTicker(50 * time.Millisecond) for range ti.C {
defer t.Stop() err := c.WritePacketRTP(medi, &testRTPPacket)
if err != nil {
for range t.C { break
err := c.WritePacketRTP(medi, &testRTPPacket)
if err != nil {
return
}
} }
}() }
time.Sleep(1 * time.Second)
}) })
} }
} }
@@ -645,143 +627,6 @@ func TestClientRecordPauseSerial(t *testing.T) {
} }
} }
func TestClientRecordPauseParallel(t *testing.T) {
for _, transport := range []string{
"udp",
"tcp",
} {
t.Run(transport, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err2 := l.Accept()
require.NoError(t, err2)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err2 := conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
string(base.Pause),
}, ", ")},
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Announce, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err2 = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err2)
th := headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
}
if transport == "udp" {
th.Protocol = headers.TransportProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err2)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Record, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
req, err2 = readRequestIgnoreFrames(conn)
require.NoError(t, err2)
require.Equal(t, base.Pause, req.Method)
err2 = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err2)
}()
c := Client{
Transport: func() *Transport {
if transport == "udp" {
v := TransportUDP
return &v
}
v := TransportTCP
return &v
}(),
}
medi := testH264Media
medias := []*description.Media{medi}
err = record(&c, "rtsp://localhost:8554/teststream", medias, nil)
require.NoError(t, err)
writerDone := make(chan struct{})
go func() {
defer close(writerDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err2 := c.WritePacketRTP(medi, &testRTPPacket)
if err2 != nil {
return
}
}
}()
time.Sleep(1 * time.Second)
_, err = c.Pause()
require.NoError(t, err)
c.Close()
<-writerDone
})
}
}
func TestClientRecordAutomaticProtocol(t *testing.T) { func TestClientRecordAutomaticProtocol(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554") l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err) require.NoError(t, err)

View File

@@ -63,10 +63,9 @@ type ServerConn struct {
bc *bytecounter.ByteCounter bc *bytecounter.ByteCounter
conn *conn.Conn conn *conn.Conn
session *ServerSession session *ServerSession
reader *serverConnReader
// in // in
chReadRequest chan readReq
chReadError chan error
chRemoveSession chan *ServerSession chRemoveSession chan *ServerSession
// out // out
@@ -84,8 +83,6 @@ func (sc *ServerConn) initialize() {
sc.ctx = ctx sc.ctx = ctx
sc.ctxCancel = ctxCancel sc.ctxCancel = ctxCancel
sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr) sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr)
sc.chReadRequest = make(chan readReq)
sc.chReadError = make(chan error)
sc.chRemoveSession = make(chan *ServerSession) sc.chRemoveSession = make(chan *ServerSession)
sc.done = make(chan struct{}) sc.done = make(chan struct{})
@@ -142,10 +139,10 @@ func (sc *ServerConn) run() {
} }
sc.conn = conn.NewConn(sc.bc) sc.conn = conn.NewConn(sc.bc)
cr := &serverConnReader{ sc.reader = &serverConnReader{
sc: sc, sc: sc,
} }
cr.initialize() sc.reader.initialize()
err := sc.runInner() err := sc.runInner()
@@ -153,7 +150,9 @@ func (sc *ServerConn) run() {
sc.nconn.Close() sc.nconn.Close()
cr.wait() if sc.reader != nil {
sc.reader.wait()
}
if sc.session != nil { if sc.session != nil {
sc.session.removeConn(sc) sc.session.removeConn(sc)
@@ -172,10 +171,11 @@ func (sc *ServerConn) run() {
func (sc *ServerConn) runInner() error { func (sc *ServerConn) runInner() error {
for { for {
select { select {
case req := <-sc.chReadRequest: case req := <-sc.reader.chRequest:
req.res <- sc.handleRequestOuter(req.req) req.res <- sc.handleRequestOuter(req.req)
case err := <-sc.chReadError: case err := <-sc.reader.chError:
sc.reader = nil
return err return err
case ss := <-sc.chRemoveSession: case ss := <-sc.chRemoveSession:
@@ -446,20 +446,3 @@ func (sc *ServerConn) removeSession(ss *ServerSession) {
case <-sc.ctx.Done(): case <-sc.ctx.Done():
} }
} }
func (sc *ServerConn) readRequest(req readReq) error {
select {
case sc.chReadRequest <- req:
return <-req.res
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
func (sc *ServerConn) readError(err error) {
select {
case sc.chReadError <- err:
case <-sc.ctx.Done():
}
}

View File

@@ -2,6 +2,7 @@ package gortsplib
import ( import (
"errors" "errors"
"fmt"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -25,26 +26,35 @@ func isSwitchReadFuncError(err error) bool {
type serverConnReader struct { type serverConnReader struct {
sc *ServerConn sc *ServerConn
chReadDone chan struct{} chRequest chan readReq
chError chan error
} }
func (cr *serverConnReader) initialize() { func (cr *serverConnReader) initialize() {
cr.chReadDone = make(chan struct{}) cr.chRequest = make(chan readReq)
cr.chError = make(chan error)
go cr.run() go cr.run()
} }
func (cr *serverConnReader) wait() { func (cr *serverConnReader) wait() {
<-cr.chReadDone for {
select {
case <-cr.chError:
return
case req := <-cr.chRequest:
req.res <- fmt.Errorf("terminated")
}
}
} }
func (cr *serverConnReader) run() { func (cr *serverConnReader) run() {
defer close(cr.chReadDone)
readFunc := cr.readFuncStandard readFunc := cr.readFuncStandard
for { for {
err := readFunc() err := readFunc()
var eerr switchReadFuncError var eerr switchReadFuncError
if errors.As(err, &eerr) { if errors.As(err, &eerr) {
if eerr.tcp { if eerr.tcp {
@@ -55,7 +65,7 @@ func (cr *serverConnReader) run() {
continue continue
} }
cr.sc.readError(err) cr.chError <- err
break break
} }
} }
@@ -74,7 +84,9 @@ func (cr *serverConnReader) readFuncStandard() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req) cr.chRequest <- req
err := <-cres
if err != nil { if err != nil {
return err return err
} }
@@ -108,7 +120,9 @@ func (cr *serverConnReader) readFuncTCP() error {
case *base.Request: case *base.Request:
cres := make(chan error) cres := make(chan error)
req := readReq{req: what, res: cres} req := readReq{req: what, res: cres}
err := cr.sc.readRequest(req) cr.chRequest <- req
err := <-cres
if err != nil { if err != nil {
return err return err
} }

View File

@@ -11,7 +11,7 @@ type serverMulticastWriter struct {
rtpl *serverUDPListener rtpl *serverUDPListener
rtcpl *serverUDPListener rtcpl *serverUDPListener
writer asyncProcessor writer *asyncProcessor
rtpAddr *net.UDPAddr rtpAddr *net.UDPAddr
rtcpAddr *net.UDPAddr rtcpAddr *net.UDPAddr
} }
@@ -48,7 +48,10 @@ func (h *serverMulticastWriter) initialize() error {
h.rtpAddr = rtpAddr h.rtpAddr = rtpAddr
h.rtcpAddr = rtcpAddr h.rtcpAddr = rtcpAddr
h.writer.allocateBuffer(h.s.WriteQueueSize) h.writer = &asyncProcessor{
bufferSize: h.s.WriteQueueSize,
}
h.writer.initialize()
h.writer.start() h.writer.start()
return nil return nil
@@ -65,8 +68,8 @@ func (h *serverMulticastWriter) ip() net.IP {
} }
func (h *serverMulticastWriter) writePacketRTP(payload []byte) error { func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
ok := h.writer.push(func() { ok := h.writer.push(func() error {
h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck return h.rtpl.write(payload, h.rtpAddr)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}
@@ -76,8 +79,8 @@ func (h *serverMulticastWriter) writePacketRTP(payload []byte) error {
} }
func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error { func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error {
ok := h.writer.push(func() { ok := h.writer.push(func() error {
h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck return h.rtcpl.write(payload, h.rtcpAddr)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}

View File

@@ -765,7 +765,7 @@ func TestServerPlay(t *testing.T) {
var l1 net.PacketConn var l1 net.PacketConn
var l2 net.PacketConn var l2 net.PacketConn
switch transport { switch transport { //nolint:dupl
case "udp": case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
@@ -942,6 +942,186 @@ func TestServerPlay(t *testing.T) {
} }
} }
func TestServerPlaySocketError(t *testing.T) {
for _, transport := range []string{
"udp",
"multicast",
"tcp",
"tls",
} {
t.Run(transport, func(t *testing.T) {
var stream *ServerStream
connClosed := make(chan struct{})
writeDone := make(chan struct{})
listenIP := multicastCapableIP(t)
s := &Server{
Handler: &testServerHandler{
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(connClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
defer close(writeDone)
t := time.NewTicker(50 * time.Millisecond)
defer t.Stop()
for range t.C {
err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket)
if err != nil {
return
}
}
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: listenIP + ":8554",
}
switch transport {
case "udp":
s.UDPRTPAddress = "127.0.0.1:8000"
s.UDPRTCPAddress = "127.0.0.1:8001"
case "multicast":
s.MulticastIPRange = "224.1.0.0/16"
s.MulticastRTPPort = 8000
s.MulticastRTCPPort = 8001
case "tls":
cert, err := tls.X509KeyPair(serverCert, serverKey)
require.NoError(t, err)
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}})
func() {
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
defer nconn.Close()
nconn = func() net.Conn {
if transport == "tls" {
return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true})
}
return nconn
}()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Mode: transportModePtr(headers.TransportModePlay),
}
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467}
case "multicast":
v := headers.TransportDeliveryMulticast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolUDP
default:
v := headers.TransportDeliveryUnicast
inTH.Delivery = &v
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{5, 6} // odd value
}
res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { //nolint:dupl
case "udp":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
l1, err = net.ListenPacket("udp", listenIP+":35466")
require.NoError(t, err)
defer l1.Close()
l2, err = net.ListenPacket("udp", listenIP+":35467")
require.NoError(t, err)
defer l2.Close()
case "multicast":
require.Equal(t, headers.TransportProtocolUDP, th.Protocol)
require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery)
l1, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err)
defer l1.Close()
p := ipv4.NewPacketConn(l1)
var intfs []net.Interface
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
l2, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[1]), 10))
require.NoError(t, err)
defer l2.Close()
p = ipv4.NewPacketConn(l2)
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination})
require.NoError(t, err)
}
default:
require.Equal(t, headers.TransportProtocolTCP, th.Protocol)
require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery)
}
session := readSession(t, res)
doPlay(t, conn, "rtsp://"+listenIP+":8554/teststream", session)
}()
<-connClosed
stream.Close()
<-writeDone
})
}
}
func TestServerPlayDecodeErrors(t *testing.T) { func TestServerPlayDecodeErrors(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
proto string proto string

View File

@@ -252,7 +252,7 @@ type ServerSession struct {
announcedDesc *description.Session // publish announcedDesc *description.Session // publish
udpLastPacketTime *int64 // publish udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer udpCheckStreamTimer *time.Timer
writer asyncProcessor writer *asyncProcessor
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
// in // in
@@ -425,12 +425,14 @@ func (ss *ServerSession) run() {
ss.setuppedStream.readerRemove(ss) ss.setuppedStream.readerRemove(ss)
} }
ss.writer.stop()
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.stop() sm.stop()
} }
if ss.writer != nil {
ss.writer.stop()
}
ss.s.closeSession(ss) ss.s.closeSession(ss)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
@@ -443,6 +445,13 @@ func (ss *ServerSession) run() {
func (ss *ServerSession) runInner() error { func (ss *ServerSession) runInner() error {
for { for {
chWriterError := func() chan error {
if ss.writer != nil {
return ss.writer.chError
}
return nil
}()
select { select {
case req := <-ss.chHandleRequest: case req := <-ss.chHandleRequest:
ss.lastRequestTime = ss.s.timeNow() ss.lastRequestTime = ss.s.timeNow()
@@ -539,6 +548,9 @@ func (ss *ServerSession) runInner() error {
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
case err := <-chWriterError:
return err
case <-ss.ctx.Done(): case <-ss.ctx.Done():
return liberrors.ErrServerTerminated{} return liberrors.ErrServerTerminated{}
} }
@@ -930,7 +942,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// inside the callback. // inside the callback.
if ss.state != ServerSessionStatePlay && if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast { *ss.setuppedTransport != TransportUDPMulticast {
ss.writer.allocateBuffer(ss.s.WriteQueueSize) ss.writer = &asyncProcessor{
bufferSize: ss.s.WriteQueueSize,
}
ss.writer.initialize()
} }
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -1023,7 +1038,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
// when recording, writeBuffer is only used to send RTCP receiver reports, // when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
ss.writer.allocateBuffer(8) ss.writer = &asyncProcessor{
bufferSize: 8,
}
ss.writer.initialize()
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss, Session: ss,
@@ -1087,45 +1105,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
return res, err return res, err
} }
if ss.setuppedStream != nil { if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
ss.setuppedStream.readerSetInactive(ss) if ss.setuppedStream != nil {
} ss.setuppedStream.readerSetInactive(ss)
ss.writer.stop()
for _, sm := range ss.setuppedMedias {
sm.stop()
}
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
} }
case ServerSessionStateRecord: for _, sm := range ss.setuppedMedias {
switch *ss.setuppedTransport { sm.stop()
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
} }
ss.state = ServerSessionStatePreRecord ss.writer.stop()
ss.writer = nil
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
case ServerSessionStateRecord:
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
ss.state = ServerSessionStatePreRecord
}
} }
return res, err return res, err

View File

@@ -27,8 +27,8 @@ type serverSessionMedia struct {
tcpRTCPFrame *base.InterleavedFrame tcpRTCPFrame *base.InterleavedFrame
tcpBuffer []byte tcpBuffer []byte
formats map[uint8]*serverSessionFormat // record only formats map[uint8]*serverSessionFormat // record only
writePacketRTPInQueue func([]byte) writePacketRTPInQueue func([]byte) error
writePacketRTCPInQueue func([]byte) writePacketRTCPInQueue func([]byte) error
} }
func (sm *serverSessionMedia) initialize() { func (sm *serverSessionMedia) initialize() {
@@ -115,33 +115,33 @@ func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionForm
return nil return nil
} }
func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) { func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck return sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr)
} }
func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) { func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr) //nolint:errcheck return sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr)
} }
func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) { func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTPFrame.Payload = payload sm.tcpRTPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer) //nolint:errcheck return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer)
} }
func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) { func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.tcpRTCPFrame.Payload = payload sm.tcpRTCPFrame.Payload = payload
sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout))
sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer)
} }
func (sm *serverSessionMedia) writePacketRTP(payload []byte) error { func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
ok := sm.ss.writer.push(func() { ok := sm.ss.writer.push(func() error {
sm.writePacketRTPInQueue(payload) return sm.writePacketRTPInQueue(payload)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}
@@ -151,8 +151,8 @@ func (sm *serverSessionMedia) writePacketRTP(payload []byte) error {
} }
func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error { func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error {
ok := sm.ss.writer.push(func() { ok := sm.ss.writer.push(func() error {
sm.writePacketRTCPInQueue(payload) return sm.writePacketRTCPInQueue(payload)
}) })
if !ok { if !ok {
return liberrors.ErrServerWriteQueueFull{} return liberrors.ErrServerWriteQueueFull{}