From bc5b3d9cbc88f8b68726360787407c5d25d9888f Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 18 Jan 2021 22:21:36 +0100 Subject: [PATCH] ServerConn: save announced tracks --- clientconn.go | 8 +- serverconn.go | 211 ++++++++++++++++++++++++++++++++------------------ serverudpl.go | 4 +- track.go | 9 ++- track_test.go | 2 +- 5 files changed, 147 insertions(+), 87 deletions(-) diff --git a/clientconn.go b/clientconn.go index 411183da..76fea44b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -59,7 +59,7 @@ func (s clientConnState) String() string { case clientConnStateRecord: return "record" } - return "uknown" + return "unknown" } // ClientConn is a client-side RTSP connection. @@ -395,15 +395,11 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) { return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp") } - tracks, err := ReadTracks(res.Body) + tracks, err := ReadTracks(res.Body, u) if err != nil { return nil, nil, err } - for _, t := range tracks { - t.BaseURL = u - } - return tracks, res, nil } diff --git a/serverconn.go b/serverconn.go index bdbfc892..49d8fc97 100644 --- a/serverconn.go +++ b/serverconn.go @@ -58,34 +58,20 @@ func (s ServerConnState) String() string { case ServerConnStateRecord: return "record" } - return "uknown" + return "unknown" } -// ServerConnTrack is a track of a ServerConn. -type ServerConnTrack struct { +// ServerConnSetuppedTrack is a setupped track of a ServerConn. +type ServerConnSetuppedTrack struct { rtpPort int rtcpPort int } -func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) { - if mode == nil || *mode == headers.TransportModePlay { - i := strings.Index(pathAndQuery, "/trackID=") - - // URL doesn't contain trackID - we assume it's track 0 - if i < 0 { - return 0, nil - } - - tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) - if err != nil || tmp < 0 { - return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery) - } - trackID := int(tmp) - - return trackID, nil - } - - return trackLen, nil +// ServerConnAnnouncedTrack is an announced track of a ServerConn. +type ServerConnAnnouncedTrack struct { + track *Track + rtcpReceiver *rtcpreceiver.RTCPReceiver + udpLastFrameTime *int64 } // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. @@ -136,15 +122,16 @@ type ServerConnReadHandlers struct { // ServerConn is a server-side RTSP connection. type ServerConn struct { - conf ServerConf - nconn net.Conn - br *bufio.Reader - bw *bufio.Writer - state ServerConnState - tracks map[int]ServerConnTrack - tracksProtocol *StreamProtocol - readHandlers ServerConnReadHandlers - rtcpReceivers []*rtcpreceiver.RTCPReceiver + conf ServerConf + nconn net.Conn + br *bufio.Reader + bw *bufio.Writer + state ServerConnState + readHandlers ServerConnReadHandlers + setuppedTracks map[int]ServerConnSetuppedTrack + setuppedTracksProtocol *StreamProtocol + announcedTracks []ServerConnAnnouncedTrack + doEnableFrames bool framesEnabled bool readTimeoutEnabled bool @@ -157,7 +144,6 @@ type ServerConn struct { backgroundRecordTerminate chan struct{} backgroundRecordDone chan struct{} udpTimeout int32 - udpLastFrameTimes []*int64 // in terminate chan struct{} @@ -176,7 +162,6 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { nconn: nconn, br: bufio.NewReaderSize(conn, serverConnReadBufferSize), bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), - tracks: make(map[int]ServerConnTrack), frameRingBuffer: ringbuffer.New(conf.ReadBufferCount), backgroundWriteDone: make(chan struct{}), terminate: make(chan struct{}), @@ -195,25 +180,25 @@ func (sc *ServerConn) State() ServerConnState { return sc.state } -// TracksProtocol returns the tracks protocol. -func (sc *ServerConn) TracksProtocol() *StreamProtocol { - return sc.tracksProtocol +// SetuppedTracksProtocol returns the setupped tracks protocol. +func (sc *ServerConn) SetuppedTracksProtocol() *StreamProtocol { + return sc.setuppedTracksProtocol } -// TracksLen returns the number of setupped tracks. -func (sc *ServerConn) TracksLen() int { - return len(sc.tracks) +// SetuppedTracksLen returns the number of setupped tracks. +func (sc *ServerConn) SetuppedTracksLen() int { + return len(sc.setuppedTracks) } -// HasTrack checks whether a track has been setup. -func (sc *ServerConn) HasTrack(trackID int) bool { - _, ok := sc.tracks[trackID] +// HasSetuppedTrack checks whether a track has been setup. +func (sc *ServerConn) HasSetuppedTrack(trackID int) bool { + _, ok := sc.setuppedTracks[trackID] return ok } -// Tracks returns the setupped tracks. -func (sc *ServerConn) Tracks() map[int]ServerConnTrack { - return sc.tracks +// SetuppedTracks returns the setupped tracks. +func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack { + return sc.setuppedTracks } func (sc *ServerConn) backgroundWrite() { @@ -269,17 +254,17 @@ func (sc *ServerConn) zone() string { func (sc *ServerConn) frameModeEnable() { switch sc.state { case ServerConnStatePlay: - if *sc.tracksProtocol == StreamProtocolTCP { + if *sc.setuppedTracksProtocol == StreamProtocolTCP { sc.doEnableFrames = true } case ServerConnStateRecord: - if *sc.tracksProtocol == StreamProtocolTCP { + if *sc.setuppedTracksProtocol == StreamProtocolTCP { sc.doEnableFrames = true sc.readTimeoutEnabled = true } else { - for trackID, track := range sc.tracks { + for trackID, track := range sc.setuppedTracks { sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) @@ -300,7 +285,7 @@ func (sc *ServerConn) frameModeEnable() { func (sc *ServerConn) frameModeDisable() { switch sc.state { case ServerConnStatePlay: - if *sc.tracksProtocol == StreamProtocolTCP { + if *sc.setuppedTracksProtocol == StreamProtocolTCP { sc.framesEnabled = false sc.frameRingBuffer.Close() <-sc.backgroundWriteDone @@ -310,7 +295,7 @@ func (sc *ServerConn) frameModeDisable() { close(sc.backgroundRecordTerminate) <-sc.backgroundRecordDone - if *sc.tracksProtocol == StreamProtocolTCP { + if *sc.setuppedTracksProtocol == StreamProtocolTCP { sc.readTimeoutEnabled = false sc.nconn.SetReadDeadline(time.Time{}) @@ -319,7 +304,7 @@ func (sc *ServerConn) frameModeDisable() { <-sc.backgroundWriteDone } else { - for _, track := range sc.tracks { + for _, track := range sc.setuppedTracks { sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) } @@ -415,7 +400,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("unsupported Content-Type '%s'", ct) } - tracks, err := ReadTracks(req.Body) + tracks, err := ReadTracks(req.Body, req.URL) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -428,19 +413,52 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, errors.New("no tracks defined") } + reqPath, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, errors.New("invalid path") + } + + for _, track := range tracks { + trackURL, err := track.URL() + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid track URL") + } + + trackPath, ok := trackURL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid track URL") + } + + if !strings.HasPrefix(trackPath, reqPath) { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid track URL: must begin with '%s', but is '%s'", + reqPath, trackPath) + } + } + res, err := sc.readHandlers.OnAnnounce(req, tracks) if res.StatusCode == 200 { sc.state = ServerConnStatePreRecord - sc.rtcpReceivers = make([]*rtcpreceiver.RTCPReceiver, len(tracks)) - sc.udpLastFrameTimes = make([]*int64, len(tracks)) + sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks)) for trackID, track := range tracks { clockRate, _ := track.ClockRate() - sc.rtcpReceivers[trackID] = rtcpreceiver.New(nil, clockRate) v := time.Now().Unix() - sc.udpLastFrameTimes[trackID] = &v + + sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{ + track: track, + rtcpReceiver: rtcpreceiver.New(nil, clockRate), + udpLastFrameTime: &v, + } } } @@ -480,20 +498,55 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("multicast is not supported") } - trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks)) + trackID, err := func() (int, error) { + if th.Mode == nil || *th.Mode == headers.TransportModePlay { + i := strings.Index(pathAndQuery, "/trackID=") + + // URL doesn't contain trackID - we assume it's track 0 + if i < 0 { + return 0, nil + } + + tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) + if err != nil || tmp < 0 { + return 0, fmt.Errorf("invalid track (%s)", pathAndQuery) + } + trackID := int(tmp) + + // remove track ID from path + nu := &base.URL{ + Scheme: req.URL.Scheme, + Host: req.URL.Host, + User: req.URL.User, + } + nu, _ = base.ParseURL(nu.String() + pathAndQuery[:i]) + req.URL = nu + + return trackID, nil + } + + for trackID, track := range sc.announcedTracks { + u, _ := track.track.URL() + if u.String() == req.URL.String() { + return trackID, nil + } + } + + return 0, fmt.Errorf("invalid track (%s)", pathAndQuery) + }() if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, }, err } - if _, ok := sc.tracks[trackID]; ok { + if _, ok := sc.setuppedTracks[trackID]; ok { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("track %d has already been setup", trackID) } - if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol { + if sc.setuppedTracksProtocol != nil && *sc.setuppedTracksProtocol != th.Protocol { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("can't setup tracks with different protocols") @@ -542,15 +595,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { StatusCode: base.StatusBadRequest, }, fmt.Errorf("transport header does not contain mode=record") } + + if trackID >= len(sc.announcedTracks) { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("unable to setup track %d", trackID) + } } res, err := sc.readHandlers.OnSetup(req, th, trackID) if res.StatusCode == 200 { - sc.tracksProtocol = &th.Protocol + sc.setuppedTracksProtocol = &th.Protocol + + if sc.setuppedTracks == nil { + sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack) + } if th.Protocol == StreamProtocolUDP { - sc.tracks[trackID] = ServerConnTrack{ + sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{ rtpPort: th.ClientPorts[0], rtcpPort: th.ClientPorts[1], } @@ -566,7 +629,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }.Write() } else { - sc.tracks[trackID] = ServerConnTrack{} + sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{} res.Header["Transport"] = headers.Transport{ Protocol: StreamProtocolTCP, @@ -610,7 +673,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } - if len(sc.tracks) == 0 { + if len(sc.setuppedTracks) == 0 { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("no tracks have been setup") @@ -637,13 +700,13 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } - if len(sc.tracks) == 0 { + if len(sc.setuppedTracks) == 0 { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("no tracks have been setup") } - if len(sc.tracks) != len(sc.rtcpReceivers) { + if len(sc.setuppedTracks) != len(sc.announcedTracks) { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("not all tracks have been setup") @@ -804,9 +867,9 @@ outer: switch what.(type) { case *base.InterleavedFrame: // forward frame only if it has been set up - if _, ok := sc.tracks[frame.TrackID]; ok { + if _, ok := sc.setuppedTracks[frame.TrackID]; ok { if sc.state == ServerConnStateRecord { - sc.rtcpReceivers[frame.TrackID].ProcessFrame(time.Now(), + sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), frame.StreamType, frame.Payload) } sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload) @@ -861,8 +924,8 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { // WriteFrame writes a frame. func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { - if *sc.tracksProtocol == StreamProtocolUDP { - track := sc.tracks[trackID] + if *sc.setuppedTracksProtocol == StreamProtocolUDP { + track := sc.setuppedTracks[trackID] if streamType == StreamTypeRTP { sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ @@ -902,13 +965,13 @@ func (sc *ServerConn) backgroundRecord() { for { select { case <-checkStreamTicker.C: - if *sc.tracksProtocol != StreamProtocolUDP { + if *sc.setuppedTracksProtocol != StreamProtocolUDP { continue } now := time.Now() - for _, lastUnix := range sc.udpLastFrameTimes { - last := time.Unix(atomic.LoadInt64(lastUnix), 0) + for _, track := range sc.announcedTracks { + last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0) if now.Sub(last) >= sc.conf.ReadTimeout { atomic.StoreInt32(&sc.udpTimeout, 1) @@ -919,8 +982,8 @@ func (sc *ServerConn) backgroundRecord() { case <-receiverReportTicker.C: now := time.Now() - for trackID := range sc.tracks { - r := sc.rtcpReceivers[trackID].Report(now) + for trackID, track := range sc.announcedTracks { + r := track.rtcpReceiver.Report(now) sc.WriteFrame(trackID, StreamTypeRTP, r) } diff --git a/serverudpl.go b/serverudpl.go index a1208820..bde94b14 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -128,8 +128,8 @@ func (s *ServerUDPListener) run() { } now := time.Now() - atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix()) - pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n]) + atomic.StoreInt64(pubData.publisher.announcedTracks[pubData.trackID].udpLastFrameTime, now.Unix()) + pubData.publisher.announcedTracks[pubData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n]) pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n]) }() } diff --git a/track.go b/track.go index 85122938..d77cd08d 100644 --- a/track.go +++ b/track.go @@ -16,7 +16,7 @@ import ( // Track is a track available in a certain URL. type Track struct { - // base url + // base URL BaseURL *base.URL // id @@ -204,7 +204,7 @@ func (t *Track) URL() (*base.URL, error) { type Tracks []*Track // ReadTracks decodes tracks from SDP. -func ReadTracks(byts []byte) (Tracks, error) { +func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) { desc := sdp.SessionDescription{} err := desc.Unmarshal(byts) if err != nil { @@ -215,8 +215,9 @@ func ReadTracks(byts []byte) (Tracks, error) { for i, media := range desc.MediaDescriptions { tracks[i] = &Track{ - ID: i, - Media: media, + BaseURL: baseURL, + ID: i, + Media: media, } } diff --git a/track_test.go b/track_test.go index 51950f60..c8927b37 100644 --- a/track_test.go +++ b/track_test.go @@ -72,7 +72,7 @@ func TestTrackClockRate(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - tracks, err := ReadTracks(ca.sdp) + tracks, err := ReadTracks(ca.sdp, nil) require.NoError(t, err) clockRate, err := tracks[0].ClockRate()