mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-07 08:50:59 +08:00
Implemented locking mechanism for ACL LOAD and ACL SAVE commands
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -107,8 +107,8 @@ func NewACL(config utils.Config) *ACL {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) RegisterConnection(conn *net.Conn) {
|
func (acl *ACL) RegisterConnection(conn *net.Conn) {
|
||||||
acl.UsersMutex.Lock()
|
acl.LockUsers()
|
||||||
defer acl.UsersMutex.Unlock()
|
defer acl.UnlockUsers()
|
||||||
|
|
||||||
// This is called only when a connection is established.
|
// This is called only when a connection is established.
|
||||||
defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool {
|
defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool {
|
||||||
@@ -122,8 +122,8 @@ func (acl *ACL) RegisterConnection(conn *net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) SetUser(cmd []string) error {
|
func (acl *ACL) SetUser(cmd []string) error {
|
||||||
acl.UsersMutex.Lock()
|
acl.LockUsers()
|
||||||
defer acl.UsersMutex.Unlock()
|
defer acl.UnlockUsers()
|
||||||
|
|
||||||
// Check if user with the given username already exists
|
// Check if user with the given username already exists
|
||||||
// If it does, replace user variable with this user
|
// If it does, replace user variable with this user
|
||||||
@@ -154,8 +154,8 @@ func (acl *ACL) SetUser(cmd []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
||||||
acl.UsersMutex.Lock()
|
acl.LockUsers()
|
||||||
defer acl.UsersMutex.Unlock()
|
defer acl.UnlockUsers()
|
||||||
|
|
||||||
var user *User
|
var user *User
|
||||||
for _, username := range usernames {
|
for _, username := range usernames {
|
||||||
@@ -188,8 +188,8 @@ func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
|
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
|
||||||
acl.UsersMutex.RLock()
|
acl.RLockUsers()
|
||||||
defer acl.UsersMutex.RUnlock()
|
defer acl.RUnlockUsers()
|
||||||
|
|
||||||
var passwords []Password
|
var passwords []Password
|
||||||
var user *User
|
var user *User
|
||||||
@@ -264,8 +264,8 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.Command, subCommand utils.SubCommand) error {
|
func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.Command, subCommand utils.SubCommand) error {
|
||||||
acl.UsersMutex.RLock()
|
acl.RLockUsers()
|
||||||
defer acl.UsersMutex.RUnlock()
|
defer acl.RUnlockUsers()
|
||||||
|
|
||||||
// Extract command, categories, and keys
|
// Extract command, categories, and keys
|
||||||
comm := command.Command
|
comm := command.Command
|
||||||
@@ -439,3 +439,19 @@ func (acl *ACL) CompileGlobs() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (acl *ACL) LockUsers() {
|
||||||
|
acl.UsersMutex.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acl *ACL) UnlockUsers() {
|
||||||
|
acl.UsersMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acl *ACL) RLockUsers() {
|
||||||
|
acl.UsersMutex.RLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acl *ACL) RUnlockUsers() {
|
||||||
|
acl.UsersMutex.RUnlock()
|
||||||
|
}
|
||||||
|
@@ -359,6 +359,9 @@ func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Con
|
|||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
acl.LockUsers()
|
||||||
|
defer acl.RUnlockUsers()
|
||||||
|
|
||||||
f, err := os.Open(acl.Config.AclConfig)
|
f, err := os.Open(acl.Config.AclConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -423,6 +426,9 @@ func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Con
|
|||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
acl.RLockUsers()
|
||||||
|
acl.RUnlockUsers()
|
||||||
|
|
||||||
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
|
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -566,7 +572,7 @@ If the optional category is provided, list all the commands in the category`,
|
|||||||
Description: `
|
Description: `
|
||||||
(ACL LOAD <MERGE | REPLACE>) Reloads the rules from the configured ACL config file.
|
(ACL LOAD <MERGE | REPLACE>) Reloads the rules from the configured ACL config file.
|
||||||
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
|
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
|
||||||
When 'REPLACED' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
|
When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
|
||||||
Sync: true,
|
Sync: true,
|
||||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
|
@@ -23,14 +23,14 @@ func init() {
|
|||||||
bindAddr = "localhost"
|
bindAddr = "localhost"
|
||||||
port = 7490
|
port = 7490
|
||||||
|
|
||||||
mockServer = setUpServer(bindAddr, port, true)
|
mockServer = setUpServer(bindAddr, port, true, "")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func setUpServer(bindAddr string, port uint16, requirePass bool) *server.Server {
|
func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig string) *server.Server {
|
||||||
config := utils.Config{
|
config := utils.Config{
|
||||||
BindAddr: bindAddr,
|
BindAddr: bindAddr,
|
||||||
Port: port,
|
Port: port,
|
||||||
@@ -38,6 +38,7 @@ func setUpServer(bindAddr string, port uint16, requirePass bool) *server.Server
|
|||||||
EvictionPolicy: utils.NoEviction,
|
EvictionPolicy: utils.NoEviction,
|
||||||
RequirePass: requirePass,
|
RequirePass: requirePass,
|
||||||
Password: "password1",
|
Password: "password1",
|
||||||
|
AclConfig: aclConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
acl = NewACL(config)
|
acl = NewACL(config)
|
||||||
@@ -379,7 +380,7 @@ func Test_HandleCat(t *testing.T) {
|
|||||||
|
|
||||||
func Test_HandleUsers(t *testing.T) {
|
func Test_HandleUsers(t *testing.T) {
|
||||||
var port uint16 = 7491
|
var port uint16 = 7491
|
||||||
mockServer := setUpServer(bindAddr, port, false)
|
mockServer := setUpServer(bindAddr, port, false, "")
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -428,7 +429,7 @@ func Test_HandleUsers(t *testing.T) {
|
|||||||
|
|
||||||
func Test_HandleSetUser(t *testing.T) {
|
func Test_HandleSetUser(t *testing.T) {
|
||||||
var port uint16 = 7492
|
var port uint16 = 7492
|
||||||
mockServer := setUpServer(bindAddr, port, false)
|
mockServer := setUpServer(bindAddr, port, false, "")
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -1017,7 +1018,7 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
|
|
||||||
func Test_HandleGetUser(t *testing.T) {
|
func Test_HandleGetUser(t *testing.T) {
|
||||||
var port uint16 = 7493
|
var port uint16 = 7493
|
||||||
mockServer := setUpServer(bindAddr, port, false)
|
mockServer := setUpServer(bindAddr, port, false, "")
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -1163,7 +1164,7 @@ func Test_HandleGetUser(t *testing.T) {
|
|||||||
|
|
||||||
func Test_HandleDelUser(t *testing.T) {
|
func Test_HandleDelUser(t *testing.T) {
|
||||||
var port uint16 = 7494
|
var port uint16 = 7494
|
||||||
mockServer := setUpServer(bindAddr, port, false)
|
mockServer := setUpServer(bindAddr, port, false, "")
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -1304,7 +1305,7 @@ func Test_HandleWhoAmI(t *testing.T) {
|
|||||||
|
|
||||||
func Test_HandleList(t *testing.T) {
|
func Test_HandleList(t *testing.T) {
|
||||||
var port uint16 = 7495
|
var port uint16 = 7495
|
||||||
mockServer := setUpServer(bindAddr, port, false)
|
mockServer := setUpServer(bindAddr, port, false, "")
|
||||||
go func() {
|
go func() {
|
||||||
mockServer.Start(context.Background())
|
mockServer.Start(context.Background())
|
||||||
}()
|
}()
|
||||||
@@ -1428,7 +1429,3 @@ func Test_HandleList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_HandleLoad(t *testing.T) {}
|
|
||||||
|
|
||||||
func Test_HandleSave(t *testing.T) {}
|
|
||||||
|
Reference in New Issue
Block a user