mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 01:40:20 +08:00
Pre-compile globs for matching during authorization in acl
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gobwas/glob"
|
||||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
"log"
|
"log"
|
||||||
@@ -18,20 +19,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Password struct {
|
|
||||||
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
|
||||||
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
User *User
|
User *User
|
||||||
}
|
}
|
||||||
|
|
||||||
type ACL struct {
|
type ACL struct {
|
||||||
Users []*User
|
Users []*User
|
||||||
Connections map[*net.Conn]Connection
|
Connections map[*net.Conn]Connection
|
||||||
Config utils.Config
|
Config utils.Config
|
||||||
|
GlobPatterns map[string]glob.Glob
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewACL(config utils.Config) *ACL {
|
func NewACL(config utils.Config) *ACL {
|
||||||
@@ -95,11 +92,14 @@ func NewACL(config utils.Config) *ACL {
|
|||||||
}
|
}
|
||||||
|
|
||||||
acl := ACL{
|
acl := ACL{
|
||||||
Users: users,
|
Users: users,
|
||||||
Connections: make(map[*net.Conn]Connection),
|
Connections: make(map[*net.Conn]Connection),
|
||||||
Config: config,
|
Config: config,
|
||||||
|
GlobPatterns: make(map[string]glob.Glob),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
acl.CompileGlobs()
|
||||||
|
|
||||||
return &acl
|
return &acl
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +119,12 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
|
|||||||
// If it does, replace user variable with this user
|
// If it does, replace user variable with this user
|
||||||
for _, user := range acl.Users {
|
for _, user := range acl.Users {
|
||||||
if user.Username == cmd[0] {
|
if user.Username == cmd[0] {
|
||||||
return user.UpdateUser(cmd)
|
if err := user.UpdateUser(cmd); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
acl.CompileGlobs()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,6 +138,8 @@ func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
|
|||||||
// Add user to ACL
|
// Add user to ACL
|
||||||
acl.Users = append(acl.Users, user)
|
acl.Users = append(acl.Users, user)
|
||||||
|
|
||||||
|
acl.CompileGlobs()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,14 +327,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
|
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
|
||||||
channel := keys[0]
|
channel := keys[0]
|
||||||
// 2.1) Check if the channel is in IncludedPubSubChannels
|
// 2.1) Check if the channel is in IncludedPubSubChannels
|
||||||
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannel string) bool {
|
if !slices.ContainsFunc(connection.User.IncludedPubSubChannels, func(includedChannelGlob string) bool {
|
||||||
return utils.GlobMatches(includedChannel, channel)
|
return acl.GlobPatterns[includedChannelGlob].Match(channel)
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to access channel &%s", channel)
|
return fmt.Errorf("not authorised to access channel &%s", channel)
|
||||||
}
|
}
|
||||||
// 2.2) Check if the channel is in ExcludedPubSubChannels
|
// 2.2) Check if the channel is in ExcludedPubSubChannels
|
||||||
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannel string) bool {
|
if slices.ContainsFunc(connection.User.ExcludedPubSubChannels, func(excludedChannelGlob string) bool {
|
||||||
return utils.GlobMatches(excludedChannel, channel)
|
return acl.GlobPatterns[excludedChannelGlob].Match(channel)
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to access channel &%s", channel)
|
return fmt.Errorf("not authorised to access channel &%s", channel)
|
||||||
}
|
}
|
||||||
@@ -336,8 +343,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
|
|
||||||
// 7. Check if keys are in IncludedKeys
|
// 7. Check if keys are in IncludedKeys
|
||||||
if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool {
|
if len(keys) > 0 && !slices.ContainsFunc(keys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKey string) bool {
|
return slices.ContainsFunc(connection.User.IncludedKeys, func(includedKeyGlob string) bool {
|
||||||
if utils.GlobMatches(includedKey, key) {
|
if acl.GlobPatterns[includedKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key))
|
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%RW", key))
|
||||||
@@ -349,8 +356,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
|
|
||||||
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
|
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
|
||||||
if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
if len(keys) > 0 && slices.Contains(categories, utils.ReadCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKey string) bool {
|
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
|
||||||
if utils.GlobMatches(readKey, key) {
|
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
||||||
@@ -362,8 +369,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
|
|
||||||
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
|
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
|
||||||
if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
if len(keys) > 0 && slices.Contains(categories, utils.WriteCategory) && !slices.ContainsFunc(keys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKey string) bool {
|
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
||||||
if utils.GlobMatches(writeKey, key) {
|
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
||||||
@@ -375,3 +382,31 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (acl *ACL) CompileGlobs() {
|
||||||
|
// Extract all the relevant globs from all the users
|
||||||
|
var allGlobs []string
|
||||||
|
var userGlobs []string
|
||||||
|
for _, user := range acl.Users {
|
||||||
|
userGlobs = append(userGlobs, user.IncludedPubSubChannels...)
|
||||||
|
userGlobs = append(userGlobs, user.ExcludedPubSubChannels...)
|
||||||
|
userGlobs = append(userGlobs, user.IncludedKeys...)
|
||||||
|
userGlobs = append(userGlobs, user.IncludedReadKeys...)
|
||||||
|
userGlobs = append(userGlobs, user.IncludedWriteKeys...)
|
||||||
|
for _, g := range userGlobs {
|
||||||
|
if !slices.Contains(allGlobs, g) {
|
||||||
|
allGlobs = append(allGlobs, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userGlobs = []string{}
|
||||||
|
}
|
||||||
|
// Compile the globs that have not been compiled yet
|
||||||
|
for _, g := range allGlobs {
|
||||||
|
if acl.GlobPatterns[g] == nil {
|
||||||
|
fmt.Println("COMPILING GLOB ", g)
|
||||||
|
acl.GlobPatterns[g] = glob.MustCompile(g)
|
||||||
|
} else {
|
||||||
|
fmt.Println("GLOB ", g, "ALREADY COMPILED, SKIPPING...")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -354,7 +354,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
|
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
|
||||||
}
|
}
|
||||||
|
|
||||||
res = res + "\r\n\n"
|
res = res + "\r\n\r\n"
|
||||||
return []byte(res), nil
|
return []byte(res), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -6,6 +6,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Password struct {
|
||||||
|
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
||||||
|
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Username string `json:"Username" yaml:"Username"`
|
Username string `json:"Username" yaml:"Username"`
|
||||||
Enabled bool `json:"Enabled" yaml:"Enabled"`
|
Enabled bool `json:"Enabled" yaml:"Enabled"`
|
||||||
|
@@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gobwas/glob"
|
|
||||||
"github.com/sethvargo/go-retry"
|
"github.com/sethvargo/go-retry"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
)
|
)
|
||||||
@@ -148,8 +147,3 @@ func AbsInt(n int) int {
|
|||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
func GlobMatches(pattern, s string) bool {
|
|
||||||
g := glob.MustCompile(pattern)
|
|
||||||
return g.Match(s)
|
|
||||||
}
|
|
||||||
|
Reference in New Issue
Block a user