server: add path and query to all contexts

This commit is contained in:
aler9
2021-03-16 22:03:46 +01:00
parent 1fc6c9e661
commit c2de28c185
2 changed files with 167 additions and 76 deletions

View File

@@ -40,14 +40,14 @@ func stringsReverseIndex(s, substr string) int {
return -1
}
func extractTrackIDAndPath(url *base.URL,
func setupGetTrackIDPathQuery(url *base.URL,
thMode *headers.TransportMode,
announcedTracks []ServerConnAnnouncedTrack,
setupPath *string) (int, string, error) {
setupPath *string, setupQuery *string) (int, string, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok {
return 0, "", fmt.Errorf("invalid URL (%s)", url)
return 0, "", "", fmt.Errorf("invalid URL (%s)", url)
}
if thMode == nil || *thMode == headers.TransportModePlay {
@@ -56,38 +56,40 @@ func extractTrackIDAndPath(url *base.URL,
// URL doesn't contain trackID - it's track zero
if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
}
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
path, query := base.PathSplitQuery(pathAndQuery)
// we assume it's track 0
return 0, pathAndQuery, nil
return 0, path, query, nil
}
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 {
return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
}
trackID := int(tmp)
pathAndQuery = pathAndQuery[:i]
path, _ := base.PathSplitQuery(pathAndQuery)
path, query := base.PathSplitQuery(pathAndQuery)
if setupPath != nil && path != *setupPath {
return 0, "", fmt.Errorf("can't setup tracks with different paths")
if setupPath != nil && (path != *setupPath || query != *setupQuery) {
return 0, "", "", fmt.Errorf("can't setup tracks with different paths")
}
return trackID, path, nil
return trackID, path, query, nil
}
for trackID, track := range announcedTracks {
u, _ := track.track.URL()
if u.String() == url.String() {
return trackID, *setupPath, nil
return trackID, *setupPath, *setupQuery, nil
}
}
return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
}
// ServerConnState is the state of the connection.
@@ -135,55 +137,74 @@ type ServerConnAnnouncedTrack struct {
// ServerConnOptionsCtx is the context of a OPTIONS request.
type ServerConnOptionsCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnDescribeCtx is the context of a DESCRIBE request.
type ServerConnDescribeCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnAnnounceCtx is the context of a ANNOUNCE request.
type ServerConnAnnounceCtx struct {
Req *base.Request
Path string
Query string
Tracks Tracks
}
// ServerConnSetupCtx is the context of a OPTIONS request.
type ServerConnSetupCtx struct {
Req *base.Request
Transport *headers.Transport
Path string
Query string
TrackID int
Transport *headers.Transport
}
// ServerConnPlayCtx is the context of a PLAY request.
type ServerConnPlayCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnRecordCtx is the context of a RECORD request.
type ServerConnRecordCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnPauseCtx is the context of a PAUSE request.
type ServerConnPauseCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnGetParameterCtx is the context of a GET_PARAMETER request.
type ServerConnGetParameterCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnSetParameterCtx is the context of a SET_PARAMETER request.
type ServerConnSetParameterCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnTeardownCtx is the context of a TEARDOWN request.
type ServerConnTeardownCtx struct {
Req *base.Request
Path string
Query string
}
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -245,6 +266,7 @@ type ServerConn struct {
setuppedTracks map[int]ServerConnSetuppedTrack
setupProtocol *StreamProtocol
setupPath *string
setupQuery *string
// frame mode only
doEnableFrames bool
@@ -453,8 +475,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
switch req.Method {
case base.Options:
if sc.readHandlers.OnOptions != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnOptions(&ServerConnOptionsCtx{
Req: req,
Path: path,
Query: query,
})
}
@@ -501,8 +534,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
res, sdp, err := sc.readHandlers.OnDescribe(&ServerConnDescribeCtx{
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK && sdp != nil {
@@ -555,13 +599,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, errors.New("no tracks defined")
}
reqPath, ok := req.URL.RTSPPath()
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, errors.New("invalid path")
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
for _, track := range tracks {
trackURL, err := track.URL()
if err != nil {
@@ -577,22 +623,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("invalid track URL (%v)", trackURL)
}
if !strings.HasPrefix(trackPath, reqPath) {
if !strings.HasPrefix(trackPath, path) {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'",
reqPath, trackPath)
path, trackPath)
}
}
res, err := sc.readHandlers.OnAnnounce(&ServerConnAnnounceCtx{
Req: req,
Path: path,
Query: query,
Tracks: tracks,
})
if res.StatusCode == base.StatusOK {
sc.state = ServerConnStatePreRecord
sc.setupPath = &reqPath
sc.setupPath = &path
sc.setupQuery = &query
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks {
@@ -636,8 +685,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, nil
}
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode,
sc.announcedTracks, sc.setupPath)
trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode,
sc.announcedTracks, sc.setupPath, sc.setupQuery)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
@@ -703,9 +752,10 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
res, err := sc.readHandlers.OnSetup(&ServerConnSetupCtx{
Req: req,
Transport: th,
Path: path,
Query: query,
TrackID: trackID,
Transport: th,
})
if res.StatusCode == base.StatusOK {
@@ -751,6 +801,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case ServerConnStateInitial:
sc.state = ServerConnStatePrePlay
sc.setupPath = &path
sc.setupQuery = &query
}
// workaround to prevent a bug in rtspclientsink
@@ -787,8 +838,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("no tracks have been setup")
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnPlay(&ServerConnPlayCtx{
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK && sc.state != ServerConnStatePlay {
@@ -822,8 +887,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("not all announced tracks have been setup")
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnRecord(&ServerConnRecordCtx{
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
@@ -848,8 +927,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err
}
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnPause(&ServerConnPauseCtx{
Req: req,
Path: path,
Query: query,
})
if res.StatusCode == base.StatusOK {
@@ -869,8 +962,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.GetParameter:
if sc.readHandlers.OnGetParameter != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnGetParameter(&ServerConnGetParameterCtx{
Req: req,
Path: path,
Query: query,
})
}
@@ -885,15 +989,37 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.SetParameter:
if sc.readHandlers.OnSetParameter != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnSetParameter(&ServerConnSetParameterCtx{
Req: req,
Path: path,
Query: query,
})
}
case base.Teardown:
if sc.readHandlers.OnTeardown != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnTeardown(&ServerConnTeardownCtx{
Req: req,
Path: path,
Query: query,
})
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"time"
@@ -78,14 +77,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
defer conn.Close()
onDescribe := func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) {
reqPath, ok := ctx.Req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
if reqPath != "teststream" {
if ctx.Path != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -106,14 +98,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}
onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
if reqPath != "teststream" {
if ctx.Path != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -155,17 +140,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}
onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
// path can end with a slash, remove it
reqPath = strings.TrimSuffix(reqPath, "/")
if reqPath != "teststream" {
if ctx.Path != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -185,17 +160,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}
onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
// path can end with a slash, remove it
reqPath = strings.TrimSuffix(reqPath, "/")
if reqPath != "teststream" {
if ctx.Path != "teststream" {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)