headers: merge parsing of key-values

This commit is contained in:
aler9
2021-04-04 14:52:28 +02:00
parent cbb47e158a
commit 5847b507d1
10 changed files with 328 additions and 131 deletions

View File

@@ -49,36 +49,6 @@ type Auth struct {
Algorithm *string
}
func findValue(v0 string) (string, string, error) {
if v0 == "" {
return "", "", nil
}
if v0[0] == '"' {
i := 1
for {
if i >= len(v0) {
return "", "", fmt.Errorf("apices not closed (%v)", v0)
}
if v0[i] == '"' {
return v0[1:i], v0[i+1:], nil
}
i++
}
}
i := 0
for {
if i >= len(v0) || v0[i] == ',' {
return v0[:i], v0[i:], nil
}
i++
}
}
// Read decodes an Authenticate or a WWW-Authenticate header.
func (h *Auth) Read(v base.HeaderValue) error {
if len(v) == 0 {
@@ -93,10 +63,11 @@ func (h *Auth) Read(v base.HeaderValue) error {
i := strings.IndexByte(v0, ' ')
if i < 0 {
return fmt.Errorf("unable to find method (%s)", v0)
return fmt.Errorf("unable to split between method and keys (%v)", v)
}
method, v0 := v0[:i], v0[i+1:]
switch v0[:i] {
switch method {
case "Basic":
h.Method = AuthBasic
@@ -104,62 +75,45 @@ func (h *Auth) Read(v base.HeaderValue) error {
h.Method = AuthDigest
default:
return fmt.Errorf("invalid method (%s)", v0[:i])
return fmt.Errorf("invalid method (%s)", method)
}
v0 = v0[i+1:]
for len(v0) > 0 {
i := strings.IndexByte(v0, '=')
if i < 0 {
return fmt.Errorf("unable to find key (%s)", v0)
}
var key string
key, v0 = v0[:i], v0[i+1:]
kvs, err := keyValParse(v0, ',')
if err != nil {
return err
}
var val string
var err error
val, v0, err = findValue(v0)
if err != nil {
return err
}
for k, rv := range kvs {
v := rv
switch key {
switch k {
case "username":
h.Username = &val
h.Username = &v
case "realm":
h.Realm = &val
h.Realm = &v
case "nonce":
h.Nonce = &val
h.Nonce = &v
case "uri":
h.URI = &val
h.URI = &v
case "response":
h.Response = &val
h.Response = &v
case "opaque":
h.Opaque = &val
h.Opaque = &v
case "stale":
h.Stale = &val
h.Stale = &v
case "algorithm":
h.Algorithm = &val
h.Algorithm = &v
default:
// ignore non-standard keys
}
// skip comma
if len(v0) > 0 && v0[0] == ',' {
v0 = v0[1:]
}
// skip spaces
for len(v0) > 0 && v0[0] == ' ' {
v0 = v0[1:]
}
}
return nil

View File

@@ -173,21 +173,51 @@ var casesAuth = []struct {
}
func TestAuthRead(t *testing.T) {
for _, c := range casesAuth {
t.Run(c.name, func(t *testing.T) {
for _, ca := range casesAuth {
t.Run(ca.name, func(t *testing.T) {
var h Auth
err := h.Read(c.vin)
err := h.Read(ca.vin)
require.NoError(t, err)
require.Equal(t, c.h, h)
require.Equal(t, ca.h, h)
})
}
}
func TestAuthReadError(t *testing.T) {
for _, ca := range []struct {
name string
hv base.HeaderValue
}{
{
"empty",
base.HeaderValue{},
},
{
"2 values",
base.HeaderValue{"a", "b"},
},
{
"no keys",
base.HeaderValue{"Basic"},
},
{
"invalid method",
base.HeaderValue{"Testing key1=val1"},
},
} {
t.Run(ca.name, func(t *testing.T) {
var h Auth
err := h.Read(ca.hv)
require.Error(t, err)
})
}
}
func TestAuthWrite(t *testing.T) {
for _, c := range casesAuth {
t.Run(c.name, func(t *testing.T) {
vout := c.h.Write()
require.Equal(t, c.vout, vout)
for _, ca := range casesAuth {
t.Run(ca.name, func(t *testing.T) {
vout := ca.h.Write()
require.Equal(t, ca.vout, vout)
})
}
}

70
pkg/headers/keyval.go Normal file
View File

@@ -0,0 +1,70 @@
package headers
import (
"fmt"
"strings"
)
func findValue(str string, separator byte) (string, string, error) {
if str == "" {
return "", "", nil
}
if str[0] == '"' {
i := 1
for {
if i >= len(str) {
return "", "", fmt.Errorf("apices not closed (%v)", str)
}
if str[i] == '"' {
return str[1:i], str[i+1:], nil
}
i++
}
}
i := 0
for {
if i >= len(str) || str[i] == separator {
return str[:i], str[i:], nil
}
i++
}
}
func keyValParse(str string, separator byte) (map[string]string, error) {
ret := make(map[string]string)
for len(str) > 0 {
i := strings.IndexByte(str, '=')
if i < 0 {
return nil, fmt.Errorf("unable to find key")
}
var k string
k, str = str[:i], str[i+1:]
var v string
var err error
v, str, err = findValue(str, separator)
if err != nil {
return nil, err
}
ret[k] = v
// skip separator
if len(str) > 0 && str[0] == separator {
str = str[1:]
}
// skip spaces
for len(str) > 0 && str[0] == ' ' {
str = str[1:]
}
}
return ret, nil
}

View File

@@ -0,0 +1,64 @@
package headers
import (
"testing"
"github.com/stretchr/testify/require"
)
var casesKeyVal = []struct {
name string
s string
kvs map[string]string
}{
{
"base",
`key1=v1,key2=v2`,
map[string]string{
"key1": "v1",
"key2": "v2",
},
},
{
"with space",
`key1=v1, key2=v2`,
map[string]string{
"key1": "v1",
"key2": "v2",
},
},
{
"with apices",
`key1="v1", key2=v2`,
map[string]string{
"key1": "v1",
"key2": "v2",
},
},
{
"with apices and comma",
`key1="v,1", key2="v2"`,
map[string]string{
"key1": "v,1",
"key2": "v2",
},
},
{
"with apices and equal",
`key1="v=1", key2="v2"`,
map[string]string{
"key1": "v=1",
"key2": "v2",
},
},
}
func TestKeyValParse(t *testing.T) {
for _, ca := range casesKeyVal {
t.Run(ca.name, func(t *testing.T) {
kvs, err := keyValParse(ca.s, ',')
require.NoError(t, err)
require.Equal(t, ca.kvs, kvs)
})
}
}

View File

@@ -34,13 +34,12 @@ func (h *RTPInfo) Read(v base.HeaderValue) error {
// remove leading spaces
part = strings.TrimLeft(part, " ")
for _, kv := range strings.Split(part, ";") {
tmp := strings.SplitN(kv, "=", 2)
if len(tmp) != 2 {
return fmt.Errorf("unable to parse key-value (%v)", kv)
}
k, v := tmp[0], tmp[1]
kvs, err := keyValParse(part, ';')
if err != nil {
return err
}
for k, v := range kvs {
switch k {
case "url":
e.URL = v

View File

@@ -168,21 +168,43 @@ var casesRTPInfo = []struct {
}
func TestRTPInfoRead(t *testing.T) {
for _, c := range casesRTPInfo {
t.Run(c.name, func(t *testing.T) {
for _, ca := range casesRTPInfo {
t.Run(ca.name, func(t *testing.T) {
var h RTPInfo
err := h.Read(c.vin)
err := h.Read(ca.vin)
require.NoError(t, err)
require.Equal(t, c.h, h)
require.Equal(t, ca.h, h)
})
}
}
func TestRTPInfoReadError(t *testing.T) {
for _, ca := range []struct {
name string
hv base.HeaderValue
}{
{
"empty",
base.HeaderValue{},
},
{
"2 values",
base.HeaderValue{"a", "b"},
},
} {
t.Run(ca.name, func(t *testing.T) {
var h RTPInfo
err := h.Read(ca.hv)
require.Error(t, err)
})
}
}
func TestRTPInfoWrite(t *testing.T) {
for _, c := range casesRTPInfo {
t.Run(c.name, func(t *testing.T) {
req := c.h.Write()
require.Equal(t, c.vout, req)
for _, ca := range casesRTPInfo {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}

View File

@@ -27,23 +27,25 @@ func (h *Session) Read(v base.HeaderValue) error {
return fmt.Errorf("value provided multiple times (%v)", v)
}
parts := strings.Split(v[0], ";")
if len(parts) == 0 {
return fmt.Errorf("invalid value (%v)", v)
v0 := v[0]
i := strings.IndexByte(v0, ';')
if i < 0 {
h.Session = v0
return nil
}
h.Session = parts[0]
h.Session = v0[:i]
v0 = v0[i+1:]
for _, kv := range parts[1:] {
// remove leading spaces
kv = strings.TrimLeft(kv, " ")
v0 = strings.TrimLeft(v0, " ")
tmp := strings.SplitN(kv, "=", 2)
if len(tmp) != 2 {
return fmt.Errorf("unable to parse key-value (%v)", kv)
}
k, v := tmp[0], tmp[1]
kvs, err := keyValParse(v0, ';')
if err != nil {
return err
}
for k, v := range kvs {
switch k {
case "timeout":
iv, err := strconv.ParseUint(v, 10, 64)

View File

@@ -49,21 +49,43 @@ var casesSession = []struct {
}
func TestSessionRead(t *testing.T) {
for _, c := range casesSession {
t.Run(c.name, func(t *testing.T) {
for _, ca := range casesSession {
t.Run(ca.name, func(t *testing.T) {
var h Session
err := h.Read(c.vin)
err := h.Read(ca.vin)
require.NoError(t, err)
require.Equal(t, c.h, h)
require.Equal(t, ca.h, h)
})
}
}
func TestSessionReadError(t *testing.T) {
for _, ca := range []struct {
name string
hv base.HeaderValue
}{
{
"empty",
base.HeaderValue{},
},
{
"2 values",
base.HeaderValue{"a", "b"},
},
} {
t.Run(ca.name, func(t *testing.T) {
var h Session
err := h.Read(ca.hv)
require.Error(t, err)
})
}
}
func TestSessionWrite(t *testing.T) {
for _, c := range casesSession {
t.Run(c.name, func(t *testing.T) {
req := c.h.Write()
require.Equal(t, c.vout, req)
for _, ca := range casesSession {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}

View File

@@ -99,12 +99,17 @@ func (h *Transport) Read(v base.HeaderValue) error {
return fmt.Errorf("value provided multiple times (%v)", v)
}
parts := strings.Split(v[0], ";")
if len(parts) == 0 {
return fmt.Errorf("invalid value (%v)", v)
v0 := v[0]
var part string
i := strings.IndexByte(v0, ';')
if i >= 0 {
part, v0 = v0[:i], v0[i+1:]
} else {
part, v0 = v0, ""
}
switch parts[0] {
switch part {
case "RTP/AVP", "RTP/AVP/UDP":
h.Protocol = base.StreamProtocolUDP
@@ -114,28 +119,35 @@ func (h *Transport) Read(v base.HeaderValue) error {
default:
return fmt.Errorf("invalid protocol (%v)", v)
}
parts = parts[1:]
switch parts[0] {
i = strings.IndexByte(v0, ';')
if i >= 0 {
part, v0 = v0[:i], v0[i+1:]
} else {
part, v0 = v0, ""
}
switch part {
case "unicast":
v := base.StreamDeliveryUnicast
h.Delivery = &v
parts = parts[1:]
case "multicast":
v := base.StreamDeliveryMulticast
h.Delivery = &v
parts = parts[1:]
// cast is optional, do not return any error
default:
// cast is optional, go back
v0 = part + ";" + v0
}
for _, kv := range parts {
tmp := strings.SplitN(kv, "=", 2)
if len(tmp) != 2 {
return fmt.Errorf("unable to parse key-value (%v)", kv)
}
k, v := tmp[0], tmp[1]
kvs, err := keyValParse(v0, ';')
if err != nil {
return err
}
for k, rv := range kvs {
v := rv
switch k {
case "destination":

View File

@@ -114,21 +114,43 @@ var casesTransport = []struct {
}
func TestTransportRead(t *testing.T) {
for _, c := range casesTransport {
t.Run(c.name, func(t *testing.T) {
for _, ca := range casesTransport {
t.Run(ca.name, func(t *testing.T) {
var h Transport
err := h.Read(c.vin)
err := h.Read(ca.vin)
require.NoError(t, err)
require.Equal(t, c.h, h)
require.Equal(t, ca.h, h)
})
}
}
func TestTransportReadError(t *testing.T) {
for _, ca := range []struct {
name string
hv base.HeaderValue
}{
{
"empty",
base.HeaderValue{},
},
{
"2 values",
base.HeaderValue{"a", "b"},
},
} {
t.Run(ca.name, func(t *testing.T) {
var h Transport
err := h.Read(ca.hv)
require.Error(t, err)
})
}
}
func TestTransportWrite(t *testing.T) {
for _, c := range casesTransport {
t.Run(c.name, func(t *testing.T) {
req := c.h.Write()
require.Equal(t, c.vout, req)
for _, ca := range casesTransport {
t.Run(ca.name, func(t *testing.T) {
req := ca.h.Write()
require.Equal(t, ca.vout, req)
})
}
}