Pre-compile globs for matching during authorization in acl

This commit is contained in:
Kelvin Clement Mwinuka
2024-01-06 20:58:26 +03:00
parent 90782ea5ff
commit 5668b759e5
4 changed files with 63 additions and 29 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gobwas/glob"
"github.com/kelvinmwinuka/memstore/src/utils" "github.com/kelvinmwinuka/memstore/src/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"log" "log"
@@ -18,20 +19,16 @@ import (
"time" "time"
) )
type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
}
type Connection struct { type Connection struct {
Authenticated bool Authenticated bool
User *User User *User
} }
type ACL struct { type ACL struct {
Users []*User Users []*User
Connections map[*net.Conn]Connection Connections map[*net.Conn]Connection
Config utils.Config Config utils.Config
GlobPatterns map[string]glob.Glob
} }
func NewACL(config utils.Config) *ACL { func NewACL(config utils.Config) *ACL {
@@ -95,11 +92,14 @@ func NewACL(config utils.Config) *ACL {
} }
acl := ACL{ acl := ACL{
Users: users, Users: users,
Connections: make(map[*net.Conn]Connection), Connections: make(map[*net.Conn]Connection),
Config: config, Config: config,
GlobPatterns: make(map[string]glob.Glob),
} }
acl.CompileGlobs()
return &acl return &acl
} }
@@ -119,7 +119,12 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
// If it does, replace user variable with this user // If it does, replace user variable with this user
for _, user := range acl.Users { for _, user := range acl.Users {
if user.Username == cmd[0] { if user.Username == cmd[0] {
return user.UpdateUser(cmd) if err := user.UpdateUser(cmd); err != nil {
return err
} else {
acl.CompileGlobs()
return nil
}
} }
} }
@@ -133,6 +138,8 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
// Add user to ACL // Add user to ACL
acl.Users = append(acl.Users, user) acl.Users = append(acl.Users, user)
acl.CompileGlobs()
return nil return nil
} }
@@ -320,14 +327,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel // In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
channel := keys[0] channel := keys[0]
// 2.1) Check if the channel is in IncludedPubSubChannels // 2.1) Check if the channel is in IncludedPubSubChannels
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannel string) bool { if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannelGlob string) bool {
return utils.GlobMatches(includedChannel, channel) return acl.GlobPatterns[includedChannelGlob].Match(channel)
}) { }) {
return fmt.Errorf("not authorised to access channel &%s", channel) return fmt.Errorf("not authorised to access channel &%s", channel)
} }
// 2.2) Check if the channel is in ExcludedPubSubChannels // 2.2) Check if the channel is in ExcludedPubSubChannels
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannel string) bool { if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannelGlob string) bool {
return utils.GlobMatches(excludedChannel, channel) return acl.GlobPatterns[excludedChannelGlob].Match(channel)
}) { }) {
return fmt.Errorf("not authorised to access channel &%s", channel) return fmt.Errorf("not authorised to access channel &%s", channel)
} }
@@ -336,8 +343,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 7. Check if keys are in IncludedKeys // 7. Check if keys are in IncludedKeys
if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool { if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKey string) bool { return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKeyGlob string) bool {
if utils.GlobMatches(includedKey, key) { if acl.GlobPatterns[includedKeyGlob].Match(key) {
return true return true
} }
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key)) notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key))
@@ -349,8 +356,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys // 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool { if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKey string) bool { return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if utils.GlobMatches(readKey, key) { if acl.GlobPatterns[readKeyGlob].Match(key) {
return true return true
} }
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key)) notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
@@ -362,8 +369,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys // 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool { if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKey string) bool { return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if utils.GlobMatches(writeKey, key) { if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true return true
} }
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key)) notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
@@ -375,3 +382,31 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
return nil return nil
} }
func (acl *ACL) CompileGlobs() {
// Extract all the relevant globs from all the users
var allGlobs []string
var userGlobs []string
for _, user := range acl.Users {
userGlobs = append(userGlobs, user.IncludedPubSubChannels...)
userGlobs = append(userGlobs, user.ExcludedPubSubChannels...)
userGlobs = append(userGlobs, user.IncludedKeys...)
userGlobs = append(userGlobs, user.IncludedReadKeys...)
userGlobs = append(userGlobs, user.IncludedWriteKeys...)
for _, g := range userGlobs {
if !slices.Contains(allGlobs, g) {
allGlobs = append(allGlobs, g)
}
}
userGlobs = []string{}
}
// Compile the globs that have not been compiled yet
for _, g := range allGlobs {
if acl.GlobPatterns[g] == nil {
fmt.Println("COMPILING GLOB ", g)
acl.GlobPatterns[g] = glob.MustCompile(g)
} else {
fmt.Println("GLOB ", g, "ALREADY COMPILED, SKIPPING...")
}
}
}

View File

@@ -354,7 +354,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s) res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
} }
res = res + "\r\n\n" res = res + "\r\n\r\n"
return []byte(res), nil return []byte(res), nil
} }

View File

@@ -6,6 +6,11 @@ import (
"strings" "strings"
) )
type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
}
type User struct { type User struct {
Username string `json:"Username" yaml:"Username"` Username string `json:"Username" yaml:"Username"`
Enabled bool `json:"Enabled" yaml:"Enabled"` Enabled bool `json:"Enabled" yaml:"Enabled"`

View File

@@ -9,7 +9,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gobwas/glob"
"github.com/sethvargo/go-retry" "github.com/sethvargo/go-retry"
"github.com/tidwall/resp" "github.com/tidwall/resp"
) )
@@ -148,8 +147,3 @@ func AbsInt(n int) int {
} }
return n return n
} }
func GlobMatches(pattern, s string) bool {
g := glob.MustCompile(pattern)
return g.Match(s)
}