headers: add authorization header

This commit is contained in:
aler9
2021-05-10 20:43:23 +02:00
parent 9aa91dce29
commit 034c43202e
7 changed files with 220 additions and 57 deletions

View File

@@ -21,7 +21,7 @@ Features:
* Pause reading without disconnecting from the server * Pause reading without disconnecting from the server
* Generate RTCP receiver reports automatically * Generate RTCP receiver reports automatically
* Publish * Publish
* Publish streams to servers with UDP, TCP or TLS * Publish streams to servers with UDP, TCP or TLS (RTSPS)
* Switch protocol automatically (switch to TCP in case of server error) * Switch protocol automatically (switch to TCP in case of server error)
* Pause publishing without disconnecting from the server * Pause publishing without disconnecting from the server
* Generate RTCP sender reports automatically * Generate RTCP sender reports automatically

View File

@@ -1,7 +1,6 @@
package auth package auth
import ( import (
"encoding/base64"
"fmt" "fmt"
"strings" "strings"
@@ -87,25 +86,28 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) {
func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue { func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue {
urStr := ur.CloneWithoutCredentials().String() urStr := ur.CloneWithoutCredentials().String()
h := headers.Authorization{
Method: se.method,
}
switch se.method { switch se.method {
case headers.AuthBasic: case headers.AuthBasic:
response := base64.StdEncoding.EncodeToString([]byte(se.user + ":" + se.pass)) h.BasicUser = se.user
h.BasicPass = se.pass
return base.HeaderValue{"Basic " + response} default: // headers.AuthDigest
case headers.AuthDigest:
response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" + response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" +
se.nonce + ":" + md5Hex(string(method)+":"+urStr)) se.nonce + ":" + md5Hex(string(method)+":"+urStr))
return headers.Auth{ h.DigestValues = headers.Auth{
Method: headers.AuthDigest, Method: headers.AuthDigest,
Username: &se.user, Username: &se.user,
Realm: &se.realm, Realm: &se.realm,
Nonce: &se.nonce, Nonce: &se.nonce,
URI: &urStr, URI: &urStr,
Response: &response, Response: &response,
}.Write() }
} }
return nil return h.Write()
} }

View File

@@ -2,7 +2,6 @@ package auth
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strings" "strings"
@@ -90,99 +89,82 @@ func (va *Validator) GenerateHeader() base.HeaderValue {
// ValidateHeader validates the Authorization header sent by a client after receiving the // ValidateHeader validates the Authorization header sent by a client after receiving the
// WWW-Authenticate header. // WWW-Authenticate header.
func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL, func (va *Validator) ValidateHeader(
v base.HeaderValue,
method base.Method,
ur *base.URL,
altURL *base.URL) error { altURL *base.URL) error {
if len(v) == 0 {
return fmt.Errorf("authorization header not provided")
}
if len(v) > 1 {
return fmt.Errorf("authorization header provided multiple times")
}
v0 := v[0] var auth headers.Authorization
err := auth.Read(v)
switch {
case strings.HasPrefix(v0, "Basic "):
inResponse := v0[len("Basic "):]
tmp, err := base64.StdEncoding.DecodeString(inResponse)
if err != nil { if err != nil {
return fmt.Errorf("wrong response") return err
} }
tmp2 := strings.Split(string(tmp), ":")
if len(tmp2) != 2 {
return fmt.Errorf("wrong response")
}
user, pass := tmp2[0], tmp2[1]
switch auth.Method {
case headers.AuthBasic:
if !va.userHashed { if !va.userHashed {
if user != va.user { if auth.BasicUser != va.user {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} }
} else { } else {
if sha256Base64(user) != va.user { if sha256Base64(auth.BasicUser) != va.user {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} }
} }
if !va.passHashed { if !va.passHashed {
if pass != va.pass { if auth.BasicPass != va.pass {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} }
} else { } else {
if sha256Base64(pass) != va.pass { if sha256Base64(auth.BasicPass) != va.pass {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} }
} }
case strings.HasPrefix(v0, "Digest "): default: // headers.AuthDigest
var auth headers.Auth if auth.DigestValues.Realm == nil {
err := auth.Read(base.HeaderValue{v0})
if err != nil {
return err
}
if auth.Realm == nil {
return fmt.Errorf("realm not provided") return fmt.Errorf("realm not provided")
} }
if auth.Nonce == nil { if auth.DigestValues.Nonce == nil {
return fmt.Errorf("nonce not provided") return fmt.Errorf("nonce not provided")
} }
if auth.Username == nil { if auth.DigestValues.Username == nil {
return fmt.Errorf("username not provided") return fmt.Errorf("username not provided")
} }
if auth.URI == nil { if auth.DigestValues.URI == nil {
return fmt.Errorf("uri not provided") return fmt.Errorf("uri not provided")
} }
if auth.Response == nil { if auth.DigestValues.Response == nil {
return fmt.Errorf("response not provided") return fmt.Errorf("response not provided")
} }
if *auth.Nonce != va.nonce { if *auth.DigestValues.Nonce != va.nonce {
return fmt.Errorf("wrong nonce") return fmt.Errorf("wrong nonce")
} }
if *auth.Realm != va.realm { if *auth.DigestValues.Realm != va.realm {
return fmt.Errorf("wrong realm") return fmt.Errorf("wrong realm")
} }
if *auth.Username != va.user { if *auth.DigestValues.Username != va.user {
return fmt.Errorf("wrong username") return fmt.Errorf("wrong username")
} }
urlString := ur.String() urlString := ur.String()
if *auth.URI != urlString { if *auth.DigestValues.URI != urlString {
// do another try with the alternative URL // do another try with the alternative URL
if altURL != nil { if altURL != nil {
urlString = altURL.String() urlString = altURL.String()
} }
if *auth.URI != urlString { if *auth.DigestValues.URI != urlString {
return fmt.Errorf("wrong url") return fmt.Errorf("wrong url")
} }
} }
@@ -190,12 +172,9 @@ func (va *Validator) ValidateHeader(v base.HeaderValue, method base.Method, ur *
response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) + response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) +
":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString)) ":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString))
if *auth.Response != response { if *auth.DigestValues.Response != response {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")
} }
default:
return fmt.Errorf("unsupported authorization header")
} }
return nil return nil

View File

@@ -0,0 +1,87 @@
package headers
import (
"encoding/base64"
"fmt"
"strings"
"github.com/aler9/gortsplib/pkg/base"
)
// Authorization is an Authorization header.
type Authorization struct {
// authentication method
Method AuthMethod
// basic user
BasicUser string
// basic password
BasicPass string
// digest values
DigestValues Auth
}
// Read decodes an Authorization header.
func (h *Authorization) Read(v base.HeaderValue) error {
if len(v) == 0 {
return fmt.Errorf("value not provided")
}
if len(v) > 1 {
return fmt.Errorf("value provided multiple times (%v)", v)
}
v0 := v[0]
switch {
case strings.HasPrefix(v0, "Basic "):
h.Method = AuthBasic
v0 = v0[len("Basic "):]
tmp, err := base64.StdEncoding.DecodeString(v0)
if err != nil {
return fmt.Errorf("invalid value")
}
tmp2 := strings.Split(string(tmp), ":")
if len(tmp2) != 2 {
return fmt.Errorf("invalid value")
}
h.BasicUser, h.BasicPass = tmp2[0], tmp2[1]
case strings.HasPrefix(v0, "Digest "):
h.Method = AuthDigest
var vals Auth
err := vals.Read(base.HeaderValue{v0})
if err != nil {
return err
}
h.DigestValues = vals
default:
return fmt.Errorf("invalid authorization header")
}
return nil
}
// Write encodes an Authorization header.
func (h Authorization) Write() base.HeaderValue {
switch h.Method {
case AuthBasic:
response := base64.StdEncoding.EncodeToString([]byte(h.BasicUser + ":" + h.BasicPass))
return base.HeaderValue{"Basic " + response}
case AuthDigest:
return h.DigestValues.Write()
}
return nil
}

View File

@@ -0,0 +1,95 @@
package headers
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/base"
)
var casesAuthorization = []struct {
name string
vin base.HeaderValue
vout base.HeaderValue
h Authorization
}{
{
"basic",
base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="},
base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="},
Authorization{
Method: AuthBasic,
BasicUser: "myuser",
BasicPass: "mypass",
},
},
{
"digest",
base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""},
base.HeaderValue{"Digest realm=\"4419b63f5e51\", nonce=\"8b84a3b789283a8bea8da7fa7d41f08b\", stale=\"FALSE\""},
Authorization{
Method: AuthDigest,
DigestValues: Auth{
Method: AuthDigest,
Realm: func() *string {
v := "4419b63f5e51"
return &v
}(),
Nonce: func() *string {
v := "8b84a3b789283a8bea8da7fa7d41f08b"
return &v
}(),
Stale: func() *string {
v := "FALSE"
return &v
}(),
},
},
},
}
func TestAuthorizationRead(t *testing.T) {
for _, ca := range casesAuthorization {
t.Run(ca.name, func(t *testing.T) {
var h Authorization
err := h.Read(ca.vin)
require.NoError(t, err)
require.Equal(t, ca.h, h)
})
}
}
func TestAuthorizationWrite(t *testing.T) {
for _, ca := range casesAuthorization {
t.Run(ca.name, func(t *testing.T) {
vout := ca.h.Write()
require.Equal(t, ca.vout, vout)
})
}
}
func TestAuthorizationReadError(t *testing.T) {
for _, ca := range []struct {
name string
hv base.HeaderValue
err string
}{
{
"empty",
base.HeaderValue{},
"value not provided",
},
{
"2 values",
base.HeaderValue{"a", "b"},
"value provided multiple times ([a b])",
},
} {
t.Run(ca.name, func(t *testing.T) {
var h Authorization
err := h.Read(ca.hv)
require.Equal(t, ca.err, err.Error())
})
}
}