diff --git a/examples/server-tls.go b/examples/server-tls.go index d0a8252b..180e805a 100644 --- a/examples/server-tls.go +++ b/examples/server-tls.go @@ -74,7 +74,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/examples/server-udp.go b/examples/server-udp.go index b5e00102..efee9fda 100644 --- a/examples/server-udp.go +++ b/examples/server-udp.go @@ -73,7 +73,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/examples/server.go b/examples/server.go index dbfec8bb..990a4b92 100644 --- a/examples/server.go +++ b/examples/server.go @@ -73,7 +73,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/pkg/auth/validator.go b/pkg/auth/validator.go index 45513d6b..1e1242a9 100644 --- a/pkg/auth/validator.go +++ b/pkg/auth/validator.go @@ -91,7 +91,7 @@ func (va *Validator) GenerateHeader() base.HeaderValue { // ValidateHeader validates the Authorization header sent by a client after receiving the // WWW-Authenticate header. func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL, - altUrl *base.URL) error { + altURL *base.URL) error { if len(v) == 0 { return fmt.Errorf("authorization header not provided") } @@ -176,8 +176,8 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur * if *auth.URI != urlString { // do another try with the alternative URL - if altUrl != nil { - urlString = altUrl.String() + if altURL != nil { + urlString = altURL.String() } if *auth.URI != urlString { diff --git a/pkg/base/url.go b/pkg/base/url.go index 58a3ac50..f254504a 100644 --- a/pkg/base/url.go +++ b/pkg/base/url.go @@ -3,7 +3,6 @@ package base import ( "fmt" "net/url" - "strings" ) func stringsReverseIndexByte(s string, c byte) int { @@ -16,7 +15,7 @@ func stringsReverseIndexByte(s string, c byte) int { } // URL is a RTSP URL. -// This is basically an HTTP url with some additional functions to handle +// This is basically an HTTP URL with some additional functions to handle // control attributes. type URL url.URL @@ -75,9 +74,8 @@ func (u *URL) CloneWithoutCredentials() *URL { }) } -// BasePath returns the base path of a RTSP URL. -// We assume that the URL doesn't contain a control attribute. -func (u *URL) BasePath() (string, bool) { +// RTSPPath returns the path of a RTSP URL. +func (u *URL) RTSPPath() (string, bool) { var path string if u.RawPath != "" { path = u.RawPath @@ -94,10 +92,8 @@ func (u *URL) BasePath() (string, bool) { return path, true } -// BasePathControlAttr returns the base path and the control attribute of a RTSP URL. -// We assume that the URL contains a control attribute. -// We assume that the base path and control attribute are divided with a slash. -func (u *URL) BasePathControlAttr() (string, string, bool) { +// RTSPPathAndQuery returns the path and the query of a RTSP URL. +func (u *URL) RTSPPathAndQuery() (string, bool) { var pathAndQuery string if u.RawPath != "" { pathAndQuery = u.RawPath @@ -110,33 +106,11 @@ func (u *URL) BasePathControlAttr() (string, string, bool) { // remove leading slash if len(pathAndQuery) == 0 || pathAndQuery[0] != '/' { - return "", "", false + return "", false } pathAndQuery = pathAndQuery[1:] - pos := stringsReverseIndexByte(pathAndQuery, '/') - if pos < 0 { - return "", "", false - } - - basePath := pathAndQuery[:pos] - - // remove query from basePath - i := strings.IndexByte(basePath, '?') - if i >= 0 { - basePath = basePath[:i] - } - - if len(basePath) == 0 { - return "", "", false - } - - controlPath := pathAndQuery[pos+1:] - if len(controlPath) == 0 { - return "", "", false - } - - return basePath, controlPath, true + return pathAndQuery, true } // AddControlAttribute adds a control attribute to a RTSP url. diff --git a/pkg/base/url_test.go b/pkg/base/url_test.go index df129af2..99bdf2c7 100644 --- a/pkg/base/url_test.go +++ b/pkg/base/url_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestURLBasePath(t *testing.T) { +func TestURLRTSPPath(t *testing.T) { for _, ca := range []struct { u *URL b string @@ -32,48 +32,41 @@ func TestURLBasePath(t *testing.T) { "user=tmp&password=BagRep1!&channel=1&stream=0.sdp", }, } { - b, ok := ca.u.BasePath() + b, ok := ca.u.RTSPPath() require.Equal(t, true, ok) require.Equal(t, ca.b, b) } } -func TestURLBasePathControlAttr(t *testing.T) { +func TestURLRTSPPathAndQuery(t *testing.T) { for _, ca := range []struct { u *URL b string - c string }{ { MustParseURL("rtsp://localhost:8554/teststream/trackID=1"), - "teststream", - "trackID=1", + "teststream/trackID=1", }, { MustParseURL("rtsp://localhost:8554/test/stream/trackID=1"), - "test/stream", - "trackID=1", + "test/stream/trackID=1", }, { MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"), - "test", - "trackID=1", + "test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1", }, { MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), - "te!st", - "trackID=1", + "te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1", }, { MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), - "user=tmp&password=BagRep1!&channel=1&stream=0.sdp", - "trackID=1", + "user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1", }, } { - b, c, ok := ca.u.BasePathControlAttr() + b, ok := ca.u.RTSPPathAndQuery() require.Equal(t, true, ok) require.Equal(t, ca.b, b) - require.Equal(t, ca.c, c) } } diff --git a/serverconf_test.go b/serverconf_test.go index 6a96d03d..cd26600f 100644 --- a/serverconf_test.go +++ b/serverconf_test.go @@ -142,7 +142,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { }, nil } - onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/serverconn.go b/serverconn.go index ae4b0649..3b386547 100644 --- a/serverconn.go +++ b/serverconn.go @@ -67,15 +67,16 @@ type ServerConnTrack struct { rtcpPort int } -func extractTrackID(controlPath string, mode *headers.TransportMode, trackLen int) (int, error) { +func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) { if mode == nil || *mode == headers.TransportModePlay { - if !strings.HasPrefix(controlPath, "trackID=") { - return 0, fmt.Errorf("invalid control attribute (%s)", controlPath) + i := strings.Index(pathAndQuery, "/trackID=") + if i < 0 { + return 0, fmt.Errorf("unable to find control attribute (%s)", pathAndQuery) } - tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) + tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) if err != nil || tmp < 0 { - return 0, fmt.Errorf("invalid track id (%s)", controlPath) + return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery) } trackID := int(tmp) @@ -105,7 +106,7 @@ type ServerConnReadHandlers struct { OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error) // called after receiving a SETUP request. - OnSetup func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) + OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) // called after receiving a PLAY request. OnPlay func(req *base.Request) (*base.Response, error) @@ -457,7 +458,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } - basePath, controlPath, ok := req.URL.BasePathControlAttr() + pathAndQuery, ok := req.URL.RTSPPathAndQuery() if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -477,7 +478,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("multicast is not supported") } - trackID, err := extractTrackID(controlPath, th.Mode, len(sc.tracks)) + trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks)) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -541,7 +542,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } - res, err := sc.readHandlers.OnSetup(req, th, basePath, trackID) + res, err := sc.readHandlers.OnSetup(req, th, trackID) if res.StatusCode == 200 { sc.tracksProtocol = &th.Protocol