client: allow calling ReadFrames() when publishing

This commit is contained in:
aler9
2021-03-28 21:36:12 +02:00
parent 58a8e777f0
commit 0e6811a178
6 changed files with 543 additions and 271 deletions

View File

@@ -88,11 +88,13 @@ type ClientConn struct {
writeFrameAllowed bool writeFrameAllowed bool
writeError error writeError error
backgroundRunning bool backgroundRunning bool
readCB func(int, StreamType, []byte)
// TCP stream protocol
tcpFrameBuffer *multibuffer.MultiBuffer
// read // read
rtpInfo *headers.RTPInfo rtpInfo *headers.RTPInfo
tcpFrameBuffer *multibuffer.MultiBuffer
readCB func(int, StreamType, []byte)
// in // in
backgroundTerminate chan struct{} backgroundTerminate chan struct{}
@@ -695,15 +697,15 @@ func (cc *ClientConn) Setup(mode headers.TransportMode, track *Track,
if mode == headers.TransportModePlay { if mode == headers.TransportModePlay {
cc.state = clientConnStatePrePlay cc.state = clientConnStatePrePlay
if *cc.streamProtocol == StreamProtocolTCP && cc.tcpFrameBuffer == nil {
cc.tcpFrameBuffer = multibuffer.New(uint64(cc.conf.ReadBufferCount), uint64(cc.conf.ReadBufferSize))
}
} else { } else {
cc.state = clientConnStatePreRecord cc.state = clientConnStatePreRecord
} }
if *cc.streamProtocol == StreamProtocolTCP &&
cc.tcpFrameBuffer == nil {
cc.tcpFrameBuffer = multibuffer.New(uint64(cc.conf.ReadBufferCount), uint64(cc.conf.ReadBufferSize))
}
return res, nil return res, nil
} }
@@ -775,3 +777,105 @@ func (cc *ClientConn) WriteFrame(trackID int, streamType StreamType, payload []b
} }
return frame.Write(cc.bw) return frame.Write(cc.bw)
} }
// ReadFrames starts reading frames.
// it returns a channel that is written when the reading stops.
func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan error {
// channel is buffered, since listening to it is not mandatory
done := make(chan error, 1)
err := cc.checkState(map[clientConnState]struct{}{
clientConnStatePlay: {},
clientConnStateRecord: {},
})
if err != nil {
done <- err
return done
}
// close previous ReadFrames()
if cc.backgroundRunning {
close(cc.backgroundTerminate)
<-cc.backgroundDone
}
cc.backgroundRunning = true
cc.backgroundTerminate = make(chan struct{})
cc.backgroundDone = make(chan struct{})
cc.readCB = onFrame
cc.writeFrameAllowed = true
go func() {
done <- func() error {
safeState := cc.state
err := func() error {
if *cc.streamProtocol == StreamProtocolUDP {
if cc.state == clientConnStatePlay {
return cc.backgroundPlayUDP()
}
return cc.backgroundRecordUDP()
}
if cc.state == clientConnStatePlay {
return cc.backgroundPlayTCP()
}
return cc.backgroundRecordTCP()
}()
cc.writeError = err
func() {
cc.writeMutex.Lock()
defer cc.writeMutex.Unlock()
cc.writeFrameAllowed = false
}()
close(cc.backgroundDone)
// automatically change protocol in case of timeout
if *cc.streamProtocol == StreamProtocolUDP &&
safeState == clientConnStatePlay {
if _, ok := err.(liberrors.ErrClientNoUDPPacketsRecently); ok {
if cc.conf.StreamProtocol == nil {
prevURL := cc.streamURL
prevTracks := cc.tracks
cc.reset()
v := StreamProtocolTCP
cc.streamProtocol = &v
err := cc.connOpen(prevURL.Scheme, prevURL.Host)
if err != nil {
return err
}
_, err = cc.Options(prevURL)
if err != nil {
cc.Close()
return err
}
for _, track := range prevTracks {
_, err := cc.Setup(headers.TransportModePlay, track.track, 0, 0)
if err != nil {
cc.Close()
return err
}
}
_, err = cc.Play()
if err != nil {
cc.Close()
return err
}
return <-cc.ReadFrames(onFrame)
}
}
}
return err
}()
}()
return done
}

View File

@@ -54,70 +54,6 @@ func (cc *ClientConn) Announce(u *base.URL, tracks Tracks) (*base.Response, erro
return res, nil return res, nil
} }
func (cc *ClientConn) backgroundRecordUDP() {
// disable deadline
cc.nconn.SetReadDeadline(time.Time{})
readerDone := make(chan error)
go func() {
for {
var res base.Response
err := res.Read(cc.br)
if err != nil {
readerDone <- err
return
}
}
}()
reportTicker := time.NewTicker(cc.conf.senderReportPeriod)
defer reportTicker.Stop()
for {
select {
case <-cc.backgroundTerminate:
cc.nconn.SetReadDeadline(time.Now())
<-readerDone
cc.writeError = fmt.Errorf("terminated")
return
case <-reportTicker.C:
now := time.Now()
for trackID, cct := range cc.tracks {
sr := cct.rtcpSender.Report(now)
if sr != nil {
cc.WriteFrame(trackID, StreamTypeRTCP, sr)
}
}
case err := <-readerDone:
cc.writeError = err
return
}
}
}
func (cc *ClientConn) backgroundRecordTCP() {
reportTicker := time.NewTicker(cc.conf.senderReportPeriod)
defer reportTicker.Stop()
for {
select {
case <-cc.backgroundTerminate:
return
case <-reportTicker.C:
now := time.Now()
for trackID, cct := range cc.tracks {
sr := cct.rtcpSender.Report(now)
if sr != nil {
cc.WriteFrame(trackID, StreamTypeRTCP, sr)
}
}
}
}
}
// Record writes a RECORD request and reads a Response. // Record writes a RECORD request and reads a Response.
// This can be called only after Announce() and Setup(). // This can be called only after Announce() and Setup().
func (cc *ClientConn) Record() (*base.Response, error) { func (cc *ClientConn) Record() (*base.Response, error) {
@@ -142,27 +78,107 @@ func (cc *ClientConn) Record() (*base.Response, error) {
} }
cc.state = clientConnStateRecord cc.state = clientConnStateRecord
cc.writeFrameAllowed = true
cc.backgroundRunning = true cc.ReadFrames(func(trackID int, streamType StreamType, payload []byte) {
cc.backgroundTerminate = make(chan struct{}) })
cc.backgroundDone = make(chan struct{})
go func() {
defer close(cc.backgroundDone)
defer func() {
cc.writeMutex.Lock()
defer cc.writeMutex.Unlock()
cc.writeFrameAllowed = false
}()
if *cc.streamProtocol == StreamProtocolUDP {
cc.backgroundRecordUDP()
} else {
cc.backgroundRecordTCP()
}
}()
return nil, nil return nil, nil
} }
func (cc *ClientConn) backgroundRecordUDP() error {
for _, cct := range cc.tracks {
cct.udpRTPListener.start()
cct.udpRTCPListener.start()
}
defer func() {
for _, cct := range cc.tracks {
cct.udpRTPListener.stop()
cct.udpRTCPListener.stop()
}
}()
// disable deadline
cc.nconn.SetReadDeadline(time.Time{})
readerDone := make(chan error)
go func() {
for {
var res base.Response
err := res.Read(cc.br)
if err != nil {
readerDone <- err
return
}
}
}()
reportTicker := time.NewTicker(cc.conf.senderReportPeriod)
defer reportTicker.Stop()
for {
select {
case <-cc.backgroundTerminate:
cc.nconn.SetReadDeadline(time.Now())
<-readerDone
return fmt.Errorf("terminated")
case <-reportTicker.C:
now := time.Now()
for trackID, cct := range cc.tracks {
sr := cct.rtcpSender.Report(now)
if sr != nil {
cc.WriteFrame(trackID, StreamTypeRTCP, sr)
}
}
case err := <-readerDone:
return err
}
}
}
func (cc *ClientConn) backgroundRecordTCP() error {
// disable deadline
cc.nconn.SetReadDeadline(time.Time{})
readerDone := make(chan error)
go func() {
for {
frame := base.InterleavedFrame{
Payload: cc.tcpFrameBuffer.Next(),
}
err := frame.Read(cc.br)
if err != nil {
readerDone <- err
return
}
cc.readCB(frame.TrackID, frame.StreamType, frame.Payload)
}
}()
reportTicker := time.NewTicker(cc.conf.senderReportPeriod)
defer reportTicker.Stop()
for {
select {
case <-cc.backgroundTerminate:
cc.nconn.SetReadDeadline(time.Now())
<-readerDone
return fmt.Errorf("terminated")
case <-reportTicker.C:
now := time.Now()
for trackID, cct := range cc.tracks {
sr := cct.rtcpSender.Report(now)
if sr != nil {
cc.WriteFrame(trackID, StreamTypeRTCP, sr)
}
}
case err := <-readerDone:
return err
}
}
}

View File

@@ -753,3 +753,174 @@ func TestClientPublishRTCP(t *testing.T) {
err = conn.WriteFrame(track.ID, StreamTypeRTP, byts) err = conn.WriteFrame(track.ID, StreamTypeRTP, byts)
require.NoError(t, err) require.NoError(t, err)
} }
func TestClientPublishReadManualRTCP(t *testing.T) {
for _, proto := range []string{
"udp",
"tcp",
} {
t.Run(proto, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := l.Accept()
require.NoError(t, err)
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
var req base.Request
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Announce),
string(base.Setup),
string(base.Record),
}, ", ")},
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Announce, req.Method)
err = base.Response{
StatusCode: base.StatusOK,
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
var inTH headers.Transport
err = inTH.Read(req.Header["Transport"])
require.NoError(t, err)
th := headers.Transport{
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
}
var l1 net.PacketConn
if proto == "udp" {
var err error
l1, err = net.ListenPacket("udp", "localhost:34557")
require.NoError(t, err)
defer l1.Close()
th.Protocol = StreamProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = StreamProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Record, req.Method)
err = base.Response{
StatusCode: base.StatusOK,
}.Write(bconn.Writer)
require.NoError(t, err)
if proto == "udp" {
buf := make([]byte, 2048)
n, _, err := l1.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n])
} else {
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err = f.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, StreamTypeRTCP, f.StreamType)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, f.Payload)
}
if proto == "udp" {
l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
} else {
err = base.InterleavedFrame{
TrackID: 0,
StreamType: StreamTypeRTCP,
Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Write(bconn.Writer)
require.NoError(t, err)
}
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
base.Response{
StatusCode: base.StatusOK,
}.Write(bconn.Writer)
conn.Close()
}()
conf := ClientConf{
StreamProtocol: func() *StreamProtocol {
if proto == "udp" {
v := StreamProtocolUDP
return &v
}
v := StreamProtocolTCP
return &v
}(),
}
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err)
conn, err := conf.DialPublish("rtsp://localhost:8554/teststream",
Tracks{track})
require.NoError(t, err)
recvDone := make(chan struct{})
done := conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) {
require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTCP, streamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
close(recvDone)
})
err = conn.WriteFrame(track.ID, StreamTypeRTCP,
[]byte{0x05, 0x06, 0x07, 0x08})
require.NoError(t, err)
<-recvDone
conn.Close()
<-done
})
}
}

View File

@@ -43,7 +43,6 @@ func (cc *ClientConn) Play() (*base.Response, error) {
} }
cc.state = clientConnStatePlay cc.state = clientConnStatePlay
cc.writeFrameAllowed = true
return res, nil return res, nil
} }
@@ -246,81 +245,3 @@ func (cc *ClientConn) backgroundPlayTCP() error {
} }
} }
} }
// ReadFrames starts reading frames.
// it returns a channel that is written when the reading stops.
// This can be called only after Play().
func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan error {
// channel is buffered, since listening to it is not mandatory
done := make(chan error, 1)
err := cc.checkState(map[clientConnState]struct{}{
clientConnStatePlay: {},
})
if err != nil {
done <- err
return done
}
cc.backgroundRunning = true
cc.backgroundTerminate = make(chan struct{})
cc.backgroundDone = make(chan struct{})
cc.readCB = onFrame
go func() {
if *cc.streamProtocol == StreamProtocolUDP {
err := cc.backgroundPlayUDP()
close(cc.backgroundDone)
// automatically change protocol in case of timeout
if _, ok := err.(liberrors.ErrClientNoUDPPacketsRecently); ok {
if cc.conf.StreamProtocol == nil {
err := func() error {
prevURL := cc.streamURL
prevTracks := cc.tracks
cc.reset()
v := StreamProtocolTCP
cc.streamProtocol = &v
err := cc.connOpen(prevURL.Scheme, prevURL.Host)
if err != nil {
return err
}
_, err = cc.Options(prevURL)
if err != nil {
cc.Close()
return err
}
for _, track := range prevTracks {
_, err := cc.Setup(headers.TransportModePlay, track.track, 0, 0)
if err != nil {
cc.Close()
return err
}
}
_, err = cc.Play()
if err != nil {
cc.Close()
return err
}
return <-cc.ReadFrames(onFrame)
}()
done <- err
return
}
}
done <- err
} else {
defer close(cc.backgroundDone)
done <- cc.backgroundPlayTCP()
}
}()
return done
}

View File

@@ -197,11 +197,6 @@ func TestClientRead(t *testing.T) {
<-frameRecv <-frameRecv
conn.Close() conn.Close()
<-done <-done
done = conn.ReadFrames(func(id int, typ StreamType, payload []byte) {
t.Error("should not happen")
})
<-done
}) })
} }
} }
@@ -1147,6 +1142,11 @@ func TestClientReadRTCP(t *testing.T) {
} }
func TestClientReadWriteManualRTCP(t *testing.T) { func TestClientReadWriteManualRTCP(t *testing.T) {
for _, proto := range []string{
"udp",
"tcp",
} {
t.Run(proto, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554") l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err) require.NoError(t, err)
defer l.Close() defer l.Close()
@@ -1198,22 +1198,37 @@ func TestClientReadWriteManualRTCP(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
var th headers.Transport var inTH headers.Transport
err = th.Read(req.Header["Transport"]) err = inTH.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
err = base.Response{ th := headers.Transport{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery { Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast v := base.StreamDeliveryUnicast
return &v return &v
}(), }(),
ClientPorts: th.ClientPorts, }
InterleavedIDs: &[2]int{0, 1},
}.Write(), var l1 net.PacketConn
if proto == "udp" {
var err error
l1, err = net.ListenPacket("udp", "localhost:34557")
require.NoError(t, err)
defer l1.Close()
th.Protocol = StreamProtocolUDP
th.ServerPorts = &[2]int{34556, 34557}
th.ClientPorts = inTH.ClientPorts
} else {
th.Protocol = StreamProtocolTCP
th.InterleavedIDs = inTH.InterleavedIDs
}
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Write(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -1227,16 +1242,34 @@ func TestClientReadWriteManualRTCP(t *testing.T) {
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
if proto == "udp" {
buf := make([]byte, 2048)
// skip firewall opening
_, _, err := l1.ReadFrom(buf)
require.NoError(t, err)
n, _, err := l1.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n])
} else {
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) f.Payload = make([]byte, 2048)
err = f.Read(bconn.Reader) err = f.Read(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, f.TrackID)
require.Equal(t, StreamTypeRTCP, f.StreamType) require.Equal(t, StreamTypeRTCP, f.StreamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload)
}
}() }()
conf := ClientConf{ conf := ClientConf{
StreamProtocol: func() *StreamProtocol { StreamProtocol: func() *StreamProtocol {
if proto == "udp" {
v := StreamProtocolUDP
return &v
}
v := StreamProtocolTCP v := StreamProtocolTCP
return &v return &v
}(), }(),
@@ -1246,6 +1279,13 @@ func TestClientReadWriteManualRTCP(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
conn.ReadFrames(func(trackID int, streamType StreamType, payload []byte) {
})
time.Sleep(500 * time.Millisecond)
err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x01, 0x02, 0x03, 0x04}) err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x01, 0x02, 0x03, 0x04})
require.NoError(t, err) require.NoError(t, err)
})
}
} }

View File

@@ -73,6 +73,7 @@ func (l *clientConnUDPListener) stop() {
func (l *clientConnUDPListener) run() { func (l *clientConnUDPListener) run() {
defer close(l.done) defer close(l.done)
if l.cc.state == clientConnStatePlay {
for { for {
buf := l.frameBuffer.Next() buf := l.frameBuffer.Next()
n, addr, err := l.pc.ReadFrom(buf) n, addr, err := l.pc.ReadFrom(buf)
@@ -91,6 +92,25 @@ func (l *clientConnUDPListener) run() {
l.cc.tracks[l.trackID].rtcpReceiver.ProcessFrame(now, l.streamType, buf[:n]) l.cc.tracks[l.trackID].rtcpReceiver.ProcessFrame(now, l.streamType, buf[:n])
l.cc.readCB(l.trackID, l.streamType, buf[:n]) l.cc.readCB(l.trackID, l.streamType, buf[:n])
} }
} else {
for {
buf := l.frameBuffer.Next()
n, addr, err := l.pc.ReadFrom(buf)
if err != nil {
return
}
uaddr := addr.(*net.UDPAddr)
if !l.remoteIP.Equal(uaddr.IP) || (l.remotePort != 0 && l.remotePort != uaddr.Port) {
continue
}
now := time.Now()
atomic.StoreInt64(l.lastFrameTime, now.Unix())
l.cc.readCB(l.trackID, l.streamType, buf[:n])
}
}
} }
func (l *clientConnUDPListener) write(buf []byte) error { func (l *clientConnUDPListener) write(buf []byte) error {