Implemented test for AUTH command handler

This commit is contained in:
Kelvin Mwinuka
2024-03-21 14:32:20 +08:00
parent 0e5f8ff99d
commit 9191d16762
4 changed files with 172 additions and 28 deletions

View File

@@ -164,7 +164,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
// Terminate every connection attached to this user
for connRef, connection := range acl.Connections {
if connection.User.Username == user.Username {
(*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
_ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
}
}
// Delete the user from the ACL
@@ -175,7 +175,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
return nil
}
func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error {
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
var passwords []Password
var user *User
@@ -194,6 +194,7 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd
})
user = acl.Users[idx]
}
if len(cmd) == 3 {
// Process AUTH <username> <password>
h.Write([]byte(cmd[2]))
@@ -278,7 +279,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
// If the command is 'auth', then return early and allow it
if strings.EqualFold(comm, "auth") {
// TODO: Add rate limiting to prevent auth spamming
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/echovault/echovault/src/utils"
"gopkg.in/yaml.v3"
"log"
"net"
"os"
"path"
@@ -28,7 +29,7 @@ func handleAuth(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(utils.OkResponse), nil
}
func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -144,7 +145,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(res), nil
}
func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -204,7 +205,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
return nil, errors.New("category not found")
}
func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
@@ -217,7 +218,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n
return []byte(res), nil
}
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
@@ -228,7 +229,7 @@ func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OkResponse), nil
}
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) < 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -242,7 +243,7 @@ func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OkResponse), nil
}
func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
@@ -251,7 +252,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
}
func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -347,7 +348,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(res), nil
}
func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -364,8 +365,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
defer func() {
if err := f.Close(); err != nil {
// TODO: Log file close error with context
fmt.Println(err)
log.Println(err)
}
}()
@@ -412,7 +412,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(utils.OkResponse), nil
}
func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
@@ -429,8 +429,7 @@ func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *ne
defer func() {
if err := f.Close(); err != nil {
// TODO: Log file close error with context
fmt.Println(err)
log.Println(err)
}
}()
@@ -492,7 +491,8 @@ func Commands() []utils.Command {
{
Command: "cat",
Categories: []string{utils.SlowCategory},
Description: "(ACL CAT [category]) List all the categories and commands inside a category.",
Description: `(ACL CAT [category]) List all the categories.
If the optional category is provided, list all the commands in the category`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil

View File

@@ -2,8 +2,12 @@ package acl
import (
"context"
"crypto/sha256"
"fmt"
"github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp"
"net"
"testing"
)
@@ -22,13 +26,17 @@ func init() {
Port: port,
DataDir: "",
EvictionPolicy: utils.NoEviction,
RequirePass: true,
Password: "password1",
}
acl = NewACL(config)
acl.Users = append(acl.Users, generateInitialTestUsers()...)
mockServer = server.NewServer(server.Opts{
Config: config,
ACL: acl,
Commands: Commands(),
})
go func() {
@@ -36,9 +44,140 @@ func init() {
}()
}
func Test_HandleAuth(t *testing.T) {}
func generateInitialTestUsers() []*User {
// User with both hash password and plaintext password
withPasswordUser := CreateUser("with_password_user")
h := sha256.New()
h.Write([]byte("password3"))
withPasswordUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password2"},
{PasswordType: PasswordSHA256, PasswordValue: string(h.Sum(nil))},
}
func Test_HandleCat(t *testing.T) {}
// User with NoPassword option
noPasswordUser := CreateUser("no_password_user")
noPasswordUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password4"},
}
noPasswordUser.NoPassword = true
// Disabled user
disabledUser := CreateUser("disabled_user")
disabledUser.Passwords = []Password{
{PasswordType: PasswordPlainText, PasswordValue: "password5"},
}
disabledUser.Enabled = false
return []*User{
withPasswordUser,
noPasswordUser,
disabledUser,
}
}
func Test_HandleAuth(t *testing.T) {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
r := resp.NewConn(conn)
tests := []struct {
cmd []resp.Value
wantRes string
wantErr string
}{
{ // 1. Authenticate with default user without specifying username
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK",
wantErr: "",
},
{ // 2. Authenticate with plaintext password
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password2"),
},
wantRes: "OK",
wantErr: "",
},
{ // 3. Authenticate with SHA256 password
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password3"),
},
wantRes: "OK",
wantErr: "",
},
{ // 4. Authenticate with no password user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("no_password_user"),
resp.StringValue("password4"),
},
wantRes: "OK",
wantErr: "",
},
{ // 5. Fail to authenticate with disabled user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("disabled_user"),
resp.StringValue("password5"),
},
wantRes: "",
wantErr: "Error user disabled_user is disabled",
},
{ // 6. Fail to authenticate with non-existent user
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"),
resp.StringValue("password6"),
},
wantRes: "",
wantErr: "Error no user with username non_existent_user",
},
{ // 7. Command too short
cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
},
{ // 8. Command too long
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("user"),
resp.StringValue("password1"),
resp.StringValue("password2"),
},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
},
}
for _, test := range tests {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
continue
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
}
}
func Test_HandleCat(t *testing.T) {
// Since only ACL commands are loaded in this test suite, this test will only test against the
// list of categories and commands available in the ACL module.
}
func Test_HandleUsers(t *testing.T) {}

View File

@@ -5,6 +5,11 @@ import (
"strings"
)
const (
PasswordPlainText = "plaintext"
PasswordSHA256 = "SHA256"
)
type Password struct {
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
@@ -105,7 +110,7 @@ func (user *User) UpdateUser(cmd []string) error {
}
if str[0] == '<' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "SHA256") {
if strings.EqualFold(password.PasswordType, PasswordSHA256) {
return false
}
return password.PasswordValue == str[1:]
@@ -114,7 +119,7 @@ func (user *User) UpdateUser(cmd []string) error {
}
if str[0] == '!' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "plaintext") {
if strings.EqualFold(password.PasswordType, PasswordPlainText) {
return false
}
return password.PasswordValue == str[1:]
@@ -278,7 +283,7 @@ func CreateUser(username string) *User {
func GetPasswordType(password string) string {
if password[0] == '#' {
return "SHA256"
return PasswordSHA256
}
return "plaintext"
return PasswordPlainText
}