ServerConn: save announced tracks

This commit is contained in:
aler9
2021-01-18 22:21:36 +01:00
parent 4c12bbe5a0
commit bc5b3d9cbc
5 changed files with 147 additions and 87 deletions

View File

@@ -59,7 +59,7 @@ func (s clientConnState) String() string {
case clientConnStateRecord:
return "record"
}
return "uknown"
return "unknown"
}
// ClientConn is a client-side RTSP connection.
@@ -395,15 +395,11 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) {
return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp")
}
tracks, err := ReadTracks(res.Body)
tracks, err := ReadTracks(res.Body, u)
if err != nil {
return nil, nil, err
}
for _, t := range tracks {
t.BaseURL = u
}
return tracks, res, nil
}

View File

@@ -58,34 +58,20 @@ func (s ServerConnState) String() string {
case ServerConnStateRecord:
return "record"
}
return "uknown"
return "unknown"
}
// ServerConnTrack is a track of a ServerConn.
type ServerConnTrack struct {
// ServerConnSetuppedTrack is a setupped track of a ServerConn.
type ServerConnSetuppedTrack struct {
rtpPort int
rtcpPort int
}
func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) {
if mode == nil || *mode == headers.TransportModePlay {
i := strings.Index(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - we assume it's track 0
if i < 0 {
return 0, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery)
}
trackID := int(tmp)
return trackID, nil
}
return trackLen, nil
// ServerConnAnnouncedTrack is an announced track of a ServerConn.
type ServerConnAnnouncedTrack struct {
track *Track
rtcpReceiver *rtcpreceiver.RTCPReceiver
udpLastFrameTime *int64
}
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -136,15 +122,16 @@ type ServerConnReadHandlers struct {
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
conf ServerConf
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
state ServerConnState
tracks map[int]ServerConnTrack
tracksProtocol *StreamProtocol
readHandlers ServerConnReadHandlers
rtcpReceivers []*rtcpreceiver.RTCPReceiver
conf ServerConf
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
state ServerConnState
readHandlers ServerConnReadHandlers
setuppedTracks map[int]ServerConnSetuppedTrack
setuppedTracksProtocol *StreamProtocol
announcedTracks []ServerConnAnnouncedTrack
doEnableFrames bool
framesEnabled bool
readTimeoutEnabled bool
@@ -157,7 +144,6 @@ type ServerConn struct {
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
udpTimeout int32
udpLastFrameTimes []*int64
// in
terminate chan struct{}
@@ -176,7 +162,6 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
tracks: make(map[int]ServerConnTrack),
frameRingBuffer: ringbuffer.New(conf.ReadBufferCount),
backgroundWriteDone: make(chan struct{}),
terminate: make(chan struct{}),
@@ -195,25 +180,25 @@ func (sc *ServerConn) State() ServerConnState {
return sc.state
}
// TracksProtocol returns the tracks protocol.
func (sc *ServerConn) TracksProtocol() *StreamProtocol {
return sc.tracksProtocol
// SetuppedTracksProtocol returns the setupped tracks protocol.
func (sc *ServerConn) SetuppedTracksProtocol() *StreamProtocol {
return sc.setuppedTracksProtocol
}
// TracksLen returns the number of setupped tracks.
func (sc *ServerConn) TracksLen() int {
return len(sc.tracks)
// SetuppedTracksLen returns the number of setupped tracks.
func (sc *ServerConn) SetuppedTracksLen() int {
return len(sc.setuppedTracks)
}
// HasTrack checks whether a track has been setup.
func (sc *ServerConn) HasTrack(trackID int) bool {
_, ok := sc.tracks[trackID]
// HasSetuppedTrack checks whether a track has been setup.
func (sc *ServerConn) HasSetuppedTrack(trackID int) bool {
_, ok := sc.setuppedTracks[trackID]
return ok
}
// Tracks returns the setupped tracks.
func (sc *ServerConn) Tracks() map[int]ServerConnTrack {
return sc.tracks
// SetuppedTracks returns the setupped tracks.
func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack {
return sc.setuppedTracks
}
func (sc *ServerConn) backgroundWrite() {
@@ -269,17 +254,17 @@ func (sc *ServerConn) zone() string {
func (sc *ServerConn) frameModeEnable() {
switch sc.state {
case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP {
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.doEnableFrames = true
}
case ServerConnStateRecord:
if *sc.tracksProtocol == StreamProtocolTCP {
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.doEnableFrames = true
sc.readTimeoutEnabled = true
} else {
for trackID, track := range sc.tracks {
for trackID, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
@@ -300,7 +285,7 @@ func (sc *ServerConn) frameModeEnable() {
func (sc *ServerConn) frameModeDisable() {
switch sc.state {
case ServerConnStatePlay:
if *sc.tracksProtocol == StreamProtocolTCP {
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.framesEnabled = false
sc.frameRingBuffer.Close()
<-sc.backgroundWriteDone
@@ -310,7 +295,7 @@ func (sc *ServerConn) frameModeDisable() {
close(sc.backgroundRecordTerminate)
<-sc.backgroundRecordDone
if *sc.tracksProtocol == StreamProtocolTCP {
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
sc.readTimeoutEnabled = false
sc.nconn.SetReadDeadline(time.Time{})
@@ -319,7 +304,7 @@ func (sc *ServerConn) frameModeDisable() {
<-sc.backgroundWriteDone
} else {
for _, track := range sc.tracks {
for _, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
}
@@ -415,7 +400,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
}
tracks, err := ReadTracks(req.Body)
tracks, err := ReadTracks(req.Body, req.URL)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
@@ -428,19 +413,52 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, errors.New("no tracks defined")
}
reqPath, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, errors.New("invalid path")
}
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL")
}
trackPath, ok := trackURL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL")
}
if !strings.HasPrefix(trackPath, reqPath) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track URL: must begin with '%s', but is '%s'",
reqPath, trackPath)
}
}
res, err := sc.readHandlers.OnAnnounce(req, tracks)
if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord
sc.rtcpReceivers = make([]*rtcpreceiver.RTCPReceiver, len(tracks))
sc.udpLastFrameTimes = make([]*int64, len(tracks))
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
clockRate, _ := track.ClockRate()
sc.rtcpReceivers[trackID] = rtcpreceiver.New(nil, clockRate)
v := time.Now().Unix()
sc.udpLastFrameTimes[trackID] = &v
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
}
}
}
@@ -480,20 +498,55 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("multicast is not supported")
}
trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks))
trackID, err := func() (int, error) {
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
i := strings.Index(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - we assume it's track 0
if i < 0 {
return 0, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
}
trackID := int(tmp)
// remove track ID from path
nu := &base.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
User: req.URL.User,
}
nu, _ = base.ParseURL(nu.String() + pathAndQuery[:i])
req.URL = nu
return trackID, nil
}
for trackID, track := range sc.announcedTracks {
u, _ := track.track.URL()
if u.String() == req.URL.String() {
return trackID, nil
}
}
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
}()
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
if _, ok := sc.tracks[trackID]; ok {
if _, ok := sc.setuppedTracks[trackID]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("track %d has already been setup", trackID)
}
if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol {
if sc.setuppedTracksProtocol != nil && *sc.setuppedTracksProtocol != th.Protocol {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("can't setup tracks with different protocols")
@@ -542,15 +595,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("transport header does not contain mode=record")
}
if trackID >= len(sc.announcedTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("unable to setup track %d", trackID)
}
}
res, err := sc.readHandlers.OnSetup(req, th, trackID)
if res.StatusCode == 200 {
sc.tracksProtocol = &th.Protocol
sc.setuppedTracksProtocol = &th.Protocol
if sc.setuppedTracks == nil {
sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack)
}
if th.Protocol == StreamProtocolUDP {
sc.tracks[trackID] = ServerConnTrack{
sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{
rtpPort: th.ClientPorts[0],
rtcpPort: th.ClientPorts[1],
}
@@ -566,7 +629,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}.Write()
} else {
sc.tracks[trackID] = ServerConnTrack{}
sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{}
res.Header["Transport"] = headers.Transport{
Protocol: StreamProtocolTCP,
@@ -610,7 +673,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
if len(sc.tracks) == 0 {
if len(sc.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup")
@@ -637,13 +700,13 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
if len(sc.tracks) == 0 {
if len(sc.setuppedTracks) == 0 {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("no tracks have been setup")
}
if len(sc.tracks) != len(sc.rtcpReceivers) {
if len(sc.setuppedTracks) != len(sc.announcedTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("not all tracks have been setup")
@@ -804,9 +867,9 @@ outer:
switch what.(type) {
case *base.InterleavedFrame:
// forward frame only if it has been set up
if _, ok := sc.tracks[frame.TrackID]; ok {
if _, ok := sc.setuppedTracks[frame.TrackID]; ok {
if sc.state == ServerConnStateRecord {
sc.rtcpReceivers[frame.TrackID].ProcessFrame(time.Now(),
sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload)
}
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
@@ -861,8 +924,8 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
// WriteFrame writes a frame.
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
if *sc.tracksProtocol == StreamProtocolUDP {
track := sc.tracks[trackID]
if *sc.setuppedTracksProtocol == StreamProtocolUDP {
track := sc.setuppedTracks[trackID]
if streamType == StreamTypeRTP {
sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{
@@ -902,13 +965,13 @@ func (sc *ServerConn) backgroundRecord() {
for {
select {
case <-checkStreamTicker.C:
if *sc.tracksProtocol != StreamProtocolUDP {
if *sc.setuppedTracksProtocol != StreamProtocolUDP {
continue
}
now := time.Now()
for _, lastUnix := range sc.udpLastFrameTimes {
last := time.Unix(atomic.LoadInt64(lastUnix), 0)
for _, track := range sc.announcedTracks {
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
if now.Sub(last) >= sc.conf.ReadTimeout {
atomic.StoreInt32(&sc.udpTimeout, 1)
@@ -919,8 +982,8 @@ func (sc *ServerConn) backgroundRecord() {
case <-receiverReportTicker.C:
now := time.Now()
for trackID := range sc.tracks {
r := sc.rtcpReceivers[trackID].Report(now)
for trackID, track := range sc.announcedTracks {
r := track.rtcpReceiver.Report(now)
sc.WriteFrame(trackID, StreamTypeRTP, r)
}

View File

@@ -128,8 +128,8 @@ func (s *ServerUDPListener) run() {
}
now := time.Now()
atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix())
pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n])
atomic.StoreInt64(pubData.publisher.announcedTracks[pubData.trackID].udpLastFrameTime, now.Unix())
pubData.publisher.announcedTracks[pubData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n])
}()
}

View File

@@ -16,7 +16,7 @@ import (
// Track is a track available in a certain URL.
type Track struct {
// base url
// base URL
BaseURL *base.URL
// id
@@ -204,7 +204,7 @@ func (t *Track) URL() (*base.URL, error) {
type Tracks []*Track
// ReadTracks decodes tracks from SDP.
func ReadTracks(byts []byte) (Tracks, error) {
func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) {
desc := sdp.SessionDescription{}
err := desc.Unmarshal(byts)
if err != nil {
@@ -215,8 +215,9 @@ func ReadTracks(byts []byte) (Tracks, error) {
for i, media := range desc.MediaDescriptions {
tracks[i] = &Track{
ID: i,
Media: media,
BaseURL: baseURL,
ID: i,
Media: media,
}
}

View File

@@ -72,7 +72,7 @@ func TestTrackClockRate(t *testing.T) {
},
} {
t.Run(ca.name, func(t *testing.T) {
tracks, err := ReadTracks(ca.sdp)
tracks, err := ReadTracks(ca.sdp, nil)
require.NoError(t, err)
clockRate, err := tracks[0].ClockRate()