ServerConn: save announced tracks

This commit is contained in:
aler9
2021-01-18 22:21:36 +01:00
parent 4c12bbe5a0
commit bc5b3d9cbc
5 changed files with 147 additions and 87 deletions

View File

@@ -59,7 +59,7 @@ func (s clientConnState) String() string {
case clientConnStateRecord: case clientConnStateRecord:
return "record" return "record"
} }
return "uknown" return "unknown"
} }
// ClientConn is a client-side RTSP connection. // ClientConn is a client-side RTSP connection.
@@ -395,15 +395,11 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) {
return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp") return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp")
} }
tracks, err := ReadTracks(res.Body) tracks, err := ReadTracks(res.Body, u)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for _, t := range tracks {
t.BaseURL = u
}
return tracks, res, nil return tracks, res, nil
} }

View File

@@ -58,34 +58,20 @@ func (s ServerConnState) String() string {
case ServerConnStateRecord: case ServerConnStateRecord:
return "record" return "record"
} }
return "uknown" return "unknown"
} }
// ServerConnTrack is a track of a ServerConn. // ServerConnSetuppedTrack is a setupped track of a ServerConn.
type ServerConnTrack struct { type ServerConnSetuppedTrack struct {
rtpPort int rtpPort int
rtcpPort int rtcpPort int
} }
func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) { // ServerConnAnnouncedTrack is an announced track of a ServerConn.
if mode == nil || *mode == headers.TransportModePlay { type ServerConnAnnouncedTrack struct {
i := strings.Index(pathAndQuery, "/trackID=") track *Track
rtcpReceiver *rtcpreceiver.RTCPReceiver
// URL doesn't contain trackID - we assume it's track 0 udpLastFrameTime *int64
if i < 0 {
return 0, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery)
}
trackID := int(tmp)
return trackID, nil
}
return trackLen, nil
} }
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -136,15 +122,16 @@ type ServerConnReadHandlers struct {
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
type ServerConn struct { type ServerConn struct {
conf ServerConf conf ServerConf
nconn net.Conn nconn net.Conn
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
state ServerConnState state ServerConnState
tracks map[int]ServerConnTrack readHandlers ServerConnReadHandlers
tracksProtocol *StreamProtocol setuppedTracks map[int]ServerConnSetuppedTrack
readHandlers ServerConnReadHandlers setuppedTracksProtocol *StreamProtocol
rtcpReceivers []*rtcpreceiver.RTCPReceiver announcedTracks []ServerConnAnnouncedTrack
doEnableFrames bool doEnableFrames bool
framesEnabled bool framesEnabled bool
readTimeoutEnabled bool readTimeoutEnabled bool
@@ -157,7 +144,6 @@ type ServerConn struct {
backgroundRecordTerminate chan struct{} backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{} backgroundRecordDone chan struct{}
udpTimeout int32 udpTimeout int32
udpLastFrameTimes []*int64
// in // in
terminate chan struct{} terminate chan struct{}
@@ -176,7 +162,6 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize), br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
tracks: make(map[int]ServerConnTrack),
frameRingBuffer: ringbuffer.New(conf.ReadBufferCount), frameRingBuffer: ringbuffer.New(conf.ReadBufferCount),
backgroundWriteDone: make(chan struct{}), backgroundWriteDone: make(chan struct{}),
terminate: make(chan struct{}), terminate: make(chan struct{}),
@@ -195,25 +180,25 @@ func (sc *ServerConn) State() ServerConnState {
return sc.state return sc.state
} }
// TracksProtocol returns the tracks protocol. // SetuppedTracksProtocol returns the setupped tracks protocol.
func (sc *ServerConn) TracksProtocol() *StreamProtocol { func (sc *ServerConn) SetuppedTracksProtocol() *StreamProtocol {
return sc.tracksProtocol return sc.setuppedTracksProtocol
} }
// TracksLen returns the number of setupped tracks. // SetuppedTracksLen returns the number of setupped tracks.
func (sc *ServerConn) TracksLen() int { func (sc *ServerConn) SetuppedTracksLen() int {
return len(sc.tracks) return len(sc.setuppedTracks)
} }
// HasTrack checks whether a track has been setup. // HasSetuppedTrack checks whether a track has been setup.
func (sc *ServerConn) HasTrack(trackID int) bool { func (sc *ServerConn) HasSetuppedTrack(trackID int) bool {
_, ok := sc.tracks[trackID] _, ok := sc.setuppedTracks[trackID]
return ok return ok
} }
// Tracks returns the setupped tracks. // SetuppedTracks returns the setupped tracks.
func (sc *ServerConn) Tracks() map[int]ServerConnTrack { func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack {
return sc.tracks return sc.setuppedTracks
} }
func (sc *ServerConn) backgroundWrite() { func (sc *ServerConn) backgroundWrite() {
@@ -269,17 +254,17 @@ func (sc *ServerConn) zone() string {
func (sc *ServerConn) frameModeEnable() { func (sc *ServerConn) frameModeEnable() {
switch sc.state { switch sc.state {
case ServerConnStatePlay: case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP { if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.doEnableFrames = true sc.doEnableFrames = true
} }
case ServerConnStateRecord: case ServerConnStateRecord:
if *sc.tracksProtocol == StreamProtocolTCP { if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.doEnableFrames = true sc.doEnableFrames = true
sc.readTimeoutEnabled = true sc.readTimeoutEnabled = true
} else { } else {
for trackID, track := range sc.tracks { for trackID, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
@@ -300,7 +285,7 @@ func (sc *ServerConn) frameModeEnable() {
func (sc *ServerConn) frameModeDisable() { func (sc *ServerConn) frameModeDisable() {
switch sc.state { switch sc.state {
case ServerConnStatePlay: case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP { if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.framesEnabled = false sc.framesEnabled = false
sc.frameRingBuffer.Close() sc.frameRingBuffer.Close()
<-sc.backgroundWriteDone <-sc.backgroundWriteDone
@@ -310,7 +295,7 @@ func (sc *ServerConn) frameModeDisable() {
close(sc.backgroundRecordTerminate) close(sc.backgroundRecordTerminate)
<-sc.backgroundRecordDone <-sc.backgroundRecordDone
if *sc.tracksProtocol == StreamProtocolTCP { if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.readTimeoutEnabled = false sc.readTimeoutEnabled = false
sc.nconn.SetReadDeadline(time.Time{}) sc.nconn.SetReadDeadline(time.Time{})
@@ -319,7 +304,7 @@ func (sc *ServerConn) frameModeDisable() {
<-sc.backgroundWriteDone <-sc.backgroundWriteDone
} else { } else {
for _, track := range sc.tracks { for _, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
} }
@@ -415,7 +400,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("unsupported Content-Type '%s'", ct) }, fmt.Errorf("unsupported Content-Type '%s'", ct)
} }
tracks, err := ReadTracks(req.Body) tracks, err := ReadTracks(req.Body, req.URL)
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -428,19 +413,52 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, errors.New("no tracks defined") }, errors.New("no tracks defined")
} }
reqPath, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, errors.New("invalid path")
}
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL")
}
trackPath, ok := trackURL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL")
}
if !strings.HasPrefix(trackPath, reqPath) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL: must begin with '%s', but is '%s'",
reqPath, trackPath)
}
}
res, err := sc.readHandlers.OnAnnounce(req, tracks) res, err := sc.readHandlers.OnAnnounce(req, tracks)
if res.StatusCode == 200 { if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord sc.state = ServerConnStatePreRecord
sc.rtcpReceivers = make([]*rtcpreceiver.RTCPReceiver, len(tracks)) sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
sc.udpLastFrameTimes = make([]*int64, len(tracks))
for trackID, track := range tracks { for trackID, track := range tracks {
clockRate, _ := track.ClockRate() clockRate, _ := track.ClockRate()
sc.rtcpReceivers[trackID] = rtcpreceiver.New(nil, clockRate)
v := time.Now().Unix() v := time.Now().Unix()
sc.udpLastFrameTimes[trackID] = &v
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
}
} }
} }
@@ -480,20 +498,55 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("multicast is not supported") }, fmt.Errorf("multicast is not supported")
} }
trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks)) trackID, err := func() (int, error) {
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
i := strings.Index(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - we assume it's track 0
if i < 0 {
return 0, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
}
trackID := int(tmp)
// remove track ID from path
nu := &base.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
User: req.URL.User,
}
nu, _ = base.ParseURL(nu.String() + pathAndQuery[:i])
req.URL = nu
return trackID, nil
}
for trackID, track := range sc.announcedTracks {
u, _ := track.track.URL()
if u.String() == req.URL.String() {
return trackID, nil
}
}
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
}()
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, err }, err
} }
if _, ok := sc.tracks[trackID]; ok { if _, ok := sc.setuppedTracks[trackID]; ok {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID) }, fmt.Errorf("track %d has already been setup", trackID)
} }
if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol { if sc.setuppedTracksProtocol != nil && *sc.setuppedTracksProtocol != th.Protocol {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't setup tracks with different protocols") }, fmt.Errorf("can't setup tracks with different protocols")
@@ -542,15 +595,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header does not contain mode=record") }, fmt.Errorf("transport header does not contain mode=record")
} }
if trackID >= len(sc.announcedTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unable to setup track %d", trackID)
}
} }
res, err := sc.readHandlers.OnSetup(req, th, trackID) res, err := sc.readHandlers.OnSetup(req, th, trackID)
if res.StatusCode == 200 { if res.StatusCode == 200 {
sc.tracksProtocol = &th.Protocol sc.setuppedTracksProtocol = &th.Protocol
if sc.setuppedTracks == nil {
sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack)
}
if th.Protocol == StreamProtocolUDP { if th.Protocol == StreamProtocolUDP {
sc.tracks[trackID] = ServerConnTrack{ sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{
rtpPort: th.ClientPorts[0], rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1], rtcpPort: th.ClientPorts[1],
} }
@@ -566,7 +629,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}.Write() }.Write()
} else { } else {
sc.tracks[trackID] = ServerConnTrack{} sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{}
res.Header["Transport"] = headers.Transport{ res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP, Protocol: StreamProtocolTCP,
@@ -610,7 +673,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
if len(sc.tracks) == 0 { if len(sc.setuppedTracks) == 0 {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup") }, fmt.Errorf("no tracks have been setup")
@@ -637,13 +700,13 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
if len(sc.tracks) == 0 { if len(sc.setuppedTracks) == 0 {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup") }, fmt.Errorf("no tracks have been setup")
} }
if len(sc.tracks) != len(sc.rtcpReceivers) { if len(sc.setuppedTracks) != len(sc.announcedTracks) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("not all tracks have been setup") }, fmt.Errorf("not all tracks have been setup")
@@ -804,9 +867,9 @@ outer:
switch what.(type) { switch what.(type) {
case *base.InterleavedFrame: case *base.InterleavedFrame:
// forward frame only if it has been set up // forward frame only if it has been set up
if _, ok := sc.tracks[frame.TrackID]; ok { if _, ok := sc.setuppedTracks[frame.TrackID]; ok {
if sc.state == ServerConnStateRecord { if sc.state == ServerConnStateRecord {
sc.rtcpReceivers[frame.TrackID].ProcessFrame(time.Now(), sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload) frame.StreamType, frame.Payload)
} }
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload) sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
@@ -861,8 +924,8 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
// WriteFrame writes a frame. // WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *sc.tracksProtocol == StreamProtocolUDP { if *sc.setuppedTracksProtocol == StreamProtocolUDP {
track := sc.tracks[trackID] track := sc.setuppedTracks[trackID]
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{
@@ -902,13 +965,13 @@ func (sc *ServerConn) backgroundRecord() {
for { for {
select { select {
case <-checkStreamTicker.C: case <-checkStreamTicker.C:
if *sc.tracksProtocol != StreamProtocolUDP { if *sc.setuppedTracksProtocol != StreamProtocolUDP {
continue continue
} }
now := time.Now() now := time.Now()
for _, lastUnix := range sc.udpLastFrameTimes { for _, track := range sc.announcedTracks {
last := time.Unix(atomic.LoadInt64(lastUnix), 0) last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
if now.Sub(last) >= sc.conf.ReadTimeout { if now.Sub(last) >= sc.conf.ReadTimeout {
atomic.StoreInt32(&sc.udpTimeout, 1) atomic.StoreInt32(&sc.udpTimeout, 1)
@@ -919,8 +982,8 @@ func (sc *ServerConn) backgroundRecord() {
case <-receiverReportTicker.C: case <-receiverReportTicker.C:
now := time.Now() now := time.Now()
for trackID := range sc.tracks { for trackID, track := range sc.announcedTracks {
r := sc.rtcpReceivers[trackID].Report(now) r := track.rtcpReceiver.Report(now)
sc.WriteFrame(trackID, StreamTypeRTP, r) sc.WriteFrame(trackID, StreamTypeRTP, r)
} }

View File

@@ -128,8 +128,8 @@ func (s *ServerUDPListener) run() {
} }
now := time.Now() now := time.Now()
atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix()) atomic.StoreInt64(pubData.publisher.announcedTracks[pubData.trackID].udpLastFrameTime, now.Unix())
pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n]) pubData.publisher.announcedTracks[pubData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n]) pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n])
}() }()
} }

View File

@@ -16,7 +16,7 @@ import (
// Track is a track available in a certain URL. // Track is a track available in a certain URL.
type Track struct { type Track struct {
// base url // base URL
BaseURL *base.URL BaseURL *base.URL
// id // id
@@ -204,7 +204,7 @@ func (t *Track) URL() (*base.URL, error) {
type Tracks []*Track type Tracks []*Track
// ReadTracks decodes tracks from SDP. // ReadTracks decodes tracks from SDP.
func ReadTracks(byts []byte) (Tracks, error) { func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) {
desc := sdp.SessionDescription{} desc := sdp.SessionDescription{}
err := desc.Unmarshal(byts) err := desc.Unmarshal(byts)
if err != nil { if err != nil {
@@ -215,8 +215,9 @@ func ReadTracks(byts []byte) (Tracks, error) {
for i, media := range desc.MediaDescriptions { for i, media := range desc.MediaDescriptions {
tracks[i] = &Track{ tracks[i] = &Track{
ID: i, BaseURL: baseURL,
Media: media, ID: i,
Media: media,
} }
} }

View File

@@ -72,7 +72,7 @@ func TestTrackClockRate(t *testing.T) {
}, },
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
tracks, err := ReadTracks(ca.sdp) tracks, err := ReadTracks(ca.sdp, nil)
require.NoError(t, err) require.NoError(t, err)
clockRate, err := tracks[0].ClockRate() clockRate, err := tracks[0].ClockRate()