mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 01:40:20 +08:00
Implemented test for AUTH command handler
This commit is contained in:
@@ -164,7 +164,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
|
||||
// Terminate every connection attached to this user
|
||||
for connRef, connection := range acl.Connections {
|
||||
if connection.User.Username == user.Username {
|
||||
(*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
_ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
}
|
||||
}
|
||||
// Delete the user from the ACL
|
||||
@@ -175,7 +175,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error {
|
||||
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
|
||||
var passwords []Password
|
||||
var user *User
|
||||
|
||||
@@ -194,6 +194,7 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd
|
||||
})
|
||||
user = acl.Users[idx]
|
||||
}
|
||||
|
||||
if len(cmd) == 3 {
|
||||
// Process AUTH <username> <password>
|
||||
h.Write([]byte(cmd[2]))
|
||||
@@ -278,7 +279,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
|
||||
// If the command is 'auth', then return early and allow it
|
||||
if strings.EqualFold(comm, "auth") {
|
||||
// TODO: Add rate limiting to prevent auth spamming
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/src/utils"
|
||||
"gopkg.in/yaml.v3"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
@@ -28,7 +29,7 @@ func handleAuth(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
return []byte(utils.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) != 3 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 3 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -204,7 +205,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
|
||||
return nil, errors.New("category not found")
|
||||
}
|
||||
|
||||
func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
@@ -217,7 +218,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
@@ -228,7 +229,7 @@ func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
return []byte(utils.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) < 3 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -242,7 +243,7 @@ func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
return []byte(utils.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
@@ -251,7 +252,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *
|
||||
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
|
||||
}
|
||||
|
||||
func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 2 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -347,7 +348,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) != 3 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -364,8 +365,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
// TODO: Log file close error with context
|
||||
fmt.Println(err)
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -412,7 +412,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
return []byte(utils.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 2 {
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
@@ -429,8 +429,7 @@ func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
// TODO: Log file close error with context
|
||||
fmt.Println(err)
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -492,7 +491,8 @@ func Commands() []utils.Command {
|
||||
{
|
||||
Command: "cat",
|
||||
Categories: []string{utils.SlowCategory},
|
||||
Description: "(ACL CAT [category]) List all the categories and commands inside a category.",
|
||||
Description: `(ACL CAT [category]) List all the categories.
|
||||
If the optional category is provided, list all the commands in the category`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
|
@@ -2,8 +2,12 @@ package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/src/server"
|
||||
"github.com/echovault/echovault/src/utils"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -22,13 +26,17 @@ func init() {
|
||||
Port: port,
|
||||
DataDir: "",
|
||||
EvictionPolicy: utils.NoEviction,
|
||||
RequirePass: true,
|
||||
Password: "password1",
|
||||
}
|
||||
|
||||
acl = NewACL(config)
|
||||
acl.Users = append(acl.Users, generateInitialTestUsers()...)
|
||||
|
||||
mockServer = server.NewServer(server.Opts{
|
||||
Config: config,
|
||||
ACL: acl,
|
||||
Commands: Commands(),
|
||||
})
|
||||
|
||||
go func() {
|
||||
@@ -36,9 +44,140 @@ func init() {
|
||||
}()
|
||||
}
|
||||
|
||||
func Test_HandleAuth(t *testing.T) {}
|
||||
func generateInitialTestUsers() []*User {
|
||||
// User with both hash password and plaintext password
|
||||
withPasswordUser := CreateUser("with_password_user")
|
||||
h := sha256.New()
|
||||
h.Write([]byte("password3"))
|
||||
withPasswordUser.Passwords = []Password{
|
||||
{PasswordType: PasswordPlainText, PasswordValue: "password2"},
|
||||
{PasswordType: PasswordSHA256, PasswordValue: string(h.Sum(nil))},
|
||||
}
|
||||
|
||||
func Test_HandleCat(t *testing.T) {}
|
||||
// User with NoPassword option
|
||||
noPasswordUser := CreateUser("no_password_user")
|
||||
noPasswordUser.Passwords = []Password{
|
||||
{PasswordType: PasswordPlainText, PasswordValue: "password4"},
|
||||
}
|
||||
noPasswordUser.NoPassword = true
|
||||
|
||||
// Disabled user
|
||||
disabledUser := CreateUser("disabled_user")
|
||||
disabledUser.Passwords = []Password{
|
||||
{PasswordType: PasswordPlainText, PasswordValue: "password5"},
|
||||
}
|
||||
disabledUser.Enabled = false
|
||||
|
||||
return []*User{
|
||||
withPasswordUser,
|
||||
noPasswordUser,
|
||||
disabledUser,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleAuth(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
r := resp.NewConn(conn)
|
||||
|
||||
tests := []struct {
|
||||
cmd []resp.Value
|
||||
wantRes string
|
||||
wantErr string
|
||||
}{
|
||||
{ // 1. Authenticate with default user without specifying username
|
||||
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 2. Authenticate with plaintext password
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
resp.StringValue("password2"),
|
||||
},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 3. Authenticate with SHA256 password
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
resp.StringValue("password3"),
|
||||
},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 4. Authenticate with no password user
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("no_password_user"),
|
||||
resp.StringValue("password4"),
|
||||
},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 5. Fail to authenticate with disabled user
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("disabled_user"),
|
||||
resp.StringValue("password5"),
|
||||
},
|
||||
wantRes: "",
|
||||
wantErr: "Error user disabled_user is disabled",
|
||||
},
|
||||
{ // 6. Fail to authenticate with non-existent user
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("non_existent_user"),
|
||||
resp.StringValue("password6"),
|
||||
},
|
||||
wantRes: "",
|
||||
wantErr: "Error no user with username non_existent_user",
|
||||
},
|
||||
{ // 7. Command too short
|
||||
cmd: []resp.Value{resp.StringValue("AUTH")},
|
||||
wantRes: "",
|
||||
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||
},
|
||||
{ // 8. Command too long
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("user"),
|
||||
resp.StringValue("password1"),
|
||||
resp.StringValue("password2"),
|
||||
},
|
||||
wantRes: "",
|
||||
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if err = r.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if test.wantErr != "" {
|
||||
if rv.Error().Error() != test.wantErr {
|
||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
if rv.String() != test.wantRes {
|
||||
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleCat(t *testing.T) {
|
||||
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
||||
// list of categories and commands available in the ACL module.
|
||||
}
|
||||
|
||||
func Test_HandleUsers(t *testing.T) {}
|
||||
|
||||
|
@@ -5,6 +5,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
PasswordPlainText = "plaintext"
|
||||
PasswordSHA256 = "SHA256"
|
||||
)
|
||||
|
||||
type Password struct {
|
||||
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
||||
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
||||
@@ -105,7 +110,7 @@ func (user *User) UpdateUser(cmd []string) error {
|
||||
}
|
||||
if str[0] == '<' {
|
||||
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
||||
if strings.EqualFold(password.PasswordType, "SHA256") {
|
||||
if strings.EqualFold(password.PasswordType, PasswordSHA256) {
|
||||
return false
|
||||
}
|
||||
return password.PasswordValue == str[1:]
|
||||
@@ -114,7 +119,7 @@ func (user *User) UpdateUser(cmd []string) error {
|
||||
}
|
||||
if str[0] == '!' {
|
||||
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
||||
if strings.EqualFold(password.PasswordType, "plaintext") {
|
||||
if strings.EqualFold(password.PasswordType, PasswordPlainText) {
|
||||
return false
|
||||
}
|
||||
return password.PasswordValue == str[1:]
|
||||
@@ -278,7 +283,7 @@ func CreateUser(username string) *User {
|
||||
|
||||
func GetPasswordType(password string) string {
|
||||
if password[0] == '#' {
|
||||
return "SHA256"
|
||||
return PasswordSHA256
|
||||
}
|
||||
return "plaintext"
|
||||
return PasswordPlainText
|
||||
}
|
||||
|
Reference in New Issue
Block a user