mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-05 07:56:52 +08:00
Removed duplicate imports for set, sorted_set, pubsub and acl modules. Moved /modules from /pkg to /internal. Delted commands package: Commands will now be automatically loaded when an EchoVault instance is initialised.
This commit is contained in:
2
Makefile
2
Makefile
@@ -7,7 +7,7 @@ build:
|
|||||||
run:
|
run:
|
||||||
make build && docker-compose up --build
|
make build && docker-compose up --build
|
||||||
|
|
||||||
test-normal:
|
test-unit:
|
||||||
go clean -testcache && go test ./... -coverprofile coverage/coverage.out
|
go clean -testcache && go test ./... -coverprofile coverage/coverage.out
|
||||||
|
|
||||||
test-race:
|
test-race:
|
||||||
|
@@ -18,7 +18,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"github.com/echovault/echovault/internal"
|
"github.com/echovault/echovault/internal"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/commands"
|
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@@ -49,7 +48,6 @@ func main() {
|
|||||||
server, err := echovault.NewEchoVault(
|
server, err := echovault.NewEchoVault(
|
||||||
echovault.WithContext(ctx),
|
echovault.WithContext(ctx),
|
||||||
echovault.WithConfig(conf),
|
echovault.WithConfig(conf),
|
||||||
echovault.WithCommands(commands.All()),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -18,7 +18,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
internal_acl "github.com/echovault/echovault/internal/acl"
|
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -33,7 +32,7 @@ func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
if len(params.Command) < 2 || len(params.Command) > 3 {
|
if len(params.Command) < 2 || len(params.Command) > 3 {
|
||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -48,12 +47,12 @@ func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *internal_acl.User
|
var user *User
|
||||||
userFound := false
|
userFound := false
|
||||||
for _, u := range acl.Users {
|
for _, u := range acl.Users {
|
||||||
if u.Username == params.Command[2] {
|
if u.Username == params.Command[2] {
|
||||||
@@ -221,7 +220,7 @@ func handleCat(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
|
func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -234,7 +233,7 @@ func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
|
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -248,7 +247,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
if len(params.Command) < 3 {
|
if len(params.Command) < 3 {
|
||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -259,7 +258,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
|
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -271,7 +270,7 @@ func handleList(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
if len(params.Command) > 2 {
|
if len(params.Command) > 2 {
|
||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -368,7 +367,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
||||||
@@ -389,7 +388,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
ext := path.Ext(f.Name())
|
ext := path.Ext(f.Name())
|
||||||
|
|
||||||
var users []*internal_acl.User
|
var users []*User
|
||||||
|
|
||||||
if ext == ".json" {
|
if ext == ".json" {
|
||||||
if err := json.NewDecoder(f).Decode(&users); err != nil {
|
if err := json.NewDecoder(f).Decode(&users); err != nil {
|
||||||
@@ -435,7 +434,7 @@ func handleSave(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
acl, ok := params.GetACL().(*ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load ACL")
|
return nil, errors.New("could not load ACL")
|
||||||
}
|
}
|
@@ -17,14 +17,13 @@ package pubsub
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
|
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
||||||
@@ -42,7 +41,7 @@ func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
||||||
@@ -55,7 +54,7 @@ func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
|
func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
||||||
@@ -71,7 +70,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, errors.New(constants.WrongArgsResponse)
|
return nil, errors.New(constants.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
||||||
@@ -85,7 +84,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
|
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
||||||
@@ -94,7 +93,7 @@ func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
|
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
|
||||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
pubsub, ok := params.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
}
|
}
|
@@ -18,7 +18,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal"
|
"github.com/echovault/echovault/internal"
|
||||||
internal_set "github.com/echovault/echovault/internal/set"
|
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -33,10 +32,10 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
key := keys.WriteKeys[0]
|
key := keys.WriteKeys[0]
|
||||||
|
|
||||||
var set *internal_set.Set
|
var set *Set
|
||||||
|
|
||||||
if !params.KeyExists(params.Context, key) {
|
if !params.KeyExists(params.Context, key) {
|
||||||
set = internal_set.NewSet(params.Command[2:])
|
set = NewSet(params.Command[2:])
|
||||||
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
|
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -52,7 +51,7 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -79,7 +78,7 @@ func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -103,7 +102,7 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||||
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
|
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
||||||
}
|
}
|
||||||
@@ -127,9 +126,9 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
for _, key := range params.Command[2:] {
|
for _, key := range params.Command[2:] {
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -166,7 +165,7 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||||
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
|
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
||||||
}
|
}
|
||||||
@@ -190,9 +189,9 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
for _, key := range keys.ReadKeys[1:] {
|
for _, key := range keys.ReadKeys[1:] {
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -252,10 +251,10 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
|
|
||||||
for key, _ := range locks {
|
for key, _ := range locks {
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
// If the value at the key is not a set, return error
|
// If the value at the key is not a set, return error
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
@@ -267,7 +266,7 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("not enough sets in the keys provided")
|
return nil, fmt.Errorf("not enough sets in the keys provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
intersect, _ := internal_set.Intersection(0, sets...)
|
intersect, _ := Intersection(0, sets...)
|
||||||
elems := intersect.GetAll()
|
elems := intersect.GetAll()
|
||||||
|
|
||||||
res := fmt.Sprintf("*%d", len(elems))
|
res := fmt.Sprintf("*%d", len(elems))
|
||||||
@@ -328,10 +327,10 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
|
|
||||||
for key, _ := range locks {
|
for key, _ := range locks {
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
// If the value at the key is not a set, return error
|
// If the value at the key is not a set, return error
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
@@ -343,7 +342,7 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("not enough sets in the keys provided")
|
return nil, fmt.Errorf("not enough sets in the keys provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
intersect, _ := internal_set.Intersection(limit, sets...)
|
intersect, _ := Intersection(limit, sets...)
|
||||||
|
|
||||||
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
|
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
|
||||||
}
|
}
|
||||||
@@ -374,10 +373,10 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
|
|
||||||
for key, _ := range locks {
|
for key, _ := range locks {
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
// If the value at the key is not a set, return error
|
// If the value at the key is not a set, return error
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
@@ -385,7 +384,7 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
sets = append(sets, set)
|
sets = append(sets, set)
|
||||||
}
|
}
|
||||||
|
|
||||||
intersect, _ := internal_set.Intersection(0, sets...)
|
intersect, _ := Intersection(0, sets...)
|
||||||
destination := keys.WriteKeys[0]
|
destination := keys.WriteKeys[0]
|
||||||
|
|
||||||
if params.KeyExists(params.Context, destination) {
|
if params.KeyExists(params.Context, destination) {
|
||||||
@@ -423,7 +422,7 @@ func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -452,7 +451,7 @@ func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -495,7 +494,7 @@ func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -531,12 +530,12 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, source)
|
defer params.KeyUnlock(params.Context, source)
|
||||||
|
|
||||||
sourceSet, ok := params.GetValue(params.Context, source).(*internal_set.Set)
|
sourceSet, ok := params.GetValue(params.Context, source).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("source is not a set")
|
return nil, errors.New("source is not a set")
|
||||||
}
|
}
|
||||||
|
|
||||||
var destinationSet *internal_set.Set
|
var destinationSet *Set
|
||||||
|
|
||||||
if !params.KeyExists(params.Context, destination) {
|
if !params.KeyExists(params.Context, destination) {
|
||||||
// Destination key does not exist
|
// Destination key does not exist
|
||||||
@@ -544,7 +543,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, destination)
|
defer params.KeyUnlock(params.Context, destination)
|
||||||
destinationSet = internal_set.NewSet([]string{})
|
destinationSet = NewSet([]string{})
|
||||||
if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
|
if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -554,7 +553,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, destination)
|
defer params.KeyUnlock(params.Context, destination)
|
||||||
ds, ok := params.GetValue(params.Context, destination).(*internal_set.Set)
|
ds, ok := params.GetValue(params.Context, destination).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("destination is not a set")
|
return nil, errors.New("destination is not a set")
|
||||||
}
|
}
|
||||||
@@ -592,7 +591,7 @@ func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a set", key)
|
return nil, fmt.Errorf("value at %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -636,7 +635,7 @@ func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a set", key)
|
return nil, fmt.Errorf("value at %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -672,7 +671,7 @@ func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
@@ -707,20 +706,20 @@ func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
|
|
||||||
for key, locked := range locks {
|
for key, locked := range locks {
|
||||||
if !locked {
|
if !locked {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
sets = append(sets, set)
|
sets = append(sets, set)
|
||||||
}
|
}
|
||||||
|
|
||||||
union := internal_set.Union(sets...)
|
union := Union(sets...)
|
||||||
|
|
||||||
res := fmt.Sprintf("*%d", union.Cardinality())
|
res := fmt.Sprintf("*%d", union.Cardinality())
|
||||||
for i, e := range union.GetAll() {
|
for i, e := range union.GetAll() {
|
||||||
@@ -758,20 +757,20 @@ func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
locks[key] = true
|
locks[key] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*internal_set.Set
|
var sets []*Set
|
||||||
|
|
||||||
for key, locked := range locks {
|
for key, locked := range locks {
|
||||||
if !locked {
|
if !locked {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
set, ok := params.GetValue(params.Context, key).(*Set)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||||
}
|
}
|
||||||
sets = append(sets, set)
|
sets = append(sets, set)
|
||||||
}
|
}
|
||||||
|
|
||||||
union := internal_set.Union(sets...)
|
union := Union(sets...)
|
||||||
|
|
||||||
destination := keys.WriteKeys[0]
|
destination := keys.WriteKeys[0]
|
||||||
|
|
@@ -19,7 +19,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal"
|
"github.com/echovault/echovault/internal"
|
||||||
"github.com/echovault/echovault/internal/sorted_set"
|
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"math"
|
"math"
|
||||||
@@ -63,7 +62,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, errors.New("score/member pairs must be float/string")
|
return nil, errors.New("score/member pairs must be float/string")
|
||||||
}
|
}
|
||||||
|
|
||||||
var members []sorted_set.MemberParam
|
var members []MemberParam
|
||||||
|
|
||||||
for i := 0; i < len(params.Command[membersStartIndex:]); i++ {
|
for i := 0; i < len(params.Command[membersStartIndex:]); i++ {
|
||||||
if i%2 != 0 {
|
if i%2 != 0 {
|
||||||
@@ -77,29 +76,29 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
var s float64
|
var s float64
|
||||||
if strings.ToLower(score.(string)) == "-inf" {
|
if strings.ToLower(score.(string)) == "-inf" {
|
||||||
s = math.Inf(-1)
|
s = math.Inf(-1)
|
||||||
members = append(members, sorted_set.MemberParam{
|
members = append(members, MemberParam{
|
||||||
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
|
Value: Value(params.Command[membersStartIndex:][i+1]),
|
||||||
Score: sorted_set.Score(s),
|
Score: Score(s),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if strings.ToLower(score.(string)) == "+inf" {
|
if strings.ToLower(score.(string)) == "+inf" {
|
||||||
s = math.Inf(1)
|
s = math.Inf(1)
|
||||||
members = append(members, sorted_set.MemberParam{
|
members = append(members, MemberParam{
|
||||||
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
|
Value: Value(params.Command[membersStartIndex:][i+1]),
|
||||||
Score: sorted_set.Score(s),
|
Score: Score(s),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
s, _ := score.(float64)
|
s, _ := score.(float64)
|
||||||
members = append(members, sorted_set.MemberParam{
|
members = append(members, MemberParam{
|
||||||
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
|
Value: Value(params.Command[membersStartIndex:][i+1]),
|
||||||
Score: sorted_set.Score(s),
|
Score: Score(s),
|
||||||
})
|
})
|
||||||
case int:
|
case int:
|
||||||
s, _ := score.(int)
|
s, _ := score.(int)
|
||||||
members = append(members, sorted_set.MemberParam{
|
members = append(members, MemberParam{
|
||||||
Value: sorted_set.Value(params.Command[membersStartIndex:][i+1]),
|
Value: Value(params.Command[membersStartIndex:][i+1]),
|
||||||
Score: sorted_set.Score(s),
|
Score: Score(s),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +147,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -171,7 +170,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set := sorted_set.NewSortedSet(members)
|
set := NewSortedSet(members)
|
||||||
if err = params.SetValue(params.Context, key, set); err != nil {
|
if err = params.SetValue(params.Context, key, set); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -195,7 +194,7 @@ func handleZCARD(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -211,40 +210,40 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
key := keys.ReadKeys[0]
|
key := keys.ReadKeys[0]
|
||||||
|
|
||||||
minimum := sorted_set.Score(math.Inf(-1))
|
minimum := Score(math.Inf(-1))
|
||||||
switch internal.AdaptType(params.Command[2]).(type) {
|
switch internal.AdaptType(params.Command[2]).(type) {
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("min constraint must be a double")
|
return nil, errors.New("min constraint must be a double")
|
||||||
case string:
|
case string:
|
||||||
if strings.ToLower(params.Command[2]) == "+inf" {
|
if strings.ToLower(params.Command[2]) == "+inf" {
|
||||||
minimum = sorted_set.Score(math.Inf(1))
|
minimum = Score(math.Inf(1))
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("min constraint must be a double")
|
return nil, errors.New("min constraint must be a double")
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
s, _ := internal.AdaptType(params.Command[2]).(float64)
|
s, _ := internal.AdaptType(params.Command[2]).(float64)
|
||||||
minimum = sorted_set.Score(s)
|
minimum = Score(s)
|
||||||
case int:
|
case int:
|
||||||
s, _ := internal.AdaptType(params.Command[2]).(int)
|
s, _ := internal.AdaptType(params.Command[2]).(int)
|
||||||
minimum = sorted_set.Score(s)
|
minimum = Score(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
maximum := sorted_set.Score(math.Inf(1))
|
maximum := Score(math.Inf(1))
|
||||||
switch internal.AdaptType(params.Command[3]).(type) {
|
switch internal.AdaptType(params.Command[3]).(type) {
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("max constraint must be a double")
|
return nil, errors.New("max constraint must be a double")
|
||||||
case string:
|
case string:
|
||||||
if strings.ToLower(params.Command[3]) == "-inf" {
|
if strings.ToLower(params.Command[3]) == "-inf" {
|
||||||
maximum = sorted_set.Score(math.Inf(-1))
|
maximum = Score(math.Inf(-1))
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("max constraint must be a double")
|
return nil, errors.New("max constraint must be a double")
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
s, _ := internal.AdaptType(params.Command[3]).(float64)
|
s, _ := internal.AdaptType(params.Command[3]).(float64)
|
||||||
maximum = sorted_set.Score(s)
|
maximum = Score(s)
|
||||||
case int:
|
case int:
|
||||||
s, _ := internal.AdaptType(params.Command[3]).(int)
|
s, _ := internal.AdaptType(params.Command[3]).(int)
|
||||||
maximum = sorted_set.Score(s)
|
maximum = Score(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !params.KeyExists(params.Context, key) {
|
if !params.KeyExists(params.Context, key) {
|
||||||
@@ -256,12 +255,12 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
var members []sorted_set.MemberParam
|
var members []MemberParam
|
||||||
for _, m := range set.GetAll() {
|
for _, m := range set.GetAll() {
|
||||||
if m.Score >= minimum && m.Score <= maximum {
|
if m.Score >= minimum && m.Score <= maximum {
|
||||||
members = append(members, m)
|
members = append(members, m)
|
||||||
@@ -290,7 +289,7 @@ func handleZLEXCOUNT(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -347,13 +346,13 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||||
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet)
|
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract the remaining sets
|
// Extract the remaining sets
|
||||||
var sets []*sorted_set.SortedSet
|
var sets []*SortedSet
|
||||||
|
|
||||||
for i := 1; i < len(keys.ReadKeys); i++ {
|
for i := 1; i < len(keys.ReadKeys); i++ {
|
||||||
if !params.KeyExists(params.Context, keys.ReadKeys[i]) {
|
if !params.KeyExists(params.Context, keys.ReadKeys[i]) {
|
||||||
@@ -364,7 +363,7 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
locks[keys.ReadKeys[i]] = locked
|
locks[keys.ReadKeys[i]] = locked
|
||||||
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
|
||||||
}
|
}
|
||||||
@@ -415,19 +414,19 @@ func handleZDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||||
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*sorted_set.SortedSet)
|
baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []*sorted_set.SortedSet
|
var sets []*SortedSet
|
||||||
|
|
||||||
for i := 1; i < len(keys.ReadKeys); i++ {
|
for i := 1; i < len(keys.ReadKeys); i++ {
|
||||||
if params.KeyExists(params.Context, keys.ReadKeys[i]) {
|
if params.KeyExists(params.Context, keys.ReadKeys[i]) {
|
||||||
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil {
|
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
|
||||||
}
|
}
|
||||||
@@ -462,26 +461,26 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
key := keys.WriteKeys[0]
|
key := keys.WriteKeys[0]
|
||||||
member := sorted_set.Value(params.Command[3])
|
member := Value(params.Command[3])
|
||||||
var increment sorted_set.Score
|
var increment Score
|
||||||
|
|
||||||
switch internal.AdaptType(params.Command[2]).(type) {
|
switch internal.AdaptType(params.Command[2]).(type) {
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("increment must be a double")
|
return nil, errors.New("increment must be a double")
|
||||||
case string:
|
case string:
|
||||||
if strings.EqualFold("-inf", strings.ToLower(params.Command[2])) {
|
if strings.EqualFold("-inf", strings.ToLower(params.Command[2])) {
|
||||||
increment = sorted_set.Score(math.Inf(-1))
|
increment = Score(math.Inf(-1))
|
||||||
} else if strings.EqualFold("+inf", strings.ToLower(params.Command[2])) {
|
} else if strings.EqualFold("+inf", strings.ToLower(params.Command[2])) {
|
||||||
increment = sorted_set.Score(math.Inf(1))
|
increment = Score(math.Inf(1))
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("increment must be a double")
|
return nil, errors.New("increment must be a double")
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
s, _ := internal.AdaptType(params.Command[2]).(float64)
|
s, _ := internal.AdaptType(params.Command[2]).(float64)
|
||||||
increment = sorted_set.Score(s)
|
increment = Score(s)
|
||||||
case int:
|
case int:
|
||||||
s, _ := internal.AdaptType(params.Command[2]).(int)
|
s, _ := internal.AdaptType(params.Command[2]).(int)
|
||||||
increment = sorted_set.Score(s)
|
increment = Score(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !params.KeyExists(params.Context, key) {
|
if !params.KeyExists(params.Context, key) {
|
||||||
@@ -493,7 +492,7 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
if err = params.SetValue(
|
if err = params.SetValue(
|
||||||
params.Context,
|
params.Context,
|
||||||
key,
|
key,
|
||||||
sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: member, Score: increment}}),
|
NewSortedSet([]MemberParam{{Value: member, Score: increment}}),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -505,12 +504,12 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
if _, err = set.AddOrUpdate(
|
if _, err = set.AddOrUpdate(
|
||||||
[]sorted_set.MemberParam{
|
[]MemberParam{
|
||||||
{Value: member, Score: increment}},
|
{Value: member, Score: increment}},
|
||||||
"xx",
|
"xx",
|
||||||
nil,
|
nil,
|
||||||
@@ -542,7 +541,7 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var setParams []sorted_set.SortedSetParam
|
var setParams []SortedSetParam
|
||||||
|
|
||||||
for i := 0; i < len(keys); i++ {
|
for i := 0; i < len(keys); i++ {
|
||||||
if !params.KeyExists(params.Context, keys[i]) {
|
if !params.KeyExists(params.Context, keys[i]) {
|
||||||
@@ -553,17 +552,17 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
locks[keys[i]] = true
|
locks[keys[i]] = true
|
||||||
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
||||||
}
|
}
|
||||||
setParams = append(setParams, sorted_set.SortedSetParam{
|
setParams = append(setParams, SortedSetParam{
|
||||||
Set: set,
|
Set: set,
|
||||||
Weight: weights[i],
|
Weight: weights[i],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
intersect := sorted_set.Intersect(aggregate, setParams...)
|
intersect := Intersect(aggregate, setParams...)
|
||||||
|
|
||||||
res := fmt.Sprintf("*%d", intersect.Cardinality())
|
res := fmt.Sprintf("*%d", intersect.Cardinality())
|
||||||
|
|
||||||
@@ -609,7 +608,7 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var setParams []sorted_set.SortedSetParam
|
var setParams []SortedSetParam
|
||||||
|
|
||||||
for i := 0; i < len(keys); i++ {
|
for i := 0; i < len(keys); i++ {
|
||||||
if !params.KeyExists(params.Context, keys[i]) {
|
if !params.KeyExists(params.Context, keys[i]) {
|
||||||
@@ -619,17 +618,17 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
locks[keys[i]] = true
|
locks[keys[i]] = true
|
||||||
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
||||||
}
|
}
|
||||||
setParams = append(setParams, sorted_set.SortedSetParam{
|
setParams = append(setParams, SortedSetParam{
|
||||||
Set: set,
|
Set: set,
|
||||||
Weight: weights[i],
|
Weight: weights[i],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
intersect := sorted_set.Intersect(aggregate, setParams...)
|
intersect := Intersect(aggregate, setParams...)
|
||||||
|
|
||||||
if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 {
|
if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 {
|
||||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
||||||
@@ -700,7 +699,7 @@ func handleZMPOP(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil {
|
if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*sorted_set.SortedSet)
|
v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*SortedSet)
|
||||||
if !ok || v.Cardinality() == 0 {
|
if !ok || v.Cardinality() == 0 {
|
||||||
params.KeyUnlock(params.Context, keys.WriteKeys[i])
|
params.KeyUnlock(params.Context, keys.WriteKeys[i])
|
||||||
continue
|
continue
|
||||||
@@ -760,7 +759,7 @@ func handleZPOP(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at key %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at key %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -797,7 +796,7 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -806,10 +805,10 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
res := fmt.Sprintf("*%d", len(members))
|
res := fmt.Sprintf("*%d", len(members))
|
||||||
|
|
||||||
var member sorted_set.MemberObject
|
var member MemberObject
|
||||||
|
|
||||||
for i := 0; i < len(members); i++ {
|
for i := 0; i < len(members); i++ {
|
||||||
member = set.Get(sorted_set.Value(members[i]))
|
member = set.Get(Value(members[i]))
|
||||||
if !member.Exists {
|
if !member.Exists {
|
||||||
res = fmt.Sprintf("%s\r\n$-1", res)
|
res = fmt.Sprintf("%s\r\n$-1", res)
|
||||||
} else {
|
} else {
|
||||||
@@ -859,7 +858,7 @@ func handleZRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -903,13 +902,13 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
members := set.GetAll()
|
members := set.GetAll()
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
if strings.EqualFold(params.Command[0], "zrevrank") {
|
if strings.EqualFold(params.Command[0], "zrevrank") {
|
||||||
return cmp.Compare(b.Score, a.Score)
|
return cmp.Compare(b.Score, a.Score)
|
||||||
}
|
}
|
||||||
@@ -917,7 +916,7 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i := 0; i < len(members); i++ {
|
for i := 0; i < len(members); i++ {
|
||||||
if members[i].Value == sorted_set.Value(member) {
|
if members[i].Value == Value(member) {
|
||||||
if withscores {
|
if withscores {
|
||||||
score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64)
|
score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64)
|
||||||
return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
|
return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
|
||||||
@@ -947,14 +946,14 @@ func handleZREM(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
deletedCount := 0
|
deletedCount := 0
|
||||||
for _, m := range params.Command[2:] {
|
for _, m := range params.Command[2:] {
|
||||||
if set.Remove(sorted_set.Value(m)) {
|
if set.Remove(Value(m)) {
|
||||||
deletedCount += 1
|
deletedCount += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -977,11 +976,11 @@ func handleZSCORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
member := set.Get(sorted_set.Value(params.Command[2]))
|
member := set.Get(Value(params.Command[2]))
|
||||||
if !member.Exists {
|
if !member.Exists {
|
||||||
return []byte("$-1\r\n"), nil
|
return []byte("$-1\r\n"), nil
|
||||||
}
|
}
|
||||||
@@ -1020,13 +1019,13 @@ func handleZREMRANGEBYSCORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range set.GetAll() {
|
for _, m := range set.GetAll() {
|
||||||
if m.Score >= sorted_set.Score(minimum) && m.Score <= sorted_set.Score(maximum) {
|
if m.Score >= Score(minimum) && m.Score <= Score(maximum) {
|
||||||
set.Remove(m.Value)
|
set.Remove(m.Value)
|
||||||
deletedCount += 1
|
deletedCount += 1
|
||||||
}
|
}
|
||||||
@@ -1062,7 +1061,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -1079,7 +1078,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
members := set.GetAll()
|
members := set.GetAll()
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
return cmp.Compare(a.Score, b.Score)
|
return cmp.Compare(a.Score, b.Score)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1119,7 +1118,7 @@ func handleZREMRANGEBYLEX(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyUnlock(params.Context, key)
|
defer params.KeyUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -1217,7 +1216,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, key)
|
defer params.KeyRUnlock(params.Context, key)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, key).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, key).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
return nil, fmt.Errorf("value at %s is not a sorted set", key)
|
||||||
}
|
}
|
||||||
@@ -1231,7 +1230,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
members := set.GetAll()
|
members := set.GetAll()
|
||||||
if strings.EqualFold(policy, "byscore") {
|
if strings.EqualFold(policy, "byscore") {
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
// Do a score sort
|
// Do a score sort
|
||||||
if reverse {
|
if reverse {
|
||||||
return cmp.Compare(b.Score, a.Score)
|
return cmp.Compare(b.Score, a.Score)
|
||||||
@@ -1246,7 +1245,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return []byte("*0\r\n"), nil
|
return []byte("*0\r\n"), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
if reverse {
|
if reverse {
|
||||||
return internal.CompareLex(string(b.Value), string(a.Value))
|
return internal.CompareLex(string(b.Value), string(a.Value))
|
||||||
}
|
}
|
||||||
@@ -1254,14 +1253,14 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var resultMembers []sorted_set.MemberParam
|
var resultMembers []MemberParam
|
||||||
|
|
||||||
for i := offset; i <= count; i++ {
|
for i := offset; i <= count; i++ {
|
||||||
if i >= len(members) {
|
if i >= len(members) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if strings.EqualFold(policy, "byscore") {
|
if strings.EqualFold(policy, "byscore") {
|
||||||
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
|
if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
|
||||||
resultMembers = append(resultMembers, members[i])
|
resultMembers = append(resultMembers, members[i])
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -1354,7 +1353,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
defer params.KeyRUnlock(params.Context, source)
|
defer params.KeyRUnlock(params.Context, source)
|
||||||
|
|
||||||
set, ok := params.GetValue(params.Context, source).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, source).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", source)
|
return nil, fmt.Errorf("value at %s is not a sorted set", source)
|
||||||
}
|
}
|
||||||
@@ -1368,7 +1367,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
|
|
||||||
members := set.GetAll()
|
members := set.GetAll()
|
||||||
if strings.EqualFold(policy, "byscore") {
|
if strings.EqualFold(policy, "byscore") {
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
// Do a score sort
|
// Do a score sort
|
||||||
if reverse {
|
if reverse {
|
||||||
return cmp.Compare(b.Score, a.Score)
|
return cmp.Compare(b.Score, a.Score)
|
||||||
@@ -1383,7 +1382,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return []byte(":0\r\n"), nil
|
return []byte(":0\r\n"), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
|
slices.SortFunc(members, func(a, b MemberParam) int {
|
||||||
if reverse {
|
if reverse {
|
||||||
return internal.CompareLex(string(b.Value), string(a.Value))
|
return internal.CompareLex(string(b.Value), string(a.Value))
|
||||||
}
|
}
|
||||||
@@ -1391,14 +1390,14 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var resultMembers []sorted_set.MemberParam
|
var resultMembers []MemberParam
|
||||||
|
|
||||||
for i := offset; i <= count; i++ {
|
for i := offset; i <= count; i++ {
|
||||||
if i >= len(members) {
|
if i >= len(members) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if strings.EqualFold(policy, "byscore") {
|
if strings.EqualFold(policy, "byscore") {
|
||||||
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
|
if members[i].Score >= Score(scoreStart) && members[i].Score <= Score(scoreStop) {
|
||||||
resultMembers = append(resultMembers, members[i])
|
resultMembers = append(resultMembers, members[i])
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -1409,7 +1408,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newSortedSet := sorted_set.NewSortedSet(resultMembers)
|
newSortedSet := NewSortedSet(resultMembers)
|
||||||
|
|
||||||
if params.KeyExists(params.Context, destination) {
|
if params.KeyExists(params.Context, destination) {
|
||||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
||||||
@@ -1448,7 +1447,7 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var setParams []sorted_set.SortedSetParam
|
var setParams []SortedSetParam
|
||||||
|
|
||||||
for i := 0; i < len(keys); i++ {
|
for i := 0; i < len(keys); i++ {
|
||||||
if params.KeyExists(params.Context, keys[i]) {
|
if params.KeyExists(params.Context, keys[i]) {
|
||||||
@@ -1456,18 +1455,18 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
locks[keys[i]] = true
|
locks[keys[i]] = true
|
||||||
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
||||||
}
|
}
|
||||||
setParams = append(setParams, sorted_set.SortedSetParam{
|
setParams = append(setParams, SortedSetParam{
|
||||||
Set: set,
|
Set: set,
|
||||||
Weight: weights[i],
|
Weight: weights[i],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
union := sorted_set.Union(aggregate, setParams...)
|
union := Union(aggregate, setParams...)
|
||||||
|
|
||||||
res := fmt.Sprintf("*%d", union.Cardinality())
|
res := fmt.Sprintf("*%d", union.Cardinality())
|
||||||
for _, m := range union.GetAll() {
|
for _, m := range union.GetAll() {
|
||||||
@@ -1510,7 +1509,7 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var setParams []sorted_set.SortedSetParam
|
var setParams []SortedSetParam
|
||||||
|
|
||||||
for i := 0; i < len(keys); i++ {
|
for i := 0; i < len(keys); i++ {
|
||||||
if params.KeyExists(params.Context, keys[i]) {
|
if params.KeyExists(params.Context, keys[i]) {
|
||||||
@@ -1518,18 +1517,18 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
locks[keys[i]] = true
|
locks[keys[i]] = true
|
||||||
set, ok := params.GetValue(params.Context, keys[i]).(*sorted_set.SortedSet)
|
set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
|
||||||
}
|
}
|
||||||
setParams = append(setParams, sorted_set.SortedSetParam{
|
setParams = append(setParams, SortedSetParam{
|
||||||
Set: set,
|
Set: set,
|
||||||
Weight: weights[i],
|
Weight: weights[i],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
union := sorted_set.Union(aggregate, setParams...)
|
union := Union(aggregate, setParams...)
|
||||||
|
|
||||||
if params.KeyExists(params.Context, destination) {
|
if params.KeyExists(params.Context, destination) {
|
||||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
@@ -90,3 +90,80 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
|
|||||||
|
|
||||||
return keys, weights, aggregate, withscores, nil
|
return keys, weights, aggregate, withscores, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
|
||||||
|
if updatePolicy == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
err := errors.New("update policy must be a string of Value NX or XX")
|
||||||
|
policy, ok := updatePolicy.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateComparison(comparison interface{}) (string, error) {
|
||||||
|
if comparison == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
err := errors.New("comparison condition must be a string of Value LT or GT")
|
||||||
|
comp, ok := comparison.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return comp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateChanged(changed interface{}) (string, error) {
|
||||||
|
if changed == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
err := errors.New("changed condition should be a string of Value CH")
|
||||||
|
ch, ok := changed.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(ch, "ch") {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateIncr(incr interface{}) (string, error) {
|
||||||
|
if incr == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
err := errors.New("incr condition should be a string of Value INCR")
|
||||||
|
i, ok := incr.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(i, "incr") {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareScores(old Score, new Score, comp string) Score {
|
||||||
|
switch strings.ToLower(comp) {
|
||||||
|
default:
|
||||||
|
return new
|
||||||
|
case "lt":
|
||||||
|
if new < old {
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
return old
|
||||||
|
case "gt":
|
||||||
|
if new > old {
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
}
|
@@ -1,98 +0,0 @@
|
|||||||
// Copyright 2024 Kelvin Clement Mwinuka
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sorted_set
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
|
|
||||||
if updatePolicy == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
err := errors.New("update policy must be a string of Value NX or XX")
|
|
||||||
policy, ok := updatePolicy.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return policy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateComparison(comparison interface{}) (string, error) {
|
|
||||||
if comparison == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
err := errors.New("comparison condition must be a string of Value LT or GT")
|
|
||||||
comp, ok := comparison.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return comp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateChanged(changed interface{}) (string, error) {
|
|
||||||
if changed == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
err := errors.New("changed condition should be a string of Value CH")
|
|
||||||
ch, ok := changed.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(ch, "ch") {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return ch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateIncr(incr interface{}) (string, error) {
|
|
||||||
if incr == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
err := errors.New("incr condition should be a string of Value INCR")
|
|
||||||
i, ok := incr.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(i, "incr") {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareScores(old Score, new Score, comp string) Score {
|
|
||||||
switch strings.ToLower(comp) {
|
|
||||||
default:
|
|
||||||
return new
|
|
||||||
case "lt":
|
|
||||||
if new < old {
|
|
||||||
return new
|
|
||||||
}
|
|
||||||
return old
|
|
||||||
case "gt":
|
|
||||||
if new > old {
|
|
||||||
return new
|
|
||||||
}
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,31 +0,0 @@
|
|||||||
package commands
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/echovault/echovault/pkg/modules/acl"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/admin"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/connection"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/generic"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/hash"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/list"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/pubsub"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/set"
|
|
||||||
"github.com/echovault/echovault/pkg/modules/sorted_set"
|
|
||||||
str "github.com/echovault/echovault/pkg/modules/string"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
// All returns all the commands currently available on EchoVault
|
|
||||||
func All() []types.Command {
|
|
||||||
var commands []types.Command
|
|
||||||
commands = append(commands, acl.Commands()...)
|
|
||||||
commands = append(commands, admin.Commands()...)
|
|
||||||
commands = append(commands, generic.Commands()...)
|
|
||||||
commands = append(commands, hash.Commands()...)
|
|
||||||
commands = append(commands, list.Commands()...)
|
|
||||||
commands = append(commands, connection.Commands()...)
|
|
||||||
commands = append(commands, pubsub.Commands()...)
|
|
||||||
commands = append(commands, set.Commands()...)
|
|
||||||
commands = append(commands, sorted_set.Commands()...)
|
|
||||||
commands = append(commands, str.Commands()...)
|
|
||||||
return commands
|
|
||||||
}
|
|
@@ -21,13 +21,21 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal"
|
"github.com/echovault/echovault/internal"
|
||||||
"github.com/echovault/echovault/internal/acl"
|
|
||||||
"github.com/echovault/echovault/internal/aof"
|
"github.com/echovault/echovault/internal/aof"
|
||||||
"github.com/echovault/echovault/internal/clock"
|
"github.com/echovault/echovault/internal/clock"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/internal/eviction"
|
"github.com/echovault/echovault/internal/eviction"
|
||||||
"github.com/echovault/echovault/internal/memberlist"
|
"github.com/echovault/echovault/internal/memberlist"
|
||||||
"github.com/echovault/echovault/internal/pubsub"
|
"github.com/echovault/echovault/internal/modules/acl"
|
||||||
|
"github.com/echovault/echovault/internal/modules/admin"
|
||||||
|
"github.com/echovault/echovault/internal/modules/connection"
|
||||||
|
"github.com/echovault/echovault/internal/modules/generic"
|
||||||
|
"github.com/echovault/echovault/internal/modules/hash"
|
||||||
|
"github.com/echovault/echovault/internal/modules/list"
|
||||||
|
"github.com/echovault/echovault/internal/modules/pubsub"
|
||||||
|
"github.com/echovault/echovault/internal/modules/set"
|
||||||
|
"github.com/echovault/echovault/internal/modules/sorted_set"
|
||||||
|
str "github.com/echovault/echovault/internal/modules/string"
|
||||||
"github.com/echovault/echovault/internal/raft"
|
"github.com/echovault/echovault/internal/raft"
|
||||||
"github.com/echovault/echovault/internal/snapshot"
|
"github.com/echovault/echovault/internal/snapshot"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
@@ -126,11 +134,24 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
|
|||||||
echovault := &EchoVault{
|
echovault := &EchoVault{
|
||||||
clock: clock.NewClock(),
|
clock: clock.NewClock(),
|
||||||
context: context.Background(),
|
context: context.Background(),
|
||||||
commands: make([]types.Command, 0),
|
|
||||||
config: config.DefaultConfig(),
|
config: config.DefaultConfig(),
|
||||||
store: make(map[string]internal.KeyData),
|
store: make(map[string]internal.KeyData),
|
||||||
keyLocks: make(map[string]*sync.RWMutex),
|
keyLocks: make(map[string]*sync.RWMutex),
|
||||||
keyCreationLock: &sync.Mutex{},
|
keyCreationLock: &sync.Mutex{},
|
||||||
|
commands: func() []types.Command {
|
||||||
|
var commands []types.Command
|
||||||
|
commands = append(commands, acl.Commands()...)
|
||||||
|
commands = append(commands, admin.Commands()...)
|
||||||
|
commands = append(commands, generic.Commands()...)
|
||||||
|
commands = append(commands, hash.Commands()...)
|
||||||
|
commands = append(commands, list.Commands()...)
|
||||||
|
commands = append(commands, connection.Commands()...)
|
||||||
|
commands = append(commands, pubsub.Commands()...)
|
||||||
|
commands = append(commands, set.Commands()...)
|
||||||
|
commands = append(commands, sorted_set.Commands()...)
|
||||||
|
commands = append(commands, str.Commands()...)
|
||||||
|
return commands
|
||||||
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
|
@@ -589,16 +589,3 @@ func (server *EchoVault) evictKeysWithExpiredTTL(ctx context.Context) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func presetValue(server *EchoVault, key string, value interface{}) {
|
|
||||||
_, _ = server.CreateKeyAndLock(server.context, key)
|
|
||||||
_ = server.SetValue(server.context, key, value)
|
|
||||||
server.KeyUnlock(server.context, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func presetKeyData(server *EchoVault, key string, data internal.KeyData) {
|
|
||||||
_, _ = server.CreateKeyAndLock(server.context, key)
|
|
||||||
defer server.KeyUnlock(server.context, key)
|
|
||||||
_ = server.SetValue(server.context, key, data.Value)
|
|
||||||
server.SetExpiry(server.context, key, data.ExpireAt, false)
|
|
||||||
}
|
|
||||||
|
@@ -17,11 +17,10 @@ package acl
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
internal_acl "github.com/echovault/echovault/internal/acl"
|
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
|
"github.com/echovault/echovault/internal/modules/acl"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
acl2 "github.com/echovault/echovault/pkg/modules/acl"
|
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -61,44 +60,43 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
mockServer, _ := echovault.NewEchoVault(
|
mockServer, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(acl2.Commands()),
|
|
||||||
echovault.WithConfig(conf),
|
echovault.WithConfig(conf),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Add the initial test users to the ACL module
|
// Add the initial test users to the ACL module
|
||||||
acl := mockServer.GetACL().(*internal_acl.ACL)
|
a := mockServer.GetACL().(*acl.ACL)
|
||||||
acl.AddUsers(generateInitialTestUsers())
|
a.AddUsers(generateInitialTestUsers())
|
||||||
|
|
||||||
return mockServer
|
return mockServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInitialTestUsers() []*internal_acl.User {
|
func generateInitialTestUsers() []*acl.User {
|
||||||
// User with both hash password and plaintext password
|
// User with both hash password and plaintext password
|
||||||
withPasswordUser := internal_acl.CreateUser("with_password_user")
|
withPasswordUser := acl.CreateUser("with_password_user")
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write([]byte("password3"))
|
h.Write([]byte("password3"))
|
||||||
withPasswordUser.Passwords = []internal_acl.Password{
|
withPasswordUser.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password2"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password2"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
|
||||||
}
|
}
|
||||||
withPasswordUser.IncludedCategories = []string{"*"}
|
withPasswordUser.IncludedCategories = []string{"*"}
|
||||||
withPasswordUser.IncludedCommands = []string{"*"}
|
withPasswordUser.IncludedCommands = []string{"*"}
|
||||||
|
|
||||||
// User with NoPassword option
|
// User with NoPassword option
|
||||||
noPasswordUser := internal_acl.CreateUser("no_password_user")
|
noPasswordUser := acl.CreateUser("no_password_user")
|
||||||
noPasswordUser.Passwords = []internal_acl.Password{
|
noPasswordUser.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password4"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password4"},
|
||||||
}
|
}
|
||||||
noPasswordUser.NoPassword = true
|
noPasswordUser.NoPassword = true
|
||||||
|
|
||||||
// Disabled user
|
// Disabled user
|
||||||
disabledUser := internal_acl.CreateUser("disabled_user")
|
disabledUser := acl.CreateUser("disabled_user")
|
||||||
disabledUser.Passwords = []internal_acl.Password{
|
disabledUser.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password5"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password5"},
|
||||||
}
|
}
|
||||||
disabledUser.Enabled = false
|
disabledUser.Enabled = false
|
||||||
|
|
||||||
return []*internal_acl.User{
|
return []*acl.User{
|
||||||
withPasswordUser,
|
withPasswordUser,
|
||||||
noPasswordUser,
|
noPasswordUser,
|
||||||
disabledUser,
|
disabledUser,
|
||||||
@@ -129,7 +127,7 @@ func compareSlices[T comparable](res, expected []T) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// compareUsers compares 2 users and checks if all their fields are equal
|
// compareUsers compares 2 users and checks if all their fields are equal
|
||||||
func compareUsers(user1, user2 *internal_acl.User) error {
|
func compareUsers(user1, user2 *acl.User) error {
|
||||||
// Compare flags
|
// Compare flags
|
||||||
if user1.Username != user2.Username {
|
if user1.Username != user2.Username {
|
||||||
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username)
|
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username)
|
||||||
@@ -146,14 +144,14 @@ func compareUsers(user1, user2 *internal_acl.User) error {
|
|||||||
|
|
||||||
// Compare passwords
|
// Compare passwords
|
||||||
for _, password1 := range user1.Passwords {
|
for _, password1 := range user1.Passwords {
|
||||||
if !slices.ContainsFunc(user2.Passwords, func(password2 internal_acl.Password) bool {
|
if !slices.ContainsFunc(user2.Passwords, func(password2 acl.Password) bool {
|
||||||
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("found password %+v in user1 that was not found in user2", password1)
|
return fmt.Errorf("found password %+v in user1 that was not found in user2", password1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, password2 := range user2.Passwords {
|
for _, password2 := range user2.Passwords {
|
||||||
if !slices.ContainsFunc(user1.Passwords, func(password1 internal_acl.Password) bool {
|
if !slices.ContainsFunc(user1.Passwords, func(password1 acl.Password) bool {
|
||||||
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("found password %+v in user2 that was not found in user1", password2)
|
return fmt.Errorf("found password %+v in user2 that was not found in user1", password2)
|
||||||
@@ -392,14 +390,6 @@ func Test_HandleCat(t *testing.T) {
|
|||||||
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check if all the elements in the response array are in the expected array
|
|
||||||
for _, value := range resArr {
|
|
||||||
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
|
|
||||||
return value.String() == expected
|
|
||||||
}) {
|
|
||||||
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -469,7 +459,7 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
acl, ok := mockServer.GetACL().(*internal_acl.ACL)
|
a, ok := mockServer.GetACL().(*acl.ACL)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("error loading ACL module")
|
t.Error("error loading ACL module")
|
||||||
}
|
}
|
||||||
@@ -487,11 +477,11 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
presetUser *internal_acl.User
|
presetUser *acl.User
|
||||||
cmd []resp.Value
|
cmd []resp.Value
|
||||||
wantRes string
|
wantRes string
|
||||||
wantErr string
|
wantErr string
|
||||||
wantUser *internal_acl.User
|
wantUser *acl.User
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
// 1. Create new enabled user
|
// 1. Create new enabled user
|
||||||
@@ -504,8 +494,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_1")
|
user := acl.CreateUser("set_user_1")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -522,8 +512,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_2")
|
user := acl.CreateUser("set_user_2")
|
||||||
user.Enabled = false
|
user.Enabled = false
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -544,14 +534,14 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_3")
|
user := acl.CreateUser("set_user_3")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.Passwords = []internal_acl.Password{
|
user.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
||||||
}
|
}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -559,14 +549,14 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 4. Remove plaintext and SHA256 password from existing user
|
// 4. Remove plaintext and SHA256 password from existing user
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_4")
|
user := acl.CreateUser("set_user_4")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.Passwords = []internal_acl.Password{
|
user.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
||||||
}
|
}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -581,12 +571,12 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_4")
|
user := acl.CreateUser("set_user_4")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.Passwords = []internal_acl.Password{
|
user.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
||||||
}
|
}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -604,8 +594,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_5")
|
user := acl.CreateUser("set_user_5")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.ExcludedCommands = []string{"*"}
|
user.ExcludedCommands = []string{"*"}
|
||||||
user.ExcludedCategories = []string{"*"}
|
user.ExcludedCategories = []string{"*"}
|
||||||
@@ -625,8 +615,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_6")
|
user := acl.CreateUser("set_user_6")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCategories = []string{"*"}
|
user.IncludedCategories = []string{"*"}
|
||||||
user.ExcludedCategories = []string{}
|
user.ExcludedCategories = []string{}
|
||||||
@@ -646,8 +636,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_7")
|
user := acl.CreateUser("set_user_7")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCategories = []string{"*"}
|
user.IncludedCategories = []string{"*"}
|
||||||
user.ExcludedCategories = []string{}
|
user.ExcludedCategories = []string{}
|
||||||
@@ -672,8 +662,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_8")
|
user := acl.CreateUser("set_user_8")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}
|
user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}
|
||||||
user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}
|
user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}
|
||||||
@@ -693,8 +683,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_9")
|
user := acl.CreateUser("set_user_9")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoKeys = true
|
user.NoKeys = true
|
||||||
user.IncludedReadKeys = []string{}
|
user.IncludedReadKeys = []string{}
|
||||||
@@ -723,8 +713,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_10")
|
user := acl.CreateUser("set_user_10")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoKeys = false
|
user.NoKeys = false
|
||||||
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
||||||
@@ -745,8 +735,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_11")
|
user := acl.CreateUser("set_user_11")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedPubSubChannels = []string{"*"}
|
user.IncludedPubSubChannels = []string{"*"}
|
||||||
user.ExcludedPubSubChannels = []string{}
|
user.ExcludedPubSubChannels = []string{}
|
||||||
@@ -766,8 +756,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_12")
|
user := acl.CreateUser("set_user_12")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedPubSubChannels = []string{"*"}
|
user.IncludedPubSubChannels = []string{"*"}
|
||||||
user.ExcludedPubSubChannels = []string{}
|
user.ExcludedPubSubChannels = []string{}
|
||||||
@@ -790,8 +780,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_13")
|
user := acl.CreateUser("set_user_13")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
||||||
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
||||||
@@ -811,8 +801,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_14")
|
user := acl.CreateUser("set_user_14")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCommands = []string{"*"}
|
user.IncludedCommands = []string{"*"}
|
||||||
user.ExcludedCommands = []string{}
|
user.ExcludedCommands = []string{}
|
||||||
@@ -837,8 +827,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_15")
|
user := acl.CreateUser("set_user_15")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
||||||
user.ExcludedCommands = []string{"rewriteaof", "save", "publish"}
|
user.ExcludedCommands = []string{"rewriteaof", "save", "publish"}
|
||||||
@@ -861,24 +851,24 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_16")
|
user := acl.CreateUser("set_user_16")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoPassword = true
|
user.NoPassword = true
|
||||||
user.Passwords = []internal_acl.Password{}
|
user.Passwords = []acl.Password{}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 17. Delete all existing users passwords using 'nopass'
|
// 17. Delete all existing users passwords using 'nopass'
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_17")
|
user := acl.CreateUser("set_user_17")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoPassword = true
|
user.NoPassword = true
|
||||||
user.Passwords = []internal_acl.Password{
|
user.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
||||||
}
|
}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -892,24 +882,24 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_17")
|
user := acl.CreateUser("set_user_17")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoPassword = true
|
user.NoPassword = true
|
||||||
user.Passwords = []internal_acl.Password{}
|
user.Passwords = []acl.Password{}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 18. Clear all of an existing user's passwords using 'resetpass'
|
// 18. Clear all of an existing user's passwords using 'resetpass'
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_18")
|
user := acl.CreateUser("set_user_18")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoPassword = true
|
user.NoPassword = true
|
||||||
user.Passwords = []internal_acl.Password{
|
user.Passwords = []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "password1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
||||||
}
|
}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
@@ -923,19 +913,19 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_18")
|
user := acl.CreateUser("set_user_18")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoPassword = true
|
user.NoPassword = true
|
||||||
user.Passwords = []internal_acl.Password{}
|
user.Passwords = []acl.Password{}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
return user
|
return user
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 19. Clear all of an existing user's command privileges using 'nocommands'
|
// 19. Clear all of an existing user's command privileges using 'nocommands'
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_19")
|
user := acl.CreateUser("set_user_19")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
||||||
user.ExcludedCommands = []string{"rewriteaof", "save"}
|
user.ExcludedCommands = []string{"rewriteaof", "save"}
|
||||||
@@ -951,8 +941,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_19")
|
user := acl.CreateUser("set_user_19")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedCommands = []string{}
|
user.IncludedCommands = []string{}
|
||||||
user.ExcludedCommands = []string{"*"}
|
user.ExcludedCommands = []string{"*"}
|
||||||
@@ -964,8 +954,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 20. Clear all of an existing user's allowed keys using 'resetkeys'
|
// 20. Clear all of an existing user's allowed keys using 'resetkeys'
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_20")
|
user := acl.CreateUser("set_user_20")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
||||||
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"}
|
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"}
|
||||||
@@ -981,8 +971,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_20")
|
user := acl.CreateUser("set_user_20")
|
||||||
user.Enabled = true
|
user.Enabled = true
|
||||||
user.NoKeys = true
|
user.NoKeys = true
|
||||||
user.IncludedReadKeys = []string{}
|
user.IncludedReadKeys = []string{}
|
||||||
@@ -993,8 +983,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
// 21. Allow user to access all channels using 'resetchannels'
|
// 21. Allow user to access all channels using 'resetchannels'
|
||||||
presetUser: func() *internal_acl.User {
|
presetUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_21")
|
user := acl.CreateUser("set_user_21")
|
||||||
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
||||||
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
@@ -1008,8 +998,8 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
wantUser: func() *internal_acl.User {
|
wantUser: func() *acl.User {
|
||||||
user := internal_acl.CreateUser("set_user_21")
|
user := acl.CreateUser("set_user_21")
|
||||||
user.IncludedPubSubChannels = []string{}
|
user.IncludedPubSubChannels = []string{}
|
||||||
user.ExcludedPubSubChannels = []string{"*"}
|
user.ExcludedPubSubChannels = []string{"*"}
|
||||||
user.Normalise()
|
user.Normalise()
|
||||||
@@ -1020,7 +1010,7 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
|
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
if test.presetUser != nil {
|
if test.presetUser != nil {
|
||||||
acl.AddUsers([]*internal_acl.User{test.presetUser})
|
a.AddUsers([]*acl.User{test.presetUser})
|
||||||
}
|
}
|
||||||
if err = r.WriteArray(test.cmd); err != nil {
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -1042,13 +1032,13 @@ func Test_HandleSetUser(t *testing.T) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
expectedUser := test.wantUser
|
expectedUser := test.wantUser
|
||||||
currUserIdx := slices.IndexFunc(acl.Users, func(user *internal_acl.User) bool {
|
currUserIdx := slices.IndexFunc(a.Users, func(user *acl.User) bool {
|
||||||
return user.Username == expectedUser.Username
|
return user.Username == expectedUser.Username
|
||||||
})
|
})
|
||||||
if currUserIdx == -1 {
|
if currUserIdx == -1 {
|
||||||
t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username)
|
t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username)
|
||||||
}
|
}
|
||||||
if err = compareUsers(expectedUser, acl.Users[currUserIdx]); err != nil {
|
if err = compareUsers(expectedUser, a.Users[currUserIdx]); err != nil {
|
||||||
t.Errorf("test idx: %d, %+v", i, err)
|
t.Errorf("test idx: %d, %+v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1065,7 +1055,7 @@ func Test_HandleGetUser(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
|
a, _ := mockServer.GetACL().(*acl.ACL)
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1080,20 +1070,20 @@ func Test_HandleGetUser(t *testing.T) {
|
|||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
presetUser *internal_acl.User
|
presetUser *acl.User
|
||||||
cmd []resp.Value
|
cmd []resp.Value
|
||||||
wantRes []resp.Value
|
wantRes []resp.Value
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{ // 1. Get the user and all their details
|
{ // 1. Get the user and all their details
|
||||||
presetUser: &internal_acl.User{
|
presetUser: &acl.User{
|
||||||
Username: "get_user_1",
|
Username: "get_user_1",
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
NoPassword: false,
|
NoPassword: false,
|
||||||
NoKeys: false,
|
NoKeys: false,
|
||||||
Passwords: []internal_acl.Password{
|
Passwords: []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
|
||||||
},
|
},
|
||||||
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
||||||
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
||||||
@@ -1165,7 +1155,7 @@ func Test_HandleGetUser(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
if test.presetUser != nil {
|
if test.presetUser != nil {
|
||||||
acl.AddUsers([]*internal_acl.User{test.presetUser})
|
a.AddUsers([]*acl.User{test.presetUser})
|
||||||
}
|
}
|
||||||
if err = r.WriteArray(test.cmd); err != nil {
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -1218,7 +1208,7 @@ func Test_HandleDelUser(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
|
a, _ := mockServer.GetACL().(*acl.ACL)
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1233,14 +1223,14 @@ func Test_HandleDelUser(t *testing.T) {
|
|||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
presetUser *internal_acl.User
|
presetUser *acl.User
|
||||||
cmd []resp.Value
|
cmd []resp.Value
|
||||||
wantRes string
|
wantRes string
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
// 1. Delete existing user while skipping default user and non-existent user
|
// 1. Delete existing user while skipping default user and non-existent user
|
||||||
presetUser: internal_acl.CreateUser("user_to_delete"),
|
presetUser: acl.CreateUser("user_to_delete"),
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("ACL"),
|
resp.StringValue("ACL"),
|
||||||
resp.StringValue("DELUSER"),
|
resp.StringValue("DELUSER"),
|
||||||
@@ -1262,7 +1252,7 @@ func Test_HandleDelUser(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
if test.presetUser != nil {
|
if test.presetUser != nil {
|
||||||
acl.AddUsers([]*internal_acl.User{test.presetUser})
|
a.AddUsers([]*acl.User{test.presetUser})
|
||||||
}
|
}
|
||||||
if err = r.WriteArray(test.cmd); err != nil {
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -1278,13 +1268,13 @@ func Test_HandleDelUser(t *testing.T) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Check that default user still exists in the list of users
|
// Check that default user still exists in the list of users
|
||||||
if !slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool {
|
if !slices.ContainsFunc(a.Users, func(user *acl.User) bool {
|
||||||
return user.Username == "default"
|
return user.Username == "default"
|
||||||
}) {
|
}) {
|
||||||
t.Error("could not find user with username \"default\" in the ACL after deleting user")
|
t.Error("could not find user with username \"default\" in the ACL after deleting user")
|
||||||
}
|
}
|
||||||
// Check that the deleted user is no longer in the list
|
// Check that the deleted user is no longer in the list
|
||||||
if slices.ContainsFunc(acl.Users, func(user *internal_acl.User) bool {
|
if slices.ContainsFunc(a.Users, func(user *acl.User) bool {
|
||||||
return user.Username == "user_to_delete"
|
return user.Username == "user_to_delete"
|
||||||
}) {
|
}) {
|
||||||
t.Error("deleted user found in the ACL")
|
t.Error("deleted user found in the ACL")
|
||||||
@@ -1368,7 +1358,7 @@ func Test_HandleList(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
acl, _ := mockServer.GetACL().(*internal_acl.ACL)
|
a, _ := mockServer.GetACL().(*acl.ACL)
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1383,21 +1373,21 @@ func Test_HandleList(t *testing.T) {
|
|||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
presetUsers []*internal_acl.User
|
presetUsers []*acl.User
|
||||||
cmd []resp.Value
|
cmd []resp.Value
|
||||||
wantRes []string
|
wantRes []string
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{ // 1. Get the user and all their details
|
{ // 1. Get the user and all their details
|
||||||
presetUsers: []*internal_acl.User{
|
presetUsers: []*acl.User{
|
||||||
{
|
{
|
||||||
Username: "list_user_1",
|
Username: "list_user_1",
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
NoPassword: false,
|
NoPassword: false,
|
||||||
NoKeys: false,
|
NoKeys: false,
|
||||||
Passwords: []internal_acl.Password{
|
Passwords: []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
|
||||||
},
|
},
|
||||||
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
||||||
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
||||||
@@ -1413,7 +1403,7 @@ func Test_HandleList(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
NoPassword: true,
|
NoPassword: true,
|
||||||
NoKeys: true,
|
NoKeys: true,
|
||||||
Passwords: []internal_acl.Password{},
|
Passwords: []acl.Password{},
|
||||||
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
||||||
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
||||||
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
||||||
@@ -1428,9 +1418,9 @@ func Test_HandleList(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
NoPassword: false,
|
NoPassword: false,
|
||||||
NoKeys: false,
|
NoKeys: false,
|
||||||
Passwords: []internal_acl.Password{
|
Passwords: []acl.Password{
|
||||||
{PasswordType: internal_acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
|
||||||
{PasswordType: internal_acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
|
||||||
},
|
},
|
||||||
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
||||||
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
||||||
@@ -1457,7 +1447,7 @@ func Test_HandleList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
acl.AddUsers(test.presetUsers)
|
a.AddUsers(test.presetUsers)
|
||||||
|
|
||||||
if err = r.WriteArray(test.cmd); err != nil {
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/admin"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(admin.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/connection"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -33,7 +32,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(connection.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/clock"
|
"github.com/echovault/echovault/internal/clock"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/generic"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -30,7 +29,6 @@ import (
|
|||||||
|
|
||||||
func createEchoVault() *echovault.EchoVault {
|
func createEchoVault() *echovault.EchoVault {
|
||||||
ev, _ := echovault.NewEchoVault(
|
ev, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(generic.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
}),
|
}),
|
||||||
|
@@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/generic"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -45,7 +44,6 @@ func init() {
|
|||||||
mockClock = clock.NewClock()
|
mockClock = clock.NewClock()
|
||||||
|
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(generic.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -18,7 +18,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/hash"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -26,7 +25,6 @@ import (
|
|||||||
|
|
||||||
func createEchoVault() *echovault.EchoVault {
|
func createEchoVault() *echovault.EchoVault {
|
||||||
ev, _ := echovault.NewEchoVault(
|
ev, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(hash.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
}),
|
}),
|
||||||
|
@@ -22,7 +22,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/hash"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -35,7 +34,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(hash.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -18,14 +18,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/list"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createEchoVault() *echovault.EchoVault {
|
func createEchoVault() *echovault.EchoVault {
|
||||||
ev, _ := echovault.NewEchoVault(
|
ev, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(list.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
}),
|
}),
|
||||||
|
@@ -22,7 +22,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
"github.com/echovault/echovault/pkg/modules/list"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -34,7 +33,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(list.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -19,10 +19,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
|
"github.com/echovault/echovault/internal/modules/pubsub"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
ps "github.com/echovault/echovault/pkg/modules/pubsub"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -33,7 +32,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var pubsub *internal_pubsub.PubSub
|
var ps *pubsub.PubSub
|
||||||
var mockServer *echovault.EchoVault
|
var mockServer *echovault.EchoVault
|
||||||
|
|
||||||
var bindAddr = "localhost"
|
var bindAddr = "localhost"
|
||||||
@@ -41,7 +40,7 @@ var port uint16 = 7490
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer = setUpServer(bindAddr, port)
|
mockServer = setUpServer(bindAddr, port)
|
||||||
pubsub = mockServer.GetPubSub().(*internal_pubsub.PubSub)
|
ps = mockServer.GetPubSub().(*pubsub.PubSub)
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -54,7 +53,6 @@ func init() {
|
|||||||
|
|
||||||
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
|
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
|
||||||
server, _ := echovault.NewEchoVault(
|
server, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(ps.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
BindAddr: bindAddr,
|
BindAddr: bindAddr,
|
||||||
Port: port,
|
Port: port,
|
||||||
@@ -126,12 +124,12 @@ func Test_HandleSubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
// Check if the channel exists in the pubsub module
|
// Check if the channel exists in the pubsub module
|
||||||
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool {
|
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
|
||||||
return c.Name() == channel
|
return c.Name() == channel
|
||||||
}) {
|
}) {
|
||||||
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
|
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
|
||||||
}
|
}
|
||||||
for _, c := range pubsub.GetAllChannels() {
|
for _, c := range ps.GetAllChannels() {
|
||||||
if c.Name() == channel {
|
if c.Name() == channel {
|
||||||
// Check if channel has nil pattern
|
// Check if channel has nil pattern
|
||||||
if c.Pattern() != nil {
|
if c.Pattern() != nil {
|
||||||
@@ -157,12 +155,12 @@ func Test_HandleSubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, pattern := range patterns {
|
for _, pattern := range patterns {
|
||||||
// Check if pattern channel exists in pubsub module
|
// Check if pattern channel exists in pubsub module
|
||||||
if !slices.ContainsFunc(pubsub.GetAllChannels(), func(c *internal_pubsub.Channel) bool {
|
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
|
||||||
return c.Name() == pattern
|
return c.Name() == pattern
|
||||||
}) {
|
}) {
|
||||||
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
|
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
|
||||||
}
|
}
|
||||||
for _, c := range pubsub.GetAllChannels() {
|
for _, c := range ps.GetAllChannels() {
|
||||||
if c.Name() == pattern {
|
if c.Name() == pattern {
|
||||||
// Check if channel has non-nil pattern
|
// Check if channel has non-nil pattern
|
||||||
if c.Pattern() == nil {
|
if c.Pattern() == nil {
|
||||||
@@ -322,7 +320,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
|||||||
verifyResponse(res, test.expectedResponses["pattern"])
|
verifyResponse(res, test.expectedResponses["pattern"])
|
||||||
|
|
||||||
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
|
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
|
||||||
for _, pubsubChannel := range pubsub.GetAllChannels() {
|
for _, pubsubChannel := range ps.GetAllChannels() {
|
||||||
if pubsubChannel.Name() == channel {
|
if pubsubChannel.Name() == channel {
|
||||||
// Assert that target connection is no longer in the unsub channels and patterns
|
// Assert that target connection is no longer in the unsub channels and patterns
|
||||||
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
|
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
|
||||||
@@ -339,7 +337,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
|||||||
|
|
||||||
// Assert that the target connection is still in the remain channels and patterns
|
// Assert that the target connection is still in the remain channels and patterns
|
||||||
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
|
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
|
||||||
for _, pubsubChannel := range pubsub.GetAllChannels() {
|
for _, pubsubChannel := range ps.GetAllChannels() {
|
||||||
if pubsubChannel.Name() == channel {
|
if pubsubChannel.Name() == channel {
|
||||||
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
|
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
|
||||||
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
|
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
|
||||||
|
@@ -17,9 +17,8 @@ package set
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/internal/set"
|
"github.com/echovault/echovault/internal/modules/set"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
s "github.com/echovault/echovault/pkg/modules/set"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -27,7 +26,6 @@ import (
|
|||||||
|
|
||||||
func createEchoVault() *echovault.EchoVault {
|
func createEchoVault() *echovault.EchoVault {
|
||||||
ev, _ := echovault.NewEchoVault(
|
ev, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(s.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
}),
|
}),
|
||||||
|
@@ -20,10 +20,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/internal/set"
|
"github.com/echovault/echovault/internal/modules/set"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
s "github.com/echovault/echovault/pkg/modules/set"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(s.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -20,10 +20,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/internal/sorted_set"
|
"github.com/echovault/echovault/internal/modules/sorted_set"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
ss "github.com/echovault/echovault/pkg/modules/sorted_set"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"math"
|
"math"
|
||||||
@@ -38,7 +37,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(ss.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
@@ -18,13 +18,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
str "github.com/echovault/echovault/pkg/modules/string"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createEchoVault() *echovault.EchoVault {
|
func createEchoVault() *echovault.EchoVault {
|
||||||
ev, _ := echovault.NewEchoVault(
|
ev, _ := echovault.NewEchoVault(
|
||||||
echovault.WithCommands(str.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
}),
|
}),
|
||||||
|
@@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/echovault/echovault/internal/config"
|
"github.com/echovault/echovault/internal/config"
|
||||||
"github.com/echovault/echovault/pkg/constants"
|
"github.com/echovault/echovault/pkg/constants"
|
||||||
"github.com/echovault/echovault/pkg/echovault"
|
"github.com/echovault/echovault/pkg/echovault"
|
||||||
str "github.com/echovault/echovault/pkg/modules/string"
|
|
||||||
"github.com/echovault/echovault/pkg/types"
|
"github.com/echovault/echovault/pkg/types"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
@@ -36,7 +35,6 @@ var mockServer *echovault.EchoVault
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mockServer, _ = echovault.NewEchoVault(
|
mockServer, _ = echovault.NewEchoVault(
|
||||||
echovault.WithCommands(str.Commands()),
|
|
||||||
echovault.WithConfig(config.Config{
|
echovault.WithConfig(config.Config{
|
||||||
DataDir: "",
|
DataDir: "",
|
||||||
EvictionPolicy: constants.NoEviction,
|
EvictionPolicy: constants.NoEviction,
|
||||||
|
Reference in New Issue
Block a user