Implemented test for AUTH command handler

This commit is contained in:
Kelvin Mwinuka
2024-03-21 14:32:20 +08:00
parent 0e5f8ff99d
commit 9191d16762
4 changed files with 172 additions and 28 deletions

View File

@@ -164,7 +164,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
// Terminate every connection attached to this user // Terminate every connection attached to this user
for connRef, connection := range acl.Connections { for connRef, connection := range acl.Connections {
if connection.User.Username == user.Username { if connection.User.Username == user.Username {
(*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second)) _ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
} }
} }
// Delete the user from the ACL // Delete the user from the ACL
@@ -175,7 +175,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
return nil return nil
} }
func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error { func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
var passwords []Password var passwords []Password
var user *User var user *User
@@ -194,6 +194,7 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd
}) })
user = acl.Users[idx] user = acl.Users[idx]
} }
if len(cmd) == 3 { if len(cmd) == 3 {
// Process AUTH <username> <password> // Process AUTH <username> <password>
h.Write([]byte(cmd[2])) h.Write([]byte(cmd[2]))
@@ -278,7 +279,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// If the command is 'auth', then return early and allow it // If the command is 'auth', then return early and allow it
if strings.EqualFold(comm, "auth") { if strings.EqualFold(comm, "auth") {
// TODO: Add rate limiting to prevent auth spamming
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"log"
"net" "net"
"os" "os"
"path" "path"
@@ -28,7 +29,7 @@ func handleAuth(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(utils.OkResponse), nil return []byte(utils.OkResponse), nil
} }
func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -144,7 +145,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(res), nil return []byte(res), nil
} }
func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 { if len(cmd) > 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -204,7 +205,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
return nil, errors.New("category not found") return nil, errors.New("category not found")
} }
func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL) acl, ok := server.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -217,7 +218,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n
return []byte(res), nil return []byte(res), nil
} }
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleSetUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL) acl, ok := server.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -228,7 +229,7 @@ func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OkResponse), nil return []byte(utils.OkResponse), nil
} }
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -242,7 +243,7 @@ func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OkResponse), nil return []byte(utils.OkResponse), nil
} }
func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL) acl, ok := server.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -251,7 +252,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
} }
func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 { if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -347,7 +348,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(res), nil return []byte(res), nil
} }
func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -364,8 +365,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
defer func() { defer func() {
if err := f.Close(); err != nil { if err := f.Close(); err != nil {
// TODO: Log file close error with context log.Println(err)
fmt.Println(err)
} }
}() }()
@@ -412,7 +412,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(utils.OkResponse), nil return []byte(utils.OkResponse), nil
} }
func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 { if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -429,8 +429,7 @@ func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *ne
defer func() { defer func() {
if err := f.Close(); err != nil { if err := f.Close(); err != nil {
// TODO: Log file close error with context log.Println(err)
fmt.Println(err)
} }
}() }()
@@ -490,10 +489,11 @@ func Commands() []utils.Command {
}, },
SubCommands: []utils.SubCommand{ SubCommands: []utils.SubCommand{
{ {
Command: "cat", Command: "cat",
Categories: []string{utils.SlowCategory}, Categories: []string{utils.SlowCategory},
Description: "(ACL CAT [category]) List all the categories and commands inside a category.", Description: `(ACL CAT [category]) List all the categories.
Sync: false, If the optional category is provided, list all the commands in the category`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil return []string{}, nil
}, },

View File

@@ -2,8 +2,12 @@ package acl
import ( import (
"context" "context"
"crypto/sha256"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp"
"net"
"testing" "testing"
) )
@@ -22,13 +26,17 @@ func init() {
Port: port, Port: port,
DataDir: "", DataDir: "",
EvictionPolicy: utils.NoEviction, EvictionPolicy: utils.NoEviction,
RequirePass: true,
Password: "password1",
} }
acl = NewACL(config) acl = NewACL(config)
acl.Users = append(acl.Users, generateInitialTestUsers()...)
mockServer = server.NewServer(server.Opts{ mockServer = server.NewServer(server.Opts{
Config: config, Config: config,
ACL: acl, ACL: acl,
Commands: Commands(),
}) })
go func() { go func() {
@@ -36,9 +44,140 @@ func init() {
}() }()
} }
func Test_HandleAuth(t *testing.T) {} func generateInitialTestUsers() []*User {
// User with both hash password and plaintext password
withPasswordUser := CreateUser("with_password_user")
h := sha256.New()
h.Write([]byte("password3"))
withPasswordUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password2"},
{PasswordType: PasswordSHA256, PasswordValue: string(h.Sum(nil))},
}
func Test_HandleCat(t *testing.T) {} // User with NoPassword option
noPasswordUser := CreateUser("no_password_user")
noPasswordUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password4"},
}
noPasswordUser.NoPassword = true
// Disabled user
disabledUser := CreateUser("disabled_user")
disabledUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password5"},
}
disabledUser.Enabled = false
return []*User{
withPasswordUser,
noPasswordUser,
disabledUser,
}
}
func Test_HandleAuth(t *testing.T) {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
r := resp.NewConn(conn)
tests := []struct {
cmd []resp.Value
wantRes string
wantErr string
}{
{ // 1. Authenticate with default user without specifying username
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK",
wantErr: "",
},
{ // 2. Authenticate with plaintext password
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password2"),
},
wantRes: "OK",
wantErr: "",
},
{ // 3. Authenticate with SHA256 password
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password3"),
},
wantRes: "OK",
wantErr: "",
},
{ // 4. Authenticate with no password user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("no_password_user"),
resp.StringValue("password4"),
},
wantRes: "OK",
wantErr: "",
},
{ // 5. Fail to authenticate with disabled user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("disabled_user"),
resp.StringValue("password5"),
},
wantRes: "",
wantErr: "Error user disabled_user is disabled",
},
{ // 6. Fail to authenticate with non-existent user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"),
resp.StringValue("password6"),
},
wantRes: "",
wantErr: "Error no user with username non_existent_user",
},
{ // 7. Command too short
cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
},
{ // 8. Command too long
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("user"),
resp.StringValue("password1"),
resp.StringValue("password2"),
},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
},
}
for _, test := range tests {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
continue
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
}
}
func Test_HandleCat(t *testing.T) {
// Since only ACL commands are loaded in this test suite, this test will only test against the
// list of categories and commands available in the ACL module.
}
func Test_HandleUsers(t *testing.T) {} func Test_HandleUsers(t *testing.T) {}

View File

@@ -5,6 +5,11 @@ import (
"strings" "strings"
) )
const (
PasswordPlainText = "plaintext"
PasswordSHA256 = "SHA256"
)
type Password struct { type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256 PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"` PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
@@ -105,7 +110,7 @@ func (user *User) UpdateUser(cmd []string) error {
} }
if str[0] == '<' { if str[0] == '<' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "SHA256") { if strings.EqualFold(password.PasswordType, PasswordSHA256) {
return false return false
} }
return password.PasswordValue == str[1:] return password.PasswordValue == str[1:]
@@ -114,7 +119,7 @@ func (user *User) UpdateUser(cmd []string) error {
} }
if str[0] == '!' { if str[0] == '!' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "plaintext") { if strings.EqualFold(password.PasswordType, PasswordPlainText) {
return false return false
} }
return password.PasswordValue == str[1:] return password.PasswordValue == str[1:]
@@ -278,7 +283,7 @@ func CreateUser(username string) *User {
func GetPasswordType(password string) string { func GetPasswordType(password string) string {
if password[0] == '#' { if password[0] == '#' {
return "SHA256" return PasswordSHA256
} }
return "plaintext" return PasswordPlainText
} }