diff --git a/cluster/cluster.go b/cluster/cluster.go index 548345d0..085703de 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -1007,6 +1007,10 @@ func (c *cluster) AddIdentity(origin string, identity iamidentity.User) error { return ErrDegraded } + if err := identity.Validate(); err != nil { + return fmt.Errorf("invalid identity: %w", err) + } + if !c.IsRaftLeader() { return c.forwarder.AddIdentity(origin, identity) } diff --git a/encoding/token/token.go b/encoding/token/token.go new file mode 100644 index 00000000..45c2b175 --- /dev/null +++ b/encoding/token/token.go @@ -0,0 +1,45 @@ +package token + +import "strings" + +func Marshal(username, token string) string { + username = strings.ReplaceAll(username, ":", "::") + + return username + ":" + token +} + +// Unmarshal splits a username/token combination into a username and +// token. If the input doesn't contain a username, the whole input +// is returned as token. +func Unmarshal(usertoken string) (string, string) { + r := []rune(usertoken) + + count := 0 + index := -1 + for i, ru := range r { + if ru == ':' { + count++ + continue + } + + if count > 0 && count%2 != 0 { + index = i - 1 + break + } + + count = 0 + } + + if index == -1 { + return "", usertoken + } + + before, after := r[:index], r[index+1:] + + username := string(before) + token := string(after) + + username = strings.ReplaceAll(username, "::", ":") + + return username, token +} diff --git a/encoding/token/token_test.go b/encoding/token/token_test.go new file mode 100644 index 00000000..ad0df144 --- /dev/null +++ b/encoding/token/token_test.go @@ -0,0 +1,41 @@ +package token + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMarshal(t *testing.T) { + token := "xxx" + + data := [][2]string{ + {"", ""}, + {"foo", "foo"}, + {"foo:bar", "foo::bar"}, + {"foo::bar", "foo::::bar"}, + } + + for _, d := range data { + encoded := Marshal(d[0], token) + require.Equal(t, d[1]+":"+token, encoded, d[1]) + } +} + +func TestUnmarshal(t *testing.T) { + data := [][3]string{ + {"foo", "", "foo"}, + {"fo::o", "", "fo::o"}, + {"::foo", "", "::foo"}, + {":xxx", "", "xxx"}, + {"foo:xxx", "foo", "xxx"}, + {"foo::bar:xxx", "foo:bar", "xxx"}, + {"foo::::bar:xxx", "foo::bar", "xxx"}, + } + + for _, d := range data { + username, token := Unmarshal(d[0]) + require.Equal(t, d[1], username, d[0]) + require.Equal(t, d[2], token) + } +} diff --git a/http/handler/api/cluster.go b/http/handler/api/cluster.go index cc8fb139..afbe1be4 100644 --- a/http/handler/api/cluster.go +++ b/http/handler/api/cluster.go @@ -744,7 +744,7 @@ func (h *ClusterHandler) DeleteProcess(c echo.Context) error { func (h *ClusterHandler) AddIdentity(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") superuser := util.DefaultContext(c, "superuser", false) - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") user := api.IAMUser{} @@ -799,7 +799,7 @@ func (h *ClusterHandler) AddIdentity(c echo.Context) error { func (h *ClusterHandler) UpdateIdentity(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") superuser := util.DefaultContext(c, "superuser", false) - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") name := util.PathParam(c, "name") if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { @@ -880,7 +880,7 @@ func (h *ClusterHandler) UpdateIdentity(c echo.Context) error { func (h *ClusterHandler) UpdateIdentityPolicies(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") superuser := util.DefaultContext(c, "superuser", false) - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") name := util.PathParam(c, "name") if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "write") { @@ -952,7 +952,7 @@ func (h *ClusterHandler) UpdateIdentityPolicies(c echo.Context) error { // @Router /api/v3/cluster/db/user [get] func (h *ClusterHandler) ListStoreIdentities(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") updatedAt, identities := h.cluster.ListIdentities() @@ -991,7 +991,7 @@ func (h *ClusterHandler) ListStoreIdentities(c echo.Context) error { // @Router /api/v3/cluster/db/user/{name} [get] func (h *ClusterHandler) ListStoreIdentity(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") name := util.PathParam(c, "name") if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "read") { @@ -1069,7 +1069,7 @@ func (h *ClusterHandler) ReloadIAM(c echo.Context) error { // @Router /api/v3/cluster/iam/user [get] func (h *ClusterHandler) ListIdentities(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") identities := h.iam.ListIdentities() @@ -1115,7 +1115,7 @@ func (h *ClusterHandler) ListIdentities(c echo.Context) error { // @Router /api/v3/cluster/iam/user/{name} [get] func (h *ClusterHandler) ListIdentity(c echo.Context) error { ctxuser := util.DefaultContext(c, "user", "") - domain := util.DefaultQuery(c, "domain", "$none") + domain := util.DefaultQuery(c, "domain", "") name := util.PathParam(c, "name") if !h.iam.Enforce(ctxuser, domain, "iam:"+name, "read") { diff --git a/iam/identity/identity.go b/iam/identity/identity.go index 57ba5710..8453b774 100644 --- a/iam/identity/identity.go +++ b/iam/identity/identity.go @@ -2,10 +2,11 @@ package identity import ( "fmt" - "regexp" + "strings" "sync" "time" + enctoken "github.com/datarhei/core/v16/encoding/token" "github.com/datarhei/core/v16/iam/jwks" "github.com/datarhei/core/v16/log" "github.com/google/uuid" @@ -45,16 +46,13 @@ type UserAuthServices struct { Session []string `json:"session"` // Secrets for session JWT } -func (u *User) validate() error { +func (u *User) Validate() error { if len(u.Name) == 0 { return fmt.Errorf("a name is required") } - chars := `A-Za-z0-9._-` - - re := regexp.MustCompile(`[^` + chars + `]`) - if re.MatchString(u.Name) { - return fmt.Errorf("the name can only contain [%s]", chars) + if strings.HasPrefix(u.Name, "$") { + return fmt.Errorf("name is not allowed to start with $") } if len(u.Auth.API.Auth0.User) != 0 { @@ -363,7 +361,7 @@ func (i *identity) GetServiceToken() string { continue } - return i.Name() + ":" + token + return enctoken.Marshal(i.Name(), token) } return "" @@ -608,7 +606,7 @@ func (im *identityManager) Reload() error { continue } - if err := u.validate(); err != nil { + if err := u.Validate(); err != nil { continue } @@ -624,7 +622,7 @@ func (im *identityManager) Reload() error { } func (im *identityManager) Create(u User) error { - if err := u.validate(); err != nil { + if err := u.Validate(); err != nil { return err } @@ -689,7 +687,7 @@ func (im *identityManager) create(u User) (*identity, error) { } func (im *identityManager) Update(name string, u User) error { - if err := u.validate(); err != nil { + if err := u.Validate(); err != nil { return err } diff --git a/iam/identity/identity_test.go b/iam/identity/identity_test.go index 87c3e724..c864969a 100644 --- a/iam/identity/identity_test.go +++ b/iam/identity/identity_test.go @@ -19,15 +19,19 @@ func createAdapter() (Adapter, error) { func TestUserName(t *testing.T) { user := User{} - err := user.validate() + err := user.Validate() require.Error(t, err) user.Name = "foobar_5" - err = user.validate() + err = user.Validate() + require.NoError(t, err) + + user.Name = "foobar:5" + err = user.Validate() require.NoError(t, err) user.Name = "$foob:ar" - err = user.validate() + err = user.Validate() require.Error(t, err) } diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index be349b16..b5cf16ae 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -12,6 +12,7 @@ import ( "time" "github.com/datarhei/core/v16/cluster/proxy" + enctoken "github.com/datarhei/core/v16/encoding/token" "github.com/datarhei/core/v16/iam" iamidentity "github.com/datarhei/core/v16/iam/identity" "github.com/datarhei/core/v16/log" @@ -212,7 +213,7 @@ func (s *server) log(who, handler, action, resource, message string, client net. func GetToken(u *url.URL) (string, string) { q := u.Query() if q.Has("token") { - // The token was in the query. Return the unmomdified path and the token + // The token was in the query. Return the unmomdified path and the token. return u.Path, q.Get("token") } @@ -471,18 +472,16 @@ func (s *server) findIdentityFromStreamKey(key string) (string, error) { return "$anon", nil } - var identity iamidentity.Verifier - var err error + var identity iamidentity.Verifier = nil + var err error = nil var token string - before, after, found := strings.Cut(key, ":") - if !found { + username, token := enctoken.Unmarshal(key) + if len(username) == 0 { identity = s.iam.GetDefaultVerifier() - token = before } else { - identity, err = s.iam.GetVerifier(before) - token = after + identity, err = s.iam.GetVerifier(username) } if err != nil { diff --git a/srt/srt.go b/srt/srt.go index 443bbf77..d39ba1d6 100644 --- a/srt/srt.go +++ b/srt/srt.go @@ -10,6 +10,7 @@ import ( "time" "github.com/datarhei/core/v16/cluster/proxy" + enctoken "github.com/datarhei/core/v16/encoding/token" "github.com/datarhei/core/v16/iam" iamidentity "github.com/datarhei/core/v16/iam/identity" "github.com/datarhei/core/v16/log" @@ -493,15 +494,11 @@ func (s *server) findIdentityFromToken(key string) (string, error) { var identity iamidentity.Verifier var err error - var token string - - before, after, found := strings.Cut(key, ":") - if !found { + username, token := enctoken.Unmarshal(key) + if len(username) == 0 { identity = s.iam.GetDefaultVerifier() - token = before } else { - identity, err = s.iam.GetVerifier(before) - token = after + identity, err = s.iam.GetVerifier(username) } if err != nil {