From c2de28c185583f58b3b73e34d7fc4b1750b4fdcd Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 16 Mar 2021 22:03:46 +0100 Subject: [PATCH] server: add path and query to all contexts --- serverconn.go | 200 ++++++++++++++++++++++++++++++++++++--------- serverconn_test.go | 43 +--------- 2 files changed, 167 insertions(+), 76 deletions(-) diff --git a/serverconn.go b/serverconn.go index 191b69d3..f64d2a49 100644 --- a/serverconn.go +++ b/serverconn.go @@ -40,14 +40,14 @@ func stringsReverseIndex(s, substr string) int { return -1 } -func extractTrackIDAndPath(url *base.URL, +func setupGetTrackIDPathQuery(url *base.URL, thMode *headers.TransportMode, announcedTracks []ServerConnAnnouncedTrack, - setupPath *string) (int, string, error) { + setupPath *string, setupQuery *string) (int, string, string, error) { pathAndQuery, ok := url.RTSPPathAndQuery() if !ok { - return 0, "", fmt.Errorf("invalid URL (%s)", url) + return 0, "", "", fmt.Errorf("invalid URL (%s)", url) } if thMode == nil || *thMode == headers.TransportModePlay { @@ -56,38 +56,40 @@ func extractTrackIDAndPath(url *base.URL, // URL doesn't contain trackID - it's track zero if i < 0 { if !strings.HasSuffix(pathAndQuery, "/") { - return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery) + return 0, "", "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery) } pathAndQuery = pathAndQuery[:len(pathAndQuery)-1] + path, query := base.PathSplitQuery(pathAndQuery) + // we assume it's track 0 - return 0, pathAndQuery, nil + return 0, path, query, nil } tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) if err != nil || tmp < 0 { - return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery) + return 0, "", "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery) } trackID := int(tmp) pathAndQuery = pathAndQuery[:i] - path, _ := base.PathSplitQuery(pathAndQuery) + path, query := base.PathSplitQuery(pathAndQuery) - if setupPath != nil && path != *setupPath { - return 0, "", fmt.Errorf("can't setup tracks with different paths") + if setupPath != nil && (path != *setupPath || query != *setupQuery) { + return 0, "", "", fmt.Errorf("can't setup tracks with different paths") } - return trackID, path, nil + return trackID, path, query, nil } for trackID, track := range announcedTracks { u, _ := track.track.URL() if u.String() == url.String() { - return trackID, *setupPath, nil + return trackID, *setupPath, *setupQuery, nil } } - return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery) + return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery) } // ServerConnState is the state of the connection. @@ -134,56 +136,75 @@ type ServerConnAnnouncedTrack struct { // ServerConnOptionsCtx is the context of a OPTIONS request. type ServerConnOptionsCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnDescribeCtx is the context of a DESCRIBE request. type ServerConnDescribeCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnAnnounceCtx is the context of a ANNOUNCE request. type ServerConnAnnounceCtx struct { Req *base.Request + Path string + Query string Tracks Tracks } // ServerConnSetupCtx is the context of a OPTIONS request. type ServerConnSetupCtx struct { Req *base.Request - Transport *headers.Transport Path string + Query string TrackID int + Transport *headers.Transport } // ServerConnPlayCtx is the context of a PLAY request. type ServerConnPlayCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnRecordCtx is the context of a RECORD request. type ServerConnRecordCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnPauseCtx is the context of a PAUSE request. type ServerConnPauseCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnGetParameterCtx is the context of a GET_PARAMETER request. type ServerConnGetParameterCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnSetParameterCtx is the context of a SET_PARAMETER request. type ServerConnSetParameterCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnTeardownCtx is the context of a TEARDOWN request. type ServerConnTeardownCtx struct { - Req *base.Request + Req *base.Request + Path string + Query string } // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. @@ -245,6 +266,7 @@ type ServerConn struct { setuppedTracks map[int]ServerConnSetuppedTrack setupProtocol *StreamProtocol setupPath *string + setupQuery *string // frame mode only doEnableFrames bool @@ -453,8 +475,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { switch req.Method { case base.Options: if sc.readHandlers.OnOptions != nil { + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + path, query := base.PathSplitQuery(pathAndQuery) + return sc.readHandlers.OnOptions(&ServerConnOptionsCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) } @@ -501,8 +534,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + path, query := base.PathSplitQuery(pathAndQuery) + res, sdp, err := sc.readHandlers.OnDescribe(&ServerConnDescribeCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) if res.StatusCode == base.StatusOK && sdp != nil { @@ -555,13 +599,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, errors.New("no tracks defined") } - reqPath, ok := req.URL.RTSPPath() + pathAndQuery, ok := req.URL.RTSPPath() if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, errors.New("invalid path") + }, fmt.Errorf("invalid URL (%s)", req.URL) } + path, query := base.PathSplitQuery(pathAndQuery) + for _, track := range tracks { trackURL, err := track.URL() if err != nil { @@ -577,22 +623,25 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("invalid track URL (%v)", trackURL) } - if !strings.HasPrefix(trackPath, reqPath) { + if !strings.HasPrefix(trackPath, path) { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'", - reqPath, trackPath) + path, trackPath) } } res, err := sc.readHandlers.OnAnnounce(&ServerConnAnnounceCtx{ Req: req, + Path: path, + Query: query, Tracks: tracks, }) if res.StatusCode == base.StatusOK { sc.state = ServerConnStatePreRecord - sc.setupPath = &reqPath + sc.setupPath = &path + sc.setupQuery = &query sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks)) for trackID, track := range tracks { @@ -636,8 +685,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, nil } - trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, - sc.announcedTracks, sc.setupPath) + trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, th.Mode, + sc.announcedTracks, sc.setupPath, sc.setupQuery) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -703,9 +752,10 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { res, err := sc.readHandlers.OnSetup(&ServerConnSetupCtx{ Req: req, - Transport: th, Path: path, + Query: query, TrackID: trackID, + Transport: th, }) if res.StatusCode == base.StatusOK { @@ -751,6 +801,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case ServerConnStateInitial: sc.state = ServerConnStatePrePlay sc.setupPath = &path + sc.setupQuery = &query } // workaround to prevent a bug in rtspclientsink @@ -787,8 +838,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("no tracks have been setup") } + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + // path can end with a slash due to Content-Base, remove it + pathAndQuery = strings.TrimSuffix(pathAndQuery, "/") + + path, query := base.PathSplitQuery(pathAndQuery) + res, err := sc.readHandlers.OnPlay(&ServerConnPlayCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) if res.StatusCode == base.StatusOK && sc.state != ServerConnStatePlay { @@ -822,8 +887,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("not all announced tracks have been setup") } + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + // path can end with a slash due to Content-Base, remove it + pathAndQuery = strings.TrimSuffix(pathAndQuery, "/") + + path, query := base.PathSplitQuery(pathAndQuery) + res, err := sc.readHandlers.OnRecord(&ServerConnRecordCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) if res.StatusCode == base.StatusOK { @@ -848,8 +927,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + // path can end with a slash due to Content-Base, remove it + pathAndQuery = strings.TrimSuffix(pathAndQuery, "/") + + path, query := base.PathSplitQuery(pathAndQuery) + res, err := sc.readHandlers.OnPause(&ServerConnPauseCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) if res.StatusCode == base.StatusOK { @@ -869,8 +962,19 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.GetParameter: if sc.readHandlers.OnGetParameter != nil { + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + path, query := base.PathSplitQuery(pathAndQuery) + return sc.readHandlers.OnGetParameter(&ServerConnGetParameterCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) } @@ -885,15 +989,37 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.SetParameter: if sc.readHandlers.OnSetParameter != nil { + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + path, query := base.PathSplitQuery(pathAndQuery) + return sc.readHandlers.OnSetParameter(&ServerConnSetParameterCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) } case base.Teardown: if sc.readHandlers.OnTeardown != nil { + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid URL (%s)", req.URL) + } + + path, query := base.PathSplitQuery(pathAndQuery) + return sc.readHandlers.OnTeardown(&ServerConnTeardownCtx{ - Req: req, + Req: req, + Path: path, + Query: query, }) } diff --git a/serverconn_test.go b/serverconn_test.go index dda6cb66..ed94a585 100644 --- a/serverconn_test.go +++ b/serverconn_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net" - "strings" "sync" "testing" "time" @@ -78,14 +77,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { defer conn.Close() onDescribe := func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) { - reqPath, ok := ctx.Req.URL.RTSPPath() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - if reqPath != "teststream" { + if ctx.Path != "teststream" { return &base.Response{ StatusCode: base.StatusBadRequest, }, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL) @@ -106,14 +98,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { } onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - reqPath, ok := ctx.Req.URL.RTSPPath() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - if reqPath != "teststream" { + if ctx.Path != "teststream" { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) @@ -155,17 +140,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { } onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - reqPath, ok := ctx.Req.URL.RTSPPath() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - // path can end with a slash, remove it - reqPath = strings.TrimSuffix(reqPath, "/") - - if reqPath != "teststream" { + if ctx.Path != "teststream" { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) @@ -185,17 +160,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { } onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - reqPath, ok := ctx.Req.URL.RTSPPath() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - // path can end with a slash, remove it - reqPath = strings.TrimSuffix(reqPath, "/") - - if reqPath != "teststream" { + if ctx.Path != "teststream" { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("invalid path (%s)", ctx.Req.URL)