server: provide path to OnSetup()

This commit is contained in:
aler9
2021-03-14 16:41:17 +01:00
parent 378c5639bb
commit d902b7da93
9 changed files with 154 additions and 204 deletions

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -30,6 +31,61 @@ var (
errServerCSeqMissing = errors.New("CSeq is missing")
)
func stringsReverseIndex(s, substr string) int {
for i := len(s) - 1 - len(substr); i >= 0; i-- {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}
func extractTrackIDAndPath(url *base.URL,
thMode *headers.TransportMode,
publishTracks []ServerConnAnnouncedTrack,
publishPath string) (int, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", fmt.Errorf("invalid URL (%s)", url)
}
if thMode == nil || *thMode == headers.TransportModePlay {
i := stringsReverseIndex(pathAndQuery, "/trackID=")
// URL doesn't contain trackID - it's track zero
if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
}
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
// we assume it's track 0
return 0, pathAndQuery, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
}
trackID := int(tmp)
pathAndQuery = pathAndQuery[:i]
path, _ := base.PathSplitQuery(pathAndQuery)
return trackID, path, nil
}
for trackID, track := range publishTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, publishPath, nil
}
}
return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerConnState is the state of the connection.
type ServerConnState int
@@ -92,7 +148,7 @@ type ServerConnReadHandlers struct {
OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)
// called after receiving a SETUP request.
OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error)
OnSetup func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error)
// called after receiving a PLAY request.
OnPlay func(req *base.Request) (*base.Response, error)
@@ -127,20 +183,22 @@ type ServerConn struct {
br *bufio.Reader
bw *bufio.Writer
state ServerConnState
readHandlers ServerConnReadHandlers
tracks map[int]ServerConnTrack
streamProtocol *StreamProtocol
announcedTracks []ServerConnAnnouncedTrack
doEnableFrames bool
framesEnabled bool
readTimeoutEnabled bool
// writer
// frame mode only
doEnableFrames bool
framesEnabled bool
readTimeoutEnabled bool
frameRingBuffer *ringbuffer.RingBuffer
backgroundWriteDone chan struct{}
// background record
// read only
readHandlers ServerConnReadHandlers
// publish only
publishPath string
publishTracks []ServerConnAnnouncedTrack
backgroundRecordTerminate chan struct{}
backgroundRecordDone chan struct{}
udpTimeout int32
@@ -457,14 +515,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
if res.StatusCode == 200 {
sc.state = ServerConnStatePreRecord
sc.publishPath = reqPath
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
clockRate, _ := track.ClockRate()
v := time.Now().Unix()
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
sc.publishTracks[trackID] = ServerConnAnnouncedTrack{
track: track,
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
udpLastFrameTime: &v,
@@ -488,13 +546,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", req.URL)
}
th, err := headers.ReadTransport(req.Header["Transport"])
if err != nil {
return &base.Response{
@@ -524,25 +575,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, nil
}
trackID, err := func() (int, error) {
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
trackID, _, ok := base.PathSplitControlAttribute(pathAndQuery)
if !ok {
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
return trackID, nil
}
for trackID, track := range sc.announcedTracks {
u, _ := track.track.URL()
if u.String() == req.URL.String() {
return trackID, nil
}
}
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
}()
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.publishTracks, sc.publishPath)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
@@ -590,7 +624,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
}
res, err := sc.readHandlers.OnSetup(req, th, trackID)
res, err := sc.readHandlers.OnSetup(req, th, path, trackID)
if res.StatusCode == 200 {
sc.streamProtocol = &th.Protocol
@@ -697,7 +731,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("no tracks have been setup")
}
if len(sc.tracks) != len(sc.announcedTracks) {
if len(sc.tracks) != len(sc.publishTracks) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("not all tracks have been setup")
@@ -860,7 +894,7 @@ outer:
// forward frame only if it has been set up
if _, ok := sc.tracks[frame.TrackID]; ok {
if sc.state == ServerConnStateRecord {
sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
sc.publishTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload)
}
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
@@ -961,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() {
}
now := time.Now()
for _, track := range sc.announcedTracks {
for _, track := range sc.publishTracks {
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
if now.Sub(last) >= sc.conf.ReadTimeout {
@@ -973,7 +1007,7 @@ func (sc *ServerConn) backgroundRecord() {
case <-receiverReportTicker.C:
now := time.Now()
for trackID, track := range sc.announcedTracks {
for trackID, track := range sc.publishTracks {
r := track.rtcpReceiver.Report(now)
sc.WriteFrame(trackID, StreamTypeRTP, r)
}