mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 01:40:20 +08:00
Implemented test for AUTH command handler
This commit is contained in:
@@ -164,7 +164,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
|
|||||||
// Terminate every connection attached to this user
|
// Terminate every connection attached to this user
|
||||||
for connRef, connection := range acl.Connections {
|
for connRef, connection := range acl.Connections {
|
||||||
if connection.User.Username == user.Username {
|
if connection.User.Username == user.Username {
|
||||||
(*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
|
_ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Delete the user from the ACL
|
// Delete the user from the ACL
|
||||||
@@ -175,7 +175,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error {
|
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
|
||||||
var passwords []Password
|
var passwords []Password
|
||||||
var user *User
|
var user *User
|
||||||
|
|
||||||
@@ -194,6 +194,7 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd
|
|||||||
})
|
})
|
||||||
user = acl.Users[idx]
|
user = acl.Users[idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cmd) == 3 {
|
if len(cmd) == 3 {
|
||||||
// Process AUTH <username> <password>
|
// Process AUTH <username> <password>
|
||||||
h.Write([]byte(cmd[2]))
|
h.Write([]byte(cmd[2]))
|
||||||
@@ -278,7 +279,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
|||||||
|
|
||||||
// If the command is 'auth', then return early and allow it
|
// If the command is 'auth', then return early and allow it
|
||||||
if strings.EqualFold(comm, "auth") {
|
if strings.EqualFold(comm, "auth") {
|
||||||
// TODO: Add rate limiting to prevent auth spamming
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/src/utils"
|
"github.com/echovault/echovault/src/utils"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
@@ -28,7 +29,7 @@ func handleAuth(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
return []byte(utils.OkResponse), nil
|
return []byte(utils.OkResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) != 3 {
|
if len(cmd) != 3 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -144,7 +145,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn
|
|||||||
return []byte(res), nil
|
return []byte(res), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) > 3 {
|
if len(cmd) > 3 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -204,7 +205,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
|
|||||||
return nil, errors.New("category not found")
|
return nil, errors.New("category not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
acl, ok := server.GetACL().(*ACL)
|
acl, ok := server.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
@@ -217,7 +218,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n
|
|||||||
return []byte(res), nil
|
return []byte(res), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleSetUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
acl, ok := server.GetACL().(*ACL)
|
acl, ok := server.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
@@ -228,7 +229,7 @@ func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn
|
|||||||
return []byte(utils.OkResponse), nil
|
return []byte(utils.OkResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) < 3 {
|
if len(cmd) < 3 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -242,7 +243,7 @@ func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn
|
|||||||
return []byte(utils.OkResponse), nil
|
return []byte(utils.OkResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||||
acl, ok := server.GetACL().(*ACL)
|
acl, ok := server.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
@@ -251,7 +252,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *
|
|||||||
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
|
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) > 2 {
|
if len(cmd) > 2 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -347,7 +348,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
return []byte(res), nil
|
return []byte(res), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) != 3 {
|
if len(cmd) != 3 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -364,8 +365,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
// TODO: Log file close error with context
|
log.Println(err)
|
||||||
fmt.Println(err)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -412,7 +412,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
return []byte(utils.OkResponse), nil
|
return []byte(utils.OkResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) > 2 {
|
if len(cmd) > 2 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -429,8 +429,7 @@ func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := f.Close(); err != nil {
|
if err := f.Close(); err != nil {
|
||||||
// TODO: Log file close error with context
|
log.Println(err)
|
||||||
fmt.Println(err)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -490,10 +489,11 @@ func Commands() []utils.Command {
|
|||||||
},
|
},
|
||||||
SubCommands: []utils.SubCommand{
|
SubCommands: []utils.SubCommand{
|
||||||
{
|
{
|
||||||
Command: "cat",
|
Command: "cat",
|
||||||
Categories: []string{utils.SlowCategory},
|
Categories: []string{utils.SlowCategory},
|
||||||
Description: "(ACL CAT [category]) List all the categories and commands inside a category.",
|
Description: `(ACL CAT [category]) List all the categories.
|
||||||
Sync: false,
|
If the optional category is provided, list all the commands in the category`,
|
||||||
|
Sync: false,
|
||||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
},
|
},
|
||||||
|
@@ -2,8 +2,12 @@ package acl
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
"github.com/echovault/echovault/src/server"
|
"github.com/echovault/echovault/src/server"
|
||||||
"github.com/echovault/echovault/src/utils"
|
"github.com/echovault/echovault/src/utils"
|
||||||
|
"github.com/tidwall/resp"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,13 +26,17 @@ func init() {
|
|||||||
Port: port,
|
Port: port,
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: utils.NoEviction,
|
EvictionPolicy: utils.NoEviction,
|
||||||
|
RequirePass: true,
|
||||||
|
Password: "password1",
|
||||||
}
|
}
|
||||||
|
|
||||||
acl = NewACL(config)
|
acl = NewACL(config)
|
||||||
|
acl.Users = append(acl.Users, generateInitialTestUsers()...)
|
||||||
|
|
||||||
mockServer = server.NewServer(server.Opts{
|
mockServer = server.NewServer(server.Opts{
|
||||||
Config: config,
|
Config: config,
|
||||||
ACL: acl,
|
ACL: acl,
|
||||||
|
Commands: Commands(),
|
||||||
})
|
})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -36,9 +44,140 @@ func init() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_HandleAuth(t *testing.T) {}
|
func generateInitialTestUsers() []*User {
|
||||||
|
// User with both hash password and plaintext password
|
||||||
|
withPasswordUser := CreateUser("with_password_user")
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte("password3"))
|
||||||
|
withPasswordUser.Passwords = []Password{
|
||||||
|
{PasswordType: PasswordPlainText, PasswordValue: "password2"},
|
||||||
|
{PasswordType: PasswordSHA256, PasswordValue: string(h.Sum(nil))},
|
||||||
|
}
|
||||||
|
|
||||||
func Test_HandleCat(t *testing.T) {}
|
// User with NoPassword option
|
||||||
|
noPasswordUser := CreateUser("no_password_user")
|
||||||
|
noPasswordUser.Passwords = []Password{
|
||||||
|
{PasswordType: PasswordPlainText, PasswordValue: "password4"},
|
||||||
|
}
|
||||||
|
noPasswordUser.NoPassword = true
|
||||||
|
|
||||||
|
// Disabled user
|
||||||
|
disabledUser := CreateUser("disabled_user")
|
||||||
|
disabledUser.Passwords = []Password{
|
||||||
|
{PasswordType: PasswordPlainText, PasswordValue: "password5"},
|
||||||
|
}
|
||||||
|
disabledUser.Enabled = false
|
||||||
|
|
||||||
|
return []*User{
|
||||||
|
withPasswordUser,
|
||||||
|
noPasswordUser,
|
||||||
|
disabledUser,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_HandleAuth(t *testing.T) {
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
cmd []resp.Value
|
||||||
|
wantRes string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{ // 1. Authenticate with default user without specifying username
|
||||||
|
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
||||||
|
wantRes: "OK",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 2. Authenticate with plaintext password
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("with_password_user"),
|
||||||
|
resp.StringValue("password2"),
|
||||||
|
},
|
||||||
|
wantRes: "OK",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 3. Authenticate with SHA256 password
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("with_password_user"),
|
||||||
|
resp.StringValue("password3"),
|
||||||
|
},
|
||||||
|
wantRes: "OK",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 4. Authenticate with no password user
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("no_password_user"),
|
||||||
|
resp.StringValue("password4"),
|
||||||
|
},
|
||||||
|
wantRes: "OK",
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 5. Fail to authenticate with disabled user
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("disabled_user"),
|
||||||
|
resp.StringValue("password5"),
|
||||||
|
},
|
||||||
|
wantRes: "",
|
||||||
|
wantErr: "Error user disabled_user is disabled",
|
||||||
|
},
|
||||||
|
{ // 6. Fail to authenticate with non-existent user
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("non_existent_user"),
|
||||||
|
resp.StringValue("password6"),
|
||||||
|
},
|
||||||
|
wantRes: "",
|
||||||
|
wantErr: "Error no user with username non_existent_user",
|
||||||
|
},
|
||||||
|
{ // 7. Command too short
|
||||||
|
cmd: []resp.Value{resp.StringValue("AUTH")},
|
||||||
|
wantRes: "",
|
||||||
|
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||||
|
},
|
||||||
|
{ // 8. Command too long
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("user"),
|
||||||
|
resp.StringValue("password1"),
|
||||||
|
resp.StringValue("password2"),
|
||||||
|
},
|
||||||
|
wantRes: "",
|
||||||
|
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rv, _, err := r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if test.wantErr != "" {
|
||||||
|
if rv.Error().Error() != test.wantErr {
|
||||||
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rv.String() != test.wantRes {
|
||||||
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_HandleCat(t *testing.T) {
|
||||||
|
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
||||||
|
// list of categories and commands available in the ACL module.
|
||||||
|
}
|
||||||
|
|
||||||
func Test_HandleUsers(t *testing.T) {}
|
func Test_HandleUsers(t *testing.T) {}
|
||||||
|
|
||||||
|
@@ -5,6 +5,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PasswordPlainText = "plaintext"
|
||||||
|
PasswordSHA256 = "SHA256"
|
||||||
|
)
|
||||||
|
|
||||||
type Password struct {
|
type Password struct {
|
||||||
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256
|
||||||
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"`
|
||||||
@@ -105,7 +110,7 @@ func (user *User) UpdateUser(cmd []string) error {
|
|||||||
}
|
}
|
||||||
if str[0] == '<' {
|
if str[0] == '<' {
|
||||||
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
||||||
if strings.EqualFold(password.PasswordType, "SHA256") {
|
if strings.EqualFold(password.PasswordType, PasswordSHA256) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return password.PasswordValue == str[1:]
|
return password.PasswordValue == str[1:]
|
||||||
@@ -114,7 +119,7 @@ func (user *User) UpdateUser(cmd []string) error {
|
|||||||
}
|
}
|
||||||
if str[0] == '!' {
|
if str[0] == '!' {
|
||||||
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
|
||||||
if strings.EqualFold(password.PasswordType, "plaintext") {
|
if strings.EqualFold(password.PasswordType, PasswordPlainText) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return password.PasswordValue == str[1:]
|
return password.PasswordValue == str[1:]
|
||||||
@@ -278,7 +283,7 @@ func CreateUser(username string) *User {
|
|||||||
|
|
||||||
func GetPasswordType(password string) string {
|
func GetPasswordType(password string) string {
|
||||||
if password[0] == '#' {
|
if password[0] == '#' {
|
||||||
return "SHA256"
|
return PasswordSHA256
|
||||||
}
|
}
|
||||||
return "plaintext"
|
return PasswordPlainText
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user