support calling Pause() in parallel with ReadFrame()

This commit is contained in:
aler9
2020-11-15 19:17:47 +01:00
parent 5636d64651
commit 862cd0ea62
6 changed files with 158 additions and 107 deletions

View File

@@ -87,17 +87,14 @@ type ConnClient struct {
udpLastFrameTimes map[int]*int64 udpLastFrameTimes map[int]*int64
udpRtpListeners map[int]*connClientUDPListener udpRtpListeners map[int]*connClientUDPListener
udpRtcpListeners map[int]*connClientUDPListener udpRtcpListeners map[int]*connClientUDPListener
response *base.Response
frame *base.InterleavedFrame
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
readFrameFunc func() (int, StreamType, []byte, error)
writeFrameFunc func(trackId int, streamType StreamType, content []byte) error writeFrameFunc func(trackId int, streamType StreamType, content []byte) error
getParameterSupported bool getParameterSupported bool
backgroundUDPError error backgroundError error
backgroundTerminate chan struct{} backgroundTerminate chan struct{}
backgroundDone chan struct{} backgroundDone chan struct{}
udpFrame chan base.InterleavedFrame readFrame chan base.InterleavedFrame
} }
// Close closes all the ConnClient resources. // Close closes all the ConnClient resources.
@@ -117,15 +114,6 @@ func (c *ConnClient) Close() error {
}) })
} }
if s == connClientStatePlay {
if *c.streamProtocol == StreamProtocolUDP {
go func() {
for range c.udpFrame {
}
}()
}
}
for _, l := range c.udpRtpListeners { for _, l := range c.udpRtpListeners {
l.close() l.close()
} }
@@ -134,12 +122,6 @@ func (c *ConnClient) Close() error {
l.close() l.close()
} }
if s == connClientStatePlay {
if *c.streamProtocol == StreamProtocolUDP {
close(c.udpFrame)
}
}
err := c.nconn.Close() err := c.nconn.Close()
return err return err
} }
@@ -169,10 +151,12 @@ func (c *ConnClient) Tracks() Tracks {
} }
func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) { func (c *ConnClient) readFrameTCPOrResponse() (interface{}, error) {
c.frame.Content = c.tcpFrameBuffer.Next()
c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout)) c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout))
return base.ReadInterleavedFrameOrResponse(c.frame, c.response, c.br) f := base.InterleavedFrame{
Content: c.tcpFrameBuffer.Next(),
}
r := base.Response{}
return base.ReadInterleavedFrameOrResponse(&f, &r, c.br)
} }
// Do writes a Request and reads a Response. // Do writes a Request and reads a Response.
@@ -589,23 +573,6 @@ func (c *ConnClient) Pause() (*base.Response, error) {
close(c.backgroundTerminate) close(c.backgroundTerminate)
<-c.backgroundDone <-c.backgroundDone
if s == connClientStatePlay {
if *c.streamProtocol == StreamProtocolUDP {
ch := c.udpFrame
go func() {
for range ch {
}
}()
for trackId := range c.udpRtpListeners {
c.udpRtpListeners[trackId].stop()
c.udpRtcpListeners[trackId].stop()
}
close(ch)
}
}
res, err := c.Do(&base.Request{ res, err := c.Do(&base.Request{
Method: base.PAUSE, Method: base.PAUSE,
URL: c.streamUrl, URL: c.streamUrl,

View File

@@ -101,12 +101,12 @@ func (c *ConnClient) backgroundRecordUDP() {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readDone
c.backgroundUDPError = fmt.Errorf("terminated") c.backgroundError = fmt.Errorf("terminated")
c.state.store(connClientStateUDPError) c.state.store(connClientStateUDPError)
return return
case err := <-readDone: case err := <-readDone:
c.backgroundUDPError = err c.backgroundError = err
c.state.store(connClientStateUDPError) c.state.store(connClientStateUDPError)
return return
} }
@@ -119,7 +119,7 @@ func (c *ConnClient) backgroundRecordTCP() {
func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error { func (c *ConnClient) writeFrameUDP(trackId int, streamType StreamType, content []byte) error {
switch c.state.load() { switch c.state.load() {
case connClientStateUDPError: case connClientStateUDPError:
return c.backgroundUDPError return c.backgroundError
case connClientStateRecord: case connClientStateRecord:

View File

@@ -30,33 +30,19 @@ func (c *ConnClient) Play() (*base.Response, error) {
return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage)
} }
if *c.streamProtocol == StreamProtocolUDP {
c.readFrameFunc = c.readFrameUDP
c.writeFrameFunc = c.writeFrameUDP
} else {
c.readFrameFunc = c.readFrameTCP
c.writeFrameFunc = c.writeFrameTCP
}
c.state.store(connClientStatePlay) c.state.store(connClientStatePlay)
c.readFrame = make(chan base.InterleavedFrame)
c.backgroundTerminate = make(chan struct{}) c.backgroundTerminate = make(chan struct{})
c.backgroundDone = make(chan struct{}) c.backgroundDone = make(chan struct{})
if *c.streamProtocol == StreamProtocolUDP { if *c.streamProtocol == StreamProtocolUDP {
c.udpFrame = make(chan base.InterleavedFrame)
for trackId := range c.udpRtpListeners {
c.udpRtpListeners[trackId].start()
c.udpRtcpListeners[trackId].start()
}
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
for trackId := range c.udpRtpListeners { for trackId := range c.udpRtpListeners {
c.WriteFrame(trackId, StreamTypeRtp, c.udpRtpListeners[trackId].write(
[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
c.WriteFrame(trackId, StreamTypeRtcp, c.udpRtcpListeners[trackId].write(
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
} }
@@ -71,7 +57,28 @@ func (c *ConnClient) Play() (*base.Response, error) {
func (c *ConnClient) backgroundPlayUDP() { func (c *ConnClient) backgroundPlayUDP() {
defer close(c.backgroundDone) defer close(c.backgroundDone)
c.nconn.SetReadDeadline(time.Time{}) // disable deadline defer func() {
ch := c.readFrame
go func() {
for range ch {
}
}()
for trackId := range c.udpRtpListeners {
c.udpRtpListeners[trackId].stop()
c.udpRtcpListeners[trackId].stop()
}
close(ch)
}()
for trackId := range c.udpRtpListeners {
c.udpRtpListeners[trackId].start()
c.udpRtcpListeners[trackId].start()
}
// disable deadline
c.nconn.SetReadDeadline(time.Time{})
readDone := make(chan error) readDone := make(chan error)
go func() { go func() {
@@ -99,14 +106,13 @@ func (c *ConnClient) backgroundPlayUDP() {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readDone
c.backgroundUDPError = fmt.Errorf("terminated") c.backgroundError = fmt.Errorf("terminated")
c.state.store(connClientStateUDPError)
return return
case <-reportTicker.C: case <-reportTicker.C:
for trackId := range c.rtcpReceivers { for trackId := range c.rtcpReceivers {
frame := c.rtcpReceivers[trackId].Report() report := c.rtcpReceivers[trackId].Report()
c.WriteFrame(trackId, StreamTypeRtcp, frame) c.udpRtcpListeners[trackId].write(report)
} }
case <-keepaliveTicker.C: case <-keepaliveTicker.C:
@@ -125,8 +131,7 @@ func (c *ConnClient) backgroundPlayUDP() {
if err != nil { if err != nil {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readDone
c.backgroundUDPError = err c.backgroundError = err
c.state.store(connClientStateUDPError)
return return
} }
@@ -139,11 +144,14 @@ func (c *ConnClient) backgroundPlayUDP() {
if now.Sub(last) >= c.d.ReadTimeout { if now.Sub(last) >= c.d.ReadTimeout {
c.nconn.SetReadDeadline(time.Now()) c.nconn.SetReadDeadline(time.Now())
<-readDone <-readDone
c.backgroundUDPError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)") c.backgroundError = fmt.Errorf("no packets received recently (maybe there's a firewall/NAT in between)")
c.state.store(connClientStateUDPError)
return return
} }
} }
case err := <-readDone:
c.backgroundError = err
return
} }
} }
} }
@@ -151,51 +159,71 @@ func (c *ConnClient) backgroundPlayUDP() {
func (c *ConnClient) backgroundPlayTCP() { func (c *ConnClient) backgroundPlayTCP() {
defer close(c.backgroundDone) defer close(c.backgroundDone)
defer func() {
ch := c.readFrame
go func() {
for range ch {
}
}()
close(ch)
}()
readDone := make(chan error)
go func() {
for {
c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout))
frame := base.InterleavedFrame{
Content: c.tcpFrameBuffer.Next(),
}
err := frame.Read(c.br)
if err != nil {
readDone <- err
return
}
c.rtcpReceivers[frame.TrackId].OnFrame(frame.StreamType, frame.Content)
c.readFrame <- frame
}
}()
reportTicker := time.NewTicker(clientReceiverReportPeriod) reportTicker := time.NewTicker(clientReceiverReportPeriod)
defer reportTicker.Stop() defer reportTicker.Stop()
for { for {
select { select {
case <-c.backgroundTerminate: case <-c.backgroundTerminate:
c.nconn.SetReadDeadline(time.Now())
<-readDone
c.backgroundError = fmt.Errorf("terminated")
return return
case <-reportTicker.C: case <-reportTicker.C:
for trackId := range c.rtcpReceivers { for trackId := range c.rtcpReceivers {
frame := c.rtcpReceivers[trackId].Report() report := c.rtcpReceivers[trackId].Report()
c.WriteFrame(trackId, StreamTypeRtcp, frame) c.nconn.SetWriteDeadline(time.Now().Add(c.d.WriteTimeout))
frame := base.InterleavedFrame{
TrackId: trackId,
StreamType: StreamTypeRtcp,
Content: report,
}
frame.Write(c.bw)
} }
case err := <-readDone:
c.backgroundError = err
return
} }
} }
} }
func (c *ConnClient) readFrameUDP() (int, StreamType, []byte, error) {
if c.state.load() != connClientStatePlay {
return 0, 0, nil, fmt.Errorf("not playing")
}
f := <-c.udpFrame
return f.TrackId, f.StreamType, f.Content, nil
}
func (c *ConnClient) readFrameTCP() (int, StreamType, []byte, error) {
if c.state.load() != connClientStatePlay {
return 0, 0, nil, fmt.Errorf("not playing")
}
c.nconn.SetReadDeadline(time.Now().Add(c.d.ReadTimeout))
c.frame.Content = c.tcpFrameBuffer.Next()
err := c.frame.Read(c.br)
if err != nil {
return 0, 0, nil, err
}
c.rtcpReceivers[c.frame.TrackId].OnFrame(c.frame.StreamType, c.frame.Content)
return c.frame.TrackId, c.frame.StreamType, c.frame.Content, nil
}
// ReadFrame reads a frame. // ReadFrame reads a frame.
// This can be used only after Play(). // This can be used only after Play().
func (c *ConnClient) ReadFrame() (int, StreamType, []byte, error) { func (c *ConnClient) ReadFrame() (int, StreamType, []byte, error) {
return c.readFrameFunc() f, ok := <-c.readFrame
if !ok {
return 0, 0, nil, c.backgroundError
}
return f.TrackId, f.StreamType, f.Content, nil
} }

View File

@@ -76,7 +76,7 @@ func (l *connClientUDPListener) run() {
l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n]) l.c.rtcpReceivers[l.trackId].OnFrame(l.streamType, buf[:n])
l.c.udpFrame <- base.InterleavedFrame{ l.c.readFrame <- base.InterleavedFrame{
TrackId: l.trackId, TrackId: l.trackId,
StreamType: l.streamType, StreamType: l.streamType,
Content: buf[:n], Content: buf[:n],

View File

@@ -96,14 +96,12 @@ func (d Dialer) Dial(host string) (*ConnClient, error) {
v := connClientState(0) v := connClientState(0)
return &v return &v
}(), }(),
rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver), rtcpReceivers: make(map[int]*rtcpreceiver.RtcpReceiver),
udpLastFrameTimes: make(map[int]*int64), udpLastFrameTimes: make(map[int]*int64),
udpRtpListeners: make(map[int]*connClientUDPListener), udpRtpListeners: make(map[int]*connClientUDPListener),
udpRtcpListeners: make(map[int]*connClientUDPListener), udpRtcpListeners: make(map[int]*connClientUDPListener),
response: &base.Response{}, tcpFrameBuffer: multibuffer.New(d.ReadBufferCount+1, clientTCPFrameReadBufferSize),
frame: &base.InterleavedFrame{}, backgroundError: fmt.Errorf("not running"),
tcpFrameBuffer: multibuffer.New(d.ReadBufferCount, clientTCPFrameReadBufferSize),
backgroundUDPError: fmt.Errorf("not running"),
}, nil }, nil
} }

View File

@@ -251,6 +251,64 @@ func TestDialReadPause(t *testing.T) {
} }
} }
func TestDialReadPauseParallel(t *testing.T) {
for _, proto := range []string{
"udp",
"tcp",
} {
t.Run(proto, func(t *testing.T) {
cnt1, err := newContainer("rtsp-simple-server", "server", []string{"{}"})
require.NoError(t, err)
defer cnt1.close()
time.Sleep(1 * time.Second)
cnt2, err := newContainer("ffmpeg", "publish", []string{
"-re",
"-stream_loop", "-1",
"-i", "/emptyvideo.ts",
"-c", "copy",
"-f", "rtsp",
"-rtsp_transport", "udp",
"rtsp://localhost:8554/teststream",
})
require.NoError(t, err)
defer cnt2.close()
time.Sleep(1 * time.Second)
dialer := func() Dialer {
if proto == "udp" {
return Dialer{}
}
return Dialer{StreamProtocol: StreamProtocolTCP}
}()
conn, err := dialer.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err)
readDone := make(chan struct{})
go func() {
defer close(readDone)
for {
_, _, _, err := conn.ReadFrame()
if err != nil {
break
}
}
}()
time.Sleep(1 * time.Second)
conn.Pause()
<-readDone
conn.Close()
})
}
}
func TestDialPublish(t *testing.T) { func TestDialPublish(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",