diff --git a/client_play_test.go b/client_play_test.go index c4a0c737..73ceb068 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -1328,7 +1328,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { require.NoError(t, err2) require.Equal(t, base.Describe, req.Method) - err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) + err2 = auth.Verify(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) err2 = conn.WriteResponse(&base.Response{ @@ -1440,7 +1440,7 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/"+medias[0].Control), req.URL) - err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) + err2 = auth.Verify(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) var inTH headers.Transport diff --git a/client_test.go b/client_test.go index dd338aa0..e13da6e7 100644 --- a/client_test.go +++ b/client_test.go @@ -287,7 +287,7 @@ func TestClientAuth(t *testing.T) { require.NoError(t, err2) require.Equal(t, base.Describe, req.Method) - err2 = auth.Validate(req, "myuser", "mypass", nil, "IPCAM", nonce) + err2 = auth.Verify(req, "myuser", "mypass", nil, "IPCAM", nonce) require.NoError(t, err2) medias := []*description.Media{testH264Media} diff --git a/pkg/auth/testdata/fuzz/FuzzValidate/771e938e4458e983 b/pkg/auth/testdata/fuzz/FuzzVerify/771e938e4458e983 similarity index 100% rename from pkg/auth/testdata/fuzz/FuzzValidate/771e938e4458e983 rename to pkg/auth/testdata/fuzz/FuzzVerify/771e938e4458e983 diff --git a/pkg/auth/validate.go b/pkg/auth/validate.go index 86691292..23acc701 100644 --- a/pkg/auth/validate.go +++ b/pkg/auth/validate.go @@ -1,67 +1,26 @@ package auth import ( - "crypto/md5" - "crypto/sha256" - "encoding/hex" - "fmt" - "regexp" - "github.com/bluenviron/gortsplib/v4/pkg/base" - "github.com/bluenviron/gortsplib/v4/pkg/headers" ) -var reControlAttribute = regexp.MustCompile("^(.+/)trackID=[0-9]+$") - -func md5Hex(in string) string { - h := md5.New() - h.Write([]byte(in)) - return hex.EncodeToString(h.Sum(nil)) -} - -func sha256Hex(in string) string { - h := sha256.New() - h.Write([]byte(in)) - return hex.EncodeToString(h.Sum(nil)) -} - -func contains(list []ValidateMethod, item ValidateMethod) bool { - for _, i := range list { - if i == item { - return true - } - } - return false -} - -func urlMatches(expected string, received string, isSetup bool) bool { - if received == expected { - return true - } - - // in SETUP requests, VLC uses the base URL of the stream - // instead of the URL of the track. - // Strip the control attribute to obtain the URL of the stream. - if isSetup { - if m := reControlAttribute.FindStringSubmatch(expected); m != nil && received == m[1] { - return true - } - } - - return false -} - // ValidateMethod is a validation method. -type ValidateMethod int +// +// Deprecated: replaced by VerifyMethod +type ValidateMethod = VerifyMethod // validation methods. +// +// Deprecated. const ( - ValidateMethodBasic ValidateMethod = iota - ValidateMethodDigestMD5 - ValidateMethodSHA256 + ValidateMethodBasic = VerifyMethodBasic + ValidateMethodDigestMD5 = VerifyMethodDigestMD5 + ValidateMethodSHA256 = VerifyMethodDigestSHA256 ) // Validate validates a request sent by a client. +// +// Deprecated: replaced by Verify. func Validate( req *base.Request, user string, @@ -70,64 +29,5 @@ func Validate( realm string, nonce string, ) error { - if methods == nil { - methods = []ValidateMethod{ValidateMethodBasic, ValidateMethodDigestMD5, ValidateMethodSHA256} - } - - var auth headers.Authorization - err := auth.Unmarshal(req.Header["Authorization"]) - if err != nil { - return err - } - - switch { - case auth.Method == headers.AuthMethodDigest && - (contains(methods, ValidateMethodDigestMD5) && - (auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5) || - contains(methods, ValidateMethodSHA256) && - auth.Algorithm != nil && *auth.Algorithm == headers.AuthAlgorithmSHA256): - if auth.Nonce != nonce { - return fmt.Errorf("wrong nonce") - } - - if auth.Realm != realm { - return fmt.Errorf("wrong realm") - } - - if auth.Username != user { - return fmt.Errorf("authentication failed") - } - - if !urlMatches(req.URL.String(), auth.URI, req.Method == base.Setup) { - return fmt.Errorf("wrong URL") - } - - var response string - - if auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5 { - response = md5Hex(md5Hex(user+":"+realm+":"+pass) + - ":" + nonce + ":" + md5Hex(string(req.Method)+":"+auth.URI)) - } else { // sha256 - response = sha256Hex(sha256Hex(user+":"+realm+":"+pass) + - ":" + nonce + ":" + sha256Hex(string(req.Method)+":"+auth.URI)) - } - - if auth.Response != response { - return fmt.Errorf("authentication failed") - } - - case auth.Method == headers.AuthMethodBasic && contains(methods, ValidateMethodBasic): - if auth.BasicUser != user { - return fmt.Errorf("authentication failed") - } - - if auth.BasicPass != pass { - return fmt.Errorf("authentication failed") - } - - default: - return fmt.Errorf("no supported authentication methods found") - } - - return nil + return Verify(req, user, pass, methods, realm, nonce) } diff --git a/pkg/auth/verify.go b/pkg/auth/verify.go new file mode 100644 index 00000000..58e8f77e --- /dev/null +++ b/pkg/auth/verify.go @@ -0,0 +1,133 @@ +package auth + +import ( + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "fmt" + "regexp" + + "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/headers" +) + +var reControlAttribute = regexp.MustCompile("^(.+/)trackID=[0-9]+$") + +func md5Hex(in string) string { + h := md5.New() + h.Write([]byte(in)) + return hex.EncodeToString(h.Sum(nil)) +} + +func sha256Hex(in string) string { + h := sha256.New() + h.Write([]byte(in)) + return hex.EncodeToString(h.Sum(nil)) +} + +func contains(list []VerifyMethod, item VerifyMethod) bool { + for _, i := range list { + if i == item { + return true + } + } + return false +} + +func urlMatches(expected string, received string, isSetup bool) bool { + if received == expected { + return true + } + + // in SETUP requests, VLC uses the base URL of the stream + // instead of the URL of the track. + // Strip the control attribute to obtain the URL of the stream. + if isSetup { + if m := reControlAttribute.FindStringSubmatch(expected); m != nil && received == m[1] { + return true + } + } + + return false +} + +// VerifyMethod is a validation method. +type VerifyMethod int + +// validation methods. +const ( + VerifyMethodBasic VerifyMethod = iota + VerifyMethodDigestMD5 + VerifyMethodDigestSHA256 +) + +// Verify validates a request sent by a client. +func Verify( + req *base.Request, + user string, + pass string, + methods []VerifyMethod, + realm string, + nonce string, +) error { + if methods == nil { + methods = []VerifyMethod{VerifyMethodBasic, VerifyMethodDigestMD5, VerifyMethodDigestSHA256} + } + + var auth headers.Authorization + err := auth.Unmarshal(req.Header["Authorization"]) + if err != nil { + return err + } + + switch { + case auth.Method == headers.AuthMethodDigest && + (contains(methods, VerifyMethodDigestMD5) && + (auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5) || + contains(methods, VerifyMethodDigestSHA256) && + auth.Algorithm != nil && *auth.Algorithm == headers.AuthAlgorithmSHA256): + if auth.Nonce != nonce { + return fmt.Errorf("wrong nonce") + } + + if auth.Realm != realm { + return fmt.Errorf("wrong realm") + } + + if auth.Username != user { + return fmt.Errorf("authentication failed") + } + + if !urlMatches(req.URL.String(), auth.URI, req.Method == base.Setup) { + return fmt.Errorf("wrong URL") + } + + var response string + + if auth.Algorithm == nil || *auth.Algorithm == headers.AuthAlgorithmMD5 { + response = md5Hex(md5Hex(user+":"+realm+":"+pass) + + ":" + nonce + ":" + md5Hex(string(req.Method)+":"+auth.URI)) + } else { // sha256 + response = sha256Hex(sha256Hex(user+":"+realm+":"+pass) + + ":" + nonce + ":" + sha256Hex(string(req.Method)+":"+auth.URI)) + } + + if auth.Response != response { + return fmt.Errorf("authentication failed") + } + + case auth.Method == headers.AuthMethodBasic && contains(methods, VerifyMethodBasic): + if auth.BasicUser != user { + return fmt.Errorf("authentication failed") + } + + if auth.BasicPass != pass { + return fmt.Errorf("authentication failed") + } + + default: + return fmt.Errorf("no supported authentication methods found") + } + + return nil +} diff --git a/pkg/auth/validate_test.go b/pkg/auth/verify_test.go similarity index 87% rename from pkg/auth/validate_test.go rename to pkg/auth/verify_test.go index 1d675b3d..e3ca0bb0 100644 --- a/pkg/auth/validate_test.go +++ b/pkg/auth/verify_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -var casesValidate = []struct { +var casesVerify = []struct { name string authorization base.HeaderValue }{ @@ -50,11 +50,11 @@ var casesValidate = []struct { }, } -func TestValidate(t *testing.T) { - for _, ca := range casesValidate { +func TestVerify(t *testing.T) { + for _, ca := range casesVerify { t.Run(ca.name, func(t *testing.T) { se, err := NewSender( - GenerateWWWAuthenticate([]ValidateMethod{ValidateMethodDigestMD5}, "myrealm", "f49ac6dd0ba708d4becddc9692d1f2ce"), + GenerateWWWAuthenticate([]VerifyMethod{VerifyMethodDigestMD5}, "myrealm", "f49ac6dd0ba708d4becddc9692d1f2ce"), "myuser", "mypass") require.NoError(t, err) @@ -71,7 +71,7 @@ func TestValidate(t *testing.T) { "Authorization": ca.authorization, }, } - err = Validate( + err = Verify( req, "myuser", "mypass", @@ -83,13 +83,13 @@ func TestValidate(t *testing.T) { } } -func FuzzValidate(f *testing.F) { - for _, ca := range casesValidate { +func FuzzVerify(f *testing.F) { + for _, ca := range casesVerify { f.Add(ca.authorization[0]) } f.Fuzz(func(_ *testing.T, a string) { - Validate( //nolint:errcheck + Verify( //nolint:errcheck &base.Request{ Method: base.Describe, URL: nil, diff --git a/server_test.go b/server_test.go index 9aec0281..0b664f4f 100644 --- a/server_test.go +++ b/server_test.go @@ -1041,7 +1041,7 @@ func TestServerAuth(t *testing.T) { s := &Server{ Handler: &testServerHandler{ onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { - err2 := auth.Validate(ctx.Request, "myuser", "mypass", nil, "IPCAM", nonce) + err2 := auth.Verify(ctx.Request, "myuser", "mypass", nil, "IPCAM", nonce) if err2 != nil { return &base.Response{ //nolint:nilerr StatusCode: base.StatusUnauthorized,