Removed duplicate imports for set, sorted_set, pubsub and acl modules. Moved /modules from /pkg to /internal. Delted commands package: Commands will now be automatically loaded when an EchoVault instance is initialised.

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-24 22:37:16 +08:00
parent 3e04b7a822
commit b6ddb43a49
46 changed files with 645 additions and 735 deletions

View File

@@ -7,7 +7,7 @@ build:
run: run:
make build && docker-compose up --build make build && docker-compose up --build
test-normal: test-unit:
go clean -testcache && go test ./... -coverprofile coverage/coverage.out go clean -testcache && go test ./... -coverprofile coverage/coverage.out
test-race: test-race:

View File

@@ -18,7 +18,6 @@ import (
"context" "context"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"log" "log"
"os" "os"
@@ -49,7 +48,6 @@ func main() {
server, err := echovault.NewEchoVault( server, err := echovault.NewEchoVault(
echovault.WithContext(ctx), echovault.WithContext(ctx),
echovault.WithConfig(conf), echovault.WithConfig(conf),
echovault.WithCommands(commands.All()),
) )
if err != nil { if err != nil {

View File

@@ -18,7 +18,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
internal_acl "github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -33,7 +32,7 @@ func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 { if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -48,12 +47,12 @@ func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
var user *internal_acl.User var user *User
userFound := false userFound := false
for _, u := range acl.Users { for _, u := range acl.Users {
if u.Username == params.Command[2] { if u.Username == params.Command[2] {
@@ -221,7 +220,7 @@ func handleCat(params types.HandlerFuncParams) ([]byte, error) {
} }
func handleUsers(params types.HandlerFuncParams) ([]byte, error) { func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -234,7 +233,7 @@ func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
} }
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) { func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -248,7 +247,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 3 { if len(params.Command) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -259,7 +258,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
} }
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) { func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -271,7 +270,7 @@ func handleList(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 { if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -368,7 +367,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
@@ -389,7 +388,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
ext := path.Ext(f.Name()) ext := path.Ext(f.Name())
var users []*internal_acl.User var users []*User
if ext == ".json" { if ext == ".json" {
if err := json.NewDecoder(f).Decode(&users); err != nil { if err := json.NewDecoder(f).Decode(&users); err != nil {
@@ -435,7 +434,7 @@ func handleSave(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
acl, ok := params.GetACL().(*internal_acl.ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }

View File

@@ -17,14 +17,13 @@ package pubsub
import ( import (
"errors" "errors"
"fmt" "fmt"
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"strings" "strings"
) )
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) { func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }
@@ -42,7 +41,7 @@ func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
} }
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) { func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }
@@ -55,7 +54,7 @@ func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
} }
func handlePublish(params types.HandlerFuncParams) ([]byte, error) { func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }
@@ -71,7 +70,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }
@@ -85,7 +84,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
} }
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) { func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }
@@ -94,7 +93,7 @@ func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
} }
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) { func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
} }

View File

@@ -18,7 +18,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
internal_set "github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"slices" "slices"
@@ -33,10 +32,10 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
var set *internal_set.Set var set *Set
if !params.KeyExists(params.Context, key) { if !params.KeyExists(params.Context, key) {
set = internal_set.NewSet(params.Command[2:]) set = NewSet(params.Command[2:])
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil { if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
return nil, err return nil, err
} }
@@ -52,7 +51,7 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -79,7 +78,7 @@ func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -103,7 +102,7 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set) baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0]) return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
} }
@@ -127,9 +126,9 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for _, key := range params.Command[2:] { for _, key := range params.Command[2:] {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
continue continue
} }
@@ -166,7 +165,7 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set) baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0]) return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
} }
@@ -190,9 +189,9 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for _, key := range keys.ReadKeys[1:] { for _, key := range keys.ReadKeys[1:] {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
continue continue
} }
@@ -252,10 +251,10 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for key, _ := range locks { for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
// If the value at the key is not a set, return error // If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -267,7 +266,7 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
return nil, fmt.Errorf("not enough sets in the keys provided") return nil, fmt.Errorf("not enough sets in the keys provided")
} }
intersect, _ := internal_set.Intersection(0, sets...) intersect, _ := Intersection(0, sets...)
elems := intersect.GetAll() elems := intersect.GetAll()
res := fmt.Sprintf("*%d", len(elems)) res := fmt.Sprintf("*%d", len(elems))
@@ -328,10 +327,10 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for key, _ := range locks { for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
// If the value at the key is not a set, return error // If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -343,7 +342,7 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
return nil, fmt.Errorf("not enough sets in the keys provided") return nil, fmt.Errorf("not enough sets in the keys provided")
} }
intersect, _ := internal_set.Intersection(limit, sets...) intersect, _ := Intersection(limit, sets...)
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
@@ -374,10 +373,10 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for key, _ := range locks { for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
// If the value at the key is not a set, return error // If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -385,7 +384,7 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
sets = append(sets, set) sets = append(sets, set)
} }
intersect, _ := internal_set.Intersection(0, sets...) intersect, _ := Intersection(0, sets...)
destination := keys.WriteKeys[0] destination := keys.WriteKeys[0]
if params.KeyExists(params.Context, destination) { if params.KeyExists(params.Context, destination) {
@@ -423,7 +422,7 @@ func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -452,7 +451,7 @@ func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -495,7 +494,7 @@ func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -531,12 +530,12 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, source) defer params.KeyUnlock(params.Context, source)
sourceSet, ok := params.GetValue(params.Context, source).(*internal_set.Set) sourceSet, ok := params.GetValue(params.Context, source).(*Set)
if !ok { if !ok {
return nil, errors.New("source is not a set") return nil, errors.New("source is not a set")
} }
var destinationSet *internal_set.Set var destinationSet *Set
if !params.KeyExists(params.Context, destination) { if !params.KeyExists(params.Context, destination) {
// Destination key does not exist // Destination key does not exist
@@ -544,7 +543,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyUnlock(params.Context, destination) defer params.KeyUnlock(params.Context, destination)
destinationSet = internal_set.NewSet([]string{}) destinationSet = NewSet([]string{})
if err = params.SetValue(params.Context, destination, destinationSet); err != nil { if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
return nil, err return nil, err
} }
@@ -554,7 +553,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyUnlock(params.Context, destination) defer params.KeyUnlock(params.Context, destination)
ds, ok := params.GetValue(params.Context, destination).(*internal_set.Set) ds, ok := params.GetValue(params.Context, destination).(*Set)
if !ok { if !ok {
return nil, errors.New("destination is not a set") return nil, errors.New("destination is not a set")
} }
@@ -592,7 +591,7 @@ func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a set", key) return nil, fmt.Errorf("value at %s is not a set", key)
} }
@@ -636,7 +635,7 @@ func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a set", key) return nil, fmt.Errorf("value at %s is not a set", key)
} }
@@ -672,7 +671,7 @@ func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
@@ -707,20 +706,20 @@ func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for key, locked := range locks { for key, locked := range locks {
if !locked { if !locked {
continue continue
} }
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
sets = append(sets, set) sets = append(sets, set)
} }
union := internal_set.Union(sets...) union := Union(sets...)
res := fmt.Sprintf("*%d", union.Cardinality()) res := fmt.Sprintf("*%d", union.Cardinality())
for i, e := range union.GetAll() { for i, e := range union.GetAll() {
@@ -758,20 +757,20 @@ func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true locks[key] = true
} }
var sets []*internal_set.Set var sets []*Set
for key, locked := range locks { for key, locked := range locks {
if !locked { if !locked {
continue continue
} }
set, ok := params.GetValue(params.Context, key).(*internal_set.Set) set, ok := params.GetValue(params.Context, key).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key) return nil, fmt.Errorf("value at key %s is not a set", key)
} }
sets = append(sets, set) sets = append(sets, set)
} }
union := internal_set.Union(sets...) union := Union(sets...)
destination := keys.WriteKeys[0] destination := keys.WriteKeys[0]

View File

@@ -19,7 +19,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/sorted_set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"math" "math"
@@ -63,7 +62,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("score/member pairs must be float/string") return nil, errors.New("score/member pairs must be float/string")
} }
var members []sorted_set.MemberParam var members []MemberParam
for i := 0; i < len(params.Command[membersStartIndex:]); i++ { for i := 0; i < len(params.Command[membersStartIndex:]); i++ {
if i%2 != 0 { if i%2 != 0 {
@@ -77,29 +76,29 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
var s float64 var s float64
if strings.ToLower(score.(string)) == "-inf" { if strings.ToLower(score.(string)) == "-inf" {
s = math.Inf(-1) s = math.Inf(-1)
members = append(members, sorted_set.MemberParam{ members = append(members, MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]), Value: Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s), Score: Score(s),
}) })
} }
if strings.ToLower(score.(string)) == "+inf" { if strings.ToLower(score.(string)) == "+inf" {
s = math.Inf(1) s = math.Inf(1)
members = append(members, sorted_set.MemberParam{ members = append(members, MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]), Value: Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s), Score: Score(s),
}) })
} }
case float64: case float64:
s, _ := score.(float64) s, _ := score.(float64)
members = append(members, sorted_set.MemberParam{ members = append(members, MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]), Value: Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s), Score: Score(s),
}) })
case int: case int:
s, _ := score.(int) s, _ := score.(int)
members = append(members, sorted_set.MemberParam{ members = append(members, MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]), Value: Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s), Score: Score(s),
}) })
} }
} }
@@ -148,7 +147,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -171,7 +170,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set := sorted_set.NewSortedSet(members) set := NewSortedSet(members)
if err = params.SetValue(params.Context, key, set); err != nil { if err = params.SetValue(params.Context, key, set); err != nil {
return nil, err return nil, err
} }
@@ -195,7 +194,7 @@ func handleZCARD(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -211,40 +210,40 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
minimum := sorted_set.Score(math.Inf(-1)) minimum := Score(math.Inf(-1))
switch internal.AdaptType(params.Command[2]).(type) { switch internal.AdaptType(params.Command[2]).(type) {
default: default:
return nil, errors.New("min constraint must be a double") return nil, errors.New("min constraint must be a double")
case string: case string:
if strings.ToLower(params.Command[2]) == "+inf" { if strings.ToLower(params.Command[2]) == "+inf" {
minimum = sorted_set.Score(math.Inf(1)) minimum = Score(math.Inf(1))
} else { } else {
return nil, errors.New("min constraint must be a double") return nil, errors.New("min constraint must be a double")
} }
case float64: case float64:
s, _ := internal.AdaptType(params.Command[2]).(float64) s, _ := internal.AdaptType(params.Command[2]).(float64)
minimum = sorted_set.Score(s) minimum = Score(s)
case int: case int:
s, _ := internal.AdaptType(params.Command[2]).(int) s, _ := internal.AdaptType(params.Command[2]).(int)
minimum = sorted_set.Score(s) minimum = Score(s)
} }
maximum := sorted_set.Score(math.Inf(1)) maximum := Score(math.Inf(1))
switch internal.AdaptType(params.Command[3]).(type) { switch internal.AdaptType(params.Command[3]).(type) {
default: default:
return nil, errors.New("max constraint must be a double") return nil, errors.New("max constraint must be a double")
case string: case string:
if strings.ToLower(params.Command[3]) == "-inf" { if strings.ToLower(params.Command[3]) == "-inf" {
maximum = sorted_set.Score(math.Inf(-1)) maximum = Score(math.Inf(-1))
} else { } else {
return nil, errors.New("max constraint must be a double") return nil, errors.New("max constraint must be a double")
} }
case float64: case float64:
s, _ := internal.AdaptType(params.Command[3]).(float64) s, _ := internal.AdaptType(params.Command[3]).(float64)
maximum = sorted_set.Score(s) maximum = Score(s)
case int: case int:
s, _ := internal.AdaptType(params.Command[3]).(int) s, _ := internal.AdaptType(params.Command[3]).(int)
maximum = sorted_set.Score(s) maximum = Score(s)
} }
if !params.KeyExists(params.Context, key) { if !params.KeyExists(params.Context, key) {
@@ -256,12 +255,12 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
var members []sorted_set.MemberParam var members []MemberParam
for _, m := range set.GetAll() { for _, m := range set.GetAll() {
if m.Score >= minimum && m.Score <= maximum { if m.Score >= minimum && m.Score <= maximum {
members = append(members, m) members = append(members, m)
@@ -290,7 +289,7 @@ func handleZLEXCOUNT(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -347,13 +346,13 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet) baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
} }
// Extract the remaining sets // Extract the remaining sets
var sets []*sorted_set.SortedSet var sets []*SortedSet
for i := 1; i < len(keys.ReadKeys); i++ { for i := 1; i < len(keys.ReadKeys); i++ {
if !params.KeyExists(params.Context, keys.ReadKeys[i]) { if !params.KeyExists(params.Context, keys.ReadKeys[i]) {
@@ -364,7 +363,7 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
locks[keys.ReadKeys[i]] = locked locks[keys.ReadKeys[i]] = locked
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
} }
@@ -415,19 +414,19 @@ func handleZDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet) baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
} }
var sets []*sorted_set.SortedSet var sets []*SortedSet
for i := 1; i < len(keys.ReadKeys); i++ { for i := 1; i < len(keys.ReadKeys); i++ {
if params.KeyExists(params.Context, keys.ReadKeys[i]) { if params.KeyExists(params.Context, keys.ReadKeys[i]) {
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil { if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil {
return nil, err return nil, err
} }
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
} }
@@ -462,26 +461,26 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
member := sorted_set.Value(params.Command[3]) member := Value(params.Command[3])
var increment sorted_set.Score var increment Score
switch internal.AdaptType(params.Command[2]).(type) { switch internal.AdaptType(params.Command[2]).(type) {
default: default:
return nil, errors.New("increment must be a double") return nil, errors.New("increment must be a double")
case string: case string:
if strings.EqualFold("-inf", strings.ToLower(params.Command[2])) { if strings.EqualFold("-inf", strings.ToLower(params.Command[2])) {
increment = sorted_set.Score(math.Inf(-1)) increment = Score(math.Inf(-1))
} else if strings.EqualFold("+inf", strings.ToLower(params.Command[2])) { } else if strings.EqualFold("+inf", strings.ToLower(params.Command[2])) {
increment = sorted_set.Score(math.Inf(1)) increment = Score(math.Inf(1))
} else { } else {
return nil, errors.New("increment must be a double") return nil, errors.New("increment must be a double")
} }
case float64: case float64:
s, _ := internal.AdaptType(params.Command[2]).(float64) s, _ := internal.AdaptType(params.Command[2]).(float64)
increment = sorted_set.Score(s) increment = Score(s)
case int: case int:
s, _ := internal.AdaptType(params.Command[2]).(int) s, _ := internal.AdaptType(params.Command[2]).(int)
increment = sorted_set.Score(s) increment = Score(s)
} }
if !params.KeyExists(params.Context, key) { if !params.KeyExists(params.Context, key) {
@@ -493,7 +492,7 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
if err = params.SetValue( if err = params.SetValue(
params.Context, params.Context,
key, key,
sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: member, Score: increment}}), NewSortedSet([]MemberParam{{Value: member, Score: increment}}),
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -505,12 +504,12 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
if _, err = set.AddOrUpdate( if _, err = set.AddOrUpdate(
[]sorted_set.MemberParam{ []MemberParam{
{Value: member, Score: increment}}, {Value: member, Score: increment}},
"xx", "xx",
nil, nil,
@@ -542,7 +541,7 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
} }
}() }()
var setParams []sorted_set.SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if !params.KeyExists(params.Context, keys[i]) { if !params.KeyExists(params.Context, keys[i]) {
@@ -553,17 +552,17 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
locks[keys[i]] = true locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
} }
setParams = append(setParams, sorted_set.SortedSetParam{ setParams = append(setParams, SortedSetParam{
Set: set, Set: set,
Weight: weights[i], Weight: weights[i],
}) })
} }
intersect := sorted_set.Intersect(aggregate, setParams...) intersect := Intersect(aggregate, setParams...)
res := fmt.Sprintf("*%d", intersect.Cardinality()) res := fmt.Sprintf("*%d", intersect.Cardinality())
@@ -609,7 +608,7 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
} }
}() }()
var setParams []sorted_set.SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if !params.KeyExists(params.Context, keys[i]) { if !params.KeyExists(params.Context, keys[i]) {
@@ -619,17 +618,17 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
locks[keys[i]] = true locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
} }
setParams = append(setParams, sorted_set.SortedSetParam{ setParams = append(setParams, SortedSetParam{
Set: set, Set: set,
Weight: weights[i], Weight: weights[i],
}) })
} }
intersect := sorted_set.Intersect(aggregate, setParams...) intersect := Intersect(aggregate, setParams...)
if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 { if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 {
if _, err = params.KeyLock(params.Context, destination); err != nil { if _, err = params.KeyLock(params.Context, destination); err != nil {
@@ -700,7 +699,7 @@ func handleZMPOP(params types.HandlerFuncParams) ([]byte, error) {
if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil { if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil {
continue continue
} }
v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*sorted_set.SortedSet) v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*SortedSet)
if !ok || v.Cardinality() == 0 { if !ok || v.Cardinality() == 0 {
params.KeyUnlock(params.Context, keys.WriteKeys[i]) params.KeyUnlock(params.Context, keys.WriteKeys[i])
continue continue
@@ -760,7 +759,7 @@ func handleZPOP(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a sorted set", key) return nil, fmt.Errorf("value at key %s is not a sorted set", key)
} }
@@ -797,7 +796,7 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -806,10 +805,10 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
res := fmt.Sprintf("*%d", len(members)) res := fmt.Sprintf("*%d", len(members))
var member sorted_set.MemberObject var member MemberObject
for i := 0; i < len(members); i++ { for i := 0; i < len(members); i++ {
member = set.Get(sorted_set.Value(members[i])) member = set.Get(Value(members[i]))
if !member.Exists { if !member.Exists {
res = fmt.Sprintf("%s\r\n$-1", res) res = fmt.Sprintf("%s\r\n$-1", res)
} else { } else {
@@ -859,7 +858,7 @@ func handleZRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -903,13 +902,13 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
members := set.GetAll() members := set.GetAll()
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
if strings.EqualFold(params.Command[0], "zrevrank") { if strings.EqualFold(params.Command[0], "zrevrank") {
return cmp.Compare(b.Score, a.Score) return cmp.Compare(b.Score, a.Score)
} }
@@ -917,7 +916,7 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
}) })
for i := 0; i < len(members); i++ { for i := 0; i < len(members); i++ {
if members[i].Value == sorted_set.Value(member) { if members[i].Value == Value(member) {
if withscores { if withscores {
score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64) score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64)
return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
@@ -947,14 +946,14 @@ func handleZREM(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
deletedCount := 0 deletedCount := 0
for _, m := range params.Command[2:] { for _, m := range params.Command[2:] {
if set.Remove(sorted_set.Value(m)) { if set.Remove(Value(m)) {
deletedCount += 1 deletedCount += 1
} }
} }
@@ -977,11 +976,11 @@ func handleZSCORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
member := set.Get(sorted_set.Value(params.Command[2])) member := set.Get(Value(params.Command[2]))
if !member.Exists { if !member.Exists {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
@@ -1020,13 +1019,13 @@ func handleZREMRANGEBYSCORE(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
for _, m := range set.GetAll() { for _, m := range set.GetAll() {
if m.Score >= sorted_set.Score(minimum) && m.Score <= sorted_set.Score(maximum) { if m.Score >= Score(minimum) && m.Score <= Score(maximum) {
set.Remove(m.Value) set.Remove(m.Value)
deletedCount += 1 deletedCount += 1
} }
@@ -1062,7 +1061,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -1079,7 +1078,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
} }
members := set.GetAll() members := set.GetAll()
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
return cmp.Compare(a.Score, b.Score) return cmp.Compare(a.Score, b.Score)
}) })
@@ -1119,7 +1118,7 @@ func handleZREMRANGEBYLEX(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyUnlock(params.Context, key) defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -1217,7 +1216,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, key) defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
} }
@@ -1231,7 +1230,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
members := set.GetAll() members := set.GetAll()
if strings.EqualFold(policy, "byscore") { if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
// Do a score sort // Do a score sort
if reverse { if reverse {
return cmp.Compare(b.Score, a.Score) return cmp.Compare(b.Score, a.Score)
@@ -1246,7 +1245,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
} }
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
if reverse { if reverse {
return internal.CompareLex(string(b.Value), string(a.Value)) return internal.CompareLex(string(b.Value), string(a.Value))
} }
@@ -1254,14 +1253,14 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
}) })
} }
var resultMembers []sorted_set.MemberParam var resultMembers []MemberParam
for i := offset; i <= count; i++ { for i := offset; i <= count; i++ {
if i >= len(members) { if i >= len(members) {
break break
} }
if strings.EqualFold(policy, "byscore") { if strings.EqualFold(policy, "byscore") {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) { if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
resultMembers = append(resultMembers, members[i]) resultMembers = append(resultMembers, members[i])
} }
continue continue
@@ -1354,7 +1353,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
} }
defer params.KeyRUnlock(params.Context, source) defer params.KeyRUnlock(params.Context, source)
set, ok := params.GetValue(params.Context, source).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, source).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", source) return nil, fmt.Errorf("value at %s is not a sorted set", source)
} }
@@ -1368,7 +1367,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
members := set.GetAll() members := set.GetAll()
if strings.EqualFold(policy, "byscore") { if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
// Do a score sort // Do a score sort
if reverse { if reverse {
return cmp.Compare(b.Score, a.Score) return cmp.Compare(b.Score, a.Score)
@@ -1383,7 +1382,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
} }
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int { slices.SortFunc(members, func(a, b MemberParam) int {
if reverse { if reverse {
return internal.CompareLex(string(b.Value), string(a.Value)) return internal.CompareLex(string(b.Value), string(a.Value))
} }
@@ -1391,14 +1390,14 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
}) })
} }
var resultMembers []sorted_set.MemberParam var resultMembers []MemberParam
for i := offset; i <= count; i++ { for i := offset; i <= count; i++ {
if i >= len(members) { if i >= len(members) {
break break
} }
if strings.EqualFold(policy, "byscore") { if strings.EqualFold(policy, "byscore") {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) { if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
resultMembers = append(resultMembers, members[i]) resultMembers = append(resultMembers, members[i])
} }
continue continue
@@ -1409,7 +1408,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
} }
} }
newSortedSet := sorted_set.NewSortedSet(resultMembers) newSortedSet := NewSortedSet(resultMembers)
if params.KeyExists(params.Context, destination) { if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil { if _, err = params.KeyLock(params.Context, destination); err != nil {
@@ -1448,7 +1447,7 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
} }
}() }()
var setParams []sorted_set.SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if params.KeyExists(params.Context, keys[i]) { if params.KeyExists(params.Context, keys[i]) {
@@ -1456,18 +1455,18 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
locks[keys[i]] = true locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
} }
setParams = append(setParams, sorted_set.SortedSetParam{ setParams = append(setParams, SortedSetParam{
Set: set, Set: set,
Weight: weights[i], Weight: weights[i],
}) })
} }
} }
union := sorted_set.Union(aggregate, setParams...) union := Union(aggregate, setParams...)
res := fmt.Sprintf("*%d", union.Cardinality()) res := fmt.Sprintf("*%d", union.Cardinality())
for _, m := range union.GetAll() { for _, m := range union.GetAll() {
@@ -1510,7 +1509,7 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
} }
}() }()
var setParams []sorted_set.SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if params.KeyExists(params.Context, keys[i]) { if params.KeyExists(params.Context, keys[i]) {
@@ -1518,18 +1517,18 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
locks[keys[i]] = true locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet) set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
} }
setParams = append(setParams, sorted_set.SortedSetParam{ setParams = append(setParams, SortedSetParam{
Set: set, Set: set,
Weight: weights[i], Weight: weights[i],
}) })
} }
} }
union := sorted_set.Union(aggregate, setParams...) union := Union(aggregate, setParams...)
if params.KeyExists(params.Context, destination) { if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil { if _, err = params.KeyLock(params.Context, destination); err != nil {

View File

@@ -90,3 +90,80 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
return keys, weights, aggregate, withscores, nil return keys, weights, aggregate, withscores, nil
} }
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of Value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of Value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of Value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of Value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}

View File

@@ -1,98 +0,0 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sorted_set
import (
"errors"
"slices"
"strings"
)
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of Value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of Value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of Value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of Value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}

View File

@@ -1,31 +0,0 @@
package commands
import (
"github.com/echovault/echovault/pkg/modules/acl"
"github.com/echovault/echovault/pkg/modules/admin"
"github.com/echovault/echovault/pkg/modules/connection"
"github.com/echovault/echovault/pkg/modules/generic"
"github.com/echovault/echovault/pkg/modules/hash"
"github.com/echovault/echovault/pkg/modules/list"
"github.com/echovault/echovault/pkg/modules/pubsub"
"github.com/echovault/echovault/pkg/modules/set"
"github.com/echovault/echovault/pkg/modules/sorted_set"
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types"
)
// All returns all the commands currently available on EchoVault
func All() []types.Command {
var commands []types.Command
commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...)
commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...)
commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...)
commands = append(commands, str.Commands()...)
return commands
}

View File

@@ -21,13 +21,21 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/internal/aof" "github.com/echovault/echovault/internal/aof"
"github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/eviction" "github.com/echovault/echovault/internal/eviction"
"github.com/echovault/echovault/internal/memberlist" "github.com/echovault/echovault/internal/memberlist"
"github.com/echovault/echovault/internal/pubsub" "github.com/echovault/echovault/internal/modules/acl"
"github.com/echovault/echovault/internal/modules/admin"
"github.com/echovault/echovault/internal/modules/connection"
"github.com/echovault/echovault/internal/modules/generic"
"github.com/echovault/echovault/internal/modules/hash"
"github.com/echovault/echovault/internal/modules/list"
"github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/internal/modules/sorted_set"
str "github.com/echovault/echovault/internal/modules/string"
"github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/raft"
"github.com/echovault/echovault/internal/snapshot" "github.com/echovault/echovault/internal/snapshot"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
@@ -126,11 +134,24 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
echovault := &EchoVault{ echovault := &EchoVault{
clock: clock.NewClock(), clock: clock.NewClock(),
context: context.Background(), context: context.Background(),
commands: make([]types.Command, 0),
config: config.DefaultConfig(), config: config.DefaultConfig(),
store: make(map[string]internal.KeyData), store: make(map[string]internal.KeyData),
keyLocks: make(map[string]*sync.RWMutex), keyLocks: make(map[string]*sync.RWMutex),
keyCreationLock: &sync.Mutex{}, keyCreationLock: &sync.Mutex{},
commands: func() []types.Command {
var commands []types.Command
commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...)
commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...)
commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...)
commands = append(commands, str.Commands()...)
return commands
}(),
} }
for _, option := range options { for _, option := range options {

View File

@@ -589,16 +589,3 @@ func (server *EchoVault) evictKeysWithExpiredTTL(ctx context.Context) error {
return nil return nil
} }
func presetValue(server *EchoVault, key string, value interface{}) {
_, _ = server.CreateKeyAndLock(server.context, key)
_ = server.SetValue(server.context, key, value)
server.KeyUnlock(server.context, key)
}
func presetKeyData(server *EchoVault, key string, data internal.KeyData) {
_, _ = server.CreateKeyAndLock(server.context, key)
defer server.KeyUnlock(server.context, key)
_ = server.SetValue(server.context, key, data.Value)
server.SetExpiry(server.context, key, data.ExpireAt, false)
}

View File

@@ -17,11 +17,10 @@ package acl
import ( import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
internal_acl "github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/modules/acl"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
acl2 "github.com/echovault/echovault/pkg/modules/acl"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"slices" "slices"
@@ -61,44 +60,43 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
} }
mockServer, _ := echovault.NewEchoVault( mockServer, _ := echovault.NewEchoVault(
echovault.WithCommands(acl2.Commands()),
echovault.WithConfig(conf), echovault.WithConfig(conf),
) )
// Add the initial test users to the ACL module // Add the initial test users to the ACL module
acl := mockServer.GetACL().(*internal_acl.ACL) a := mockServer.GetACL().(*acl.ACL)
acl.AddUsers(generateInitialTestUsers()) a.AddUsers(generateInitialTestUsers())
return mockServer return mockServer
} }
func generateInitialTestUsers() []*internal_acl.User { func generateInitialTestUsers() []*acl.User {
// User with both hash password and plaintext password // User with both hash password and plaintext password
withPasswordUser := internal_acl.CreateUser("with_password_user") withPasswordUser := acl.CreateUser("with_password_user")
h := sha256.New() h := sha256.New()
h.Write([]byte("password3")) h.Write([]byte("password3"))
withPasswordUser.Passwords = []internal_acl.Password{ withPasswordUser.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password2"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "password2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))}, {PasswordType: acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
} }
withPasswordUser.IncludedCategories = []string{"*"} withPasswordUser.IncludedCategories = []string{"*"}
withPasswordUser.IncludedCommands = []string{"*"} withPasswordUser.IncludedCommands = []string{"*"}
// User with NoPassword option // User with NoPassword option
noPasswordUser := internal_acl.CreateUser("no_password_user") noPasswordUser := acl.CreateUser("no_password_user")
noPasswordUser.Passwords = []internal_acl.Password{ noPasswordUser.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password4"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "password4"},
} }
noPasswordUser.NoPassword = true noPasswordUser.NoPassword = true
// Disabled user // Disabled user
disabledUser := internal_acl.CreateUser("disabled_user") disabledUser := acl.CreateUser("disabled_user")
disabledUser.Passwords = []internal_acl.Password{ disabledUser.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password5"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "password5"},
} }
disabledUser.Enabled = false disabledUser.Enabled = false
return []*internal_acl.User{ return []*acl.User{
withPasswordUser, withPasswordUser,
noPasswordUser, noPasswordUser,
disabledUser, disabledUser,
@@ -129,7 +127,7 @@ func compareSlices[T comparable](res, expected []T) error {
} }
// compareUsers compares 2 users and checks if all their fields are equal // compareUsers compares 2 users and checks if all their fields are equal
func compareUsers(user1, user2 *internal_acl.User) error { func compareUsers(user1, user2 *acl.User) error {
// Compare flags // Compare flags
if user1.Username != user2.Username { if user1.Username != user2.Username {
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username) return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username)
@@ -146,14 +144,14 @@ func compareUsers(user1, user2 *internal_acl.User) error {
// Compare passwords // Compare passwords
for _, password1 := range user1.Passwords { for _, password1 := range user1.Passwords {
if !slices.ContainsFunc(user2.Passwords, func(password2 internal_acl.Password) bool { if !slices.ContainsFunc(user2.Passwords, func(password2 acl.Password) bool {
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
}) { }) {
return fmt.Errorf("found password %+v in user1 that was not found in user2", password1) return fmt.Errorf("found password %+v in user1 that was not found in user2", password1)
} }
} }
for _, password2 := range user2.Passwords { for _, password2 := range user2.Passwords {
if !slices.ContainsFunc(user1.Passwords, func(password1 internal_acl.Password) bool { if !slices.ContainsFunc(user1.Passwords, func(password1 acl.Password) bool {
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
}) { }) {
return fmt.Errorf("found password %+v in user2 that was not found in user1", password2) return fmt.Errorf("found password %+v in user2 that was not found in user1", password2)
@@ -392,14 +390,6 @@ func Test_HandleCat(t *testing.T) {
t.Errorf("could not find expected command \"%s\" in the response array for category", expected) t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
} }
} }
// Check if all the elements in the response array are in the expected array
for _, value := range resArr {
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
return value.String() == expected
}) {
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
}
}
} }
} }
@@ -469,7 +459,7 @@ func Test_HandleSetUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
acl, ok := mockServer.GetACL().(*internal_acl.ACL) a, ok := mockServer.GetACL().(*acl.ACL)
if !ok { if !ok {
t.Error("error loading ACL module") t.Error("error loading ACL module")
} }
@@ -487,11 +477,11 @@ func Test_HandleSetUser(t *testing.T) {
r := resp.NewConn(conn) r := resp.NewConn(conn)
tests := []struct { tests := []struct {
presetUser *internal_acl.User presetUser *acl.User
cmd []resp.Value cmd []resp.Value
wantRes string wantRes string
wantErr string wantErr string
wantUser *internal_acl.User wantUser *acl.User
}{ }{
{ {
// 1. Create new enabled user // 1. Create new enabled user
@@ -504,8 +494,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_1") user := acl.CreateUser("set_user_1")
user.Enabled = true user.Enabled = true
user.Normalise() user.Normalise()
return user return user
@@ -522,8 +512,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_2") user := acl.CreateUser("set_user_2")
user.Enabled = false user.Enabled = false
user.Normalise() user.Normalise()
return user return user
@@ -544,14 +534,14 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_3") user := acl.CreateUser("set_user_3")
user.Enabled = true user.Enabled = true
user.Passwords = []internal_acl.Password{ user.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
} }
user.Normalise() user.Normalise()
return user return user
@@ -559,14 +549,14 @@ func Test_HandleSetUser(t *testing.T) {
}, },
{ {
// 4. Remove plaintext and SHA256 password from existing user // 4. Remove plaintext and SHA256 password from existing user
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_4") user := acl.CreateUser("set_user_4")
user.Enabled = true user.Enabled = true
user.Passwords = []internal_acl.Password{ user.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
} }
user.Normalise() user.Normalise()
return user return user
@@ -581,12 +571,12 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_4") user := acl.CreateUser("set_user_4")
user.Enabled = true user.Enabled = true
user.Passwords = []internal_acl.Password{ user.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
} }
user.Normalise() user.Normalise()
return user return user
@@ -604,8 +594,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_5") user := acl.CreateUser("set_user_5")
user.Enabled = true user.Enabled = true
user.ExcludedCommands = []string{"*"} user.ExcludedCommands = []string{"*"}
user.ExcludedCategories = []string{"*"} user.ExcludedCategories = []string{"*"}
@@ -625,8 +615,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_6") user := acl.CreateUser("set_user_6")
user.Enabled = true user.Enabled = true
user.IncludedCategories = []string{"*"} user.IncludedCategories = []string{"*"}
user.ExcludedCategories = []string{} user.ExcludedCategories = []string{}
@@ -646,8 +636,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_7") user := acl.CreateUser("set_user_7")
user.Enabled = true user.Enabled = true
user.IncludedCategories = []string{"*"} user.IncludedCategories = []string{"*"}
user.ExcludedCategories = []string{} user.ExcludedCategories = []string{}
@@ -672,8 +662,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_8") user := acl.CreateUser("set_user_8")
user.Enabled = true user.Enabled = true
user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory} user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}
user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory} user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}
@@ -693,8 +683,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_9") user := acl.CreateUser("set_user_9")
user.Enabled = true user.Enabled = true
user.NoKeys = true user.NoKeys = true
user.IncludedReadKeys = []string{} user.IncludedReadKeys = []string{}
@@ -723,8 +713,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_10") user := acl.CreateUser("set_user_10")
user.Enabled = true user.Enabled = true
user.NoKeys = false user.NoKeys = false
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"} user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
@@ -745,8 +735,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_11") user := acl.CreateUser("set_user_11")
user.Enabled = true user.Enabled = true
user.IncludedPubSubChannels = []string{"*"} user.IncludedPubSubChannels = []string{"*"}
user.ExcludedPubSubChannels = []string{} user.ExcludedPubSubChannels = []string{}
@@ -766,8 +756,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_12") user := acl.CreateUser("set_user_12")
user.Enabled = true user.Enabled = true
user.IncludedPubSubChannels = []string{"*"} user.IncludedPubSubChannels = []string{"*"}
user.ExcludedPubSubChannels = []string{} user.ExcludedPubSubChannels = []string{}
@@ -790,8 +780,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_13") user := acl.CreateUser("set_user_13")
user.Enabled = true user.Enabled = true
user.IncludedPubSubChannels = []string{"channel1", "channel2"} user.IncludedPubSubChannels = []string{"channel1", "channel2"}
user.ExcludedPubSubChannels = []string{"channel3", "channel4"} user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
@@ -811,8 +801,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_14") user := acl.CreateUser("set_user_14")
user.Enabled = true user.Enabled = true
user.IncludedCommands = []string{"*"} user.IncludedCommands = []string{"*"}
user.ExcludedCommands = []string{} user.ExcludedCommands = []string{}
@@ -837,8 +827,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_15") user := acl.CreateUser("set_user_15")
user.Enabled = true user.Enabled = true
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"} user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
user.ExcludedCommands = []string{"rewriteaof", "save", "publish"} user.ExcludedCommands = []string{"rewriteaof", "save", "publish"}
@@ -861,24 +851,24 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_16") user := acl.CreateUser("set_user_16")
user.Enabled = true user.Enabled = true
user.NoPassword = true user.NoPassword = true
user.Passwords = []internal_acl.Password{} user.Passwords = []acl.Password{}
user.Normalise() user.Normalise()
return user return user
}(), }(),
}, },
{ {
// 17. Delete all existing users passwords using 'nopass' // 17. Delete all existing users passwords using 'nopass'
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_17") user := acl.CreateUser("set_user_17")
user.Enabled = true user.Enabled = true
user.NoPassword = true user.NoPassword = true
user.Passwords = []internal_acl.Password{ user.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
} }
user.Normalise() user.Normalise()
return user return user
@@ -892,24 +882,24 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_17") user := acl.CreateUser("set_user_17")
user.Enabled = true user.Enabled = true
user.NoPassword = true user.NoPassword = true
user.Passwords = []internal_acl.Password{} user.Passwords = []acl.Password{}
user.Normalise() user.Normalise()
return user return user
}(), }(),
}, },
{ {
// 18. Clear all of an existing user's passwords using 'resetpass' // 18. Clear all of an existing user's passwords using 'resetpass'
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_18") user := acl.CreateUser("set_user_18")
user.Enabled = true user.Enabled = true
user.NoPassword = true user.NoPassword = true
user.Passwords = []internal_acl.Password{ user.Passwords = []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
} }
user.Normalise() user.Normalise()
return user return user
@@ -923,19 +913,19 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_18") user := acl.CreateUser("set_user_18")
user.Enabled = true user.Enabled = true
user.NoPassword = true user.NoPassword = true
user.Passwords = []internal_acl.Password{} user.Passwords = []acl.Password{}
user.Normalise() user.Normalise()
return user return user
}(), }(),
}, },
{ {
// 19. Clear all of an existing user's command privileges using 'nocommands' // 19. Clear all of an existing user's command privileges using 'nocommands'
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_19") user := acl.CreateUser("set_user_19")
user.Enabled = true user.Enabled = true
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"} user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
user.ExcludedCommands = []string{"rewriteaof", "save"} user.ExcludedCommands = []string{"rewriteaof", "save"}
@@ -951,8 +941,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_19") user := acl.CreateUser("set_user_19")
user.Enabled = true user.Enabled = true
user.IncludedCommands = []string{} user.IncludedCommands = []string{}
user.ExcludedCommands = []string{"*"} user.ExcludedCommands = []string{"*"}
@@ -964,8 +954,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
{ {
// 20. Clear all of an existing user's allowed keys using 'resetkeys' // 20. Clear all of an existing user's allowed keys using 'resetkeys'
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_20") user := acl.CreateUser("set_user_20")
user.Enabled = true user.Enabled = true
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"} user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"} user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"}
@@ -981,8 +971,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_20") user := acl.CreateUser("set_user_20")
user.Enabled = true user.Enabled = true
user.NoKeys = true user.NoKeys = true
user.IncludedReadKeys = []string{} user.IncludedReadKeys = []string{}
@@ -993,8 +983,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
{ {
// 21. Allow user to access all channels using 'resetchannels' // 21. Allow user to access all channels using 'resetchannels'
presetUser: func() *internal_acl.User { presetUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_21") user := acl.CreateUser("set_user_21")
user.IncludedPubSubChannels = []string{"channel1", "channel2"} user.IncludedPubSubChannels = []string{"channel1", "channel2"}
user.ExcludedPubSubChannels = []string{"channel3", "channel4"} user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
user.Normalise() user.Normalise()
@@ -1008,8 +998,8 @@ func Test_HandleSetUser(t *testing.T) {
}, },
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
wantUser: func() *internal_acl.User { wantUser: func() *acl.User {
user := internal_acl.CreateUser("set_user_21") user := acl.CreateUser("set_user_21")
user.IncludedPubSubChannels = []string{} user.IncludedPubSubChannels = []string{}
user.ExcludedPubSubChannels = []string{"*"} user.ExcludedPubSubChannels = []string{"*"}
user.Normalise() user.Normalise()
@@ -1020,7 +1010,7 @@ func Test_HandleSetUser(t *testing.T) {
for i, test := range tests { for i, test := range tests {
if test.presetUser != nil { if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser}) a.AddUsers([]*acl.User{test.presetUser})
} }
if err = r.WriteArray(test.cmd); err != nil { if err = r.WriteArray(test.cmd); err != nil {
t.Error(err) t.Error(err)
@@ -1042,13 +1032,13 @@ func Test_HandleSetUser(t *testing.T) {
continue continue
} }
expectedUser := test.wantUser expectedUser := test.wantUser
currUserIdx := slices.IndexFunc(acl.Users, func(user *internal_acl.User) bool { currUserIdx := slices.IndexFunc(a.Users, func(user *acl.User) bool {
return user.Username == expectedUser.Username return user.Username == expectedUser.Username
}) })
if currUserIdx == -1 { if currUserIdx == -1 {
t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username) t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username)
} }
if err = compareUsers(expectedUser, acl.Users[currUserIdx]); err != nil { if err = compareUsers(expectedUser, a.Users[currUserIdx]); err != nil {
t.Errorf("test idx: %d, %+v", i, err) t.Errorf("test idx: %d, %+v", i, err)
} }
} }
@@ -1065,7 +1055,7 @@ func Test_HandleGetUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL) a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1080,20 +1070,20 @@ func Test_HandleGetUser(t *testing.T) {
r := resp.NewConn(conn) r := resp.NewConn(conn)
tests := []struct { tests := []struct {
presetUser *internal_acl.User presetUser *acl.User
cmd []resp.Value cmd []resp.Value
wantRes []resp.Value wantRes []resp.Value
wantErr string wantErr string
}{ }{
{ // 1. Get the user and all their details { // 1. Get the user and all their details
presetUser: &internal_acl.User{ presetUser: &acl.User{
Username: "get_user_1", Username: "get_user_1",
Enabled: true, Enabled: true,
NoPassword: false, NoPassword: false,
NoKeys: false, NoKeys: false,
Passwords: []internal_acl.Password{ Passwords: []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "get_user_password_1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
}, },
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1165,7 +1155,7 @@ func Test_HandleGetUser(t *testing.T) {
for _, test := range tests { for _, test := range tests {
if test.presetUser != nil { if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser}) a.AddUsers([]*acl.User{test.presetUser})
} }
if err = r.WriteArray(test.cmd); err != nil { if err = r.WriteArray(test.cmd); err != nil {
t.Error(err) t.Error(err)
@@ -1218,7 +1208,7 @@ func Test_HandleDelUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL) a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1233,14 +1223,14 @@ func Test_HandleDelUser(t *testing.T) {
r := resp.NewConn(conn) r := resp.NewConn(conn)
tests := []struct { tests := []struct {
presetUser *internal_acl.User presetUser *acl.User
cmd []resp.Value cmd []resp.Value
wantRes string wantRes string
wantErr string wantErr string
}{ }{
{ {
// 1. Delete existing user while skipping default user and non-existent user // 1. Delete existing user while skipping default user and non-existent user
presetUser: internal_acl.CreateUser("user_to_delete"), presetUser: acl.CreateUser("user_to_delete"),
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("ACL"), resp.StringValue("ACL"),
resp.StringValue("DELUSER"), resp.StringValue("DELUSER"),
@@ -1262,7 +1252,7 @@ func Test_HandleDelUser(t *testing.T) {
for _, test := range tests { for _, test := range tests {
if test.presetUser != nil { if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser}) a.AddUsers([]*acl.User{test.presetUser})
} }
if err = r.WriteArray(test.cmd); err != nil { if err = r.WriteArray(test.cmd); err != nil {
t.Error(err) t.Error(err)
@@ -1278,13 +1268,13 @@ func Test_HandleDelUser(t *testing.T) {
continue continue
} }
// Check that default user still exists in the list of users // Check that default user still exists in the list of users
if !slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool { if !slices.ContainsFunc(a.Users, func(user *acl.User) bool {
return user.Username == "default" return user.Username == "default"
}) { }) {
t.Error("could not find user with username \"default\" in the ACL after deleting user") t.Error("could not find user with username \"default\" in the ACL after deleting user")
} }
// Check that the deleted user is no longer in the list // Check that the deleted user is no longer in the list
if slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool { if slices.ContainsFunc(a.Users, func(user *acl.User) bool {
return user.Username == "user_to_delete" return user.Username == "user_to_delete"
}) { }) {
t.Error("deleted user found in the ACL") t.Error("deleted user found in the ACL")
@@ -1368,7 +1358,7 @@ func Test_HandleList(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL) a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1383,21 +1373,21 @@ func Test_HandleList(t *testing.T) {
r := resp.NewConn(conn) r := resp.NewConn(conn)
tests := []struct { tests := []struct {
presetUsers []*internal_acl.User presetUsers []*acl.User
cmd []resp.Value cmd []resp.Value
wantRes []string wantRes []string
wantErr string wantErr string
}{ }{
{ // 1. Get the user and all their details { // 1. Get the user and all their details
presetUsers: []*internal_acl.User{ presetUsers: []*acl.User{
{ {
Username: "list_user_1", Username: "list_user_1",
Enabled: true, Enabled: true,
NoPassword: false, NoPassword: false,
NoKeys: false, NoKeys: false,
Passwords: []internal_acl.Password{ Passwords: []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_1"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
}, },
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1413,7 +1403,7 @@ func Test_HandleList(t *testing.T) {
Enabled: true, Enabled: true,
NoPassword: true, NoPassword: true,
NoKeys: true, NoKeys: true,
Passwords: []internal_acl.Password{}, Passwords: []acl.Password{},
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"}, IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
@@ -1428,9 +1418,9 @@ func Test_HandleList(t *testing.T) {
Enabled: true, Enabled: true,
NoPassword: false, NoPassword: false,
NoKeys: false, NoKeys: false,
Passwords: []internal_acl.Password{ Passwords: []acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_3"}, {PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")}, {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
}, },
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1457,7 +1447,7 @@ func Test_HandleList(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
acl.AddUsers(test.presetUsers) a.AddUsers(test.presetUsers)
if err = r.WriteArray(test.cmd); err != nil { if err = r.WriteArray(test.cmd); err != nil {
t.Error(err) t.Error(err)

View File

@@ -21,7 +21,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/admin"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(admin.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -21,7 +21,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/connection"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(connection.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -20,7 +20,6 @@ import (
"github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"reflect" "reflect"
"slices" "slices"
"strings" "strings"
@@ -30,7 +29,6 @@ import (
func createEchoVault() *echovault.EchoVault { func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault( ev, _ := echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
}), }),

View File

@@ -23,7 +23,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -45,7 +44,6 @@ func init() {
mockClock = clock.NewClock() mockClock = clock.NewClock()
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -18,7 +18,6 @@ import (
"context" "context"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
@@ -26,7 +25,6 @@ import (
func createEchoVault() *echovault.EchoVault { func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault( ev, _ := echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
}), }),

View File

@@ -22,7 +22,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -35,7 +34,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -18,14 +18,12 @@ import (
"context" "context"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"reflect" "reflect"
"testing" "testing"
) )
func createEchoVault() *echovault.EchoVault { func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault( ev, _ := echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
}), }),

View File

@@ -22,7 +22,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -34,7 +33,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -19,10 +19,9 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
internal_pubsub "github.com/echovault/echovault/internal/pubsub" "github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
ps "github.com/echovault/echovault/pkg/modules/pubsub"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -33,7 +32,7 @@ import (
"time" "time"
) )
var pubsub *internal_pubsub.PubSub var ps *pubsub.PubSub
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
var bindAddr = "localhost" var bindAddr = "localhost"
@@ -41,7 +40,7 @@ var port uint16 = 7490
func init() { func init() {
mockServer = setUpServer(bindAddr, port) mockServer = setUpServer(bindAddr, port)
pubsub = mockServer.GetPubSub().(*internal_pubsub.PubSub) ps = mockServer.GetPubSub().(*pubsub.PubSub)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -54,7 +53,6 @@ func init() {
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault { func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
server, _ := echovault.NewEchoVault( server, _ := echovault.NewEchoVault(
echovault.WithCommands(ps.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
BindAddr: bindAddr, BindAddr: bindAddr,
Port: port, Port: port,
@@ -126,12 +124,12 @@ func Test_HandleSubscribe(t *testing.T) {
} }
for _, channel := range channels { for _, channel := range channels {
// Check if the channel exists in the pubsub module // Check if the channel exists in the pubsub module
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool { if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == channel return c.Name() == channel
}) { }) {
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel) t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
} }
for _, c := range pubsub.GetAllChannels() { for _, c := range ps.GetAllChannels() {
if c.Name() == channel { if c.Name() == channel {
// Check if channel has nil pattern // Check if channel has nil pattern
if c.Pattern() != nil { if c.Pattern() != nil {
@@ -157,12 +155,12 @@ func Test_HandleSubscribe(t *testing.T) {
} }
for _, pattern := range patterns { for _, pattern := range patterns {
// Check if pattern channel exists in pubsub module // Check if pattern channel exists in pubsub module
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool { if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == pattern return c.Name() == pattern
}) { }) {
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern) t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
} }
for _, c := range pubsub.GetAllChannels() { for _, c := range ps.GetAllChannels() {
if c.Name() == pattern { if c.Name() == pattern {
// Check if channel has non-nil pattern // Check if channel has non-nil pattern
if c.Pattern() == nil { if c.Pattern() == nil {
@@ -322,7 +320,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
verifyResponse(res, test.expectedResponses["pattern"]) verifyResponse(res, test.expectedResponses["pattern"])
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) { for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
for _, pubsubChannel := range pubsub.GetAllChannels() { for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel { if pubsubChannel.Name() == channel {
// Assert that target connection is no longer in the unsub channels and patterns // Assert that target connection is no longer in the unsub channels and patterns
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok { if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
@@ -339,7 +337,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
// Assert that the target connection is still in the remain channels and patterns // Assert that the target connection is still in the remain channels and patterns
for _, channel := range append(test.remainChannels, test.remainPatterns...) { for _, channel := range append(test.remainChannels, test.remainPatterns...) {
for _, pubsubChannel := range pubsub.GetAllChannels() { for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel { if pubsubChannel.Name() == channel {
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok { if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
t.Errorf("could not find expected target connection in channel \"%s\"", channel) t.Errorf("could not find expected target connection in channel \"%s\"", channel)

View File

@@ -17,9 +17,8 @@ package set
import ( import (
"context" "context"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/set" "github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
@@ -27,7 +26,6 @@ import (
func createEchoVault() *echovault.EchoVault { func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault( ev, _ := echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
}), }),

View File

@@ -20,10 +20,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/set" "github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

File diff suppressed because it is too large Load Diff

View File

@@ -20,10 +20,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/sorted_set" "github.com/echovault/echovault/internal/modules/sorted_set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
ss "github.com/echovault/echovault/pkg/modules/sorted_set"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"math" "math"
@@ -38,7 +37,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(ss.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,

View File

@@ -18,13 +18,11 @@ import (
"context" "context"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"testing" "testing"
) )
func createEchoVault() *echovault.EchoVault { func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault( ev, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
}), }),

View File

@@ -23,7 +23,6 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types" "github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
func init() { func init() {
mockServer, _ = echovault.NewEchoVault( mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,