client: allow calling WriteFrame() when reading

This commit is contained in:
aler9
2021-03-28 17:55:40 +02:00
parent 4c5742354c
commit 1c88a7fee9
7 changed files with 181 additions and 84 deletions

View File

@@ -84,17 +84,16 @@ type ClientConn struct {
streamProtocol *StreamProtocol streamProtocol *StreamProtocol
tracks map[int]clientConnTrack tracks map[int]clientConnTrack
getParameterSupported bool getParameterSupported bool
writeMutex sync.Mutex
writeFrameAllowed bool
writeError error
backgroundRunning bool
// read // read
rtpInfo *headers.RTPInfo rtpInfo *headers.RTPInfo
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
readCB func(int, StreamType, []byte) readCB func(int, StreamType, []byte)
// publish
publishError error
publishWriteMutex sync.RWMutex
publishOpen bool
// in // in
backgroundTerminate chan struct{} backgroundTerminate chan struct{}
@@ -139,7 +138,7 @@ func newClientConn(conf ClientConf, scheme string, host string) (*ClientConn, er
cc := &ClientConn{ cc := &ClientConn{
conf: conf, conf: conf,
tracks: make(map[int]clientConnTrack), tracks: make(map[int]clientConnTrack),
publishError: fmt.Errorf("not running"), writeError: fmt.Errorf("not running"),
} }
err := cc.connOpen(scheme, host) err := cc.connOpen(scheme, host)
@@ -152,10 +151,12 @@ func newClientConn(conf ClientConf, scheme string, host string) (*ClientConn, er
// Close closes all the ClientConn resources. // Close closes all the ClientConn resources.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
if cc.state == clientConnStatePlay || cc.state == clientConnStateRecord { if cc.backgroundRunning {
close(cc.backgroundTerminate) close(cc.backgroundTerminate)
<-cc.backgroundDone <-cc.backgroundDone
}
if cc.state == clientConnStatePlay || cc.state == clientConnStateRecord {
cc.Do(&base.Request{ cc.Do(&base.Request{
Method: base.Teardown, Method: base.Teardown,
URL: cc.streamURL, URL: cc.streamURL,
@@ -186,6 +187,7 @@ func (cc *ClientConn) reset() {
cc.streamProtocol = nil cc.streamProtocol = nil
cc.tracks = make(map[int]clientConnTrack) cc.tracks = make(map[int]clientConnTrack)
cc.getParameterSupported = false cc.getParameterSupported = false
cc.backgroundRunning = false
// read // read
cc.rtpInfo = nil cc.rtpInfo = nil
@@ -718,6 +720,7 @@ func (cc *ClientConn) Pause() (*base.Response, error) {
close(cc.backgroundTerminate) close(cc.backgroundTerminate)
<-cc.backgroundDone <-cc.backgroundDone
cc.backgroundRunning = false
res, err := cc.Do(&base.Request{ res, err := cc.Do(&base.Request{
Method: base.Pause, Method: base.Pause,
@@ -741,3 +744,34 @@ func (cc *ClientConn) Pause() (*base.Response, error) {
return res, nil return res, nil
} }
// WriteFrame writes a frame.
func (cc *ClientConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error {
now := time.Now()
cc.writeMutex.Lock()
defer cc.writeMutex.Unlock()
if !cc.writeFrameAllowed {
return cc.writeError
}
if cc.tracks[trackID].rtcpSender != nil {
cc.tracks[trackID].rtcpSender.ProcessFrame(now, streamType, payload)
}
if *cc.streamProtocol == StreamProtocolUDP {
if streamType == StreamTypeRTP {
return cc.tracks[trackID].udpRTPListener.write(payload)
}
return cc.tracks[trackID].udpRTCPListener.write(payload)
}
cc.nconn.SetWriteDeadline(now.Add(cc.conf.WriteTimeout))
frame := base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,
}
return frame.Write(cc.bw)
}

View File

@@ -55,12 +55,6 @@ func (cc *ClientConn) Announce(u *base.URL, tracks Tracks) (*base.Response, erro
} }
func (cc *ClientConn) backgroundRecordUDP() { func (cc *ClientConn) backgroundRecordUDP() {
defer func() {
cc.publishWriteMutex.Lock()
defer cc.publishWriteMutex.Unlock()
cc.publishOpen = false
}()
// disable deadline // disable deadline
cc.nconn.SetReadDeadline(time.Time{}) cc.nconn.SetReadDeadline(time.Time{})
@@ -84,34 +78,26 @@ func (cc *ClientConn) backgroundRecordUDP() {
case <-cc.backgroundTerminate: case <-cc.backgroundTerminate:
cc.nconn.SetReadDeadline(time.Now()) cc.nconn.SetReadDeadline(time.Now())
<-readerDone <-readerDone
cc.publishError = fmt.Errorf("terminated") cc.writeError = fmt.Errorf("terminated")
return return
case <-reportTicker.C: case <-reportTicker.C:
cc.publishWriteMutex.Lock()
now := time.Now() now := time.Now()
for _, cct := range cc.tracks { for trackID, cct := range cc.tracks {
r := cct.rtcpSender.Report(now) sr := cct.rtcpSender.Report(now)
if r != nil { if sr != nil {
cct.udpRTCPListener.write(r) cc.WriteFrame(trackID, StreamTypeRTCP, sr)
} }
} }
cc.publishWriteMutex.Unlock()
case err := <-readerDone: case err := <-readerDone:
cc.publishError = err cc.writeError = err
return return
} }
} }
} }
func (cc *ClientConn) backgroundRecordTCP() { func (cc *ClientConn) backgroundRecordTCP() {
defer func() {
cc.publishWriteMutex.Lock()
defer cc.publishWriteMutex.Unlock()
cc.publishOpen = false
}()
reportTicker := time.NewTicker(cc.conf.senderReportPeriod) reportTicker := time.NewTicker(cc.conf.senderReportPeriod)
defer reportTicker.Stop() defer reportTicker.Stop()
@@ -121,21 +107,13 @@ func (cc *ClientConn) backgroundRecordTCP() {
return return
case <-reportTicker.C: case <-reportTicker.C:
cc.publishWriteMutex.Lock()
now := time.Now() now := time.Now()
for trackID, cct := range cc.tracks { for trackID, cct := range cc.tracks {
r := cct.rtcpSender.Report(now) sr := cct.rtcpSender.Report(now)
if r != nil { if sr != nil {
cc.nconn.SetWriteDeadline(time.Now().Add(cc.conf.WriteTimeout)) cc.WriteFrame(trackID, StreamTypeRTCP, sr)
frame := base.InterleavedFrame{
TrackID: trackID,
StreamType: StreamTypeRTCP,
Payload: r,
}
frame.Write(cc.bw)
} }
} }
cc.publishWriteMutex.Unlock()
} }
} }
} }
@@ -164,13 +142,21 @@ func (cc *ClientConn) Record() (*base.Response, error) {
} }
cc.state = clientConnStateRecord cc.state = clientConnStateRecord
cc.publishOpen = true cc.writeFrameAllowed = true
cc.backgroundRunning = true
cc.backgroundTerminate = make(chan struct{}) cc.backgroundTerminate = make(chan struct{})
cc.backgroundDone = make(chan struct{}) cc.backgroundDone = make(chan struct{})
go func() { go func() {
defer close(cc.backgroundDone) defer close(cc.backgroundDone)
defer func() {
cc.writeMutex.Lock()
defer cc.writeMutex.Unlock()
cc.writeFrameAllowed = false
}()
if *cc.streamProtocol == StreamProtocolUDP { if *cc.streamProtocol == StreamProtocolUDP {
cc.backgroundRecordUDP() cc.backgroundRecordUDP()
} else { } else {
@@ -180,33 +166,3 @@ func (cc *ClientConn) Record() (*base.Response, error) {
return nil, nil return nil, nil
} }
// WriteFrame writes a frame.
// This can be called only after Record().
func (cc *ClientConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error {
cc.publishWriteMutex.RLock()
defer cc.publishWriteMutex.RUnlock()
if !cc.publishOpen {
return cc.publishError
}
now := time.Now()
cc.tracks[trackID].rtcpSender.ProcessFrame(now, streamType, payload)
if *cc.streamProtocol == StreamProtocolUDP {
if streamType == StreamTypeRTP {
return cc.tracks[trackID].udpRTPListener.write(payload)
}
return cc.tracks[trackID].udpRTCPListener.write(payload)
}
cc.nconn.SetWriteDeadline(now.Add(cc.conf.WriteTimeout))
frame := base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,
}
return frame.Write(cc.bw)
}

View File

@@ -743,7 +743,7 @@ func TestClientPublishRTCP(t *testing.T) {
Timestamp: 54352, Timestamp: 54352,
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}).Marshal() }).Marshal()
err = conn.WriteFrame(track.ID, StreamTypeRTP, byts) err = conn.WriteFrame(track.ID, StreamTypeRTP, byts)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -42,6 +42,9 @@ func (cc *ClientConn) Play() (*base.Response, error) {
cc.rtpInfo = &ri cc.rtpInfo = &ri
} }
cc.state = clientConnStatePlay
cc.writeFrameAllowed = true
return res, nil return res, nil
} }
@@ -252,17 +255,17 @@ func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan err
done := make(chan error, 1) done := make(chan error, 1)
err := cc.checkState(map[clientConnState]struct{}{ err := cc.checkState(map[clientConnState]struct{}{
clientConnStatePrePlay: {}, clientConnStatePlay: {},
}) })
if err != nil { if err != nil {
done <- err done <- err
return done return done
} }
cc.state = clientConnStatePlay cc.backgroundRunning = true
cc.readCB = onFrame
cc.backgroundTerminate = make(chan struct{}) cc.backgroundTerminate = make(chan struct{})
cc.backgroundDone = make(chan struct{}) cc.backgroundDone = make(chan struct{})
cc.readCB = onFrame
go func() { go func() {
if *cc.streamProtocol == StreamProtocolUDP { if *cc.streamProtocol == StreamProtocolUDP {

View File

@@ -1072,7 +1072,7 @@ func TestClientReadRTCP(t *testing.T) {
Timestamp: 54352, Timestamp: 54352,
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}).Marshal() }).Marshal()
err = base.InterleavedFrame{ err = base.InterleavedFrame{
TrackID: 0, TrackID: 0,
@@ -1145,3 +1145,107 @@ func TestClientReadRTCP(t *testing.T) {
conn.Close() conn.Close()
<-done <-done
} }
func TestClientReadWriteManualRTCP(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
conn, err := l.Accept()
require.NoError(t, err)
defer conn.Close()
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
var req base.Request
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Setup),
string(base.Play),
}, ", ")},
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err)
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: Tracks{track}.Write(),
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
var th headers.Transport
err = th.Read(req.Header["Transport"])
require.NoError(t, err)
err = base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
ClientPorts: th.ClientPorts,
InterleavedIDs: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
err = req.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Play, req.Method)
err = base.Response{
StatusCode: base.StatusOK,
}.Write(bconn.Writer)
require.NoError(t, err)
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err = f.Read(bconn.Reader)
require.NoError(t, err)
require.Equal(t, StreamTypeRTCP, f.StreamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload)
}()
conf := ClientConf{
StreamProtocol: func() *StreamProtocol {
v := StreamProtocolTCP
return &v
}(),
}
conn, err := conf.DialRead("rtsp://localhost:8554/teststream")
require.NoError(t, err)
defer conn.Close()
err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x01, 0x02, 0x03, 0x04})
require.NoError(t, err)
}

View File

@@ -569,7 +569,7 @@ func TestServerPublishFrames(t *testing.T) {
if atomic.SwapUint64(&rtpReceived, 1) == 0 { if atomic.SwapUint64(&rtpReceived, 1) == 0 {
require.Equal(t, 0, trackID) require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTP, typ) require.Equal(t, StreamTypeRTP, typ)
require.Equal(t, []byte("\x01\x02\x03\x04"), buf) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf)
} else { } else {
require.Equal(t, 0, trackID) require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTCP, typ) require.Equal(t, StreamTypeRTCP, typ)
@@ -675,7 +675,7 @@ func TestServerPublishFrames(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1.Close()
l1.WriteTo([]byte("\x01\x02\x03\x04"), &net.UDPAddr{ l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[0], Port: th.ServerPorts[0],
}) })
@@ -694,7 +694,7 @@ func TestServerPublishFrames(t *testing.T) {
err = base.InterleavedFrame{ err = base.InterleavedFrame{
TrackID: 0, TrackID: 0,
StreamType: StreamTypeRTP, StreamType: StreamTypeRTP,
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -839,7 +839,7 @@ func TestServerPublishFramesErrorWrongProtocol(t *testing.T) {
err = base.InterleavedFrame{ err = base.InterleavedFrame{
TrackID: 0, TrackID: 0,
StreamType: StreamTypeRTP, StreamType: StreamTypeRTP,
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -976,7 +976,7 @@ func TestServerPublishRTCP(t *testing.T) {
Timestamp: 54352, Timestamp: 54352,
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}).Marshal() }).Marshal()
err = base.InterleavedFrame{ err = base.InterleavedFrame{
TrackID: 0, TrackID: 0,

View File

@@ -332,7 +332,7 @@ func TestServerReadFrames(t *testing.T) {
onFrame := func(trackID int, typ StreamType, buf []byte) { onFrame := func(trackID int, typ StreamType, buf []byte) {
require.Equal(t, 0, trackID) require.Equal(t, 0, trackID)
require.Equal(t, StreamTypeRTCP, typ) require.Equal(t, StreamTypeRTCP, typ)
require.Equal(t, []byte("\x01\x02\x03\x04"), buf) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf)
close(framesReceived) close(framesReceived)
} }
@@ -406,7 +406,7 @@ func TestServerReadFrames(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1.Close()
l1.WriteTo([]byte("\x01\x02\x03\x04"), &net.UDPAddr{ l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1], Port: th.ServerPorts[1],
}) })
@@ -414,7 +414,7 @@ func TestServerReadFrames(t *testing.T) {
err = base.InterleavedFrame{ err = base.InterleavedFrame{
TrackID: 0, TrackID: 0,
StreamType: StreamTypeRTCP, StreamType: StreamTypeRTCP,
Payload: []byte("\x01\x02\x03\x04"), Payload: []byte{0x01, 0x02, 0x03, 0x04},
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
} }