Implemented locking mechanism for ACL LOAD and ACL SAVE commands

This commit is contained in:
Kelvin Mwinuka
2024-03-24 17:32:29 +08:00
parent 2a7d47271b
commit 0cd8a4aec2
4 changed files with 677 additions and 654 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -107,8 +107,8 @@ func NewACL(config utils.Config) *ACL {
} }
func (acl *ACL) RegisterConnection(conn *net.Conn) { func (acl *ACL) RegisterConnection(conn *net.Conn) {
acl.UsersMutex.Lock() acl.LockUsers()
defer acl.UsersMutex.Unlock() defer acl.UnlockUsers()
// This is called only when a connection is established. // This is called only when a connection is established.
defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool { defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool {
@@ -122,8 +122,8 @@ func (acl *ACL) RegisterConnection(conn *net.Conn) {
} }
func (acl *ACL) SetUser(cmd []string) error { func (acl *ACL) SetUser(cmd []string) error {
acl.UsersMutex.Lock() acl.LockUsers()
defer acl.UsersMutex.Unlock() defer acl.UnlockUsers()
// Check if user with the given username already exists // Check if user with the given username already exists
// If it does, replace user variable with this user // If it does, replace user variable with this user
@@ -154,8 +154,8 @@ func (acl *ACL) SetUser(cmd []string) error {
} }
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error { func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
acl.UsersMutex.Lock() acl.LockUsers()
defer acl.UsersMutex.Unlock() defer acl.UnlockUsers()
var user *User var user *User
for _, username := range usernames { for _, username := range usernames {
@@ -188,8 +188,8 @@ func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
} }
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error { func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
acl.UsersMutex.RLock() acl.RLockUsers()
defer acl.UsersMutex.RUnlock() defer acl.RUnlockUsers()
var passwords []Password var passwords []Password
var user *User var user *User
@@ -264,8 +264,8 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
} }
func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.Command, subCommand utils.SubCommand) error { func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.Command, subCommand utils.SubCommand) error {
acl.UsersMutex.RLock() acl.RLockUsers()
defer acl.UsersMutex.RUnlock() defer acl.RUnlockUsers()
// Extract command, categories, and keys // Extract command, categories, and keys
comm := command.Command comm := command.Command
@@ -439,3 +439,19 @@ func (acl *ACL) CompileGlobs() {
} }
} }
} }
func (acl *ACL) LockUsers() {
acl.UsersMutex.Lock()
}
func (acl *ACL) UnlockUsers() {
acl.UsersMutex.Unlock()
}
func (acl *ACL) RLockUsers() {
acl.UsersMutex.RLock()
}
func (acl *ACL) RUnlockUsers() {
acl.UsersMutex.RUnlock()
}

View File

@@ -359,6 +359,9 @@ func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Con
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.LockUsers()
defer acl.RUnlockUsers()
f, err := os.Open(acl.Config.AclConfig) f, err := os.Open(acl.Config.AclConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -423,6 +426,9 @@ func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Con
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers()
acl.RUnlockUsers()
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend) f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -566,7 +572,7 @@ If the optional category is provided, list all the commands in the category`,
Description: ` Description: `
(ACL LOAD <MERGE | REPLACE>) Reloads the rules from the configured ACL config file. (ACL LOAD <MERGE | REPLACE>) Reloads the rules from the configured ACL config file.
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged. When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
When 'REPLACED' is passed, users from config file who share a username with users in memory will replace the user in memory.`, When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil return []string{}, nil

View File

@@ -23,14 +23,14 @@ func init() {
bindAddr = "localhost" bindAddr = "localhost"
port = 7490 port = 7490
mockServer = setUpServer(bindAddr, port, true) mockServer = setUpServer(bindAddr, port, true, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
} }
func setUpServer(bindAddr string, port uint16, requirePass bool) *server.Server { func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig string) *server.Server {
config := utils.Config{ config := utils.Config{
BindAddr: bindAddr, BindAddr: bindAddr,
Port: port, Port: port,
@@ -38,6 +38,7 @@ func setUpServer(bindAddr string, port uint16, requirePass bool) *server.Server
EvictionPolicy: utils.NoEviction, EvictionPolicy: utils.NoEviction,
RequirePass: requirePass, RequirePass: requirePass,
Password: "password1", Password: "password1",
AclConfig: aclConfig,
} }
acl = NewACL(config) acl = NewACL(config)
@@ -379,7 +380,7 @@ func Test_HandleCat(t *testing.T) {
func Test_HandleUsers(t *testing.T) { func Test_HandleUsers(t *testing.T) {
var port uint16 = 7491 var port uint16 = 7491
mockServer := setUpServer(bindAddr, port, false) mockServer := setUpServer(bindAddr, port, false, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
@@ -428,7 +429,7 @@ func Test_HandleUsers(t *testing.T) {
func Test_HandleSetUser(t *testing.T) { func Test_HandleSetUser(t *testing.T) {
var port uint16 = 7492 var port uint16 = 7492
mockServer := setUpServer(bindAddr, port, false) mockServer := setUpServer(bindAddr, port, false, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
@@ -1017,7 +1018,7 @@ func Test_HandleSetUser(t *testing.T) {
func Test_HandleGetUser(t *testing.T) { func Test_HandleGetUser(t *testing.T) {
var port uint16 = 7493 var port uint16 = 7493
mockServer := setUpServer(bindAddr, port, false) mockServer := setUpServer(bindAddr, port, false, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
@@ -1163,7 +1164,7 @@ func Test_HandleGetUser(t *testing.T) {
func Test_HandleDelUser(t *testing.T) { func Test_HandleDelUser(t *testing.T) {
var port uint16 = 7494 var port uint16 = 7494
mockServer := setUpServer(bindAddr, port, false) mockServer := setUpServer(bindAddr, port, false, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
@@ -1304,7 +1305,7 @@ func Test_HandleWhoAmI(t *testing.T) {
func Test_HandleList(t *testing.T) { func Test_HandleList(t *testing.T) {
var port uint16 = 7495 var port uint16 = 7495
mockServer := setUpServer(bindAddr, port, false) mockServer := setUpServer(bindAddr, port, false, "")
go func() { go func() {
mockServer.Start(context.Background()) mockServer.Start(context.Background())
}() }()
@@ -1428,7 +1429,3 @@ func Test_HandleList(t *testing.T) {
} }
} }
} }
func Test_HandleLoad(t *testing.T) {}
func Test_HandleSave(t *testing.T) {}