diff --git a/README.md b/README.md index 425316a4..b4a94fb6 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Features: * Pause reading without disconnecting from the server * Generate RTCP receiver reports automatically * Publish - * Publish streams to servers with UDP, TCP or TLS + * Publish streams to servers with UDP, TCP or TLS (RTSPS) * Switch protocol automatically (switch to TCP in case of server error) * Pause publishing without disconnecting from the server * Generate RTCP sender reports automatically diff --git a/pkg/auth/sender.go b/pkg/auth/sender.go index dd68e6b4..a8649712 100644 --- a/pkg/auth/sender.go +++ b/pkg/auth/sender.go @@ -1,7 +1,6 @@ package auth import ( - "encoding/base64" "fmt" "strings" @@ -87,25 +86,28 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) { func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue { urStr := ur.CloneWithoutCredentials().String() + h := headers.Authorization{ + Method: se.method, + } + switch se.method { case headers.AuthBasic: - response := base64.StdEncoding.EncodeToString([]byte(se.user + ":" + se.pass)) + h.BasicUser = se.user + h.BasicPass = se.pass - return base.HeaderValue{"Basic " + response} - - case headers.AuthDigest: + default: // headers.AuthDigest response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" + se.nonce + ":" + md5Hex(string(method)+":"+urStr)) - return headers.Auth{ + h.DigestValues = headers.Auth{ Method: headers.AuthDigest, Username: &se.user, Realm: &se.realm, Nonce: &se.nonce, URI: &urStr, Response: &response, - }.Write() + } } - return nil + return h.Write() } diff --git a/pkg/auth/validator.go b/pkg/auth/validator.go index ff3705b8..d1dec1a8 100644 --- a/pkg/auth/validator.go +++ b/pkg/auth/validator.go @@ -2,7 +2,6 @@ package auth import ( "crypto/rand" - "encoding/base64" "encoding/hex" "fmt" "strings" @@ -90,99 +89,82 @@ func (va *Validator) GenerateHeader() base.HeaderValue { // ValidateHeader validates the Authorization header sent by a client after receiving the // WWW-Authenticate header. -func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL, +func (va *Validator) ValidateHeader( + v base.HeaderValue, + method base.Method, + ur *base.URL, altURL *base.URL) error { - if len(v) == 0 { - return fmt.Errorf("authorization header not provided") - } - if len(v) > 1 { - return fmt.Errorf("authorization header provided multiple times") + + var auth headers.Authorization + err := auth.Read(v) + if err != nil { + return err } - v0 := v[0] - - switch { - case strings.HasPrefix(v0, "Basic "): - inResponse := v0[len("Basic "):] - - tmp, err := base64.StdEncoding.DecodeString(inResponse) - if err != nil { - return fmt.Errorf("wrong response") - } - tmp2 := strings.Split(string(tmp), ":") - if len(tmp2) != 2 { - return fmt.Errorf("wrong response") - } - user, pass := tmp2[0], tmp2[1] - + switch auth.Method { + case headers.AuthBasic: if !va.userHashed { - if user != va.user { + if auth.BasicUser != va.user { return fmt.Errorf("wrong response") } } else { - if sha256Base64(user) != va.user { + if sha256Base64(auth.BasicUser) != va.user { return fmt.Errorf("wrong response") } } if !va.passHashed { - if pass != va.pass { + if auth.BasicPass != va.pass { return fmt.Errorf("wrong response") } } else { - if sha256Base64(pass) != va.pass { + if sha256Base64(auth.BasicPass) != va.pass { return fmt.Errorf("wrong response") } } - case strings.HasPrefix(v0, "Digest "): - var auth headers.Auth - err := auth.Read(base.HeaderValue{v0}) - if err != nil { - return err - } - - if auth.Realm == nil { + default: // headers.AuthDigest + if auth.DigestValues.Realm == nil { return fmt.Errorf("realm not provided") } - if auth.Nonce == nil { + if auth.DigestValues.Nonce == nil { return fmt.Errorf("nonce not provided") } - if auth.Username == nil { + if auth.DigestValues.Username == nil { return fmt.Errorf("username not provided") } - if auth.URI == nil { + if auth.DigestValues.URI == nil { return fmt.Errorf("uri not provided") } - if auth.Response == nil { + if auth.DigestValues.Response == nil { return fmt.Errorf("response not provided") } - if *auth.Nonce != va.nonce { + if *auth.DigestValues.Nonce != va.nonce { return fmt.Errorf("wrong nonce") } - if *auth.Realm != va.realm { + if *auth.DigestValues.Realm != va.realm { return fmt.Errorf("wrong realm") } - if *auth.Username != va.user { + if *auth.DigestValues.Username != va.user { return fmt.Errorf("wrong username") } urlString := ur.String() - if *auth.URI != urlString { + if *auth.DigestValues.URI != urlString { // do another try with the alternative URL if altURL != nil { urlString = altURL.String() } - if *auth.URI != urlString { + if *auth.DigestValues.URI != urlString { return fmt.Errorf("wrong url") } } @@ -190,12 +172,9 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur * response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) + ":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString)) - if *auth.Response != response { + if *auth.DigestValues.Response != response { return fmt.Errorf("wrong response") } - - default: - return fmt.Errorf("unsupported authorization header") } return nil diff --git a/pkg/headers/auth.go b/pkg/headers/authenticate.go similarity index 100% rename from pkg/headers/auth.go rename to pkg/headers/authenticate.go diff --git a/pkg/headers/auth_test.go b/pkg/headers/authenticate_test.go similarity index 100% rename from pkg/headers/auth_test.go rename to pkg/headers/authenticate_test.go diff --git a/pkg/headers/authorization.go b/pkg/headers/authorization.go new file mode 100644 index 00000000..377b864e --- /dev/null +++ b/pkg/headers/authorization.go @@ -0,0 +1,87 @@ +package headers + +import ( + "encoding/base64" + "fmt" + "strings" + + "github.com/aler9/gortsplib/pkg/base" +) + +// Authorization is an Authorization header. +type Authorization struct { + // authentication method + Method AuthMethod + + // basic user + BasicUser string + + // basic password + BasicPass string + + // digest values + DigestValues Auth +} + +// Read decodes an Authorization header. +func (h *Authorization) Read(v base.HeaderValue) error { + if len(v) == 0 { + return fmt.Errorf("value not provided") + } + + if len(v) > 1 { + return fmt.Errorf("value provided multiple times (%v)", v) + } + + v0 := v[0] + + switch { + case strings.HasPrefix(v0, "Basic "): + h.Method = AuthBasic + + v0 = v0[len("Basic "):] + + tmp, err := base64.StdEncoding.DecodeString(v0) + if err != nil { + return fmt.Errorf("invalid value") + } + + tmp2 := strings.Split(string(tmp), ":") + if len(tmp2) != 2 { + return fmt.Errorf("invalid value") + } + + h.BasicUser, h.BasicPass = tmp2[0], tmp2[1] + + case strings.HasPrefix(v0, "Digest "): + h.Method = AuthDigest + + var vals Auth + err := vals.Read(base.HeaderValue{v0}) + if err != nil { + return err + } + + h.DigestValues = vals + + default: + return fmt.Errorf("invalid authorization header") + } + + return nil +} + +// Write encodes an Authorization header. +func (h Authorization) Write() base.HeaderValue { + switch h.Method { + case AuthBasic: + response := base64.StdEncoding.EncodeToString([]byte(h.BasicUser + ":" + h.BasicPass)) + + return base.HeaderValue{"Basic " + response} + + case AuthDigest: + return h.DigestValues.Write() + } + + return nil +} diff --git a/pkg/headers/authorization_test.go b/pkg/headers/authorization_test.go new file mode 100644 index 00000000..55bca86e --- /dev/null +++ b/pkg/headers/authorization_test.go @@ -0,0 +1,95 @@ +package headers + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/aler9/gortsplib/pkg/base" +) + +var casesAuthorization = []struct { + name string + vin base.HeaderValue + vout base.HeaderValue + h Authorization +}{ + { + "basic", + base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="}, + base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="}, + Authorization{ + Method: AuthBasic, + BasicUser: "myuser", + BasicPass: "mypass", + }, + }, + { + "digest", + base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""}, + base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""}, + Authorization{ + Method: AuthDigest, + DigestValues: Auth{ + Method: AuthDigest, + Realm: func() *string { + v := "4419b63f5e51" + return &v + }(), + Nonce: func() *string { + v := "8b84a3b789283a8bea8da7fa7d41f08b" + return &v + }(), + Stale: func() *string { + v := "FALSE" + return &v + }(), + }, + }, + }, +} + +func TestAuthorizationRead(t *testing.T) { + for _, ca := range casesAuthorization { + t.Run(ca.name, func(t *testing.T) { + var h Authorization + err := h.Read(ca.vin) + require.NoError(t, err) + require.Equal(t, ca.h, h) + }) + } +} + +func TestAuthorizationWrite(t *testing.T) { + for _, ca := range casesAuthorization { + t.Run(ca.name, func(t *testing.T) { + vout := ca.h.Write() + require.Equal(t, ca.vout, vout) + }) + } +} + +func TestAuthorizationReadError(t *testing.T) { + for _, ca := range []struct { + name string + hv base.HeaderValue + err string + }{ + { + "empty", + base.HeaderValue{}, + "value not provided", + }, + { + "2 values", + base.HeaderValue{"a", "b"}, + "value provided multiple times ([a b])", + }, + } { + t.Run(ca.name, func(t *testing.T) { + var h Authorization + err := h.Read(ca.hv) + require.Equal(t, ca.err, err.Error()) + }) + } +}