mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-10 02:10:17 +08:00
Pre-compile globs for matching during authorization in acl
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"gopkg.in/yaml.v3"
|
||||
"log"
|
||||
@@ -18,11 +19,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Password struct {
|
||||
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
||||
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
Authenticated bool
|
||||
User *User
|
||||
@@ -32,6 +28,7 @@ type ACL struct {
|
||||
Users []*User
|
||||
Connections map[*net.Conn]Connection
|
||||
Config utils.Config
|
||||
GlobPatterns map[string]glob.Glob
|
||||
}
|
||||
|
||||
func NewACL(config utils.Config) *ACL {
|
||||
@@ -98,8 +95,11 @@ func NewACL(config utils.Config) *ACL {
|
||||
Users: users,
|
||||
Connections: make(map[*net.Conn]Connection),
|
||||
Config: config,
|
||||
GlobPatterns: make(map[string]glob.Glob),
|
||||
}
|
||||
|
||||
acl.CompileGlobs()
|
||||
|
||||
return &acl
|
||||
}
|
||||
|
||||
@@ -119,7 +119,12 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
|
||||
// If it does, replace user variable with this user
|
||||
for _, user := range acl.Users {
|
||||
if user.Username == cmd[0] {
|
||||
return user.UpdateUser(cmd)
|
||||
if err := user.UpdateUser(cmd); err != nil {
|
||||
return err
|
||||
} else {
|
||||
acl.CompileGlobs()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +138,8 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
|
||||
// Add user to ACL
|
||||
acl.Users = append(acl.Users, user)
|
||||
|
||||
acl.CompileGlobs()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -320,14 +327,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
|
||||
channel := keys[0]
|
||||
// 2.1) Check if the channel is in IncludedPubSubChannels
|
||||
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannel string) bool {
|
||||
return utils.GlobMatches(includedChannel, channel)
|
||||
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannelGlob string) bool {
|
||||
return acl.GlobPatterns[includedChannelGlob].Match(channel)
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to access channel &%s", channel)
|
||||
}
|
||||
// 2.2) Check if the channel is in ExcludedPubSubChannels
|
||||
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannel string) bool {
|
||||
return utils.GlobMatches(excludedChannel, channel)
|
||||
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannelGlob string) bool {
|
||||
return acl.GlobPatterns[excludedChannelGlob].Match(channel)
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to access channel &%s", channel)
|
||||
}
|
||||
@@ -336,8 +343,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
|
||||
// 7. Check if keys are in IncludedKeys
|
||||
if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool {
|
||||
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKey string) bool {
|
||||
if utils.GlobMatches(includedKey, key) {
|
||||
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKeyGlob string) bool {
|
||||
if acl.GlobPatterns[includedKeyGlob].Match(key) {
|
||||
return true
|
||||
}
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key))
|
||||
@@ -349,8 +356,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
|
||||
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
|
||||
if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
||||
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKey string) bool {
|
||||
if utils.GlobMatches(readKey, key) {
|
||||
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
|
||||
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
||||
return true
|
||||
}
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
||||
@@ -362,8 +369,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
|
||||
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
|
||||
if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKey string) bool {
|
||||
if utils.GlobMatches(writeKey, key) {
|
||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
||||
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
||||
return true
|
||||
}
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
||||
@@ -375,3 +382,31 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (acl *ACL) CompileGlobs() {
|
||||
// Extract all the relevant globs from all the users
|
||||
var allGlobs []string
|
||||
var userGlobs []string
|
||||
for _, user := range acl.Users {
|
||||
userGlobs = append(userGlobs, user.IncludedPubSubChannels...)
|
||||
userGlobs = append(userGlobs, user.ExcludedPubSubChannels...)
|
||||
userGlobs = append(userGlobs, user.IncludedKeys...)
|
||||
userGlobs = append(userGlobs, user.IncludedReadKeys...)
|
||||
userGlobs = append(userGlobs, user.IncludedWriteKeys...)
|
||||
for _, g := range userGlobs {
|
||||
if !slices.Contains(allGlobs, g) {
|
||||
allGlobs = append(allGlobs, g)
|
||||
}
|
||||
}
|
||||
userGlobs = []string{}
|
||||
}
|
||||
// Compile the globs that have not been compiled yet
|
||||
for _, g := range allGlobs {
|
||||
if acl.GlobPatterns[g] == nil {
|
||||
fmt.Println("COMPILING GLOB ", g)
|
||||
acl.GlobPatterns[g] = glob.MustCompile(g)
|
||||
} else {
|
||||
fmt.Println("GLOB ", g, "ALREADY COMPILED, SKIPPING...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -354,7 +354,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
|
||||
}
|
||||
|
||||
res = res + "\r\n\n"
|
||||
res = res + "\r\n\r\n"
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
|
@@ -6,6 +6,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Password struct {
|
||||
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
||||
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string `json:"Username" yaml:"Username"`
|
||||
Enabled bool `json:"Enabled" yaml:"Enabled"`
|
||||
|
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/sethvargo/go-retry"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
@@ -148,8 +147,3 @@ func AbsInt(n int) int {
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func GlobMatches(pattern, s string) bool {
|
||||
g := glob.MustCompile(pattern)
|
||||
return g.Match(s)
|
||||
}
|
||||
|
Reference in New Issue
Block a user