diff --git a/pkg/headers/auth.go b/pkg/headers/auth.go index 1839d1ab..97585dd7 100644 --- a/pkg/headers/auth.go +++ b/pkg/headers/auth.go @@ -49,36 +49,6 @@ type Auth struct { Algorithm *string } -func findValue(v0 string) (string, string, error) { - if v0 == "" { - return "", "", nil - } - - if v0[0] == '"' { - i := 1 - for { - if i >= len(v0) { - return "", "", fmt.Errorf("apices not closed (%v)", v0) - } - - if v0[i] == '"' { - return v0[1:i], v0[i+1:], nil - } - - i++ - } - } - - i := 0 - for { - if i >= len(v0) || v0[i] == ',' { - return v0[:i], v0[i:], nil - } - - i++ - } -} - // Read decodes an Authenticate or a WWW-Authenticate header. func (h *Auth) Read(v base.HeaderValue) error { if len(v) == 0 { @@ -93,10 +63,11 @@ func (h *Auth) Read(v base.HeaderValue) error { i := strings.IndexByte(v0, ' ') if i < 0 { - return fmt.Errorf("unable to find method (%s)", v0) + return fmt.Errorf("unable to split between method and keys (%v)", v) } + method, v0 := v0[:i], v0[i+1:] - switch v0[:i] { + switch method { case "Basic": h.Method = AuthBasic @@ -104,62 +75,45 @@ func (h *Auth) Read(v base.HeaderValue) error { h.Method = AuthDigest default: - return fmt.Errorf("invalid method (%s)", v0[:i]) + return fmt.Errorf("invalid method (%s)", method) } - v0 = v0[i+1:] - for len(v0) > 0 { - i := strings.IndexByte(v0, '=') - if i < 0 { - return fmt.Errorf("unable to find key (%s)", v0) - } - var key string - key, v0 = v0[:i], v0[i+1:] + kvs, err := keyValParse(v0, ',') + if err != nil { + return err + } - var val string - var err error - val, v0, err = findValue(v0) - if err != nil { - return err - } + for k, rv := range kvs { + v := rv - switch key { + switch k { case "username": - h.Username = &val + h.Username = &v case "realm": - h.Realm = &val + h.Realm = &v case "nonce": - h.Nonce = &val + h.Nonce = &v case "uri": - h.URI = &val + h.URI = &v case "response": - h.Response = &val + h.Response = &v case "opaque": - h.Opaque = &val + h.Opaque = &v case "stale": - h.Stale = &val + h.Stale = &v case "algorithm": - h.Algorithm = &val + h.Algorithm = &v + default: // ignore non-standard keys } - - // skip comma - if len(v0) > 0 && v0[0] == ',' { - v0 = v0[1:] - } - - // skip spaces - for len(v0) > 0 && v0[0] == ' ' { - v0 = v0[1:] - } } return nil diff --git a/pkg/headers/auth_test.go b/pkg/headers/auth_test.go index fcd09661..c67aa329 100644 --- a/pkg/headers/auth_test.go +++ b/pkg/headers/auth_test.go @@ -173,21 +173,51 @@ var casesAuth = []struct { } func TestAuthRead(t *testing.T) { - for _, c := range casesAuth { - t.Run(c.name, func(t *testing.T) { + for _, ca := range casesAuth { + t.Run(ca.name, func(t *testing.T) { var h Auth - err := h.Read(c.vin) + err := h.Read(ca.vin) require.NoError(t, err) - require.Equal(t, c.h, h) + require.Equal(t, ca.h, h) + }) + } +} + +func TestAuthReadError(t *testing.T) { + for _, ca := range []struct { + name string + hv base.HeaderValue + }{ + { + "empty", + base.HeaderValue{}, + }, + { + "2 values", + base.HeaderValue{"a", "b"}, + }, + { + "no keys", + base.HeaderValue{"Basic"}, + }, + { + "invalid method", + base.HeaderValue{"Testing key1=val1"}, + }, + } { + t.Run(ca.name, func(t *testing.T) { + var h Auth + err := h.Read(ca.hv) + require.Error(t, err) }) } } func TestAuthWrite(t *testing.T) { - for _, c := range casesAuth { - t.Run(c.name, func(t *testing.T) { - vout := c.h.Write() - require.Equal(t, c.vout, vout) + for _, ca := range casesAuth { + t.Run(ca.name, func(t *testing.T) { + vout := ca.h.Write() + require.Equal(t, ca.vout, vout) }) } } diff --git a/pkg/headers/keyval.go b/pkg/headers/keyval.go new file mode 100644 index 00000000..f09bf530 --- /dev/null +++ b/pkg/headers/keyval.go @@ -0,0 +1,70 @@ +package headers + +import ( + "fmt" + "strings" +) + +func findValue(str string, separator byte) (string, string, error) { + if str == "" { + return "", "", nil + } + + if str[0] == '"' { + i := 1 + for { + if i >= len(str) { + return "", "", fmt.Errorf("apices not closed (%v)", str) + } + + if str[i] == '"' { + return str[1:i], str[i+1:], nil + } + + i++ + } + } + + i := 0 + for { + if i >= len(str) || str[i] == separator { + return str[:i], str[i:], nil + } + + i++ + } +} + +func keyValParse(str string, separator byte) (map[string]string, error) { + ret := make(map[string]string) + + for len(str) > 0 { + i := strings.IndexByte(str, '=') + if i < 0 { + return nil, fmt.Errorf("unable to find key") + } + var k string + k, str = str[:i], str[i+1:] + + var v string + var err error + v, str, err = findValue(str, separator) + if err != nil { + return nil, err + } + + ret[k] = v + + // skip separator + if len(str) > 0 && str[0] == separator { + str = str[1:] + } + + // skip spaces + for len(str) > 0 && str[0] == ' ' { + str = str[1:] + } + } + + return ret, nil +} diff --git a/pkg/headers/keyval_test.go b/pkg/headers/keyval_test.go new file mode 100644 index 00000000..bd96dbd9 --- /dev/null +++ b/pkg/headers/keyval_test.go @@ -0,0 +1,64 @@ +package headers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +var casesKeyVal = []struct { + name string + s string + kvs map[string]string +}{ + { + "base", + `key1=v1,key2=v2`, + map[string]string{ + "key1": "v1", + "key2": "v2", + }, + }, + { + "with space", + `key1=v1, key2=v2`, + map[string]string{ + "key1": "v1", + "key2": "v2", + }, + }, + { + "with apices", + `key1="v1", key2=v2`, + map[string]string{ + "key1": "v1", + "key2": "v2", + }, + }, + { + "with apices and comma", + `key1="v,1", key2="v2"`, + map[string]string{ + "key1": "v,1", + "key2": "v2", + }, + }, + { + "with apices and equal", + `key1="v=1", key2="v2"`, + map[string]string{ + "key1": "v=1", + "key2": "v2", + }, + }, +} + +func TestKeyValParse(t *testing.T) { + for _, ca := range casesKeyVal { + t.Run(ca.name, func(t *testing.T) { + kvs, err := keyValParse(ca.s, ',') + require.NoError(t, err) + require.Equal(t, ca.kvs, kvs) + }) + } +} diff --git a/pkg/headers/rtpinfo.go b/pkg/headers/rtpinfo.go index 3f02fe67..31744c42 100644 --- a/pkg/headers/rtpinfo.go +++ b/pkg/headers/rtpinfo.go @@ -34,13 +34,12 @@ func (h *RTPInfo) Read(v base.HeaderValue) error { // remove leading spaces part = strings.TrimLeft(part, " ") - for _, kv := range strings.Split(part, ";") { - tmp := strings.SplitN(kv, "=", 2) - if len(tmp) != 2 { - return fmt.Errorf("unable to parse key-value (%v)", kv) - } - k, v := tmp[0], tmp[1] + kvs, err := keyValParse(part, ';') + if err != nil { + return err + } + for k, v := range kvs { switch k { case "url": e.URL = v diff --git a/pkg/headers/rtpinfo_test.go b/pkg/headers/rtpinfo_test.go index 3beccafe..b06b44e0 100644 --- a/pkg/headers/rtpinfo_test.go +++ b/pkg/headers/rtpinfo_test.go @@ -168,21 +168,43 @@ var casesRTPInfo = []struct { } func TestRTPInfoRead(t *testing.T) { - for _, c := range casesRTPInfo { - t.Run(c.name, func(t *testing.T) { + for _, ca := range casesRTPInfo { + t.Run(ca.name, func(t *testing.T) { var h RTPInfo - err := h.Read(c.vin) + err := h.Read(ca.vin) require.NoError(t, err) - require.Equal(t, c.h, h) + require.Equal(t, ca.h, h) + }) + } +} + +func TestRTPInfoReadError(t *testing.T) { + for _, ca := range []struct { + name string + hv base.HeaderValue + }{ + { + "empty", + base.HeaderValue{}, + }, + { + "2 values", + base.HeaderValue{"a", "b"}, + }, + } { + t.Run(ca.name, func(t *testing.T) { + var h RTPInfo + err := h.Read(ca.hv) + require.Error(t, err) }) } } func TestRTPInfoWrite(t *testing.T) { - for _, c := range casesRTPInfo { - t.Run(c.name, func(t *testing.T) { - req := c.h.Write() - require.Equal(t, c.vout, req) + for _, ca := range casesRTPInfo { + t.Run(ca.name, func(t *testing.T) { + req := ca.h.Write() + require.Equal(t, ca.vout, req) }) } } diff --git a/pkg/headers/session.go b/pkg/headers/session.go index 6e62bf4f..f66b2fbc 100644 --- a/pkg/headers/session.go +++ b/pkg/headers/session.go @@ -27,23 +27,25 @@ func (h *Session) Read(v base.HeaderValue) error { return fmt.Errorf("value provided multiple times (%v)", v) } - parts := strings.Split(v[0], ";") - if len(parts) == 0 { - return fmt.Errorf("invalid value (%v)", v) + v0 := v[0] + + i := strings.IndexByte(v0, ';') + if i < 0 { + h.Session = v0 + return nil } - h.Session = parts[0] + h.Session = v0[:i] + v0 = v0[i+1:] - for _, kv := range parts[1:] { - // remove leading spaces - kv = strings.TrimLeft(kv, " ") + v0 = strings.TrimLeft(v0, " ") - tmp := strings.SplitN(kv, "=", 2) - if len(tmp) != 2 { - return fmt.Errorf("unable to parse key-value (%v)", kv) - } - k, v := tmp[0], tmp[1] + kvs, err := keyValParse(v0, ';') + if err != nil { + return err + } + for k, v := range kvs { switch k { case "timeout": iv, err := strconv.ParseUint(v, 10, 64) diff --git a/pkg/headers/session_test.go b/pkg/headers/session_test.go index 39966c5c..b528f22f 100644 --- a/pkg/headers/session_test.go +++ b/pkg/headers/session_test.go @@ -49,21 +49,43 @@ var casesSession = []struct { } func TestSessionRead(t *testing.T) { - for _, c := range casesSession { - t.Run(c.name, func(t *testing.T) { + for _, ca := range casesSession { + t.Run(ca.name, func(t *testing.T) { var h Session - err := h.Read(c.vin) + err := h.Read(ca.vin) require.NoError(t, err) - require.Equal(t, c.h, h) + require.Equal(t, ca.h, h) + }) + } +} + +func TestSessionReadError(t *testing.T) { + for _, ca := range []struct { + name string + hv base.HeaderValue + }{ + { + "empty", + base.HeaderValue{}, + }, + { + "2 values", + base.HeaderValue{"a", "b"}, + }, + } { + t.Run(ca.name, func(t *testing.T) { + var h Session + err := h.Read(ca.hv) + require.Error(t, err) }) } } func TestSessionWrite(t *testing.T) { - for _, c := range casesSession { - t.Run(c.name, func(t *testing.T) { - req := c.h.Write() - require.Equal(t, c.vout, req) + for _, ca := range casesSession { + t.Run(ca.name, func(t *testing.T) { + req := ca.h.Write() + require.Equal(t, ca.vout, req) }) } } diff --git a/pkg/headers/transport.go b/pkg/headers/transport.go index c49b86ee..e0b5b466 100644 --- a/pkg/headers/transport.go +++ b/pkg/headers/transport.go @@ -99,12 +99,17 @@ func (h *Transport) Read(v base.HeaderValue) error { return fmt.Errorf("value provided multiple times (%v)", v) } - parts := strings.Split(v[0], ";") - if len(parts) == 0 { - return fmt.Errorf("invalid value (%v)", v) + v0 := v[0] + + var part string + i := strings.IndexByte(v0, ';') + if i >= 0 { + part, v0 = v0[:i], v0[i+1:] + } else { + part, v0 = v0, "" } - switch parts[0] { + switch part { case "RTP/AVP", "RTP/AVP/UDP": h.Protocol = base.StreamProtocolUDP @@ -114,28 +119,35 @@ func (h *Transport) Read(v base.HeaderValue) error { default: return fmt.Errorf("invalid protocol (%v)", v) } - parts = parts[1:] - switch parts[0] { + i = strings.IndexByte(v0, ';') + if i >= 0 { + part, v0 = v0[:i], v0[i+1:] + } else { + part, v0 = v0, "" + } + + switch part { case "unicast": v := base.StreamDeliveryUnicast h.Delivery = &v - parts = parts[1:] case "multicast": v := base.StreamDeliveryMulticast h.Delivery = &v - parts = parts[1:] - // cast is optional, do not return any error + default: + // cast is optional, go back + v0 = part + ";" + v0 } - for _, kv := range parts { - tmp := strings.SplitN(kv, "=", 2) - if len(tmp) != 2 { - return fmt.Errorf("unable to parse key-value (%v)", kv) - } - k, v := tmp[0], tmp[1] + kvs, err := keyValParse(v0, ';') + if err != nil { + return err + } + + for k, rv := range kvs { + v := rv switch k { case "destination": diff --git a/pkg/headers/transport_test.go b/pkg/headers/transport_test.go index 3eca0e88..18085660 100644 --- a/pkg/headers/transport_test.go +++ b/pkg/headers/transport_test.go @@ -114,21 +114,43 @@ var casesTransport = []struct { } func TestTransportRead(t *testing.T) { - for _, c := range casesTransport { - t.Run(c.name, func(t *testing.T) { + for _, ca := range casesTransport { + t.Run(ca.name, func(t *testing.T) { var h Transport - err := h.Read(c.vin) + err := h.Read(ca.vin) require.NoError(t, err) - require.Equal(t, c.h, h) + require.Equal(t, ca.h, h) + }) + } +} + +func TestTransportReadError(t *testing.T) { + for _, ca := range []struct { + name string + hv base.HeaderValue + }{ + { + "empty", + base.HeaderValue{}, + }, + { + "2 values", + base.HeaderValue{"a", "b"}, + }, + } { + t.Run(ca.name, func(t *testing.T) { + var h Transport + err := h.Read(ca.hv) + require.Error(t, err) }) } } func TestTransportWrite(t *testing.T) { - for _, c := range casesTransport { - t.Run(c.name, func(t *testing.T) { - req := c.h.Write() - require.Equal(t, c.vout, req) + for _, ca := range casesTransport { + t.Run(ca.name, func(t *testing.T) { + req := ca.h.Write() + require.Equal(t, ca.vout, req) }) } }