mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 16:30:21 +08:00
Fixed data race issue when reading and writing ACL user data. Now, a write-lock is acquired before modifying the connection object as well in addition to the users list.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -99,7 +99,7 @@ type EchoVault struct {
|
||||
snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode.
|
||||
aofEngine *aof.Engine // AOF engine for standalone mode.
|
||||
|
||||
listener net.Listener // TCP listener.
|
||||
listener atomic.Value // Holds the TCP listener.
|
||||
quit chan struct{} // Channel that signals the closing of all client connections.
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ func (server *EchoVault) startTCP() {
|
||||
})
|
||||
}
|
||||
|
||||
server.listener = listener
|
||||
server.listener.Store(listener)
|
||||
|
||||
// Listen to connection.
|
||||
for {
|
||||
@@ -393,7 +393,7 @@ func (server *EchoVault) startTCP() {
|
||||
case <-server.quit:
|
||||
return
|
||||
default:
|
||||
conn, err := server.listener.Accept()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("listener error: %v\n", err)
|
||||
continue
|
||||
@@ -553,10 +553,10 @@ func (server *EchoVault) rewriteAOF() error {
|
||||
// ShutDown gracefully shuts down the EchoVault instance.
|
||||
// This function shuts down the memberlist and raft layers.
|
||||
func (server *EchoVault) ShutDown() {
|
||||
if server.listener != nil {
|
||||
if server.listener.Load() != nil {
|
||||
go func() { server.quit <- struct{}{} }()
|
||||
log.Println("closing tcp listener...")
|
||||
if err := server.listener.Close(); err != nil {
|
||||
if err := server.listener.Load().(net.Listener).Close(); err != nil {
|
||||
log.Printf("listener close: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
@@ -226,9 +226,6 @@ func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
||||
}
|
||||
|
||||
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
|
||||
acl.RLockUsers()
|
||||
defer acl.RUnlockUsers()
|
||||
|
||||
var passwords []Password
|
||||
var user *User
|
||||
|
||||
|
@@ -36,12 +36,75 @@ func handleAuth(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
acl.LockUsers()
|
||||
defer acl.UnlockUsers()
|
||||
|
||||
if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
categories := make(map[string][]string)
|
||||
|
||||
commands := params.GetAllCommands()
|
||||
|
||||
for _, command := range commands {
|
||||
if len(command.SubCommands) == 0 {
|
||||
for _, category := range command.Categories {
|
||||
categories[category] = append(categories[category], command.Command)
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, subcommand := range command.SubCommands {
|
||||
for _, category := range subcommand.Categories {
|
||||
categories[category] = append(categories[category],
|
||||
fmt.Sprintf("%s|%s", command.Command, subcommand.Command))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(params.Command) == 2 {
|
||||
var cats []string
|
||||
length := 0
|
||||
for key, _ := range categories {
|
||||
cats = append(cats, key)
|
||||
length += 1
|
||||
}
|
||||
res := fmt.Sprintf("*%d", length)
|
||||
for i, cat := range cats {
|
||||
res = fmt.Sprintf("%s\r\n+%s", res, cat)
|
||||
if i == len(cats)-1 {
|
||||
res = res + "\r\n"
|
||||
}
|
||||
}
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
if len(params.Command) == 3 {
|
||||
var res string
|
||||
for category, commands := range categories {
|
||||
if strings.EqualFold(category, params.Command[2]) {
|
||||
res = fmt.Sprintf("*%d", len(commands))
|
||||
for i, command := range commands {
|
||||
res = fmt.Sprintf("%s\r\n+%s", res, command)
|
||||
if i == len(commands)-1 {
|
||||
res = res + "\r\n"
|
||||
}
|
||||
}
|
||||
return []byte(res), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
|
||||
}
|
||||
|
||||
func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) != 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
@@ -51,6 +114,8 @@ func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
acl.RLockUsers()
|
||||
defer acl.RUnlockUsers()
|
||||
|
||||
var user *User
|
||||
userFound := false
|
||||
@@ -159,71 +224,12 @@ func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
categories := make(map[string][]string)
|
||||
|
||||
commands := params.GetAllCommands()
|
||||
|
||||
for _, command := range commands {
|
||||
if len(command.SubCommands) == 0 {
|
||||
for _, category := range command.Categories {
|
||||
categories[category] = append(categories[category], command.Command)
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, subcommand := range command.SubCommands {
|
||||
for _, category := range subcommand.Categories {
|
||||
categories[category] = append(categories[category],
|
||||
fmt.Sprintf("%s|%s", command.Command, subcommand.Command))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(params.Command) == 2 {
|
||||
var cats []string
|
||||
length := 0
|
||||
for key, _ := range categories {
|
||||
cats = append(cats, key)
|
||||
length += 1
|
||||
}
|
||||
res := fmt.Sprintf("*%d", length)
|
||||
for i, cat := range cats {
|
||||
res = fmt.Sprintf("%s\r\n+%s", res, cat)
|
||||
if i == len(cats)-1 {
|
||||
res = res + "\r\n"
|
||||
}
|
||||
}
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
if len(params.Command) == 3 {
|
||||
var res string
|
||||
for category, commands := range categories {
|
||||
if strings.EqualFold(category, params.Command[2]) {
|
||||
res = fmt.Sprintf("*%d", len(commands))
|
||||
for i, command := range commands {
|
||||
res = fmt.Sprintf("%s\r\n+%s", res, command)
|
||||
if i == len(commands)-1 {
|
||||
res = res + "\r\n"
|
||||
}
|
||||
}
|
||||
return []byte(res), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
|
||||
}
|
||||
|
||||
func handleUsers(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
acl, ok := params.GetACL().(*ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
|
||||
res := fmt.Sprintf("*%d", len(acl.Users))
|
||||
for _, user := range acl.Users {
|
||||
res += fmt.Sprintf("\r\n$%d\r\n%s", len(user.Username), user.Username)
|
||||
@@ -262,6 +268,9 @@ func handleWhoAmI(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
acl.RLockUsers()
|
||||
defer acl.RUnlockUsers()
|
||||
|
||||
connectionInfo := acl.Connections[params.Connection]
|
||||
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
|
||||
}
|
||||
@@ -274,6 +283,9 @@ func handleList(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
acl.RLockUsers()
|
||||
defer acl.RUnlockUsers()
|
||||
|
||||
res := fmt.Sprintf("*%d", len(acl.Users))
|
||||
s := ""
|
||||
for _, user := range acl.Users {
|
||||
@@ -371,7 +383,6 @@ func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
|
||||
acl.LockUsers()
|
||||
defer acl.UnlockUsers()
|
||||
|
||||
@@ -438,9 +449,8 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
|
||||
acl.LockUsers()
|
||||
acl.UnlockUsers()
|
||||
acl.RLockUsers()
|
||||
defer acl.RUnlockUsers()
|
||||
|
||||
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
|
||||
if err != nil {
|
||||
|
@@ -27,6 +27,7 @@ import (
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -167,7 +168,6 @@ func Test_ACL(t *testing.T) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
@@ -711,7 +711,7 @@ func Test_ACL(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: `10. Create user that can access some read keys and some write keys.
|
||||
Provide keys that are RW, W-Only and R-Only`,
|
||||
Provide keys that are RW, W-Only and R-Only`,
|
||||
presetUser: nil,
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("ACL"),
|
||||
@@ -853,7 +853,7 @@ Provide keys that are RW, W-Only and R-Only`,
|
||||
},
|
||||
{
|
||||
name: `16. Create new user with no password using 'nopass'.
|
||||
When nopass is provided, ignore any passwords that may have been provided in the command.`,
|
||||
When nopass is provided, ignore any passwords that may have been provided in the command.`,
|
||||
presetUser: nil,
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("ACL"),
|
||||
@@ -1566,10 +1566,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
|
||||
baseDir := path.Join(".", "testdata", "save")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
@@ -1610,10 +1606,22 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
servers := make([]*echovault.EchoVault, len(tests))
|
||||
mut := sync.Mutex{}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
for _, server := range servers {
|
||||
if server != nil {
|
||||
server.ShutDown()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mut.Lock()
|
||||
defer mut.Unlock()
|
||||
// Get free port.
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
@@ -1627,7 +1635,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
servers[i] = mockServer
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
@@ -1635,34 +1643,29 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected OK response, got \"%s\"", res.String())
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Close client connection
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1687,21 +1690,18 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
conn, err = internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
client = resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1709,7 +1709,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
resArr := res.Array()
|
||||
if len(resArr) != len(test.want) {
|
||||
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1721,12 +1720,9 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
return compareSlices(resStr, expectedUserSlice) == nil
|
||||
}) {
|
||||
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -1736,19 +1732,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
|
||||
baseDir := path.Join(".", "testdata", "load")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
})
|
||||
|
||||
servers := make([]*echovault.EchoVault, 5)
|
||||
defer func() {
|
||||
for _, server := range servers {
|
||||
if server != nil {
|
||||
server.ShutDown()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
@@ -1862,8 +1845,22 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
},
|
||||
}
|
||||
|
||||
servers := make([]*echovault.EchoVault, len(tests))
|
||||
mut := sync.Mutex{}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
for _, server := range servers {
|
||||
if server != nil {
|
||||
server.ShutDown()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mut.Lock()
|
||||
defer mut.Unlock()
|
||||
// Create server with pre-generated users.
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
@@ -1907,7 +1904,6 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println("COMMAND WRITTEN")
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user