mirror of
https://github.com/aler9/gortsplib
synced 2025-10-06 15:46:51 +08:00
headers: add authorization header
This commit is contained in:
@@ -21,7 +21,7 @@ Features:
|
||||
* Pause reading without disconnecting from the server
|
||||
* Generate RTCP receiver reports automatically
|
||||
* Publish
|
||||
* Publish streams to servers with UDP, TCP or TLS
|
||||
* Publish streams to servers with UDP, TCP or TLS (RTSPS)
|
||||
* Switch protocol automatically (switch to TCP in case of server error)
|
||||
* Pause publishing without disconnecting from the server
|
||||
* Generate RTCP sender reports automatically
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -87,25 +86,28 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) {
|
||||
func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue {
|
||||
urStr := ur.CloneWithoutCredentials().String()
|
||||
|
||||
h := headers.Authorization{
|
||||
Method: se.method,
|
||||
}
|
||||
|
||||
switch se.method {
|
||||
case headers.AuthBasic:
|
||||
response := base64.StdEncoding.EncodeToString([]byte(se.user + ":" + se.pass))
|
||||
h.BasicUser = se.user
|
||||
h.BasicPass = se.pass
|
||||
|
||||
return base.HeaderValue{"Basic " + response}
|
||||
|
||||
case headers.AuthDigest:
|
||||
default: // headers.AuthDigest
|
||||
response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" +
|
||||
se.nonce + ":" + md5Hex(string(method)+":"+urStr))
|
||||
|
||||
return headers.Auth{
|
||||
h.DigestValues = headers.Auth{
|
||||
Method: headers.AuthDigest,
|
||||
Username: &se.user,
|
||||
Realm: &se.realm,
|
||||
Nonce: &se.nonce,
|
||||
URI: &urStr,
|
||||
Response: &response,
|
||||
}.Write()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return h.Write()
|
||||
}
|
||||
|
@@ -2,7 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -90,99 +89,82 @@ func (va *Validator) GenerateHeader() base.HeaderValue {
|
||||
|
||||
// ValidateHeader validates the Authorization header sent by a client after receiving the
|
||||
// WWW-Authenticate header.
|
||||
func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL,
|
||||
func (va *Validator) ValidateHeader(
|
||||
v base.HeaderValue,
|
||||
method base.Method,
|
||||
ur *base.URL,
|
||||
altURL *base.URL) error {
|
||||
if len(v) == 0 {
|
||||
return fmt.Errorf("authorization header not provided")
|
||||
}
|
||||
if len(v) > 1 {
|
||||
return fmt.Errorf("authorization header provided multiple times")
|
||||
}
|
||||
|
||||
v0 := v[0]
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(v0, "Basic "):
|
||||
inResponse := v0[len("Basic "):]
|
||||
|
||||
tmp, err := base64.StdEncoding.DecodeString(inResponse)
|
||||
var auth headers.Authorization
|
||||
err := auth.Read(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong response")
|
||||
return err
|
||||
}
|
||||
tmp2 := strings.Split(string(tmp), ":")
|
||||
if len(tmp2) != 2 {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
user, pass := tmp2[0], tmp2[1]
|
||||
|
||||
switch auth.Method {
|
||||
case headers.AuthBasic:
|
||||
if !va.userHashed {
|
||||
if user != va.user {
|
||||
if auth.BasicUser != va.user {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
} else {
|
||||
if sha256Base64(user) != va.user {
|
||||
if sha256Base64(auth.BasicUser) != va.user {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
}
|
||||
|
||||
if !va.passHashed {
|
||||
if pass != va.pass {
|
||||
if auth.BasicPass != va.pass {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
} else {
|
||||
if sha256Base64(pass) != va.pass {
|
||||
if sha256Base64(auth.BasicPass) != va.pass {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(v0, "Digest "):
|
||||
var auth headers.Auth
|
||||
err := auth.Read(base.HeaderValue{v0})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if auth.Realm == nil {
|
||||
default: // headers.AuthDigest
|
||||
if auth.DigestValues.Realm == nil {
|
||||
return fmt.Errorf("realm not provided")
|
||||
}
|
||||
|
||||
if auth.Nonce == nil {
|
||||
if auth.DigestValues.Nonce == nil {
|
||||
return fmt.Errorf("nonce not provided")
|
||||
}
|
||||
|
||||
if auth.Username == nil {
|
||||
if auth.DigestValues.Username == nil {
|
||||
return fmt.Errorf("username not provided")
|
||||
}
|
||||
|
||||
if auth.URI == nil {
|
||||
if auth.DigestValues.URI == nil {
|
||||
return fmt.Errorf("uri not provided")
|
||||
}
|
||||
|
||||
if auth.Response == nil {
|
||||
if auth.DigestValues.Response == nil {
|
||||
return fmt.Errorf("response not provided")
|
||||
}
|
||||
|
||||
if *auth.Nonce != va.nonce {
|
||||
if *auth.DigestValues.Nonce != va.nonce {
|
||||
return fmt.Errorf("wrong nonce")
|
||||
}
|
||||
|
||||
if *auth.Realm != va.realm {
|
||||
if *auth.DigestValues.Realm != va.realm {
|
||||
return fmt.Errorf("wrong realm")
|
||||
}
|
||||
|
||||
if *auth.Username != va.user {
|
||||
if *auth.DigestValues.Username != va.user {
|
||||
return fmt.Errorf("wrong username")
|
||||
}
|
||||
|
||||
urlString := ur.String()
|
||||
|
||||
if *auth.URI != urlString {
|
||||
if *auth.DigestValues.URI != urlString {
|
||||
// do another try with the alternative URL
|
||||
if altURL != nil {
|
||||
urlString = altURL.String()
|
||||
}
|
||||
|
||||
if *auth.URI != urlString {
|
||||
if *auth.DigestValues.URI != urlString {
|
||||
return fmt.Errorf("wrong url")
|
||||
}
|
||||
}
|
||||
@@ -190,12 +172,9 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *
|
||||
response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) +
|
||||
":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString))
|
||||
|
||||
if *auth.Response != response {
|
||||
if *auth.DigestValues.Response != response {
|
||||
return fmt.Errorf("wrong response")
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported authorization header")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
87
pkg/headers/authorization.go
Normal file
87
pkg/headers/authorization.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package headers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/aler9/gortsplib/pkg/base"
|
||||
)
|
||||
|
||||
// Authorization is an Authorization header.
|
||||
type Authorization struct {
|
||||
// authentication method
|
||||
Method AuthMethod
|
||||
|
||||
// basic user
|
||||
BasicUser string
|
||||
|
||||
// basic password
|
||||
BasicPass string
|
||||
|
||||
// digest values
|
||||
DigestValues Auth
|
||||
}
|
||||
|
||||
// Read decodes an Authorization header.
|
||||
func (h *Authorization) Read(v base.HeaderValue) error {
|
||||
if len(v) == 0 {
|
||||
return fmt.Errorf("value not provided")
|
||||
}
|
||||
|
||||
if len(v) > 1 {
|
||||
return fmt.Errorf("value provided multiple times (%v)", v)
|
||||
}
|
||||
|
||||
v0 := v[0]
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(v0, "Basic "):
|
||||
h.Method = AuthBasic
|
||||
|
||||
v0 = v0[len("Basic "):]
|
||||
|
||||
tmp, err := base64.StdEncoding.DecodeString(v0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value")
|
||||
}
|
||||
|
||||
tmp2 := strings.Split(string(tmp), ":")
|
||||
if len(tmp2) != 2 {
|
||||
return fmt.Errorf("invalid value")
|
||||
}
|
||||
|
||||
h.BasicUser, h.BasicPass = tmp2[0], tmp2[1]
|
||||
|
||||
case strings.HasPrefix(v0, "Digest "):
|
||||
h.Method = AuthDigest
|
||||
|
||||
var vals Auth
|
||||
err := vals.Read(base.HeaderValue{v0})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.DigestValues = vals
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid authorization header")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write encodes an Authorization header.
|
||||
func (h Authorization) Write() base.HeaderValue {
|
||||
switch h.Method {
|
||||
case AuthBasic:
|
||||
response := base64.StdEncoding.EncodeToString([]byte(h.BasicUser + ":" + h.BasicPass))
|
||||
|
||||
return base.HeaderValue{"Basic " + response}
|
||||
|
||||
case AuthDigest:
|
||||
return h.DigestValues.Write()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
95
pkg/headers/authorization_test.go
Normal file
95
pkg/headers/authorization_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package headers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/aler9/gortsplib/pkg/base"
|
||||
)
|
||||
|
||||
var casesAuthorization = []struct {
|
||||
name string
|
||||
vin base.HeaderValue
|
||||
vout base.HeaderValue
|
||||
h Authorization
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="},
|
||||
base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="},
|
||||
Authorization{
|
||||
Method: AuthBasic,
|
||||
BasicUser: "myuser",
|
||||
BasicPass: "mypass",
|
||||
},
|
||||
},
|
||||
{
|
||||
"digest",
|
||||
base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""},
|
||||
base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""},
|
||||
Authorization{
|
||||
Method: AuthDigest,
|
||||
DigestValues: Auth{
|
||||
Method: AuthDigest,
|
||||
Realm: func() *string {
|
||||
v := "4419b63f5e51"
|
||||
return &v
|
||||
}(),
|
||||
Nonce: func() *string {
|
||||
v := "8b84a3b789283a8bea8da7fa7d41f08b"
|
||||
return &v
|
||||
}(),
|
||||
Stale: func() *string {
|
||||
v := "FALSE"
|
||||
return &v
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestAuthorizationRead(t *testing.T) {
|
||||
for _, ca := range casesAuthorization {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
var h Authorization
|
||||
err := h.Read(ca.vin)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ca.h, h)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationWrite(t *testing.T) {
|
||||
for _, ca := range casesAuthorization {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
vout := ca.h.Write()
|
||||
require.Equal(t, ca.vout, vout)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationReadError(t *testing.T) {
|
||||
for _, ca := range []struct {
|
||||
name string
|
||||
hv base.HeaderValue
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
base.HeaderValue{},
|
||||
"value not provided",
|
||||
},
|
||||
{
|
||||
"2 values",
|
||||
base.HeaderValue{"a", "b"},
|
||||
"value provided multiple times ([a b])",
|
||||
},
|
||||
} {
|
||||
t.Run(ca.name, func(t *testing.T) {
|
||||
var h Authorization
|
||||
err := h.Read(ca.hv)
|
||||
require.Equal(t, ca.err, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user