Removed duplicate imports for set, sorted_set, pubsub and acl modules. Moved /modules from /pkg to /internal. Delted commands package: Commands will now be automatically loaded when an EchoVault instance is initialised.

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-24 22:37:16 +08:00
parent 3e04b7a822
commit b6ddb43a49
46 changed files with 645 additions and 735 deletions

View File

@@ -7,7 +7,7 @@ build:
run:
make build && docker-compose up --build
test-normal:
test-unit:
go clean -testcache && go test ./... -coverprofile coverage/coverage.out
test-race:

View File

@@ -18,7 +18,6 @@ import (
"context"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/echovault"
"log"
"os"
@@ -49,7 +48,6 @@ func main() {
server, err := echovault.NewEchoVault(
echovault.WithContext(ctx),
echovault.WithConfig(conf),
echovault.WithCommands(commands.All()),
)
if err != nil {

View File

@@ -18,7 +18,6 @@ import (
"encoding/json"
"errors"
"fmt"
internal_acl "github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"gopkg.in/yaml.v3"
@@ -33,7 +32,7 @@ func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -48,12 +47,12 @@ func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
var user *internal_acl.User
var user *User
userFound := false
for _, u := range acl.Users {
if u.Username == params.Command[2] {
@@ -221,7 +220,7 @@ func handleCat(params types.HandlerFuncParams) ([]byte, error) {
}
func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -234,7 +233,7 @@ func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
}
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -248,7 +247,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -259,7 +258,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
}
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -271,7 +270,7 @@ func handleList(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -368,7 +367,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -389,7 +388,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
ext := path.Ext(f.Name())
var users []*internal_acl.User
var users []*User
if ext == ".json" {
if err := json.NewDecoder(f).Decode(&users); err != nil {
@@ -435,7 +434,7 @@ func handleSave(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}

View File

@@ -17,14 +17,13 @@ package pubsub
import (
"errors"
"fmt"
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"strings"
)
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -42,7 +41,7 @@ func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
}
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -55,7 +54,7 @@ func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
}
func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -71,7 +70,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New(constants.WrongArgsResponse)
}
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -85,7 +84,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
}
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -94,7 +93,7 @@ func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
}
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}

View File

@@ -18,7 +18,6 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal"
internal_set "github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices"
@@ -33,10 +32,10 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
key := keys.WriteKeys[0]
var set *internal_set.Set
var set *Set
if !params.KeyExists(params.Context, key) {
set = internal_set.NewSet(params.Command[2:])
set = NewSet(params.Command[2:])
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
return nil, err
}
@@ -52,7 +51,7 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -79,7 +78,7 @@ func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -103,7 +102,7 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
@@ -127,9 +126,9 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for _, key := range params.Command[2:] {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
continue
}
@@ -166,7 +165,7 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
@@ -190,9 +189,9 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for _, key := range keys.ReadKeys[1:] {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
continue
}
@@ -252,10 +251,10 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -267,7 +266,7 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
return nil, fmt.Errorf("not enough sets in the keys provided")
}
intersect, _ := internal_set.Intersection(0, sets...)
intersect, _ := Intersection(0, sets...)
elems := intersect.GetAll()
res := fmt.Sprintf("*%d", len(elems))
@@ -328,10 +327,10 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -343,7 +342,7 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
return nil, fmt.Errorf("not enough sets in the keys provided")
}
intersect, _ := internal_set.Intersection(limit, sets...)
intersect, _ := Intersection(limit, sets...)
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
@@ -374,10 +373,10 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for key, _ := range locks {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -385,7 +384,7 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
sets = append(sets, set)
}
intersect, _ := internal_set.Intersection(0, sets...)
intersect, _ := Intersection(0, sets...)
destination := keys.WriteKeys[0]
if params.KeyExists(params.Context, destination) {
@@ -423,7 +422,7 @@ func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -452,7 +451,7 @@ func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -495,7 +494,7 @@ func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -531,12 +530,12 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, source)
sourceSet, ok := params.GetValue(params.Context, source).(*internal_set.Set)
sourceSet, ok := params.GetValue(params.Context, source).(*Set)
if !ok {
return nil, errors.New("source is not a set")
}
var destinationSet *internal_set.Set
var destinationSet *Set
if !params.KeyExists(params.Context, destination) {
// Destination key does not exist
@@ -544,7 +543,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyUnlock(params.Context, destination)
destinationSet = internal_set.NewSet([]string{})
destinationSet = NewSet([]string{})
if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
return nil, err
}
@@ -554,7 +553,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyUnlock(params.Context, destination)
ds, ok := params.GetValue(params.Context, destination).(*internal_set.Set)
ds, ok := params.GetValue(params.Context, destination).(*Set)
if !ok {
return nil, errors.New("destination is not a set")
}
@@ -592,7 +591,7 @@ func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at %s is not a set", key)
}
@@ -636,7 +635,7 @@ func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at %s is not a set", key)
}
@@ -672,7 +671,7 @@ func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -707,20 +706,20 @@ func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for key, locked := range locks {
if !locked {
continue
}
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
sets = append(sets, set)
}
union := internal_set.Union(sets...)
union := Union(sets...)
res := fmt.Sprintf("*%d", union.Cardinality())
for i, e := range union.GetAll() {
@@ -758,20 +757,20 @@ func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
locks[key] = true
}
var sets []*internal_set.Set
var sets []*Set
for key, locked := range locks {
if !locked {
continue
}
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
sets = append(sets, set)
}
union := internal_set.Union(sets...)
union := Union(sets...)
destination := keys.WriteKeys[0]

View File

@@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/sorted_set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math"
@@ -63,7 +62,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("score/member pairs must be float/string")
}
var members []sorted_set.MemberParam
var members []MemberParam
for i := 0; i < len(params.Command[membersStartIndex:]); i++ {
if i%2 != 0 {
@@ -77,29 +76,29 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
var s float64
if strings.ToLower(score.(string)) == "-inf" {
s = math.Inf(-1)
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
members = append(members, MemberParam{
Value: Value(params.Command[membersStartIndex:][i+1]),
Score: Score(s),
})
}
if strings.ToLower(score.(string)) == "+inf" {
s = math.Inf(1)
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
members = append(members, MemberParam{
Value: Value(params.Command[membersStartIndex:][i+1]),
Score: Score(s),
})
}
case float64:
s, _ := score.(float64)
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
members = append(members, MemberParam{
Value: Value(params.Command[membersStartIndex:][i+1]),
Score: Score(s),
})
case int:
s, _ := score.(int)
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
members = append(members, MemberParam{
Value: Value(params.Command[membersStartIndex:][i+1]),
Score: Score(s),
})
}
}
@@ -148,7 +147,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -171,7 +170,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set := sorted_set.NewSortedSet(members)
set := NewSortedSet(members)
if err = params.SetValue(params.Context, key, set); err != nil {
return nil, err
}
@@ -195,7 +194,7 @@ func handleZCARD(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -211,40 +210,40 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
key := keys.ReadKeys[0]
minimum := sorted_set.Score(math.Inf(-1))
minimum := Score(math.Inf(-1))
switch internal.AdaptType(params.Command[2]).(type) {
default:
return nil, errors.New("min constraint must be a double")
case string:
if strings.ToLower(params.Command[2]) == "+inf" {
minimum = sorted_set.Score(math.Inf(1))
minimum = Score(math.Inf(1))
} else {
return nil, errors.New("min constraint must be a double")
}
case float64:
s, _ := internal.AdaptType(params.Command[2]).(float64)
minimum = sorted_set.Score(s)
minimum = Score(s)
case int:
s, _ := internal.AdaptType(params.Command[2]).(int)
minimum = sorted_set.Score(s)
minimum = Score(s)
}
maximum := sorted_set.Score(math.Inf(1))
maximum := Score(math.Inf(1))
switch internal.AdaptType(params.Command[3]).(type) {
default:
return nil, errors.New("max constraint must be a double")
case string:
if strings.ToLower(params.Command[3]) == "-inf" {
maximum = sorted_set.Score(math.Inf(-1))
maximum = Score(math.Inf(-1))
} else {
return nil, errors.New("max constraint must be a double")
}
case float64:
s, _ := internal.AdaptType(params.Command[3]).(float64)
maximum = sorted_set.Score(s)
maximum = Score(s)
case int:
s, _ := internal.AdaptType(params.Command[3]).(int)
maximum = sorted_set.Score(s)
maximum = Score(s)
}
if !params.KeyExists(params.Context, key) {
@@ -256,12 +255,12 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
var members []sorted_set.MemberParam
var members []MemberParam
for _, m := range set.GetAll() {
if m.Score >= minimum && m.Score <= maximum {
members = append(members, m)
@@ -290,7 +289,7 @@ func handleZLEXCOUNT(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -347,13 +346,13 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet)
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
}
// Extract the remaining sets
var sets []*sorted_set.SortedSet
var sets []*SortedSet
for i := 1; i < len(keys.ReadKeys); i++ {
if !params.KeyExists(params.Context, keys.ReadKeys[i]) {
@@ -364,7 +363,7 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
locks[keys.ReadKeys[i]] = locked
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
}
@@ -415,19 +414,19 @@ func handleZDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet)
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
}
var sets []*sorted_set.SortedSet
var sets []*SortedSet
for i := 1; i < len(keys.ReadKeys); i++ {
if params.KeyExists(params.Context, keys.ReadKeys[i]) {
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil {
return nil, err
}
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
}
@@ -462,26 +461,26 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
}
key := keys.WriteKeys[0]
member := sorted_set.Value(params.Command[3])
var increment sorted_set.Score
member := Value(params.Command[3])
var increment Score
switch internal.AdaptType(params.Command[2]).(type) {
default:
return nil, errors.New("increment must be a double")
case string:
if strings.EqualFold("-inf", strings.ToLower(params.Command[2])) {
increment = sorted_set.Score(math.Inf(-1))
increment = Score(math.Inf(-1))
} else if strings.EqualFold("+inf", strings.ToLower(params.Command[2])) {
increment = sorted_set.Score(math.Inf(1))
increment = Score(math.Inf(1))
} else {
return nil, errors.New("increment must be a double")
}
case float64:
s, _ := internal.AdaptType(params.Command[2]).(float64)
increment = sorted_set.Score(s)
increment = Score(s)
case int:
s, _ := internal.AdaptType(params.Command[2]).(int)
increment = sorted_set.Score(s)
increment = Score(s)
}
if !params.KeyExists(params.Context, key) {
@@ -493,7 +492,7 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
if err = params.SetValue(
params.Context,
key,
sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: member, Score: increment}}),
NewSortedSet([]MemberParam{{Value: member, Score: increment}}),
); err != nil {
return nil, err
}
@@ -505,12 +504,12 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
if _, err = set.AddOrUpdate(
[]sorted_set.MemberParam{
[]MemberParam{
{Value: member, Score: increment}},
"xx",
nil,
@@ -542,7 +541,7 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
}
}()
var setParams []sorted_set.SortedSetParam
var setParams []SortedSetParam
for i := 0; i < len(keys); i++ {
if !params.KeyExists(params.Context, keys[i]) {
@@ -553,17 +552,17 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, sorted_set.SortedSetParam{
setParams = append(setParams, SortedSetParam{
Set: set,
Weight: weights[i],
})
}
intersect := sorted_set.Intersect(aggregate, setParams...)
intersect := Intersect(aggregate, setParams...)
res := fmt.Sprintf("*%d", intersect.Cardinality())
@@ -609,7 +608,7 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
}
}()
var setParams []sorted_set.SortedSetParam
var setParams []SortedSetParam
for i := 0; i < len(keys); i++ {
if !params.KeyExists(params.Context, keys[i]) {
@@ -619,17 +618,17 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, sorted_set.SortedSetParam{
setParams = append(setParams, SortedSetParam{
Set: set,
Weight: weights[i],
})
}
intersect := sorted_set.Intersect(aggregate, setParams...)
intersect := Intersect(aggregate, setParams...)
if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 {
if _, err = params.KeyLock(params.Context, destination); err != nil {
@@ -700,7 +699,7 @@ func handleZMPOP(params types.HandlerFuncParams) ([]byte, error) {
if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil {
continue
}
v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*sorted_set.SortedSet)
v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*SortedSet)
if !ok || v.Cardinality() == 0 {
params.KeyUnlock(params.Context, keys.WriteKeys[i])
continue
@@ -760,7 +759,7 @@ func handleZPOP(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at key %s is not a sorted set", key)
}
@@ -797,7 +796,7 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -806,10 +805,10 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
res := fmt.Sprintf("*%d", len(members))
var member sorted_set.MemberObject
var member MemberObject
for i := 0; i < len(members); i++ {
member = set.Get(sorted_set.Value(members[i]))
member = set.Get(Value(members[i]))
if !member.Exists {
res = fmt.Sprintf("%s\r\n$-1", res)
} else {
@@ -859,7 +858,7 @@ func handleZRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -903,13 +902,13 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
members := set.GetAll()
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
if strings.EqualFold(params.Command[0], "zrevrank") {
return cmp.Compare(b.Score, a.Score)
}
@@ -917,7 +916,7 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
})
for i := 0; i < len(members); i++ {
if members[i].Value == sorted_set.Value(member) {
if members[i].Value == Value(member) {
if withscores {
score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64)
return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
@@ -947,14 +946,14 @@ func handleZREM(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
deletedCount := 0
for _, m := range params.Command[2:] {
if set.Remove(sorted_set.Value(m)) {
if set.Remove(Value(m)) {
deletedCount += 1
}
}
@@ -977,11 +976,11 @@ func handleZSCORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
member := set.Get(sorted_set.Value(params.Command[2]))
member := set.Get(Value(params.Command[2]))
if !member.Exists {
return []byte("$-1\r\n"), nil
}
@@ -1020,13 +1019,13 @@ func handleZREMRANGEBYSCORE(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
for _, m := range set.GetAll() {
if m.Score >= sorted_set.Score(minimum) && m.Score <= sorted_set.Score(maximum) {
if m.Score >= Score(minimum) && m.Score <= Score(maximum) {
set.Remove(m.Value)
deletedCount += 1
}
@@ -1062,7 +1061,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1079,7 +1078,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
}
members := set.GetAll()
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
return cmp.Compare(a.Score, b.Score)
})
@@ -1119,7 +1118,7 @@ func handleZREMRANGEBYLEX(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1217,7 +1216,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, key)
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, key).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1231,7 +1230,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
members := set.GetAll()
if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
// Do a score sort
if reverse {
return cmp.Compare(b.Score, a.Score)
@@ -1246,7 +1245,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
if reverse {
return internal.CompareLex(string(b.Value), string(a.Value))
}
@@ -1254,14 +1253,14 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
})
}
var resultMembers []sorted_set.MemberParam
var resultMembers []MemberParam
for i := offset; i <= count; i++ {
if i >= len(members) {
break
}
if strings.EqualFold(policy, "byscore") {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
resultMembers = append(resultMembers, members[i])
}
continue
@@ -1354,7 +1353,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
}
defer params.KeyRUnlock(params.Context, source)
set, ok := params.GetValue(params.Context, source).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, source).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", source)
}
@@ -1368,7 +1367,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
members := set.GetAll()
if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
// Do a score sort
if reverse {
return cmp.Compare(b.Score, a.Score)
@@ -1383,7 +1382,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
slices.SortFunc(members, func(a, b MemberParam) int {
if reverse {
return internal.CompareLex(string(b.Value), string(a.Value))
}
@@ -1391,14 +1390,14 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
})
}
var resultMembers []sorted_set.MemberParam
var resultMembers []MemberParam
for i := offset; i <= count; i++ {
if i >= len(members) {
break
}
if strings.EqualFold(policy, "byscore") {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
resultMembers = append(resultMembers, members[i])
}
continue
@@ -1409,7 +1408,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
}
}
newSortedSet := sorted_set.NewSortedSet(resultMembers)
newSortedSet := NewSortedSet(resultMembers)
if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil {
@@ -1448,7 +1447,7 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
}
}()
var setParams []sorted_set.SortedSetParam
var setParams []SortedSetParam
for i := 0; i < len(keys); i++ {
if params.KeyExists(params.Context, keys[i]) {
@@ -1456,18 +1455,18 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, sorted_set.SortedSetParam{
setParams = append(setParams, SortedSetParam{
Set: set,
Weight: weights[i],
})
}
}
union := sorted_set.Union(aggregate, setParams...)
union := Union(aggregate, setParams...)
res := fmt.Sprintf("*%d", union.Cardinality())
for _, m := range union.GetAll() {
@@ -1510,7 +1509,7 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
}
}()
var setParams []sorted_set.SortedSetParam
var setParams []SortedSetParam
for i := 0; i < len(keys); i++ {
if params.KeyExists(params.Context, keys[i]) {
@@ -1518,18 +1517,18 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
return nil, err
}
locks[keys[i]] = true
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, sorted_set.SortedSetParam{
setParams = append(setParams, SortedSetParam{
Set: set,
Weight: weights[i],
})
}
}
union := sorted_set.Union(aggregate, setParams...)
union := Union(aggregate, setParams...)
if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil {

View File

@@ -90,3 +90,80 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
return keys, weights, aggregate, withscores, nil
}
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of Value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of Value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of Value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of Value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}

View File

@@ -1,98 +0,0 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sorted_set
import (
"errors"
"slices"
"strings"
)
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of Value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of Value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of Value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of Value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}

View File

@@ -1,31 +0,0 @@
package commands
import (
"github.com/echovault/echovault/pkg/modules/acl"
"github.com/echovault/echovault/pkg/modules/admin"
"github.com/echovault/echovault/pkg/modules/connection"
"github.com/echovault/echovault/pkg/modules/generic"
"github.com/echovault/echovault/pkg/modules/hash"
"github.com/echovault/echovault/pkg/modules/list"
"github.com/echovault/echovault/pkg/modules/pubsub"
"github.com/echovault/echovault/pkg/modules/set"
"github.com/echovault/echovault/pkg/modules/sorted_set"
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types"
)
// All returns all the commands currently available on EchoVault
func All() []types.Command {
var commands []types.Command
commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...)
commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...)
commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...)
commands = append(commands, str.Commands()...)
return commands
}

View File

@@ -21,13 +21,21 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/internal/aof"
"github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/eviction"
"github.com/echovault/echovault/internal/memberlist"
"github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/internal/modules/acl"
"github.com/echovault/echovault/internal/modules/admin"
"github.com/echovault/echovault/internal/modules/connection"
"github.com/echovault/echovault/internal/modules/generic"
"github.com/echovault/echovault/internal/modules/hash"
"github.com/echovault/echovault/internal/modules/list"
"github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/internal/modules/sorted_set"
str "github.com/echovault/echovault/internal/modules/string"
"github.com/echovault/echovault/internal/raft"
"github.com/echovault/echovault/internal/snapshot"
"github.com/echovault/echovault/pkg/constants"
@@ -126,11 +134,24 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
echovault := &EchoVault{
clock: clock.NewClock(),
context: context.Background(),
commands: make([]types.Command, 0),
config: config.DefaultConfig(),
store: make(map[string]internal.KeyData),
keyLocks: make(map[string]*sync.RWMutex),
keyCreationLock: &sync.Mutex{},
commands: func() []types.Command {
var commands []types.Command
commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...)
commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...)
commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...)
commands = append(commands, str.Commands()...)
return commands
}(),
}
for _, option := range options {

View File

@@ -589,16 +589,3 @@ func (server *EchoVault) evictKeysWithExpiredTTL(ctx context.Context) error {
return nil
}
func presetValue(server *EchoVault, key string, value interface{}) {
_, _ = server.CreateKeyAndLock(server.context, key)
_ = server.SetValue(server.context, key, value)
server.KeyUnlock(server.context, key)
}
func presetKeyData(server *EchoVault, key string, data internal.KeyData) {
_, _ = server.CreateKeyAndLock(server.context, key)
defer server.KeyUnlock(server.context, key)
_ = server.SetValue(server.context, key, data.Value)
server.SetExpiry(server.context, key, data.ExpireAt, false)
}

View File

@@ -17,11 +17,10 @@ package acl
import (
"crypto/sha256"
"fmt"
internal_acl "github.com/echovault/echovault/internal/acl"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/modules/acl"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
acl2 "github.com/echovault/echovault/pkg/modules/acl"
"github.com/tidwall/resp"
"net"
"slices"
@@ -61,44 +60,43 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
}
mockServer, _ := echovault.NewEchoVault(
echovault.WithCommands(acl2.Commands()),
echovault.WithConfig(conf),
)
// Add the initial test users to the ACL module
acl := mockServer.GetACL().(*internal_acl.ACL)
acl.AddUsers(generateInitialTestUsers())
a := mockServer.GetACL().(*acl.ACL)
a.AddUsers(generateInitialTestUsers())
return mockServer
}
func generateInitialTestUsers() []*internal_acl.User {
func generateInitialTestUsers() []*acl.User {
// User with both hash password and plaintext password
withPasswordUser := internal_acl.CreateUser("with_password_user")
withPasswordUser := acl.CreateUser("with_password_user")
h := sha256.New()
h.Write([]byte("password3"))
withPasswordUser.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
withPasswordUser.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "password2"},
{PasswordType: acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
}
withPasswordUser.IncludedCategories = []string{"*"}
withPasswordUser.IncludedCommands = []string{"*"}
// User with NoPassword option
noPasswordUser := internal_acl.CreateUser("no_password_user")
noPasswordUser.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password4"},
noPasswordUser := acl.CreateUser("no_password_user")
noPasswordUser.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "password4"},
}
noPasswordUser.NoPassword = true
// Disabled user
disabledUser := internal_acl.CreateUser("disabled_user")
disabledUser.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password5"},
disabledUser := acl.CreateUser("disabled_user")
disabledUser.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "password5"},
}
disabledUser.Enabled = false
return []*internal_acl.User{
return []*acl.User{
withPasswordUser,
noPasswordUser,
disabledUser,
@@ -129,7 +127,7 @@ func compareSlices[T comparable](res, expected []T) error {
}
// compareUsers compares 2 users and checks if all their fields are equal
func compareUsers(user1, user2 *internal_acl.User) error {
func compareUsers(user1, user2 *acl.User) error {
// Compare flags
if user1.Username != user2.Username {
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username)
@@ -146,14 +144,14 @@ func compareUsers(user1, user2 *internal_acl.User) error {
// Compare passwords
for _, password1 := range user1.Passwords {
if !slices.ContainsFunc(user2.Passwords, func(password2 internal_acl.Password) bool {
if !slices.ContainsFunc(user2.Passwords, func(password2 acl.Password) bool {
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
}) {
return fmt.Errorf("found password %+v in user1 that was not found in user2", password1)
}
}
for _, password2 := range user2.Passwords {
if !slices.ContainsFunc(user1.Passwords, func(password1 internal_acl.Password) bool {
if !slices.ContainsFunc(user1.Passwords, func(password1 acl.Password) bool {
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
}) {
return fmt.Errorf("found password %+v in user2 that was not found in user1", password2)
@@ -392,14 +390,6 @@ func Test_HandleCat(t *testing.T) {
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
}
}
// Check if all the elements in the response array are in the expected array
for _, value := range resArr {
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
return value.String() == expected
}) {
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
}
}
}
}
@@ -469,7 +459,7 @@ func Test_HandleSetUser(t *testing.T) {
}()
wg.Wait()
acl, ok := mockServer.GetACL().(*internal_acl.ACL)
a, ok := mockServer.GetACL().(*acl.ACL)
if !ok {
t.Error("error loading ACL module")
}
@@ -487,11 +477,11 @@ func Test_HandleSetUser(t *testing.T) {
r := resp.NewConn(conn)
tests := []struct {
presetUser *internal_acl.User
presetUser *acl.User
cmd []resp.Value
wantRes string
wantErr string
wantUser *internal_acl.User
wantUser *acl.User
}{
{
// 1. Create new enabled user
@@ -504,8 +494,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_1")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_1")
user.Enabled = true
user.Normalise()
return user
@@ -522,8 +512,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_2")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_2")
user.Enabled = false
user.Normalise()
return user
@@ -544,14 +534,14 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_3")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_3")
user.Enabled = true
user.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
user.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
}
user.Normalise()
return user
@@ -559,14 +549,14 @@ func Test_HandleSetUser(t *testing.T) {
},
{
// 4. Remove plaintext and SHA256 password from existing user
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_4")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_4")
user.Enabled = true
user.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
user.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
}
user.Normalise()
return user
@@ -581,12 +571,12 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_4")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_4")
user.Enabled = true
user.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
user.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
}
user.Normalise()
return user
@@ -604,8 +594,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_5")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_5")
user.Enabled = true
user.ExcludedCommands = []string{"*"}
user.ExcludedCategories = []string{"*"}
@@ -625,8 +615,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_6")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_6")
user.Enabled = true
user.IncludedCategories = []string{"*"}
user.ExcludedCategories = []string{}
@@ -646,8 +636,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_7")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_7")
user.Enabled = true
user.IncludedCategories = []string{"*"}
user.ExcludedCategories = []string{}
@@ -672,8 +662,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_8")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_8")
user.Enabled = true
user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}
user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}
@@ -693,8 +683,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_9")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_9")
user.Enabled = true
user.NoKeys = true
user.IncludedReadKeys = []string{}
@@ -723,8 +713,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_10")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_10")
user.Enabled = true
user.NoKeys = false
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
@@ -745,8 +735,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_11")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_11")
user.Enabled = true
user.IncludedPubSubChannels = []string{"*"}
user.ExcludedPubSubChannels = []string{}
@@ -766,8 +756,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_12")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_12")
user.Enabled = true
user.IncludedPubSubChannels = []string{"*"}
user.ExcludedPubSubChannels = []string{}
@@ -790,8 +780,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_13")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_13")
user.Enabled = true
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
@@ -811,8 +801,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_14")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_14")
user.Enabled = true
user.IncludedCommands = []string{"*"}
user.ExcludedCommands = []string{}
@@ -837,8 +827,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_15")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_15")
user.Enabled = true
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
user.ExcludedCommands = []string{"rewriteaof", "save", "publish"}
@@ -861,24 +851,24 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_16")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_16")
user.Enabled = true
user.NoPassword = true
user.Passwords = []internal_acl.Password{}
user.Passwords = []acl.Password{}
user.Normalise()
return user
}(),
},
{
// 17. Delete all existing users passwords using 'nopass'
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_17")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_17")
user.Enabled = true
user.NoPassword = true
user.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
user.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
}
user.Normalise()
return user
@@ -892,24 +882,24 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_17")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_17")
user.Enabled = true
user.NoPassword = true
user.Passwords = []internal_acl.Password{}
user.Passwords = []acl.Password{}
user.Normalise()
return user
}(),
},
{
// 18. Clear all of an existing user's passwords using 'resetpass'
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_18")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_18")
user.Enabled = true
user.NoPassword = true
user.Passwords = []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
user.Passwords = []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
}
user.Normalise()
return user
@@ -923,19 +913,19 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_18")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_18")
user.Enabled = true
user.NoPassword = true
user.Passwords = []internal_acl.Password{}
user.Passwords = []acl.Password{}
user.Normalise()
return user
}(),
},
{
// 19. Clear all of an existing user's command privileges using 'nocommands'
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_19")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_19")
user.Enabled = true
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
user.ExcludedCommands = []string{"rewriteaof", "save"}
@@ -951,8 +941,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_19")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_19")
user.Enabled = true
user.IncludedCommands = []string{}
user.ExcludedCommands = []string{"*"}
@@ -964,8 +954,8 @@ func Test_HandleSetUser(t *testing.T) {
},
{
// 20. Clear all of an existing user's allowed keys using 'resetkeys'
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_20")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_20")
user.Enabled = true
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"}
@@ -981,8 +971,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_20")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_20")
user.Enabled = true
user.NoKeys = true
user.IncludedReadKeys = []string{}
@@ -993,8 +983,8 @@ func Test_HandleSetUser(t *testing.T) {
},
{
// 21. Allow user to access all channels using 'resetchannels'
presetUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_21")
presetUser: func() *acl.User {
user := acl.CreateUser("set_user_21")
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
user.Normalise()
@@ -1008,8 +998,8 @@ func Test_HandleSetUser(t *testing.T) {
},
wantRes: "OK",
wantErr: "",
wantUser: func() *internal_acl.User {
user := internal_acl.CreateUser("set_user_21")
wantUser: func() *acl.User {
user := acl.CreateUser("set_user_21")
user.IncludedPubSubChannels = []string{}
user.ExcludedPubSubChannels = []string{"*"}
user.Normalise()
@@ -1020,7 +1010,7 @@ func Test_HandleSetUser(t *testing.T) {
for i, test := range tests {
if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser})
a.AddUsers([]*acl.User{test.presetUser})
}
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
@@ -1042,13 +1032,13 @@ func Test_HandleSetUser(t *testing.T) {
continue
}
expectedUser := test.wantUser
currUserIdx := slices.IndexFunc(acl.Users, func(user *internal_acl.User) bool {
currUserIdx := slices.IndexFunc(a.Users, func(user *acl.User) bool {
return user.Username == expectedUser.Username
})
if currUserIdx == -1 {
t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username)
}
if err = compareUsers(expectedUser, acl.Users[currUserIdx]); err != nil {
if err = compareUsers(expectedUser, a.Users[currUserIdx]); err != nil {
t.Errorf("test idx: %d, %+v", i, err)
}
}
@@ -1065,7 +1055,7 @@ func Test_HandleGetUser(t *testing.T) {
}()
wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
@@ -1080,20 +1070,20 @@ func Test_HandleGetUser(t *testing.T) {
r := resp.NewConn(conn)
tests := []struct {
presetUser *internal_acl.User
presetUser *acl.User
cmd []resp.Value
wantRes []resp.Value
wantErr string
}{
{ // 1. Get the user and all their details
presetUser: &internal_acl.User{
presetUser: &acl.User{
Username: "get_user_1",
Enabled: true,
NoPassword: false,
NoKeys: false,
Passwords: []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
Passwords: []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
},
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1165,7 +1155,7 @@ func Test_HandleGetUser(t *testing.T) {
for _, test := range tests {
if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser})
a.AddUsers([]*acl.User{test.presetUser})
}
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
@@ -1218,7 +1208,7 @@ func Test_HandleDelUser(t *testing.T) {
}()
wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
@@ -1233,14 +1223,14 @@ func Test_HandleDelUser(t *testing.T) {
r := resp.NewConn(conn)
tests := []struct {
presetUser *internal_acl.User
presetUser *acl.User
cmd []resp.Value
wantRes string
wantErr string
}{
{
// 1. Delete existing user while skipping default user and non-existent user
presetUser: internal_acl.CreateUser("user_to_delete"),
presetUser: acl.CreateUser("user_to_delete"),
cmd: []resp.Value{
resp.StringValue("ACL"),
resp.StringValue("DELUSER"),
@@ -1262,7 +1252,7 @@ func Test_HandleDelUser(t *testing.T) {
for _, test := range tests {
if test.presetUser != nil {
acl.AddUsers([]*internal_acl.User{test.presetUser})
a.AddUsers([]*acl.User{test.presetUser})
}
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
@@ -1278,13 +1268,13 @@ func Test_HandleDelUser(t *testing.T) {
continue
}
// Check that default user still exists in the list of users
if !slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool {
if !slices.ContainsFunc(a.Users, func(user *acl.User) bool {
return user.Username == "default"
}) {
t.Error("could not find user with username \"default\" in the ACL after deleting user")
}
// Check that the deleted user is no longer in the list
if slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool {
if slices.ContainsFunc(a.Users, func(user *acl.User) bool {
return user.Username == "user_to_delete"
}) {
t.Error("deleted user found in the ACL")
@@ -1368,7 +1358,7 @@ func Test_HandleList(t *testing.T) {
}()
wg.Wait()
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
a, _ := mockServer.GetACL().(*acl.ACL)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
@@ -1383,21 +1373,21 @@ func Test_HandleList(t *testing.T) {
r := resp.NewConn(conn)
tests := []struct {
presetUsers []*internal_acl.User
presetUsers []*acl.User
cmd []resp.Value
wantRes []string
wantErr string
}{
{ // 1. Get the user and all their details
presetUsers: []*internal_acl.User{
presetUsers: []*acl.User{
{
Username: "list_user_1",
Enabled: true,
NoPassword: false,
NoKeys: false,
Passwords: []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
Passwords: []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
},
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1413,7 +1403,7 @@ func Test_HandleList(t *testing.T) {
Enabled: true,
NoPassword: true,
NoKeys: true,
Passwords: []internal_acl.Password{},
Passwords: []acl.Password{},
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
@@ -1428,9 +1418,9 @@ func Test_HandleList(t *testing.T) {
Enabled: true,
NoPassword: false,
NoKeys: false,
Passwords: []internal_acl.Password{
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
Passwords: []acl.Password{
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
},
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
@@ -1457,7 +1447,7 @@ func Test_HandleList(t *testing.T) {
}
for _, test := range tests {
acl.AddUsers(test.presetUsers)
a.AddUsers(test.presetUsers)
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)

View File

@@ -21,7 +21,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/admin"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(admin.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -21,7 +21,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/connection"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(connection.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -20,7 +20,6 @@ import (
"github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"reflect"
"slices"
"strings"
@@ -30,7 +29,6 @@ import (
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),

View File

@@ -23,7 +23,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -45,7 +44,6 @@ func init() {
mockClock = clock.NewClock()
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -18,7 +18,6 @@ import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"reflect"
"slices"
"testing"
@@ -26,7 +25,6 @@ import (
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),

View File

@@ -22,7 +22,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -35,7 +34,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -18,14 +18,12 @@ import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"reflect"
"testing"
)
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),

View File

@@ -22,7 +22,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -34,7 +33,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -19,10 +19,9 @@ import (
"context"
"fmt"
"github.com/echovault/echovault/internal/config"
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
ps "github.com/echovault/echovault/pkg/modules/pubsub"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -33,7 +32,7 @@ import (
"time"
)
var pubsub *internal_pubsub.PubSub
var ps *pubsub.PubSub
var mockServer *echovault.EchoVault
var bindAddr = "localhost"
@@ -41,7 +40,7 @@ var port uint16 = 7490
func init() {
mockServer = setUpServer(bindAddr, port)
pubsub = mockServer.GetPubSub().(*internal_pubsub.PubSub)
ps = mockServer.GetPubSub().(*pubsub.PubSub)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -54,7 +53,6 @@ func init() {
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
server, _ := echovault.NewEchoVault(
echovault.WithCommands(ps.Commands()),
echovault.WithConfig(config.Config{
BindAddr: bindAddr,
Port: port,
@@ -126,12 +124,12 @@ func Test_HandleSubscribe(t *testing.T) {
}
for _, channel := range channels {
// Check if the channel exists in the pubsub module
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool {
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == channel
}) {
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
}
for _, c := range pubsub.GetAllChannels() {
for _, c := range ps.GetAllChannels() {
if c.Name() == channel {
// Check if channel has nil pattern
if c.Pattern() != nil {
@@ -157,12 +155,12 @@ func Test_HandleSubscribe(t *testing.T) {
}
for _, pattern := range patterns {
// Check if pattern channel exists in pubsub module
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool {
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == pattern
}) {
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
}
for _, c := range pubsub.GetAllChannels() {
for _, c := range ps.GetAllChannels() {
if c.Name() == pattern {
// Check if channel has non-nil pattern
if c.Pattern() == nil {
@@ -322,7 +320,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
verifyResponse(res, test.expectedResponses["pattern"])
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
for _, pubsubChannel := range pubsub.GetAllChannels() {
for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel {
// Assert that target connection is no longer in the unsub channels and patterns
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
@@ -339,7 +337,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
// Assert that the target connection is still in the remain channels and patterns
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
for _, pubsubChannel := range pubsub.GetAllChannels() {
for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel {
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
t.Errorf("could not find expected target connection in channel \"%s\"", channel)

View File

@@ -17,9 +17,8 @@ package set
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"reflect"
"slices"
"testing"
@@ -27,7 +26,6 @@ import (
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),

View File

@@ -20,10 +20,9 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

File diff suppressed because it is too large Load Diff

View File

@@ -20,10 +20,9 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/sorted_set"
"github.com/echovault/echovault/internal/modules/sorted_set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
ss "github.com/echovault/echovault/pkg/modules/sorted_set"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"math"
@@ -38,7 +37,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(ss.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,

View File

@@ -18,13 +18,11 @@ import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"testing"
)
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),

View File

@@ -23,7 +23,6 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,