Fixed data race issue when reading and writing ACL user data. Now, a write-lock is acquired before modifying the connection object as well in addition to the users list.

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-02 15:06:12 +08:00
parent 66b6c4b809
commit 99be0fd4f0
5 changed files with 647 additions and 1814 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -99,7 +99,7 @@ type EchoVault struct {
snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode. snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode.
aofEngine *aof.Engine // AOF engine for standalone mode. aofEngine *aof.Engine // AOF engine for standalone mode.
listener net.Listener // TCP listener. listener atomic.Value // Holds the TCP listener.
quit chan struct{} // Channel that signals the closing of all client connections. quit chan struct{} // Channel that signals the closing of all client connections.
} }
@@ -385,7 +385,7 @@ func (server *EchoVault) startTCP() {
}) })
} }
server.listener = listener server.listener.Store(listener)
// Listen to connection. // Listen to connection.
for { for {
@@ -393,7 +393,7 @@ func (server *EchoVault) startTCP() {
case <-server.quit: case <-server.quit:
return return
default: default:
conn, err := server.listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Printf("listener error: %v\n", err) log.Printf("listener error: %v\n", err)
continue continue
@@ -553,10 +553,10 @@ func (server *EchoVault) rewriteAOF() error {
// ShutDown gracefully shuts down the EchoVault instance. // ShutDown gracefully shuts down the EchoVault instance.
// This function shuts down the memberlist and raft layers. // This function shuts down the memberlist and raft layers.
func (server *EchoVault) ShutDown() { func (server *EchoVault) ShutDown() {
if server.listener != nil { if server.listener.Load() != nil {
go func() { server.quit <- struct{}{} }() go func() { server.quit <- struct{}{} }()
log.Println("closing tcp listener...") log.Println("closing tcp listener...")
if err := server.listener.Close(); err != nil { if err := server.listener.Load().(net.Listener).Close(); err != nil {
log.Printf("listener close: %v\n", err) log.Printf("listener close: %v\n", err)
} }
} }

View File

@@ -226,9 +226,6 @@ func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
} }
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error { func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
acl.RLockUsers()
defer acl.RUnlockUsers()
var passwords []Password var passwords []Password
var user *User var user *User

View File

@@ -36,12 +36,75 @@ func handleAuth(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.LockUsers()
defer acl.UnlockUsers()
if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil { if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
return nil, err return nil, err
} }
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
categories := make(map[string][]string)
commands := params.GetAllCommands()
for _, command := range commands {
if len(command.SubCommands) == 0 {
for _, category := range command.Categories {
categories[category] = append(categories[category], command.Command)
}
continue
}
for _, subcommand := range command.SubCommands {
for _, category := range subcommand.Categories {
categories[category] = append(categories[category],
fmt.Sprintf("%s|%s", command.Command, subcommand.Command))
}
}
}
if len(params.Command) == 2 {
var cats []string
length := 0
for key, _ := range categories {
cats = append(cats, key)
length += 1
}
res := fmt.Sprintf("*%d", length)
for i, cat := range cats {
res = fmt.Sprintf("%s\r\n+%s", res, cat)
if i == len(cats)-1 {
res = res + "\r\n"
}
}
return []byte(res), nil
}
if len(params.Command) == 3 {
var res string
for category, commands := range categories {
if strings.EqualFold(category, params.Command[2]) {
res = fmt.Sprintf("*%d", len(commands))
for i, command := range commands {
res = fmt.Sprintf("%s\r\n+%s", res, command)
if i == len(commands)-1 {
res = res + "\r\n"
}
}
return []byte(res), nil
}
}
}
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
}
func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) { func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 { if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
@@ -51,6 +114,8 @@ func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers()
defer acl.RUnlockUsers()
var user *User var user *User
userFound := false userFound := false
@@ -159,71 +224,12 @@ func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
categories := make(map[string][]string)
commands := params.GetAllCommands()
for _, command := range commands {
if len(command.SubCommands) == 0 {
for _, category := range command.Categories {
categories[category] = append(categories[category], command.Command)
}
continue
}
for _, subcommand := range command.SubCommands {
for _, category := range subcommand.Categories {
categories[category] = append(categories[category],
fmt.Sprintf("%s|%s", command.Command, subcommand.Command))
}
}
}
if len(params.Command) == 2 {
var cats []string
length := 0
for key, _ := range categories {
cats = append(cats, key)
length += 1
}
res := fmt.Sprintf("*%d", length)
for i, cat := range cats {
res = fmt.Sprintf("%s\r\n+%s", res, cat)
if i == len(cats)-1 {
res = res + "\r\n"
}
}
return []byte(res), nil
}
if len(params.Command) == 3 {
var res string
for category, commands := range categories {
if strings.EqualFold(category, params.Command[2]) {
res = fmt.Sprintf("*%d", len(commands))
for i, command := range commands {
res = fmt.Sprintf("%s\r\n+%s", res, command)
if i == len(commands)-1 {
res = res + "\r\n"
}
}
return []byte(res), nil
}
}
}
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
}
func handleUsers(params internal.HandlerFuncParams) ([]byte, error) { func handleUsers(params internal.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
res := fmt.Sprintf("*%d", len(acl.Users)) res := fmt.Sprintf("*%d", len(acl.Users))
for _, user := range acl.Users { for _, user := range acl.Users {
res += fmt.Sprintf("\r\n$%d\r\n%s", len(user.Username), user.Username) res += fmt.Sprintf("\r\n$%d\r\n%s", len(user.Username), user.Username)
@@ -262,6 +268,9 @@ func handleWhoAmI(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers()
defer acl.RUnlockUsers()
connectionInfo := acl.Connections[params.Connection] connectionInfo := acl.Connections[params.Connection]
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
} }
@@ -274,6 +283,9 @@ func handleList(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers()
defer acl.RUnlockUsers()
res := fmt.Sprintf("*%d", len(acl.Users)) res := fmt.Sprintf("*%d", len(acl.Users))
s := "" s := ""
for _, user := range acl.Users { for _, user := range acl.Users {
@@ -371,7 +383,6 @@ func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.LockUsers() acl.LockUsers()
defer acl.UnlockUsers() defer acl.UnlockUsers()
@@ -438,9 +449,8 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers()
acl.LockUsers() defer acl.RUnlockUsers()
acl.UnlockUsers()
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil { if err != nil {

View File

@@ -27,6 +27,7 @@ import (
"path" "path"
"slices" "slices"
"strings" "strings"
"sync"
"testing" "testing"
) )
@@ -167,7 +168,6 @@ func Test_ACL(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
go func() { go func() {
mockServer.Start() mockServer.Start()
}() }()
@@ -710,8 +710,8 @@ func Test_ACL(t *testing.T) {
}, },
}, },
{ {
name: `10. Create user that can access some read keys and some write keys. name: `10. Create user that can access some read keys and some write keys.
Provide keys that are RW, W-Only and R-Only`, Provide keys that are RW, W-Only and R-Only`,
presetUser: nil, presetUser: nil,
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("ACL"), resp.StringValue("ACL"),
@@ -853,7 +853,7 @@ Provide keys that are RW, W-Only and R-Only`,
}, },
{ {
name: `16. Create new user with no password using 'nopass'. name: `16. Create new user with no password using 'nopass'.
When nopass is provided, ignore any passwords that may have been provided in the command.`, When nopass is provided, ignore any passwords that may have been provided in the command.`,
presetUser: nil, presetUser: nil,
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("ACL"), resp.StringValue("ACL"),
@@ -1566,10 +1566,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
baseDir := path.Join(".", "testdata", "save") baseDir := path.Join(".", "testdata", "save")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
tests := []struct { tests := []struct {
name string name string
path string path string
@@ -1610,10 +1606,22 @@ When nopass is provided, ignore any passwords that may have been provided in the
}, },
} }
for _, test := range tests { servers := make([]*echovault.EchoVault, len(tests))
mut := sync.Mutex{}
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
for _, server := range servers {
if server != nil {
server.ShutDown()
}
}
})
for i, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
mut.Lock()
defer mut.Unlock()
// Get free port. // Get free port.
port, err := internal.GetFreePort() port, err := internal.GetFreePort()
if err != nil { if err != nil {
@@ -1627,7 +1635,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Error(err) t.Error(err)
return return
} }
servers[i] = mockServer
go func() { go func() {
mockServer.Start() mockServer.Start()
}() }()
@@ -1635,34 +1643,29 @@ When nopass is provided, ignore any passwords that may have been provided in the
conn, err := internal.GetConnection("localhost", port) conn, err := internal.GetConnection("localhost", port)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
client := resp.NewConn(conn) client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil { if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
res, _, err := client.ReadValue() res, _, err := client.ReadValue()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
if !strings.EqualFold(res.String(), "ok") { if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK response, got \"%s\"", res.String()) t.Errorf("expected OK response, got \"%s\"", res.String())
mockServer.ShutDown()
return return
} }
// Close client connection // Close client connection
if err = conn.Close(); err != nil { if err = conn.Close(); err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
@@ -1687,21 +1690,18 @@ When nopass is provided, ignore any passwords that may have been provided in the
conn, err = internal.GetConnection("localhost", port) conn, err = internal.GetConnection("localhost", port)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
client = resp.NewConn(conn) client = resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil { if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
res, _, err = client.ReadValue() res, _, err = client.ReadValue()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
mockServer.ShutDown()
return return
} }
@@ -1709,7 +1709,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
resArr := res.Array() resArr := res.Array()
if len(resArr) != len(test.want) { if len(resArr) != len(test.want) {
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr)) t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
mockServer.ShutDown()
return return
} }
@@ -1721,12 +1720,9 @@ When nopass is provided, ignore any passwords that may have been provided in the
return compareSlices(resStr, expectedUserSlice) == nil return compareSlices(resStr, expectedUserSlice) == nil
}) { }) {
t.Errorf("could not find the following user in expected slice: %+v", resStr) t.Errorf("could not find the following user in expected slice: %+v", resStr)
mockServer.ShutDown()
return return
} }
} }
mockServer.ShutDown()
}) })
} }
}) })
@@ -1736,19 +1732,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
baseDir := path.Join(".", "testdata", "load") baseDir := path.Join(".", "testdata", "load")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
servers := make([]*echovault.EchoVault, 5)
defer func() {
for _, server := range servers {
if server != nil {
server.ShutDown()
}
}
}()
tests := []struct { tests := []struct {
name string name string
path string path string
@@ -1862,8 +1845,22 @@ When nopass is provided, ignore any passwords that may have been provided in the
}, },
} }
servers := make([]*echovault.EchoVault, len(tests))
mut := sync.Mutex{}
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
for _, server := range servers {
if server != nil {
server.ShutDown()
}
}
})
for i, test := range tests { for i, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel()
mut.Lock()
defer mut.Unlock()
// Create server with pre-generated users. // Create server with pre-generated users.
port, err := internal.GetFreePort() port, err := internal.GetFreePort()
if err != nil { if err != nil {
@@ -1907,7 +1904,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Error(err) t.Error(err)
return return
} }
fmt.Println("COMMAND WRITTEN")
res, _, err := client.ReadValue() res, _, err := client.ReadValue()
if err != nil { if err != nil {