mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 07:06:58 +08:00
server: provide path to OnSetup()
This commit is contained in:
120
serverconn.go
120
serverconn.go
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -30,6 +31,61 @@ var (
|
||||
errServerCSeqMissing = errors.New("CSeq is missing")
|
||||
)
|
||||
|
||||
func stringsReverseIndex(s, substr string) int {
|
||||
for i := len(s) - 1 - len(substr); i >= 0; i-- {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func extractTrackIDAndPath(url *base.URL,
|
||||
thMode *headers.TransportMode,
|
||||
publishTracks []ServerConnAnnouncedTrack,
|
||||
publishPath string) (int, string, error) {
|
||||
|
||||
pathAndQuery, ok := url.RTSPPathAndQuery()
|
||||
if !ok {
|
||||
return 0, "", fmt.Errorf("invalid URL (%s)", url)
|
||||
}
|
||||
|
||||
if thMode == nil || *thMode == headers.TransportModePlay {
|
||||
i := stringsReverseIndex(pathAndQuery, "/trackID=")
|
||||
|
||||
// URL doesn't contain trackID - it's track zero
|
||||
if i < 0 {
|
||||
if !strings.HasSuffix(pathAndQuery, "/") {
|
||||
return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
|
||||
}
|
||||
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
|
||||
|
||||
// we assume it's track 0
|
||||
return 0, pathAndQuery, nil
|
||||
}
|
||||
|
||||
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
|
||||
if err != nil || tmp < 0 {
|
||||
return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
|
||||
}
|
||||
trackID := int(tmp)
|
||||
pathAndQuery = pathAndQuery[:i]
|
||||
|
||||
path, _ := base.PathSplitQuery(pathAndQuery)
|
||||
|
||||
return trackID, path, nil
|
||||
}
|
||||
|
||||
for trackID, track := range publishTracks {
|
||||
u, _ := track.track.URL()
|
||||
if u.String() == url.String() {
|
||||
return trackID, publishPath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
|
||||
}
|
||||
|
||||
// ServerConnState is the state of the connection.
|
||||
type ServerConnState int
|
||||
|
||||
@@ -92,7 +148,7 @@ type ServerConnReadHandlers struct {
|
||||
OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)
|
||||
|
||||
// called after receiving a SETUP request.
|
||||
OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error)
|
||||
OnSetup func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error)
|
||||
|
||||
// called after receiving a PLAY request.
|
||||
OnPlay func(req *base.Request) (*base.Response, error)
|
||||
@@ -127,20 +183,22 @@ type ServerConn struct {
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
state ServerConnState
|
||||
readHandlers ServerConnReadHandlers
|
||||
tracks map[int]ServerConnTrack
|
||||
streamProtocol *StreamProtocol
|
||||
announcedTracks []ServerConnAnnouncedTrack
|
||||
|
||||
doEnableFrames bool
|
||||
framesEnabled bool
|
||||
readTimeoutEnabled bool
|
||||
|
||||
// writer
|
||||
// frame mode only
|
||||
doEnableFrames bool
|
||||
framesEnabled bool
|
||||
readTimeoutEnabled bool
|
||||
frameRingBuffer *ringbuffer.RingBuffer
|
||||
backgroundWriteDone chan struct{}
|
||||
|
||||
// background record
|
||||
// read only
|
||||
readHandlers ServerConnReadHandlers
|
||||
|
||||
// publish only
|
||||
publishPath string
|
||||
publishTracks []ServerConnAnnouncedTrack
|
||||
backgroundRecordTerminate chan struct{}
|
||||
backgroundRecordDone chan struct{}
|
||||
udpTimeout int32
|
||||
@@ -457,14 +515,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.state = ServerConnStatePreRecord
|
||||
sc.publishPath = reqPath
|
||||
|
||||
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
|
||||
|
||||
sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks))
|
||||
for trackID, track := range tracks {
|
||||
clockRate, _ := track.ClockRate()
|
||||
v := time.Now().Unix()
|
||||
|
||||
sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{
|
||||
sc.publishTracks[trackID] = ServerConnAnnouncedTrack{
|
||||
track: track,
|
||||
rtcpReceiver: rtcpreceiver.New(nil, clockRate),
|
||||
udpLastFrameTime: &v,
|
||||
@@ -488,13 +546,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, err
|
||||
}
|
||||
|
||||
pathAndQuery, ok := req.URL.RTSPPathAndQuery()
|
||||
if !ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("invalid path (%s)", req.URL)
|
||||
}
|
||||
|
||||
th, err := headers.ReadTransport(req.Header["Transport"])
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
@@ -524,25 +575,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
trackID, err := func() (int, error) {
|
||||
if th.Mode == nil || *th.Mode == headers.TransportModePlay {
|
||||
trackID, _, ok := base.PathSplitControlAttribute(pathAndQuery)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
|
||||
}
|
||||
|
||||
return trackID, nil
|
||||
}
|
||||
|
||||
for trackID, track := range sc.announcedTracks {
|
||||
u, _ := track.track.URL()
|
||||
if u.String() == req.URL.String() {
|
||||
return trackID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery)
|
||||
}()
|
||||
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
|
||||
sc.publishTracks, sc.publishPath)
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
@@ -590,7 +624,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := sc.readHandlers.OnSetup(req, th, trackID)
|
||||
res, err := sc.readHandlers.OnSetup(req, th, path, trackID)
|
||||
|
||||
if res.StatusCode == 200 {
|
||||
sc.streamProtocol = &th.Protocol
|
||||
@@ -697,7 +731,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}, fmt.Errorf("no tracks have been setup")
|
||||
}
|
||||
|
||||
if len(sc.tracks) != len(sc.announcedTracks) {
|
||||
if len(sc.tracks) != len(sc.publishTracks) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, fmt.Errorf("not all tracks have been setup")
|
||||
@@ -860,7 +894,7 @@ outer:
|
||||
// forward frame only if it has been set up
|
||||
if _, ok := sc.tracks[frame.TrackID]; ok {
|
||||
if sc.state == ServerConnStateRecord {
|
||||
sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
|
||||
sc.publishTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
|
||||
frame.StreamType, frame.Payload)
|
||||
}
|
||||
sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload)
|
||||
@@ -961,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, track := range sc.announcedTracks {
|
||||
for _, track := range sc.publishTracks {
|
||||
last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0)
|
||||
|
||||
if now.Sub(last) >= sc.conf.ReadTimeout {
|
||||
@@ -973,7 +1007,7 @@ func (sc *ServerConn) backgroundRecord() {
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
now := time.Now()
|
||||
for trackID, track := range sc.announcedTracks {
|
||||
for trackID, track := range sc.publishTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
sc.WriteFrame(trackID, StreamTypeRTP, r)
|
||||
}
|
||||
|
Reference in New Issue
Block a user