diff --git a/clientconn.go b/clientconn.go index f589f517..1622bfd7 100644 --- a/clientconn.go +++ b/clientconn.go @@ -84,17 +84,16 @@ type ClientConn struct { streamProtocol *StreamProtocol tracks map[int]clientConnTrack getParameterSupported bool + writeMutex sync.Mutex + writeFrameAllowed bool + writeError error + backgroundRunning bool // read rtpInfo *headers.RTPInfo tcpFrameBuffer *multibuffer.MultiBuffer readCB func(int, StreamType, []byte) - // publish - publishError error - publishWriteMutex sync.RWMutex - publishOpen bool - // in backgroundTerminate chan struct{} @@ -137,9 +136,9 @@ func newClientConn(conf ClientConf, scheme string, host string) (*ClientConn, er } cc := &ClientConn{ - conf: conf, - tracks: make(map[int]clientConnTrack), - publishError: fmt.Errorf("not running"), + conf: conf, + tracks: make(map[int]clientConnTrack), + writeError: fmt.Errorf("not running"), } err := cc.connOpen(scheme, host) @@ -152,10 +151,12 @@ func newClientConn(conf ClientConf, scheme string, host string) (*ClientConn, er // Close closes all the ClientConn resources. func (cc *ClientConn) Close() error { - if cc.state == clientConnStatePlay || cc.state == clientConnStateRecord { + if cc.backgroundRunning { close(cc.backgroundTerminate) <-cc.backgroundDone + } + if cc.state == clientConnStatePlay || cc.state == clientConnStateRecord { cc.Do(&base.Request{ Method: base.Teardown, URL: cc.streamURL, @@ -186,6 +187,7 @@ func (cc *ClientConn) reset() { cc.streamProtocol = nil cc.tracks = make(map[int]clientConnTrack) cc.getParameterSupported = false + cc.backgroundRunning = false // read cc.rtpInfo = nil @@ -718,6 +720,7 @@ func (cc *ClientConn) Pause() (*base.Response, error) { close(cc.backgroundTerminate) <-cc.backgroundDone + cc.backgroundRunning = false res, err := cc.Do(&base.Request{ Method: base.Pause, @@ -741,3 +744,34 @@ func (cc *ClientConn) Pause() (*base.Response, error) { return res, nil } + +// WriteFrame writes a frame. +func (cc *ClientConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error { + now := time.Now() + + cc.writeMutex.Lock() + defer cc.writeMutex.Unlock() + + if !cc.writeFrameAllowed { + return cc.writeError + } + + if cc.tracks[trackID].rtcpSender != nil { + cc.tracks[trackID].rtcpSender.ProcessFrame(now, streamType, payload) + } + + if *cc.streamProtocol == StreamProtocolUDP { + if streamType == StreamTypeRTP { + return cc.tracks[trackID].udpRTPListener.write(payload) + } + return cc.tracks[trackID].udpRTCPListener.write(payload) + } + + cc.nconn.SetWriteDeadline(now.Add(cc.conf.WriteTimeout)) + frame := base.InterleavedFrame{ + TrackID: trackID, + StreamType: streamType, + Payload: payload, + } + return frame.Write(cc.bw) +} diff --git a/clientconnpublish.go b/clientconnpublish.go index e6e823ee..84404b30 100644 --- a/clientconnpublish.go +++ b/clientconnpublish.go @@ -55,12 +55,6 @@ func (cc *ClientConn) Announce(u *base.URL, tracks Tracks) (*base.Response, erro } func (cc *ClientConn) backgroundRecordUDP() { - defer func() { - cc.publishWriteMutex.Lock() - defer cc.publishWriteMutex.Unlock() - cc.publishOpen = false - }() - // disable deadline cc.nconn.SetReadDeadline(time.Time{}) @@ -84,34 +78,26 @@ func (cc *ClientConn) backgroundRecordUDP() { case <-cc.backgroundTerminate: cc.nconn.SetReadDeadline(time.Now()) <-readerDone - cc.publishError = fmt.Errorf("terminated") + cc.writeError = fmt.Errorf("terminated") return case <-reportTicker.C: - cc.publishWriteMutex.Lock() now := time.Now() - for _, cct := range cc.tracks { - r := cct.rtcpSender.Report(now) - if r != nil { - cct.udpRTCPListener.write(r) + for trackID, cct := range cc.tracks { + sr := cct.rtcpSender.Report(now) + if sr != nil { + cc.WriteFrame(trackID, StreamTypeRTCP, sr) } } - cc.publishWriteMutex.Unlock() case err := <-readerDone: - cc.publishError = err + cc.writeError = err return } } } func (cc *ClientConn) backgroundRecordTCP() { - defer func() { - cc.publishWriteMutex.Lock() - defer cc.publishWriteMutex.Unlock() - cc.publishOpen = false - }() - reportTicker := time.NewTicker(cc.conf.senderReportPeriod) defer reportTicker.Stop() @@ -121,21 +107,13 @@ func (cc *ClientConn) backgroundRecordTCP() { return case <-reportTicker.C: - cc.publishWriteMutex.Lock() now := time.Now() for trackID, cct := range cc.tracks { - r := cct.rtcpSender.Report(now) - if r != nil { - cc.nconn.SetWriteDeadline(time.Now().Add(cc.conf.WriteTimeout)) - frame := base.InterleavedFrame{ - TrackID: trackID, - StreamType: StreamTypeRTCP, - Payload: r, - } - frame.Write(cc.bw) + sr := cct.rtcpSender.Report(now) + if sr != nil { + cc.WriteFrame(trackID, StreamTypeRTCP, sr) } } - cc.publishWriteMutex.Unlock() } } } @@ -164,13 +142,21 @@ func (cc *ClientConn) Record() (*base.Response, error) { } cc.state = clientConnStateRecord - cc.publishOpen = true + cc.writeFrameAllowed = true + + cc.backgroundRunning = true cc.backgroundTerminate = make(chan struct{}) cc.backgroundDone = make(chan struct{}) go func() { defer close(cc.backgroundDone) + defer func() { + cc.writeMutex.Lock() + defer cc.writeMutex.Unlock() + cc.writeFrameAllowed = false + }() + if *cc.streamProtocol == StreamProtocolUDP { cc.backgroundRecordUDP() } else { @@ -180,33 +166,3 @@ func (cc *ClientConn) Record() (*base.Response, error) { return nil, nil } - -// WriteFrame writes a frame. -// This can be called only after Record(). -func (cc *ClientConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error { - cc.publishWriteMutex.RLock() - defer cc.publishWriteMutex.RUnlock() - - if !cc.publishOpen { - return cc.publishError - } - - now := time.Now() - - cc.tracks[trackID].rtcpSender.ProcessFrame(now, streamType, payload) - - if *cc.streamProtocol == StreamProtocolUDP { - if streamType == StreamTypeRTP { - return cc.tracks[trackID].udpRTPListener.write(payload) - } - return cc.tracks[trackID].udpRTCPListener.write(payload) - } - - cc.nconn.SetWriteDeadline(now.Add(cc.conf.WriteTimeout)) - frame := base.InterleavedFrame{ - TrackID: trackID, - StreamType: streamType, - Payload: payload, - } - return frame.Write(cc.bw) -} diff --git a/clientconnpublish_test.go b/clientconnpublish_test.go index 96b49f34..15b32636 100644 --- a/clientconnpublish_test.go +++ b/clientconnpublish_test.go @@ -743,7 +743,7 @@ func TestClientPublishRTCP(t *testing.T) { Timestamp: 54352, SSRC: 753621, }, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() err = conn.WriteFrame(track.ID, StreamTypeRTP, byts) require.NoError(t, err) diff --git a/clientconnread.go b/clientconnread.go index 76d2ba0a..83cb6459 100644 --- a/clientconnread.go +++ b/clientconnread.go @@ -42,6 +42,9 @@ func (cc *ClientConn) Play() (*base.Response, error) { cc.rtpInfo = &ri } + cc.state = clientConnStatePlay + cc.writeFrameAllowed = true + return res, nil } @@ -252,17 +255,17 @@ func (cc *ClientConn) ReadFrames(onFrame func(int, StreamType, []byte)) chan err done := make(chan error, 1) err := cc.checkState(map[clientConnState]struct{}{ - clientConnStatePrePlay: {}, + clientConnStatePlay: {}, }) if err != nil { done <- err return done } - cc.state = clientConnStatePlay - cc.readCB = onFrame + cc.backgroundRunning = true cc.backgroundTerminate = make(chan struct{}) cc.backgroundDone = make(chan struct{}) + cc.readCB = onFrame go func() { if *cc.streamProtocol == StreamProtocolUDP { diff --git a/clientconnread_test.go b/clientconnread_test.go index c369f4ac..502770cf 100644 --- a/clientconnread_test.go +++ b/clientconnread_test.go @@ -1072,7 +1072,7 @@ func TestClientReadRTCP(t *testing.T) { Timestamp: 54352, SSRC: 753621, }, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() err = base.InterleavedFrame{ TrackID: 0, @@ -1145,3 +1145,107 @@ func TestClientReadRTCP(t *testing.T) { conn.Close() <-done } + +func TestClientReadWriteManualRTCP(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + var req base.Request + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: Tracks{track}.Write(), + }.Write(bconn.Writer) + require.NoError(t, err) + + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + var th headers.Transport + err = th.Read(req.Header["Transport"]) + require.NoError(t, err) + + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + ClientPorts: th.ClientPorts, + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = req.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + + err = base.Response{ + StatusCode: base.StatusOK, + }.Write(bconn.Writer) + require.NoError(t, err) + + var f base.InterleavedFrame + f.Payload = make([]byte, 2048) + err = f.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, StreamTypeRTCP, f.StreamType) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, f.Payload) + }() + + conf := ClientConf{ + StreamProtocol: func() *StreamProtocol { + v := StreamProtocolTCP + return &v + }(), + } + + conn, err := conf.DialRead("rtsp://localhost:8554/teststream") + require.NoError(t, err) + defer conn.Close() + + err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x01, 0x02, 0x03, 0x04}) + require.NoError(t, err) +} diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index dc69289d..ef4da20a 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -569,7 +569,7 @@ func TestServerPublishFrames(t *testing.T) { if atomic.SwapUint64(&rtpReceived, 1) == 0 { require.Equal(t, 0, trackID) require.Equal(t, StreamTypeRTP, typ) - require.Equal(t, []byte("\x01\x02\x03\x04"), buf) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf) } else { require.Equal(t, 0, trackID) require.Equal(t, StreamTypeRTCP, typ) @@ -675,7 +675,7 @@ func TestServerPublishFrames(t *testing.T) { require.NoError(t, err) defer l1.Close() - l1.WriteTo([]byte("\x01\x02\x03\x04"), &net.UDPAddr{ + l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[0], }) @@ -694,7 +694,7 @@ func TestServerPublishFrames(t *testing.T) { err = base.InterleavedFrame{ TrackID: 0, StreamType: StreamTypeRTP, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }.Write(bconn.Writer) require.NoError(t, err) @@ -839,7 +839,7 @@ func TestServerPublishFramesErrorWrongProtocol(t *testing.T) { err = base.InterleavedFrame{ TrackID: 0, StreamType: StreamTypeRTP, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }.Write(bconn.Writer) require.NoError(t, err) } @@ -976,7 +976,7 @@ func TestServerPublishRTCP(t *testing.T) { Timestamp: 54352, SSRC: 753621, }, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() err = base.InterleavedFrame{ TrackID: 0, diff --git a/serverconnread_test.go b/serverconnread_test.go index 9a0ebe69..d5065afd 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -332,7 +332,7 @@ func TestServerReadFrames(t *testing.T) { onFrame := func(trackID int, typ StreamType, buf []byte) { require.Equal(t, 0, trackID) require.Equal(t, StreamTypeRTCP, typ) - require.Equal(t, []byte("\x01\x02\x03\x04"), buf) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf) close(framesReceived) } @@ -406,7 +406,7 @@ func TestServerReadFrames(t *testing.T) { require.NoError(t, err) defer l1.Close() - l1.WriteTo([]byte("\x01\x02\x03\x04"), &net.UDPAddr{ + l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) @@ -414,7 +414,7 @@ func TestServerReadFrames(t *testing.T) { err = base.InterleavedFrame{ TrackID: 0, StreamType: StreamTypeRTCP, - Payload: []byte("\x01\x02\x03\x04"), + Payload: []byte{0x01, 0x02, 0x03, 0x04}, }.Write(bconn.Writer) require.NoError(t, err) }