From 3f62e11795f69b10099ac2fee67db852e44a3c0f Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Fri, 3 May 2024 22:42:50 +0200 Subject: [PATCH] simplify usage of auth.Validate (#557) --- client_play_test.go | 4 ++-- client_test.go | 2 +- pkg/auth/auth_test.go | 16 ++++++++-------- pkg/auth/validate.go | 40 +++++++++++++++++++++++---------------- pkg/auth/validate_test.go | 2 -- server_test.go | 2 +- 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/client_play_test.go b/client_play_test.go index c4247e93..cc69a8f8 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -1320,7 +1320,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { require.NoError(t, err2) require.Equal(t, base.Describe, req.Method) - err2 = auth.Validate(req, "myuser", "mypass", nil, nil, "IPCAM", nonce) + err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) err2 = conn.WriteResponse(&base.Response{ @@ -1432,7 +1432,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"+medias[0].Control), req.URL) - err2 = auth.Validate(req, "myuser", "mypass", nil, nil, "IPCAM", nonce) + err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) var inTH headers.Transport diff --git a/client_test.go b/client_test.go index 6125ba74..c37ae980 100644 --- a/client_test.go +++ b/client_test.go @@ -287,7 +287,7 @@ func TestClientAuth(t *testing.T) { require.NoError(t, err2) require.Equal(t, base.Describe, req.Method) - err2 = auth.Validate(req, "myuser", "mypass", nil, nil, "IPCAM", nonce) + err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) medias := []*description.Media{testH264Media} diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 6f064871..61b8b3b4 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -82,7 +83,7 @@ func TestAuth(t *testing.T) { req.URL = mustParseURL("rtsp://myhost/mypath") - err = Validate(req, "testuser", "testpass", nil, c1.methods, "IPCAM", nonce) + err = Validate(req, "testuser", "testpass", c1.methods, "IPCAM", nonce) if conf != "nofail" { require.Error(t, err) @@ -96,8 +97,8 @@ func TestAuth(t *testing.T) { func TestAuthVLC(t *testing.T) { for _, ca := range []struct { - clientURL string - mediaURL string + baseURL string + mediaURL string }{ { "rtsp://myhost/mypath/", @@ -119,15 +120,14 @@ func TestAuthVLC(t *testing.T) { req := &base.Request{ Method: base.Setup, - URL: mustParseURL(ca.clientURL), + URL: mustParseURL(ca.baseURL), } se.AddAuthorization(req) req.URL = mustParseURL(ca.mediaURL) - err = Validate(req, "testuser", "testpass", mustParseURL(ca.clientURL), nil, "IPCAM", nonce) - require.NoError(t, err) + fmt.Println(req.URL, req.Header) - err = Validate(req, "testuser", "testpass", mustParseURL("rtsp://invalid"), nil, "IPCAM", nonce) - require.Error(t, err) + err = Validate(req, "testuser", "testpass", nil, "IPCAM", nonce) + require.NoError(t, err) } } diff --git a/pkg/auth/validate.go b/pkg/auth/validate.go index 50bac2f4..dca2bf81 100644 --- a/pkg/auth/validate.go +++ b/pkg/auth/validate.go @@ -5,11 +5,14 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "regexp" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/headers" ) +var reControlAttribute = regexp.MustCompile("^(.+/)trackID=[0-9]+$") + func md5Hex(in string) string { h := md5.New() h.Write([]byte(in)) @@ -31,12 +34,28 @@ func contains(list []headers.AuthMethod, item headers.AuthMethod) bool { return false } +func urlMatches(expected string, received string, isSetup bool) bool { + if received == expected { + return true + } + + // in SETUP requests, VLC uses the base URL of the stream + // instead of the URL of the track. + // Strip the control attribute to obtain the URL of the stream. + if isSetup { + if m := reControlAttribute.FindStringSubmatch(expected); m != nil && received == m[1] { + return true + } + } + + return false +} + // Validate validates a request sent by a client. func Validate( req *base.Request, user string, pass string, - baseURL *base.URL, methods []headers.AuthMethod, realm string, nonce string, @@ -66,29 +85,18 @@ func Validate( return fmt.Errorf("authentication failed") } - ur := req.URL - - if auth.URI != ur.String() { - // in SETUP requests, VLC strips the control attribute. - // try again with the base URL. - if baseURL != nil { - ur = baseURL - if auth.URI != ur.String() { - return fmt.Errorf("wrong URL") - } - } else { - return fmt.Errorf("wrong URL") - } + if !urlMatches(req.URL.String(), auth.URI, req.Method == base.Setup) { + return fmt.Errorf("wrong URL") } var response string if auth.Method == headers.AuthDigestSHA256 { response = sha256Hex(sha256Hex(user+":"+realm+":"+pass) + - ":" + nonce + ":" + sha256Hex(string(req.Method)+":"+ur.String())) + ":" + nonce + ":" + sha256Hex(string(req.Method)+":"+auth.URI)) } else { response = md5Hex(md5Hex(user+":"+realm+":"+pass) + - ":" + nonce + ":" + md5Hex(string(req.Method)+":"+ur.String())) + ":" + nonce + ":" + md5Hex(string(req.Method)+":"+auth.URI)) } if auth.Response != response { diff --git a/pkg/auth/validate_test.go b/pkg/auth/validate_test.go index d85d507c..2484682f 100644 --- a/pkg/auth/validate_test.go +++ b/pkg/auth/validate_test.go @@ -30,7 +30,6 @@ func FuzzValidate(f *testing.F) { "myuser", "mypass", nil, - nil, "IPCAM", "abcde", ) @@ -48,7 +47,6 @@ func TestValidateAdditionalErrors(t *testing.T) { }, "myuser", "mypass", - nil, []headers.AuthMethod{headers.AuthDigestMD5}, "IPCAM", "abcde", diff --git a/server_test.go b/server_test.go index 8bb768cf..d88f09d1 100644 --- a/server_test.go +++ b/server_test.go @@ -1041,7 +1041,7 @@ func TestServerAuth(t *testing.T) { s := &Server{ Handler: &testServerHandler{ onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { - err2 := auth.Validate(ctx.Request, "myuser", "mypass", nil, nil, "IPCAM", nonce) + err2 := auth.Validate(ctx.Request, "myuser", "mypass", nil, "IPCAM", nonce) if err2 != nil { return &base.Response{ //nolint:nilerr StatusCode: base.StatusUnauthorized,