rewrite URL functions

This commit is contained in:
aler9
2021-01-17 22:46:11 +01:00
parent 64422b391e
commit d54a602e20
8 changed files with 33 additions and 65 deletions

View File

@@ -74,7 +74,7 @@ func handleConn(conn *gortsplib.ServerConn) {
} }
// called after receiving a SETUP request. // called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{

View File

@@ -73,7 +73,7 @@ func handleConn(conn *gortsplib.ServerConn) {
} }
// called after receiving a SETUP request. // called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{

View File

@@ -73,7 +73,7 @@ func handleConn(conn *gortsplib.ServerConn) {
} }
// called after receiving a SETUP request. // called after receiving a SETUP request.
onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{

View File

@@ -91,7 +91,7 @@ func (va *Validator) GenerateHeader() base.HeaderValue {
// ValidateHeader validates the Authorization header sent by a client after receiving the // ValidateHeader validates the Authorization header sent by a client after receiving the
// WWW-Authenticate header. // WWW-Authenticate header.
func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL, func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL,
altUrl *base.URL) error { altURL *base.URL) error {
if len(v) == 0 { if len(v) == 0 {
return fmt.Errorf("authorization header not provided") return fmt.Errorf("authorization header not provided")
} }
@@ -176,8 +176,8 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *
if *auth.URI != urlString { if *auth.URI != urlString {
// do another try with the alternative URL // do another try with the alternative URL
if altUrl != nil { if altURL != nil {
urlString = altUrl.String() urlString = altURL.String()
} }
if *auth.URI != urlString { if *auth.URI != urlString {

View File

@@ -3,7 +3,6 @@ package base
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"strings"
) )
func stringsReverseIndexByte(s string, c byte) int { func stringsReverseIndexByte(s string, c byte) int {
@@ -16,7 +15,7 @@ func stringsReverseIndexByte(s string, c byte) int {
} }
// URL is a RTSP URL. // URL is a RTSP URL.
// This is basically an HTTP url with some additional functions to handle // This is basically an HTTP URL with some additional functions to handle
// control attributes. // control attributes.
type URL url.URL type URL url.URL
@@ -75,9 +74,8 @@ func (u *URL) CloneWithoutCredentials() *URL {
}) })
} }
// BasePath returns the base path of a RTSP URL. // RTSPPath returns the path of a RTSP URL.
// We assume that the URL doesn't contain a control attribute. func (u *URL) RTSPPath() (string, bool) {
func (u *URL) BasePath() (string, bool) {
var path string var path string
if u.RawPath != "" { if u.RawPath != "" {
path = u.RawPath path = u.RawPath
@@ -94,10 +92,8 @@ func (u *URL) BasePath() (string, bool) {
return path, true return path, true
} }
// BasePathControlAttr returns the base path and the control attribute of a RTSP URL. // RTSPPathAndQuery returns the path and the query of a RTSP URL.
// We assume that the URL contains a control attribute. func (u *URL) RTSPPathAndQuery() (string, bool) {
// We assume that the base path and control attribute are divided with a slash.
func (u *URL) BasePathControlAttr() (string, string, bool) {
var pathAndQuery string var pathAndQuery string
if u.RawPath != "" { if u.RawPath != "" {
pathAndQuery = u.RawPath pathAndQuery = u.RawPath
@@ -110,33 +106,11 @@ func (u *URL) BasePathControlAttr() (string, string, bool) {
// remove leading slash // remove leading slash
if len(pathAndQuery) == 0 || pathAndQuery[0] != '/' { if len(pathAndQuery) == 0 || pathAndQuery[0] != '/' {
return "", "", false return "", false
} }
pathAndQuery = pathAndQuery[1:] pathAndQuery = pathAndQuery[1:]
pos := stringsReverseIndexByte(pathAndQuery, '/') return pathAndQuery, true
if pos < 0 {
return "", "", false
}
basePath := pathAndQuery[:pos]
// remove query from basePath
i := strings.IndexByte(basePath, '?')
if i >= 0 {
basePath = basePath[:i]
}
if len(basePath) == 0 {
return "", "", false
}
controlPath := pathAndQuery[pos+1:]
if len(controlPath) == 0 {
return "", "", false
}
return basePath, controlPath, true
} }
// AddControlAttribute adds a control attribute to a RTSP url. // AddControlAttribute adds a control attribute to a RTSP url.

View File

@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestURLBasePath(t *testing.T) { func TestURLRTSPPath(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
u *URL u *URL
b string b string
@@ -32,48 +32,41 @@ func TestURLBasePath(t *testing.T) {
"user=tmp&password=BagRep1!&channel=1&stream=0.sdp", "user=tmp&password=BagRep1!&channel=1&stream=0.sdp",
}, },
} { } {
b, ok := ca.u.BasePath() b, ok := ca.u.RTSPPath()
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, ca.b, b) require.Equal(t, ca.b, b)
} }
} }
func TestURLBasePathControlAttr(t *testing.T) { func TestURLRTSPPathAndQuery(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
u *URL u *URL
b string b string
c string
}{ }{
{ {
MustParseURL("rtsp://localhost:8554/teststream/trackID=1"), MustParseURL("rtsp://localhost:8554/teststream/trackID=1"),
"teststream", "teststream/trackID=1",
"trackID=1",
}, },
{ {
MustParseURL("rtsp://localhost:8554/test/stream/trackID=1"), MustParseURL("rtsp://localhost:8554/test/stream/trackID=1"),
"test/stream", "test/stream/trackID=1",
"trackID=1",
}, },
{ {
MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"),
"test", "test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1",
"trackID=1",
}, },
{ {
MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
"te!st", "te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1",
"trackID=1",
}, },
{ {
MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
"user=tmp&password=BagRep1!&channel=1&stream=0.sdp", "user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1",
"trackID=1",
}, },
} { } {
b, c, ok := ca.u.BasePathControlAttr() b, ok := ca.u.RTSPPathAndQuery()
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, ca.b, b) require.Equal(t, ca.b, b)
require.Equal(t, ca.c, c)
} }
} }

View File

@@ -142,7 +142,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}, nil }, nil
} }
onSetup := func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) { onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{

View File

@@ -67,15 +67,16 @@ type ServerConnTrack struct {
rtcpPort int rtcpPort int
} }
func extractTrackID(controlPath string, mode *headers.TransportMode, trackLen int) (int, error) { func extractTrackID(pathAndQuery string, mode *headers.TransportMode, trackLen int) (int, error) {
if mode == nil || *mode == headers.TransportModePlay { if mode == nil || *mode == headers.TransportModePlay {
if !strings.HasPrefix(controlPath, "trackID=") { i := strings.Index(pathAndQuery, "/trackID=")
return 0, fmt.Errorf("invalid control attribute (%s)", controlPath) if i < 0 {
return 0, fmt.Errorf("unable to find control attribute (%s)", pathAndQuery)
} }
tmp, err := strconv.ParseInt(controlPath[len("trackID="):], 10, 64) tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64)
if err != nil || tmp < 0 { if err != nil || tmp < 0 {
return 0, fmt.Errorf("invalid track id (%s)", controlPath) return 0, fmt.Errorf("invalid track id (%s)", pathAndQuery)
} }
trackID := int(tmp) trackID := int(tmp)
@@ -105,7 +106,7 @@ type ServerConnReadHandlers struct {
OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error) OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error)
// called after receiving a SETUP request. // called after receiving a SETUP request.
OnSetup func(req *base.Request, th *headers.Transport, basePath string, trackID int) (*base.Response, error) OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error)
// called after receiving a PLAY request. // called after receiving a PLAY request.
OnPlay func(req *base.Request) (*base.Response, error) OnPlay func(req *base.Request) (*base.Response, error)
@@ -457,7 +458,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, err }, err
} }
basePath, controlPath, ok := req.URL.BasePathControlAttr() pathAndQuery, ok := req.URL.RTSPPathAndQuery()
if !ok { if !ok {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -477,7 +478,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}, fmt.Errorf("multicast is not supported") }, fmt.Errorf("multicast is not supported")
} }
trackID, err := extractTrackID(controlPath, th.Mode, len(sc.tracks)) trackID, err := extractTrackID(pathAndQuery, th.Mode, len(sc.tracks))
if err != nil { if err != nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
@@ -541,7 +542,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
} }
res, err := sc.readHandlers.OnSetup(req, th, basePath, trackID) res, err := sc.readHandlers.OnSetup(req, th, trackID)
if res.StatusCode == 200 { if res.StatusCode == 200 {
sc.tracksProtocol = &th.Protocol sc.tracksProtocol = &th.Protocol