Pre-compile globs for matching during authorization in acl

This commit is contained in:
Kelvin Clement Mwinuka
2024-01-06 20:58:26 +03:00
parent 90782ea5ff
commit 5668b759e5
4 changed files with 63 additions and 29 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gobwas/glob"
"github.com/kelvinmwinuka/memstore/src/utils"
"gopkg.in/yaml.v3"
"log"
@@ -18,11 +19,6 @@ import (
"time"
)
type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
}
type Connection struct {
Authenticated bool
User *User
@@ -32,6 +28,7 @@ type ACL struct {
Users []*User
Connections map[*net.Conn]Connection
Config utils.Config
GlobPatterns map[string]glob.Glob
}
func NewACL(config utils.Config) *ACL {
@@ -98,8 +95,11 @@ func NewACL(config utils.Config) *ACL {
Users: users,
Connections: make(map[*net.Conn]Connection),
Config: config,
GlobPatterns: make(map[string]glob.Glob),
}
acl.CompileGlobs()
return &acl
}
@@ -119,7 +119,12 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
// If it does, replace user variable with this user
for _, user := range acl.Users {
if user.Username == cmd[0] {
return user.UpdateUser(cmd)
if err := user.UpdateUser(cmd); err != nil {
return err
} else {
acl.CompileGlobs()
return nil
}
}
}
@@ -133,6 +138,8 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
// Add user to ACL
acl.Users = append(acl.Users, user)
acl.CompileGlobs()
return nil
}
@@ -320,14 +327,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
channel := keys[0]
// 2.1) Check if the channel is in IncludedPubSubChannels
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannel string) bool {
return utils.GlobMatches(includedChannel, channel)
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannelGlob string) bool {
return acl.GlobPatterns[includedChannelGlob].Match(channel)
}) {
return fmt.Errorf("not authorised to access channel &%s", channel)
}
// 2.2) Check if the channel is in ExcludedPubSubChannels
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannel string) bool {
return utils.GlobMatches(excludedChannel, channel)
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannelGlob string) bool {
return acl.GlobPatterns[excludedChannelGlob].Match(channel)
}) {
return fmt.Errorf("not authorised to access channel &%s", channel)
}
@@ -336,8 +343,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 7. Check if keys are in IncludedKeys
if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKey string) bool {
if utils.GlobMatches(includedKey, key) {
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKeyGlob string) bool {
if acl.GlobPatterns[includedKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key))
@@ -349,8 +356,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKey string) bool {
if utils.GlobMatches(readKey, key) {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if acl.GlobPatterns[readKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
@@ -362,8 +369,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKey string) bool {
if utils.GlobMatches(writeKey, key) {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
@@ -375,3 +382,31 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
return nil
}
func (acl *ACL) CompileGlobs() {
// Extract all the relevant globs from all the users
var allGlobs []string
var userGlobs []string
for _, user := range acl.Users {
userGlobs = append(userGlobs, user.IncludedPubSubChannels...)
userGlobs = append(userGlobs, user.ExcludedPubSubChannels...)
userGlobs = append(userGlobs, user.IncludedKeys...)
userGlobs = append(userGlobs, user.IncludedReadKeys...)
userGlobs = append(userGlobs, user.IncludedWriteKeys...)
for _, g := range userGlobs {
if !slices.Contains(allGlobs, g) {
allGlobs = append(allGlobs, g)
}
}
userGlobs = []string{}
}
// Compile the globs that have not been compiled yet
for _, g := range allGlobs {
if acl.GlobPatterns[g] == nil {
fmt.Println("COMPILING GLOB ", g)
acl.GlobPatterns[g] = glob.MustCompile(g)
} else {
fmt.Println("GLOB ", g, "ALREADY COMPILED, SKIPPING...")
}
}
}

View File

@@ -354,7 +354,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
}
res = res + "\r\n\n"
res = res + "\r\n\r\n"
return []byte(res), nil
}

View File

@@ -6,6 +6,11 @@ import (
"strings"
)
type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
}
type User struct {
Username string `json:"Username" yaml:"Username"`
Enabled bool `json:"Enabled" yaml:"Enabled"`

View File

@@ -9,7 +9,6 @@ import (
"strings"
"time"
"github.com/gobwas/glob"
"github.com/sethvargo/go-retry"
"github.com/tidwall/resp"
)
@@ -148,8 +147,3 @@ func AbsInt(n int) int {
}
return n
}
func GlobMatches(pattern, s string) bool {
g := glob.MustCompile(pattern)
return g.Match(s)
}