mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 23:26:54 +08:00
ServerConn: save announced tracks
This commit is contained in:
@@ -59,7 +59,7 @@ func (s clientConnState) String() string {
|
||||
case clientConnStateRecord:
|
||||
return "record"
|
||||
}
|
||||
return "uknown"
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// ClientConn is a client-side RTSP connection.
|
||||
@@ -395,15 +395,11 @@ func (c *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) {
|
||||
return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp")
|
||||
}
|
||||
|
||||
tracks, err := ReadTracks(res.Body)
|
||||
tracks, err := ReadTracks(res.Body, u)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range tracks {
|
||||
t.BaseURL = u
|
||||
}
|
||||
|
||||
return tracks, res, nil
|
||||
}
|
||||
|
||||
|
199
serverconn.go
199
serverconn.go
@@ -58,34 +58,20 @@ func (s ServerConnState) String() string {
|
||||
case ServerConnStateRecord:
|
||||
return "record"
|
||||
}
|
||||
return "uknown"
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// ServerConnTrack is a track of a ServerConn.
|
||||
type ServerConnTrack struct {
|
||||
// ServerConnSetuppedTrack is a setupped track of a ServerConn.
|
||||
type ServerConnSetuppedTrack struct {
|
||||
rtpPort int
|
||||
rtcpPort int
|
||||
}
|
||||
|
||||
func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) {
|
||||
if mode == nil || *mode == headers.TransportModePlay {
|
||||
i := strings.Index(pathAndQuery, "/trackID=")
|
||||
|
||||
// URL doesn't contain trackID - we assume it's track 0
|
||||
if i < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
|
||||
if err != nil || tmp < 0 {
|
||||
return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery)
|
||||
}
|
||||
trackID := int(tmp)
|
||||
|
||||
return trackID, nil
|
||||
}
|
||||
|
||||
return trackLen, nil
|
||||
// ServerConnAnnouncedTrack is an announced track of a ServerConn.
|
||||
type ServerConnAnnouncedTrack struct {
|
||||
track *Track
|
||||
rtcpReceiver *rtcpreceiver.RTCPReceiver
|
||||
udpLastFrameTime *int64
|
||||
}
|
||||
|
||||
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
|
||||
@@ -141,10 +127,11 @@ type ServerConn struct {
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
state ServerConnState
|
||||
tracks map[int]ServerConnTrack
|
||||
tracksProtocol *StreamProtocol
|
||||
readHandlers ServerConnReadHandlers
|
||||
rtcpReceivers []*rtcpreceiver.RTCPReceiver
|
||||
setuppedTracks map[int]ServerConnSetuppedTrack
|
||||
setuppedTracksProtocol *StreamProtocol
|
||||
announcedTracks []ServerConnAnnouncedTrack
|
||||
|
||||
doEnableFrames bool
|
||||
framesEnabled bool
|
||||
readTimeoutEnabled bool
|
||||
@@ -157,7 +144,6 @@ type ServerConn struct {
|
||||
backgroundRecordTerminate chan struct{}
|
||||
backgroundRecordDone chan struct{}
|
||||
udpTimeout int32
|
||||
udpLastFrameTimes []*int64
|
||||
|
||||
// in
|
||||
terminate chan struct{}
|
||||
@@ -176,7 +162,6 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
|
||||
nconn: nconn,
|
||||
br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
|
||||
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
|
||||
tracks: make(map[int]ServerConnTrack),
|
||||
frameRingBuffer: ringbuffer.New(conf.ReadBufferCount),
|
||||
backgroundWriteDone: make(chan struct{}),
|
||||
terminate: make(chan struct{}),
|
||||
@@ -195,25 +180,25 @@ func (sc *ServerConn) State() ServerConnState {
|
||||
return sc.state
|
||||
}
|
||||
|
||||
// TracksProtocol returns the tracks protocol.
|
||||
func (sc *ServerConn) TracksProtocol() *StreamProtocol {
|
||||
return sc.tracksProtocol
|
||||
// SetuppedTracksProtocol returns the setupped tracks protocol.
|
||||
func (sc *ServerConn) SetuppedTracksProtocol() *StreamProtocol {
|
||||
return sc.setuppedTracksProtocol
|
||||
}
|
||||
|
||||
// TracksLen returns the number of setupped tracks.
|
||||
func (sc *ServerConn) TracksLen() int {
|
||||
return len(sc.tracks)
|
||||
// SetuppedTracksLen returns the number of setupped tracks.
|
||||
func (sc *ServerConn) SetuppedTracksLen() int {
|
||||
return len(sc.setuppedTracks)
|
||||
}
|
||||
|
||||
// HasTrack checks whether a track has been setup.
|
||||
func (sc *ServerConn) HasTrack(trackID int) bool {
|
||||
_, ok := sc.tracks[trackID]
|
||||
// HasSetuppedTrack checks whether a track has been setup.
|
||||
func (sc *ServerConn) HasSetuppedTrack(trackID int) bool {
|
||||
_, ok := sc.setuppedTracks[trackID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Tracks returns the setupped tracks.
|
||||
func (sc *ServerConn) Tracks() map[int]ServerConnTrack {
|
||||
return sc.tracks
|
||||
// SetuppedTracks returns the setupped tracks.
|
||||
func (sc *ServerConn) SetuppedTracks() map[int]ServerConnSetuppedTrack {
|
||||
return sc.setuppedTracks
|
||||
}
|
||||
|
||||
func (sc *ServerConn) backgroundWrite() {
|
||||
@@ -269,17 +254,17 @@ func (sc *ServerConn) zone() string {
|
||||
func (sc *ServerConn) frameModeEnable() {
|
||||
switch sc.state {
|
||||
case ServerConnStatePlay:
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
|
||||
sc.doEnableFrames = true
|
||||
}
|
||||
|
||||
case ServerConnStateRecord:
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
|
||||
sc.doEnableFrames = true
|
||||
sc.readTimeoutEnabled = true
|
||||
|
||||
} else {
|
||||
for trackID, track := range sc.tracks {
|
||||
for trackID, track := range sc.setuppedTracks {
|
||||
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
|
||||
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
|
||||
|
||||
@@ -300,7 +285,7 @@ func (sc *ServerConn) frameModeEnable() {
|
||||
func (sc *ServerConn) frameModeDisable() {
|
||||
switch sc.state {
|
||||
case ServerConnStatePlay:
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
|
||||
sc.framesEnabled = false
|
||||
sc.frameRingBuffer.Close()
|
||||
<-sc.backgroundWriteDone
|
||||
@@ -310,7 +295,7 @@ func (sc *ServerConn) frameModeDisable() {
|
||||
close(sc.backgroundRecordTerminate)
|
||||
<-sc.backgroundRecordDone
|
||||
|
||||
if *sc.tracksProtocol == StreamProtocolTCP {
|
||||
if *sc.setuppedTracksProtocol == StreamProtocolTCP {
|
||||
sc.readTimeoutEnabled = false
|
||||
sc.nconn.SetReadDeadline(time.Time{})
|
||||
|
||||
@@ -319,7 +304,7 @@ func (sc *ServerConn) frameModeDisable() {
|
||||
<-sc.backgroundWriteDone
|
||||
|
||||
} else {
|
||||
for _, track := range sc.tracks {
|
||||
for _, track := range sc.setuppedTracks {
|
||||
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort)
|
||||
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
|
||||
}
|
||||
@@ -415,7 +400,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, fmt.Errorf("unsupported Content-Type '%s'", ct)
|
||||
}
|
||||
|
||||
tracks, err := ReadTracks(req.Body)
|
||||
tracks, err := ReadTracks(req.Body, req.URL)
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
@@ -428,19 +413,52 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, errors.New("no tracks defined")
|
||||
}
|
||||
|
||||
reqPath, ok := req.URL.RTSPPath()
|
||||
if !ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, errors.New("invalid path")
|
||||
}
|
||||
|
||||
for _, track := range tracks {
|
||||
trackURL, err := track.URL()
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("invalid track URL")
|
||||
}
|
||||
|
||||
trackPath, ok := trackURL.RTSPPath()
|
||||
if !ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("invalid track URL")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(trackPath, reqPath) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("invalid track URL: must begin with '%s', but is '%s'",
|
||||
reqPath, trackPath)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnAnnounce(req, tracks)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.state = ServerConnStatePreRecord
|
||||
|
||||
sc.rtcpReceivers = make([]*rtcpreceiver.RTCPReceiver, len(tracks))
|
||||
sc.udpLastFrameTimes = make([]*int64, len(tracks))
|
||||
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
|
||||
|
||||
for trackID, track := range tracks {
|
||||
clockRate, _ := track.ClockRate()
|
||||
sc.rtcpReceivers[trackID] = rtcpreceiver.New(nil, clockRate)
|
||||
v := time.Now().Unix()
|
||||
sc.udpLastFrameTimes[trackID] = &v
|
||||
|
||||
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
|
||||
track: track,
|
||||
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
|
||||
udpLastFrameTime: &v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,20 +498,55 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, fmt.Errorf("multicast is not supported")
|
||||
}
|
||||
|
||||
trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks))
|
||||
trackID, err := func() (int, error) {
|
||||
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
|
||||
i := strings.Index(pathAndQuery, "/trackID=")
|
||||
|
||||
// URL doesn't contain trackID - we assume it's track 0
|
||||
if i < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
|
||||
if err != nil || tmp < 0 {
|
||||
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
|
||||
}
|
||||
trackID := int(tmp)
|
||||
|
||||
// remove track ID from path
|
||||
nu := &base.URL{
|
||||
Scheme: req.URL.Scheme,
|
||||
Host: req.URL.Host,
|
||||
User: req.URL.User,
|
||||
}
|
||||
nu, _ = base.ParseURL(nu.String() + pathAndQuery[:i])
|
||||
req.URL = nu
|
||||
|
||||
return trackID, nil
|
||||
}
|
||||
|
||||
for trackID, track := range sc.announcedTracks {
|
||||
u, _ := track.track.URL()
|
||||
if u.String() == req.URL.String() {
|
||||
return trackID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("invalid track (%s)", pathAndQuery)
|
||||
}()
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
if _, ok := sc.tracks[trackID]; ok {
|
||||
if _, ok := sc.setuppedTracks[trackID]; ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("track %d has already been setup", trackID)
|
||||
}
|
||||
|
||||
if sc.tracksProtocol != nil && *sc.tracksProtocol != th.Protocol {
|
||||
if sc.setuppedTracksProtocol != nil && *sc.setuppedTracksProtocol != th.Protocol {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("can't setup tracks with different protocols")
|
||||
@@ -542,15 +595,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("transport header does not contain mode=record")
|
||||
}
|
||||
|
||||
if trackID >= len(sc.announcedTracks) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("unable to setup track %d", trackID)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnSetup(req, th, trackID)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.tracksProtocol = &th.Protocol
|
||||
sc.setuppedTracksProtocol = &th.Protocol
|
||||
|
||||
if sc.setuppedTracks == nil {
|
||||
sc.setuppedTracks = make(map[int]ServerConnSetuppedTrack)
|
||||
}
|
||||
|
||||
if th.Protocol == StreamProtocolUDP {
|
||||
sc.tracks[trackID] = ServerConnTrack{
|
||||
sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{
|
||||
rtpPort: th.ClientPorts[0],
|
||||
rtcpPort: th.ClientPorts[1],
|
||||
}
|
||||
@@ -566,7 +629,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}.Write()
|
||||
|
||||
} else {
|
||||
sc.tracks[trackID] = ServerConnTrack{}
|
||||
sc.setuppedTracks[trackID] = ServerConnSetuppedTrack{}
|
||||
|
||||
res.Header["Transport"] = headers.Transport{
|
||||
Protocol: StreamProtocolTCP,
|
||||
@@ -610,7 +673,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, err
|
||||
}
|
||||
|
||||
if len(sc.tracks) == 0 {
|
||||
if len(sc.setuppedTracks) == 0 {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("no tracks have been setup")
|
||||
@@ -637,13 +700,13 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, err
|
||||
}
|
||||
|
||||
if len(sc.tracks) == 0 {
|
||||
if len(sc.setuppedTracks) == 0 {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("no tracks have been setup")
|
||||
}
|
||||
|
||||
if len(sc.tracks) != len(sc.rtcpReceivers) {
|
||||
if len(sc.setuppedTracks) != len(sc.announcedTracks) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("not all tracks have been setup")
|
||||
@@ -804,9 +867,9 @@ outer:
|
||||
switch what.(type) {
|
||||
case *base.InterleavedFrame:
|
||||
// forward frame only if it has been set up
|
||||
if _, ok := sc.tracks[frame.TrackID]; ok {
|
||||
if _, ok := sc.setuppedTracks[frame.TrackID]; ok {
|
||||
if sc.state == ServerConnStateRecord {
|
||||
sc.rtcpReceivers[frame.TrackID].ProcessFrame(time.Now(),
|
||||
sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
|
||||
frame.StreamType, frame.Payload)
|
||||
}
|
||||
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
|
||||
@@ -861,8 +924,8 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error {
|
||||
|
||||
// WriteFrame writes a frame.
|
||||
func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) {
|
||||
if *sc.tracksProtocol == StreamProtocolUDP {
|
||||
track := sc.tracks[trackID]
|
||||
if *sc.setuppedTracksProtocol == StreamProtocolUDP {
|
||||
track := sc.setuppedTracks[trackID]
|
||||
|
||||
if streamType == StreamTypeRTP {
|
||||
sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{
|
||||
@@ -902,13 +965,13 @@ func (sc *ServerConn) backgroundRecord() {
|
||||
for {
|
||||
select {
|
||||
case <-checkStreamTicker.C:
|
||||
if *sc.tracksProtocol != StreamProtocolUDP {
|
||||
if *sc.setuppedTracksProtocol != StreamProtocolUDP {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, lastUnix := range sc.udpLastFrameTimes {
|
||||
last := time.Unix(atomic.LoadInt64(lastUnix), 0)
|
||||
for _, track := range sc.announcedTracks {
|
||||
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
|
||||
|
||||
if now.Sub(last) >= sc.conf.ReadTimeout {
|
||||
atomic.StoreInt32(&sc.udpTimeout, 1)
|
||||
@@ -919,8 +982,8 @@ func (sc *ServerConn) backgroundRecord() {
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
now := time.Now()
|
||||
for trackID := range sc.tracks {
|
||||
r := sc.rtcpReceivers[trackID].Report(now)
|
||||
for trackID, track := range sc.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
sc.WriteFrame(trackID, StreamTypeRTP, r)
|
||||
}
|
||||
|
||||
|
@@ -128,8 +128,8 @@ func (s *ServerUDPListener) run() {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
atomic.StoreInt64(pubData.publisher.udpLastFrameTimes[pubData.trackID], now.Unix())
|
||||
pubData.publisher.rtcpReceivers[pubData.trackID].ProcessFrame(now, s.streamType, buf[:n])
|
||||
atomic.StoreInt64(pubData.publisher.announcedTracks[pubData.trackID].udpLastFrameTime, now.Unix())
|
||||
pubData.publisher.announcedTracks[pubData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n])
|
||||
pubData.publisher.readHandlers.OnFrame(pubData.trackID, s.streamType, buf[:n])
|
||||
}()
|
||||
}
|
||||
|
5
track.go
5
track.go
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
// Track is a track available in a certain URL.
|
||||
type Track struct {
|
||||
// base url
|
||||
// base URL
|
||||
BaseURL *base.URL
|
||||
|
||||
// id
|
||||
@@ -204,7 +204,7 @@ func (t *Track) URL() (*base.URL, error) {
|
||||
type Tracks []*Track
|
||||
|
||||
// ReadTracks decodes tracks from SDP.
|
||||
func ReadTracks(byts []byte) (Tracks, error) {
|
||||
func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) {
|
||||
desc := sdp.SessionDescription{}
|
||||
err := desc.Unmarshal(byts)
|
||||
if err != nil {
|
||||
@@ -215,6 +215,7 @@ func ReadTracks(byts []byte) (Tracks, error) {
|
||||
|
||||
for i, media := range desc.MediaDescriptions {
|
||||
tracks[i] = &Track{
|
||||
BaseURL: baseURL,
|
||||
ID: i,
|
||||
Media: media,
|
||||
}
|
||||
|
@@ -72,7 +72,7 @@ func TestTrackClockRate(t *testing.T) {
|
||||
},
|
||||
} {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
tracks, err := ReadTracks(ca.sdp)
|
||||
tracks, err := ReadTracks(ca.sdp, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clockRate, err := tracks[0].ClockRate()
|
||||
|
Reference in New Issue
Block a user