server: add path and query to all contexts

This commit is contained in:
aler9
2021-03-16 22:03:46 +01:00
parent 1fc6c9e661
commit c2de28c185
2 changed files with 167 additions and 76 deletions

View File

@@ -40,14 +40,14 @@ func stringsReverseIndex(s, substr string) int {
return -1 return -1
} }
func extractTrackIDAndPath(url *base.URL, func setupGetTrackIDPathQuery(url *base.URL,
thMode *headers.TransportMode, thMode *headers.TransportMode,
announcedTracks []ServerConnAnnouncedTrack, announcedTracks []ServerConnAnnouncedTrack,
setupPath *string) (int, string, error) { setupPath *string, setupQuery *string) (int, string, string, error) {
pathAndQuery, ok := url.RTSPPathAndQuery() pathAndQuery, ok := url.RTSPPathAndQuery()
if !ok { if !ok {
return 0, "", fmt.Errorf("invalid URL (%s)", url) return 0, "", "", fmt.Errorf("invalid URL (%s)", url)
} }
if thMode == nil || *thMode == headers.TransportModePlay { if thMode == nil || *thMode == headers.TransportModePlay {
@@ -56,38 +56,40 @@ func extractTrackIDAndPath(url *base.URL,
// URL doesn't contain trackID - it's track zero // URL doesn't contain trackID - it's track zero
if i < 0 { if i < 0 {
if !strings.HasSuffix(pathAndQuery, "/") { if !strings.HasSuffix(pathAndQuery, "/") {
return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery) return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery)
} }
pathAndQuery = pathAndQuery[:len(pathAndQuery)-1] pathAndQuery = pathAndQuery[:len(pathAndQuery)-1]
path, query := base.PathSplitQuery(pathAndQuery)
// we assume it's track 0 // we assume it's track 0
return 0, pathAndQuery, nil return 0, path, query, nil
} }
tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 { if err != nil || tmp < 0 {
return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery) return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery)
} }
trackID := int(tmp) trackID := int(tmp)
pathAndQuery = pathAndQuery[:i] pathAndQuery = pathAndQuery[:i]
path, _ := base.PathSplitQuery(pathAndQuery) path, query := base.PathSplitQuery(pathAndQuery)
if setupPath != nil && path != *setupPath { if setupPath != nil && (path != *setupPath || query != *setupQuery) {
return 0, "", fmt.Errorf("can't setup tracks with different paths") return 0, "", "", fmt.Errorf("can't setup tracks with different paths")
} }
return trackID, path, nil return trackID, path, query, nil
} }
for trackID, track := range announcedTracks { for trackID, track := range announcedTracks {
u, _ := track.track.URL() u, _ := track.track.URL()
if u.String() == url.String() { if u.String() == url.String() {
return trackID, *setupPath, nil return trackID, *setupPath, *setupQuery, nil
} }
} }
return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery) return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery)
} }
// ServerConnState is the state of the connection. // ServerConnState is the state of the connection.
@@ -134,56 +136,75 @@ type ServerConnAnnouncedTrack struct {
// ServerConnOptionsCtx is the context of a OPTIONS request. // ServerConnOptionsCtx is the context of a OPTIONS request.
type ServerConnOptionsCtx struct { type ServerConnOptionsCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnDescribeCtx is the context of a DESCRIBE request. // ServerConnDescribeCtx is the context of a DESCRIBE request.
type ServerConnDescribeCtx struct { type ServerConnDescribeCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnAnnounceCtx is the context of a ANNOUNCE request. // ServerConnAnnounceCtx is the context of a ANNOUNCE request.
type ServerConnAnnounceCtx struct { type ServerConnAnnounceCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
Tracks Tracks Tracks Tracks
} }
// ServerConnSetupCtx is the context of a OPTIONS request. // ServerConnSetupCtx is the context of a OPTIONS request.
type ServerConnSetupCtx struct { type ServerConnSetupCtx struct {
Req *base.Request Req *base.Request
Transport *headers.Transport
Path string Path string
Query string
TrackID int TrackID int
Transport *headers.Transport
} }
// ServerConnPlayCtx is the context of a PLAY request. // ServerConnPlayCtx is the context of a PLAY request.
type ServerConnPlayCtx struct { type ServerConnPlayCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnRecordCtx is the context of a RECORD request. // ServerConnRecordCtx is the context of a RECORD request.
type ServerConnRecordCtx struct { type ServerConnRecordCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnPauseCtx is the context of a PAUSE request. // ServerConnPauseCtx is the context of a PAUSE request.
type ServerConnPauseCtx struct { type ServerConnPauseCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnGetParameterCtx is the context of a GET_PARAMETER request. // ServerConnGetParameterCtx is the context of a GET_PARAMETER request.
type ServerConnGetParameterCtx struct { type ServerConnGetParameterCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnSetParameterCtx is the context of a SET_PARAMETER request. // ServerConnSetParameterCtx is the context of a SET_PARAMETER request.
type ServerConnSetParameterCtx struct { type ServerConnSetParameterCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnTeardownCtx is the context of a TEARDOWN request. // ServerConnTeardownCtx is the context of a TEARDOWN request.
type ServerConnTeardownCtx struct { type ServerConnTeardownCtx struct {
Req *base.Request Req *base.Request
Path string
Query string
} }
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -245,6 +266,7 @@ type ServerConn struct {
setuppedTracks map[int]ServerConnSetuppedTrack setuppedTracks map[int]ServerConnSetuppedTrack
setupProtocol *StreamProtocol setupProtocol *StreamProtocol
setupPath *string setupPath *string
setupQuery *string
// frame mode only // frame mode only
doEnableFrames bool doEnableFrames bool
@@ -453,8 +475,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
switch req.Method { switch req.Method {
case base.Options: case base.Options:
if sc.readHandlers.OnOptions != nil { if sc.readHandlers.OnOptions != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnOptions(&ServerConnOptionsCtx{ return sc.readHandlers.OnOptions(&ServerConnOptionsCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
} }
@@ -501,8 +534,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
res, sdp, err := sc.readHandlers.OnDescribe(&ServerConnDescribeCtx{ res, sdp, err := sc.readHandlers.OnDescribe(&ServerConnDescribeCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
if res.StatusCode == base.StatusOK && sdp != nil { if res.StatusCode == base.StatusOK && sdp != nil {
@@ -555,13 +599,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, errors.New("no tracks defined") }, errors.New("no tracks defined")
} }
reqPath, ok := req.URL.RTSPPath() pathAndQuery, ok := req.URL.RTSPPath()
if !ok { if !ok {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, errors.New("invalid path") }, fmt.Errorf("invalid URL (%s)", req.URL)
} }
path, query := base.PathSplitQuery(pathAndQuery)
for _, track := range tracks { for _, track := range tracks {
trackURL, err := track.URL() trackURL, err := track.URL()
if err != nil { if err != nil {
@@ -577,22 +623,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("invalid track URL (%v)", trackURL) }, fmt.Errorf("invalid track URL (%v)", trackURL)
} }
if !strings.HasPrefix(trackPath, reqPath) { if !strings.HasPrefix(trackPath, path) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'", }, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'",
reqPath, trackPath) path, trackPath)
} }
} }
res, err := sc.readHandlers.OnAnnounce(&ServerConnAnnounceCtx{ res, err := sc.readHandlers.OnAnnounce(&ServerConnAnnounceCtx{
Req: req, Req: req,
Path: path,
Query: query,
Tracks: tracks, Tracks: tracks,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
sc.state = ServerConnStatePreRecord sc.state = ServerConnStatePreRecord
sc.setupPath = &reqPath sc.setupPath = &path
sc.setupQuery = &query
sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks)) sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks))
for trackID, track := range tracks { for trackID, track := range tracks {
@@ -636,8 +685,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, nil }, nil
} }
trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode,
sc.announcedTracks, sc.setupPath) sc.announcedTracks, sc.setupPath, sc.setupQuery)
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -703,9 +752,10 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
res, err := sc.readHandlers.OnSetup(&ServerConnSetupCtx{ res, err := sc.readHandlers.OnSetup(&ServerConnSetupCtx{
Req: req, Req: req,
Transport: th,
Path: path, Path: path,
Query: query,
TrackID: trackID, TrackID: trackID,
Transport: th,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
@@ -751,6 +801,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case ServerConnStateInitial: case ServerConnStateInitial:
sc.state = ServerConnStatePrePlay sc.state = ServerConnStatePrePlay
sc.setupPath = &path sc.setupPath = &path
sc.setupQuery = &query
} }
// workaround to prevent a bug in rtspclientsink // workaround to prevent a bug in rtspclientsink
@@ -787,8 +838,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("no tracks have been setup") }, fmt.Errorf("no tracks have been setup")
} }
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnPlay(&ServerConnPlayCtx{ res, err := sc.readHandlers.OnPlay(&ServerConnPlayCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
if res.StatusCode == base.StatusOK && sc.state != ServerConnStatePlay { if res.StatusCode == base.StatusOK && sc.state != ServerConnStatePlay {
@@ -822,8 +887,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("not all announced tracks have been setup") }, fmt.Errorf("not all announced tracks have been setup")
} }
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnRecord(&ServerConnRecordCtx{ res, err := sc.readHandlers.OnRecord(&ServerConnRecordCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
@@ -848,8 +927,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
// path can end with a slash due to Content-Base, remove it
pathAndQuery = strings.TrimSuffix(pathAndQuery, "/")
path, query := base.PathSplitQuery(pathAndQuery)
res, err := sc.readHandlers.OnPause(&ServerConnPauseCtx{ res, err := sc.readHandlers.OnPause(&ServerConnPauseCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
if res.StatusCode == base.StatusOK { if res.StatusCode == base.StatusOK {
@@ -869,8 +962,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.GetParameter: case base.GetParameter:
if sc.readHandlers.OnGetParameter != nil { if sc.readHandlers.OnGetParameter != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnGetParameter(&ServerConnGetParameterCtx{ return sc.readHandlers.OnGetParameter(&ServerConnGetParameterCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
} }
@@ -885,15 +989,37 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.SetParameter: case base.SetParameter:
if sc.readHandlers.OnSetParameter != nil { if sc.readHandlers.OnSetParameter != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnSetParameter(&ServerConnSetParameterCtx{ return sc.readHandlers.OnSetParameter(&ServerConnSetParameterCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
} }
case base.Teardown: case base.Teardown:
if sc.readHandlers.OnTeardown != nil { if sc.readHandlers.OnTeardown != nil {
pathAndQuery, ok := req.URL.RTSPPath()
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid URL (%s)", req.URL)
}
path, query := base.PathSplitQuery(pathAndQuery)
return sc.readHandlers.OnTeardown(&ServerConnTeardownCtx{ return sc.readHandlers.OnTeardown(&ServerConnTeardownCtx{
Req: req, Req: req,
Path: path,
Query: query,
}) })
} }

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -78,14 +77,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
defer conn.Close() defer conn.Close()
onDescribe := func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) { onDescribe := func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) {
reqPath, ok := ctx.Req.URL.RTSPPath() if ctx.Path != "teststream" {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
if reqPath != "teststream" {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL) }, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -106,14 +98,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
} }
onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath() if ctx.Path != "teststream" {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
if reqPath != "teststream" {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL) }, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -155,17 +140,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
} }
onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath() if ctx.Path != "teststream" {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
// path can end with a slash, remove it
reqPath = strings.TrimSuffix(reqPath, "/")
if reqPath != "teststream" {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL) }, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
@@ -185,17 +160,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
} }
onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) {
reqPath, ok := ctx.Req.URL.RTSPPath() if ctx.Path != "teststream" {
if !ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL)
}
// path can end with a slash, remove it
reqPath = strings.TrimSuffix(reqPath, "/")
if reqPath != "teststream" {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, fmt.Errorf("invalid path (%s)", ctx.Req.URL) }, fmt.Errorf("invalid path (%s)", ctx.Req.URL)