server: expose both setupped tracks and published tracks

This commit is contained in:
aler9
2021-03-16 10:21:35 +01:00
parent 580917d607
commit b47ebbee01
3 changed files with 37 additions and 38 deletions

View File

@@ -42,7 +42,7 @@ func stringsReverseIndex(s, substr string) int {
func extractTrackIDAndPath(url *base.URL, func extractTrackIDAndPath(url *base.URL,
thMode *headers.TransportMode, thMode *headers.TransportMode,
publishTracks []ServerConnAnnouncedTrack, announcedTracks []ServerConnAnnouncedTrack,
setupPath *string) (int, string, error) { setupPath *string) (int, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery() pathAndQuery, ok := url.RTSPPathAndQuery()
@@ -80,7 +80,7 @@ func extractTrackIDAndPath(url *base.URL,
return trackID, path, nil return trackID, path, nil
} }
for trackID, track := range publishTracks { for trackID, track := range announcedTracks {
u, _ := track.track.URL() u, _ := track.track.URL()
if u.String() == url.String() { if u.String() == url.String() {
return trackID, *setupPath, nil return trackID, *setupPath, nil
@@ -119,8 +119,8 @@ func (s ServerConnState) String() string {
return "unknown" return "unknown"
} }
// ServerConnTrack is a setupped track of a ServerConn. // ServerConnSetuppedTrack is a setupped track of a ServerConn.
type ServerConnTrack struct { type ServerConnSetuppedTrack struct {
rtpPort int rtpPort int
rtcpPort int rtcpPort int
} }
@@ -187,7 +187,7 @@ type ServerConn struct {
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
state ServerConnState state ServerConnState
tracks map[int]ServerConnTrack setuppedTracks map[int]ServerConnSetuppedTrack
setupProtocol *StreamProtocol setupProtocol *StreamProtocol
setupPath *string setupPath *string
@@ -202,7 +202,7 @@ type ServerConn struct {
readHandlers ServerConnReadHandlers readHandlers ServerConnReadHandlers
// publish only // publish only
publishTracks []ServerConnAnnouncedTrack announcedTracks []ServerConnAnnouncedTrack
backgroundRecordTerminate chan struct{} backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{} backgroundRecordDone chan struct{}
udpTimeout int32 udpTimeout int32
@@ -247,20 +247,19 @@ func (sc *ServerConn) State() ServerConnState {
return sc.state return sc.state
} }
// StreamProtocol returns the setupped tracks protocol. // StreamProtocol returns the stream protocol of the setupped tracks.
func (sc *ServerConn) StreamProtocol() *StreamProtocol { func (sc *ServerConn) StreamProtocol() *StreamProtocol {
return sc.setupProtocol return sc.setupProtocol
} }
// HasTrack checks whether a track has been setup. // SetuppedTracks returns the setupped tracks.
func (sc *ServerConn) HasTrack(trackID int) bool { func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack {
_, ok := sc.tracks[trackID] return sc.setuppedTracks
return ok
} }
// Tracks returns the setupped tracks. // AnnouncedTracks returns the announced tracks.
func (sc *ServerConn) Tracks() map[int]ServerConnTrack { func (sc *ServerConn) AnnouncedTracks() []ServerConnAnnouncedTrack {
return sc.tracks return sc.announcedTracks
} }
func (sc *ServerConn) backgroundWrite() { func (sc *ServerConn) backgroundWrite() {
@@ -320,7 +319,7 @@ func (sc *ServerConn) frameModeEnable() {
sc.doEnableFrames = true sc.doEnableFrames = true
} else { } else {
// readers can send RTCP frames, they cannot sent RTP frames // readers can send RTCP frames, they cannot sent RTP frames
for trackID, track := range sc.tracks { for trackID, track := range sc.setuppedTracks {
sc.udpRTCPListener.addClient(sc.ip(), track.rtcpPort, sc, trackID, false) sc.udpRTCPListener.addClient(sc.ip(), track.rtcpPort, sc, trackID, false)
} }
} }
@@ -331,7 +330,7 @@ func (sc *ServerConn) frameModeEnable() {
sc.readTimeoutEnabled = true sc.readTimeoutEnabled = true
} else { } else {
for trackID, track := range sc.tracks { for trackID, track := range sc.setuppedTracks {
sc.udpRTPListener.addClient(sc.ip(), track.rtpPort, sc, trackID, true) sc.udpRTPListener.addClient(sc.ip(), track.rtpPort, sc, trackID, true)
sc.udpRTCPListener.addClient(sc.ip(), track.rtcpPort, sc, trackID, true) sc.udpRTCPListener.addClient(sc.ip(), track.rtcpPort, sc, trackID, true)
@@ -358,7 +357,7 @@ func (sc *ServerConn) frameModeDisable() {
<-sc.backgroundWriteDone <-sc.backgroundWriteDone
} else { } else {
for _, track := range sc.tracks { for _, track := range sc.setuppedTracks {
sc.udpRTCPListener.removeClient(sc.ip(), track.rtcpPort) sc.udpRTCPListener.removeClient(sc.ip(), track.rtcpPort)
} }
} }
@@ -376,7 +375,7 @@ func (sc *ServerConn) frameModeDisable() {
<-sc.backgroundWriteDone <-sc.backgroundWriteDone
} else { } else {
for _, track := range sc.tracks { for _, track := range sc.setuppedTracks {
sc.udpRTPListener.removeClient(sc.ip(), track.rtpPort) sc.udpRTPListener.removeClient(sc.ip(), track.rtpPort)
sc.udpRTCPListener.removeClient(sc.ip(), track.rtcpPort) sc.udpRTCPListener.removeClient(sc.ip(), track.rtcpPort)
} }
@@ -521,12 +520,12 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
sc.state = ServerConnStatePreRecord sc.state = ServerConnStatePreRecord
sc.setupPath = &reqPath sc.setupPath = &reqPath
sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks)) sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks { for trackID, track := range tracks {
clockRate, _ := track.ClockRate() clockRate, _ := track.ClockRate()
v := time.Now().Unix() v := time.Now().Unix()
sc.publishTracks[trackID] = ServerConnAnnouncedTrack{ sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
track: track, track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate), rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v, udpLastFrameTime: &v,
@@ -564,14 +563,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.publishTracks, sc.setupPath) sc.announcedTracks, sc.setupPath)
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, err }, err
} }
if _, ok := sc.tracks[trackID]; ok { if _, ok := sc.setuppedTracks[trackID]; ok {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID) }, fmt.Errorf("track %d has already been setup", trackID)
@@ -633,12 +632,12 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if res.StatusCode == 200 { if res.StatusCode == 200 {
sc.setupProtocol = &th.Protocol sc.setupProtocol = &th.Protocol
if sc.tracks == nil { if sc.setuppedTracks == nil {
sc.tracks = make(map[int]ServerConnTrack) sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack)
} }
if th.Protocol == StreamProtocolUDP { if th.Protocol == StreamProtocolUDP {
sc.tracks[trackID] = ServerConnTrack{ sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{
rtpPort: th.ClientPorts[0], rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1], rtcpPort: th.ClientPorts[1],
} }
@@ -657,7 +656,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}.Write() }.Write()
} else { } else {
sc.tracks[trackID] = ServerConnTrack{} sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{}
if res.Header == nil { if res.Header == nil {
res.Header = make(base.Header) res.Header = make(base.Header)
@@ -703,7 +702,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
if len(sc.tracks) == 0 { if len(sc.setuppedTracks) == 0 {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup") }, fmt.Errorf("no tracks have been setup")
@@ -730,16 +729,16 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
if len(sc.tracks) == 0 { if len(sc.setuppedTracks) == 0 {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup") }, fmt.Errorf("no tracks have been setup")
} }
if len(sc.tracks) != len(sc.publishTracks) { if len(sc.setuppedTracks) != len(sc.announcedTracks) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("not all tracks have been setup") }, fmt.Errorf("not all announced tracks have been setup")
} }
res, err := sc.readHandlers.OnRecord(req) res, err := sc.readHandlers.OnRecord(req)
@@ -897,9 +896,9 @@ outer:
switch what.(type) { switch what.(type) {
case *base.InterleavedFrame: case *base.InterleavedFrame:
// forward frame only if it has been set up // forward frame only if it has been set up
if _, ok := sc.tracks[frame.TrackID]; ok { if _, ok := sc.setuppedTracks[frame.TrackID]; ok {
if sc.state == ServerConnStateRecord { if sc.state == ServerConnStateRecord {
sc.publishTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload) frame.StreamType, frame.Payload)
} }
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload) sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
@@ -955,7 +954,7 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
// WriteFrame writes a frame. // WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *sc.setupProtocol == StreamProtocolUDP { if *sc.setupProtocol == StreamProtocolUDP {
track := sc.tracks[trackID] track := sc.setuppedTracks[trackID]
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
sc.udpRTPListener.write(payload, &net.UDPAddr{ sc.udpRTPListener.write(payload, &net.UDPAddr{
@@ -1000,7 +999,7 @@ func (sc *ServerConn) backgroundRecord() {
} }
now := time.Now() now := time.Now()
for _, track := range sc.publishTracks { for _, track := range sc.announcedTracks {
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0) last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
if now.Sub(last) >= sc.conf.ReadTimeout { if now.Sub(last) >= sc.conf.ReadTimeout {
@@ -1012,7 +1011,7 @@ func (sc *ServerConn) backgroundRecord() {
case <-receiverReportTicker.C: case <-receiverReportTicker.C:
now := time.Now() now := time.Now()
for trackID, track := range sc.publishTracks { for trackID, track := range sc.announcedTracks {
r := track.rtcpReceiver.Report(now) r := track.rtcpReceiver.Report(now)
sc.WriteFrame(trackID, StreamTypeRTP, r) sc.WriteFrame(trackID, StreamTypeRTP, r)
} }

View File

@@ -401,7 +401,7 @@ func TestServerConnPublishRecordPartialTracks(t *testing.T) {
require.Equal(t, base.StatusBadRequest, res.StatusCode) require.Equal(t, base.StatusBadRequest, res.StatusCode)
err = <-serverErr err = <-serverErr
require.Equal(t, "not all tracks have been setup", err.Error()) require.Equal(t, "not all announced tracks have been setup", err.Error())
} }
func TestServerConnPublishReceivePackets(t *testing.T) { func TestServerConnPublishReceivePackets(t *testing.T) {

View File

@@ -121,8 +121,8 @@ func (s *serverUDPListener) run() {
if clientData.isPublishing { if clientData.isPublishing {
now := time.Now() now := time.Now()
atomic.StoreInt64(clientData.sc.publishTracks[clientData.trackID].udpLastFrameTime, now.Unix()) atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix())
clientData.sc.publishTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n]) clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
} }
clientData.sc.readHandlers.OnFrame(clientData.trackID, s.streamType, buf[:n]) clientData.sc.readHandlers.OnFrame(clientData.trackID, s.streamType, buf[:n])