HeaderAuth: use struct instead of map for storing

This commit is contained in:
aler9
2020-09-13 16:40:12 +02:00
parent 633f25bb33
commit 45cf5562de
4 changed files with 304 additions and 147 deletions

View File

@@ -2,18 +2,68 @@ package gortsplib
import (
"fmt"
"regexp"
"sort"
"strings"
)
// HeaderAuth is an Authenticate or a WWWW-Authenticate header.
type HeaderAuth struct {
Prefix string
Values map[string]string
// authentication method
Method AuthMethod
// (optional) username
Username *string
// (optional) realm
Realm *string
// (optional) nonce
Nonce *string
// (optional) uri
URI *string
// (optional) response
Response *string
// (optional) opaque
Opaque *string
// (optional) stale
Stale *string
// (optional) algorithm
Algorithm *string
}
var regHeaderAuthKeyValue = regexp.MustCompile("^([a-z]+)=(\"(.*?)\"|([a-zA-Z0-9]+))(, *|$)")
func findValue(v0 string) (string, string, error) {
if v0 == "" {
return "", "", nil
}
if v0[0] == '"' {
i := 1
for {
if i >= len(v0) {
return "", "", fmt.Errorf("apices not closed (%v)", v0)
}
if v0[i] == '"' {
return v0[1:i], v0[i+1:], nil
}
i++
}
}
i := 0
for {
if i >= len(v0) || v0[i] == ',' {
return v0[:i], v0[i:], nil
}
i++
}
}
// ReadHeaderAuth parses an Authenticate or a WWW-Authenticate header.
func ReadHeaderAuth(v HeaderValue) (*HeaderAuth, error) {
@@ -25,28 +75,79 @@ func ReadHeaderAuth(v HeaderValue) (*HeaderAuth, error) {
return nil, fmt.Errorf("value provided multiple times (%v)", v)
}
ha := &HeaderAuth{
Values: make(map[string]string),
}
ha := &HeaderAuth{}
v0 := v[0]
i := strings.IndexByte(v[0], ' ')
i := strings.IndexByte(v0, ' ')
if i < 0 {
return nil, fmt.Errorf("unable to find prefix (%s)", v0)
return nil, fmt.Errorf("unable to find method (%s)", v0)
}
ha.Prefix, v0 = v0[:i], v0[i+1:]
switch v0[:i] {
case "Basic":
ha.Method = Basic
case "Digest":
ha.Method = Digest
default:
return nil, fmt.Errorf("invalid method (%s)", v0[:i])
}
v0 = v0[i+1:]
for len(v0) > 0 {
m := regHeaderAuthKeyValue.FindStringSubmatch(v0)
if m == nil {
return nil, fmt.Errorf("unable to parse key-value (%s)", v0)
i := strings.IndexByte(v0, '=')
if i < 0 {
return nil, fmt.Errorf("unable to find key (%s)", v0)
}
v0 = v0[len(m[0]):]
var key string
key, v0 = v0[:i], v0[i+1:]
m[2] = strings.TrimPrefix(m[2], "\"")
m[2] = strings.TrimSuffix(m[2], "\"")
ha.Values[m[1]] = m[2]
var val string
var err error
val, v0, err = findValue(v0)
if err != nil {
return nil, err
}
switch key {
case "username":
ha.Username = &val
case "realm":
ha.Realm = &val
case "nonce":
ha.Nonce = &val
case "uri":
ha.URI = &val
case "response":
ha.Response = &val
case "opaque":
ha.Opaque = &val
case "stale":
ha.Stale = &val
case "algorithm":
ha.Algorithm = &val
// ignore non-standard keys
}
// skip comma
if len(v0) > 0 && v0[0] == ',' {
v0 = v0[1:]
}
// skip spaces
for len(v0) > 0 && v0[0] == ' ' {
v0 = v0[1:]
}
}
return ha, nil
@@ -54,48 +155,53 @@ func ReadHeaderAuth(v HeaderValue) (*HeaderAuth, error) {
// Write encodes an Authenticate or a WWW-Authenticate header.
func (ha *HeaderAuth) Write() HeaderValue {
ret := ha.Prefix + " "
ret := ""
// follow a specific order, otherwise some clients/servers do not work correctly
var sortedKeys []string
for key := range ha.Values {
sortedKeys = append(sortedKeys, key)
}
score := func(v string) int {
switch v {
case "username":
return 0
case "realm":
return 1
case "nonce":
return 2
case "uri":
return 3
case "response":
return 4
case "opaque":
return 5
case "stale":
return 6
case "algorithm":
return 7
}
return 8
}
sort.Slice(sortedKeys, func(a, b int) bool {
sa := score(sortedKeys[a])
sb := score(sortedKeys[b])
if sa != sb {
return sa < sb
}
return a < b
})
switch ha.Method {
case Basic:
ret += "Basic"
var tmp []string
for _, key := range sortedKeys {
tmp = append(tmp, key+"=\""+ha.Values[key]+"\"")
case Digest:
ret += "Digest"
}
ret += strings.Join(tmp, ", ")
ret += " "
var vals []string
if ha.Username != nil {
vals = append(vals, "username=\""+*ha.Username+"\"")
}
if ha.Realm != nil {
vals = append(vals, "realm=\""+*ha.Realm+"\"")
}
if ha.Nonce != nil {
vals = append(vals, "nonce=\""+*ha.Nonce+"\"")
}
if ha.URI != nil {
vals = append(vals, "uri=\""+*ha.URI+"\"")
}
if ha.Response != nil {
vals = append(vals, "response=\""+*ha.Response+"\"")
}
if ha.Opaque != nil {
vals = append(vals, "opaque=\""+*ha.Opaque+"\"")
}
if ha.Stale != nil {
vals = append(vals, "stale=\""+*ha.Stale+"\"")
}
if ha.Algorithm != nil {
vals = append(vals, "algorithm=\""+*ha.Algorithm+"\"")
}
ret += strings.Join(vals, ", ")
return HeaderValue{ret}
}