Moved tests for module commands and apis into 'test' folder

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-24 21:36:59 +08:00
parent fbf4782b7c
commit 3e04b7a822
36 changed files with 2559 additions and 4912 deletions

View File

@@ -7,7 +7,7 @@ build:
run:
make build && docker-compose up --build
test:
test-normal:
go clean -testcache && go test ./... -coverprofile coverage/coverage.out
test-race:

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@ import (
"github.com/hashicorp/raft"
"io"
"log"
"net"
"strings"
)
@@ -36,6 +37,7 @@ type FSMOpts struct {
StartSnapshot func()
FinishSnapshot func()
SetLatestSnapshotTime func(msec int64)
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams
}
type FSM struct {
@@ -86,34 +88,33 @@ func (fsm *FSM) Apply(log *raft.Log) interface{} {
}
case "command":
// TODO: Re-Implement Command handling with dependency injection
// Handle command
// command, err := fsm.options.GetCommand(request.CMD[0])
// if err != nil {
// return internal.ApplyResponse{
// Error: err,
// Response: nil,
// }
// }
//
// handler := command.HandlerFunc
//
// subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
// if ok {
// handler = subCommand.HandlerFunc
// }
//
// if res, err := handler(ctx, request.CMD, fsm.options.EchoVault, nil); err != nil {
// return internal.ApplyResponse{
// Error: err,
// Response: nil,
// }
// } else {
// return internal.ApplyResponse{
// Error: nil,
// Response: res,
// }
// }
command, err := fsm.options.GetCommand(request.CMD[0])
if err != nil {
return internal.ApplyResponse{
Error: err,
Response: nil,
}
}
handler := command.HandlerFunc
subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
if ok {
handler = subCommand.HandlerFunc
}
if res, err := handler(fsm.options.GetHandlerFuncParams(ctx, request.CMD, nil)); err != nil {
return internal.ApplyResponse{
Error: err,
Response: nil,
}
} else {
return internal.ApplyResponse{
Error: nil,
Response: res,
}
}
}
}

View File

@@ -41,6 +41,7 @@ type Opts struct {
StartSnapshot func()
FinishSnapshot func()
SetLatestSnapshotTime func(msec int64)
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams
}
type Raft struct {
@@ -120,6 +121,7 @@ func (r *Raft) RaftInit(ctx context.Context) {
StartSnapshot: r.options.StartSnapshot,
FinishSnapshot: r.options.FinishSnapshot,
SetLatestSnapshotTime: r.options.SetLatestSnapshotTime,
GetHandlerFuncParams: r.options.GetHandlerFuncParams,
}),
logStore,
stableStore,

View File

@@ -157,6 +157,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
StartSnapshot: echovault.startSnapshot,
FinishSnapshot: echovault.finishSnapshot,
SetLatestSnapshotTime: echovault.setLatestSnapshot,
GetHandlerFuncParams: echovault.getHandlerFuncParams,
GetState: func() map[string]internal.KeyData {
state := make(map[string]internal.KeyData)
for k, v := range echovault.getState() {

View File

@@ -48,7 +48,6 @@ func (server *EchoVault) getCommand(cmd string) (types.Command, error) {
func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
// TODO: Add all the required methods here
Context: ctx,
Command: cmd,
Connection: conn,
@@ -60,6 +59,13 @@ func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string,
KeyRUnlock: server.KeyRUnlock,
GetValue: server.GetValue,
SetValue: server.SetValue,
GetClock: server.GetClock,
GetExpiry: server.GetExpiry,
SetExpiry: server.SetExpiry,
DeleteKey: server.DeleteKey,
GetPubSub: server.GetPubSub,
GetACL: server.GetACL,
GetAllCommands: server.GetAllCommands,
}
}

View File

@@ -15,7 +15,6 @@
package acl
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -24,33 +23,32 @@ import (
"github.com/echovault/echovault/pkg/types"
"gopkg.in/yaml.v3"
"log"
"net"
"os"
"path"
"slices"
"strings"
)
func handleAuth(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if len(cmd) < 2 || len(cmd) > 3 {
func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
if err := acl.AuthenticateConnection(ctx, conn, cmd); err != nil {
if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 {
func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -58,7 +56,7 @@ func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *n
var user *internal_acl.User
userFound := false
for _, u := range acl.Users {
if u.Username == cmd[2] {
if u.Username == params.Command[2] {
user = u
userFound = true
break
@@ -162,14 +160,14 @@ func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(res), nil
}
func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 {
func handleCat(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
categories := make(map[string][]string)
commands := server.GetAllCommands()
commands := params.GetAllCommands()
for _, command := range commands {
if len(command.SubCommands) == 0 {
@@ -186,7 +184,7 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
}
}
if len(cmd) == 2 {
if len(params.Command) == 2 {
var cats []string
length := 0
for key, _ := range categories {
@@ -203,10 +201,10 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
return []byte(res), nil
}
if len(cmd) == 3 {
if len(params.Command) == 3 {
var res string
for category, commands := range categories {
if strings.EqualFold(category, cmd[2]) {
if strings.EqualFold(category, params.Command[2]) {
res = fmt.Sprintf("*%d", len(commands))
for i, command := range commands {
res = fmt.Sprintf("%s\r\n+%s", res, command)
@@ -219,11 +217,11 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
}
}
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
}
func handleUsers(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*internal_acl.ACL)
func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -235,45 +233,45 @@ func handleUsers(_ context.Context, _ []string, server types.EchoVault, _ *net.C
return []byte(res), nil
}
func handleSetUser(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*internal_acl.ACL)
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
if err := acl.SetUser(cmd[2:]); err != nil {
if err := acl.SetUser(params.Command[2:]); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleDelUser(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) < 3 {
func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
if err := acl.DeleteUser(ctx, cmd[2:]); err != nil {
if err := acl.DeleteUser(params.Context, params.Command[2:]); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleWhoAmI(_ context.Context, _ []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*internal_acl.ACL)
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
connectionInfo := acl.Connections[conn]
connectionInfo := acl.Connections[params.Connection]
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
}
func handleList(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 {
func handleList(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -365,12 +363,12 @@ func handleList(_ context.Context, cmd []string, server types.EchoVault, _ *net.
return []byte(res), nil
}
func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 {
func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
@@ -414,7 +412,7 @@ func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.
if u.Username == user.Username {
userFound = true
// If we have a user with the current username and are in merge mode, merge the two users.
if strings.EqualFold(cmd[2], "merge") {
if strings.EqualFold(params.Command[2], "merge") {
u.Merge(user)
} else {
// If we have a user with the current username and are in replace mode, merge the two users.
@@ -432,12 +430,12 @@ func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.
return []byte(constants.OkResponse), nil
}
func handleSave(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 {
func handleSave(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := server.GetACL().(*internal_acl.ACL)
acl, ok := params.GetACL().(*internal_acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}

View File

@@ -15,19 +15,17 @@
package admin
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"github.com/gobwas/glob"
"net"
"slices"
"strings"
)
func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
commands := server.GetAllCommands()
func handleGetAllCommands(params types.HandlerFuncParams) ([]byte, error) {
commands := params.GetAllCommands()
res := ""
commandCount := 0
@@ -71,10 +69,10 @@ func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault,
return []byte(res), nil
}
func handleCommandCount(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
func handleCommandCount(params types.HandlerFuncParams) ([]byte, error) {
var count int
commands := server.GetAllCommands()
commands := params.GetAllCommands()
for _, command := range commands {
if command.SubCommands != nil && len(command.SubCommands) > 0 {
for _, _ = range command.SubCommands {
@@ -88,13 +86,13 @@ func handleCommandCount(_ context.Context, _ []string, server types.EchoVault, _
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleCommandList(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
switch len(cmd) {
func handleCommandList(params types.HandlerFuncParams) ([]byte, error) {
switch len(params.Command) {
case 2:
// Command is COMMAND LIST
var count int
var res string
commands := server.GetAllCommands()
commands := params.GetAllCommands()
for _, command := range commands {
if command.SubCommands != nil && len(command.SubCommands) > 0 {
for _, subcommand := range command.SubCommands {
@@ -114,13 +112,13 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
var count int
var res string
// Command has filter
if !strings.EqualFold("FILTERBY", cmd[2]) {
return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(cmd[2]))
if !strings.EqualFold("FILTERBY", params.Command[2]) {
return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(params.Command[2]))
}
if strings.EqualFold("ACLCAT", cmd[3]) {
if strings.EqualFold("ACLCAT", params.Command[3]) {
// ACL Category filter
commands := server.GetAllCommands()
category := strings.ToLower(cmd[4])
commands := params.GetAllCommands()
category := strings.ToLower(params.Command[4])
for _, command := range commands {
if command.SubCommands != nil && len(command.SubCommands) > 0 {
for _, subcommand := range command.SubCommands {
@@ -137,10 +135,10 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
count += 1
}
}
} else if strings.EqualFold("PATTERN", cmd[3]) {
} else if strings.EqualFold("PATTERN", params.Command[3]) {
// Pattern filter
commands := server.GetAllCommands()
g := glob.MustCompile(cmd[4])
commands := params.GetAllCommands()
g := glob.MustCompile(params.Command[4])
for _, command := range commands {
if command.SubCommands != nil && len(command.SubCommands) > 0 {
for _, subcommand := range command.SubCommands {
@@ -157,10 +155,10 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
count += 1
}
}
} else if strings.EqualFold("MODULE", cmd[3]) {
} else if strings.EqualFold("MODULE", params.Command[3]) {
// Module filter
commands := server.GetAllCommands()
module := strings.ToLower(cmd[4])
commands := params.GetAllCommands()
module := strings.ToLower(params.Command[4])
for _, command := range commands {
if command.SubCommands != nil && len(command.SubCommands) > 0 {
for _, subcommand := range command.SubCommands {
@@ -178,7 +176,7 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
}
}
} else {
return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(cmd[3]))
return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(params.Command[3]))
}
res = fmt.Sprintf("*%d\r\n%s", count, res)
return []byte(res), nil
@@ -187,7 +185,7 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
}
}
func handleCommandDocs(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
func handleCommandDocs(params types.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
@@ -283,8 +281,8 @@ Allows for filtering by ACL category or glob pattern.`,
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.TakeSnapshot(); err != nil {
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
if err := params.TakeSnapshot(); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
@@ -303,8 +301,8 @@ Allows for filtering by ACL category or glob pattern.`,
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
msec := server.GetLatestSnapshotTime()
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
msec := params.GetLatestSnapshotTime()
if msec == 0 {
return nil, errors.New("no snapshot")
}
@@ -324,8 +322,8 @@ Allows for filtering by ACL category or glob pattern.`,
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.RewriteAOF(); err != nil {
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
if err := params.RewriteAOF(); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil

View File

@@ -15,29 +15,27 @@
package connection
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"net"
)
func handlePing(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
switch len(cmd) {
func handlePing(params types.HandlerFuncParams) ([]byte, error) {
switch len(params.Command) {
default:
return nil, errors.New(constants.WrongArgsResponse)
case 1:
return []byte("+PONG\r\n"), nil
case 2:
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(cmd[1]), cmd[1])), nil
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(params.Command[1]), params.Command[1])), nil
}
}
func Commands() []types.Command {
return []types.Command{
{
Command: "connection",
Command: "ping",
Module: constants.ConnectionModule,
Categories: []string{constants.FastCategory, constants.ConnectionCategory},
Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.",

View File

@@ -15,14 +15,12 @@
package generic
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"log"
"net"
"strconv"
"strings"
"time"
@@ -33,73 +31,73 @@ type KeyObject struct {
locked bool
}
func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := setKeyFunc(cmd)
func handleSet(params types.HandlerFuncParams) ([]byte, error) {
keys, err := setKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
value := cmd[2]
value := params.Command[2]
res := []byte(constants.OkResponse)
clock := server.GetClock()
clock := params.GetClock()
params, err := getSetCommandParams(clock, cmd[3:], SetParams{})
options, err := getSetCommandOptions(clock, params.Command[3:], SetOptions{})
if err != nil {
return nil, err
}
// If GET is provided, the response should be the current stored value.
// If there's no current value, then the response should be nil.
if params.get {
if !server.KeyExists(ctx, key) {
if options.get {
if !params.KeyExists(params.Context, key) {
res = []byte("$-1\r\n")
} else {
res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key)))
res = []byte(fmt.Sprintf("+%v\r\n", params.GetValue(params.Context, key)))
}
}
if "xx" == strings.ToLower(params.exists) {
if "xx" == strings.ToLower(options.exists) {
// If XX is specified, make sure the key exists.
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, fmt.Errorf("key %s does not exist", key)
}
_, err = server.KeyLock(ctx, key)
} else if "nx" == strings.ToLower(params.exists) {
_, err = params.KeyLock(params.Context, key)
} else if "nx" == strings.ToLower(options.exists) {
// If NX is specified, make sure that the key does not currently exist.
if server.KeyExists(ctx, key) {
if params.KeyExists(params.Context, key) {
return nil, fmt.Errorf("key %s already exists", key)
}
_, err = server.CreateKeyAndLock(ctx, key)
_, err = params.CreateKeyAndLock(params.Context, key)
} else {
// Neither XX not NX are specified, lock or create the lock
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
// Key does not exist, create it
_, err = server.CreateKeyAndLock(ctx, key)
_, err = params.CreateKeyAndLock(params.Context, key)
} else {
// Key exists, acquire the lock
_, err = server.KeyLock(ctx, key)
_, err = params.KeyLock(params.Context, key)
}
}
if err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
if err = server.SetValue(ctx, key, internal.AdaptType(value)); err != nil {
if err = params.SetValue(params.Context, key, internal.AdaptType(value)); err != nil {
return nil, err
}
// If expiresAt is set, set the key's expiry time as well
if params.expireAt != nil {
server.SetExpiry(ctx, key, params.expireAt.(time.Time), false)
if options.expireAt != nil {
params.SetExpiry(params.Context, key, options.expireAt.(time.Time), false)
}
return res, nil
}
func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
_, err := msetKeyFunc(cmd)
func handleMSet(params types.HandlerFuncParams) ([]byte, error) {
_, err := msetKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -110,7 +108,7 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
defer func() {
for k, v := range entries {
if v.locked {
server.KeyUnlock(ctx, k)
params.KeyUnlock(params.Context, k)
entries[k] = KeyObject{
value: v.value,
locked: false,
@@ -120,10 +118,10 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
}()
// Extract all the key/value pairs
for i, key := range cmd[1:] {
for i, key := range params.Command[1:] {
if i%2 == 0 {
entries[key] = KeyObject{
value: internal.AdaptType(cmd[1:][i+1]),
value: internal.AdaptType(params.Command[1:][i+1]),
locked: false,
}
}
@@ -132,14 +130,14 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
// Acquire all the locks for each key first
// If any key cannot be acquired, abandon transaction and release all currently held keys
for k, v := range entries {
if server.KeyExists(ctx, k) {
if _, err := server.KeyLock(ctx, k); err != nil {
if params.KeyExists(params.Context, k) {
if _, err := params.KeyLock(params.Context, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
continue
}
if _, err := server.CreateKeyAndLock(ctx, k); err != nil {
if _, err := params.CreateKeyAndLock(params.Context, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
@@ -147,7 +145,7 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
// Set all the values
for k, v := range entries {
if err := server.SetValue(ctx, k, v.value); err != nil {
if err := params.SetValue(params.Context, k, v.value); err != nil {
return nil, err
}
}
@@ -155,30 +153,30 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(constants.OkResponse), nil
}
func handleGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := getKeyFunc(cmd)
func handleGet(params types.HandlerFuncParams) ([]byte, error) {
keys, err := getKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("$-1\r\n"), nil
}
_, err = server.KeyRLock(ctx, key)
_, err = params.KeyRLock(params.Context, key)
if err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
value := server.GetValue(ctx, key)
value := params.GetValue(params.Context, key)
return []byte(fmt.Sprintf("+%v\r\n", value)), nil
}
func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := mgetKeyFunc(cmd)
func handleMGet(params types.HandlerFuncParams) ([]byte, error) {
keys, err := mgetKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -191,8 +189,8 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
// Skip if we have already locked this key
continue
}
if server.KeyExists(ctx, key) {
_, err = server.KeyRLock(ctx, key)
if params.KeyExists(params.Context, key) {
_, err = params.KeyRLock(params.Context, key)
if err != nil {
return nil, fmt.Errorf("could not obtain lock for %s key", key)
}
@@ -204,19 +202,19 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
locks[key] = false
}
}
}()
for key, _ := range locks {
values[key] = fmt.Sprintf("%v", server.GetValue(ctx, key))
values[key] = fmt.Sprintf("%v", params.GetValue(params.Context, key))
}
bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:])))
bytes := []byte(fmt.Sprintf("*%d\r\n", len(params.Command[1:])))
for _, key := range cmd[1:] {
for _, key := range params.Command[1:] {
if values[key] == "" {
bytes = append(bytes, []byte("$-1\r\n")...)
continue
@@ -227,14 +225,14 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return bytes, nil
}
func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := delKeyFunc(cmd)
func handleDel(params types.HandlerFuncParams) ([]byte, error) {
keys, err := delKeyFunc(params.Command)
if err != nil {
return nil, err
}
count := 0
for _, key := range keys.WriteKeys {
err = server.DeleteKey(ctx, key)
err = params.DeleteKey(params.Context, key)
if err != nil {
log.Printf("could not delete key %s due to error: %+v\n", key, err)
continue
@@ -244,91 +242,91 @@ func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handlePersist(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := persistKeyFunc(cmd)
func handlePersist(params types.HandlerFuncParams) ([]byte, error) {
keys, err := persistKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
expireAt := server.GetExpiry(ctx, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, time.Time{}, false)
params.SetExpiry(params.Context, key, time.Time{}, false)
return []byte(":1\r\n"), nil
}
func handleExpireTime(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := expireTimeKeyFunc(cmd)
func handleExpireTime(params types.HandlerFuncParams) ([]byte, error) {
keys, err := expireTimeKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":-2\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
expireAt := server.GetExpiry(ctx, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) {
return []byte(":-1\r\n"), nil
}
t := expireAt.Unix()
if strings.ToLower(cmd[0]) == "pexpiretime" {
if strings.ToLower(params.Command[0]) == "pexpiretime" {
t = expireAt.UnixMilli()
}
return []byte(fmt.Sprintf(":%d\r\n", t)), nil
}
func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := ttlKeyFunc(cmd)
func handleTTL(params types.HandlerFuncParams) ([]byte, error) {
keys, err := ttlKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
clock := server.GetClock()
clock := params.GetClock()
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":-2\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
expireAt := server.GetExpiry(ctx, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) {
return []byte(":-1\r\n"), nil
}
t := expireAt.Unix() - clock.Now().Unix()
if strings.ToLower(cmd[0]) == "pttl" {
if strings.ToLower(params.Command[0]) == "pttl" {
t = expireAt.UnixMilli() - clock.Now().UnixMilli()
}
@@ -339,8 +337,8 @@ func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return []byte(fmt.Sprintf(":%d\r\n", t)), nil
}
func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := expireKeyFunc(cmd)
func handleExpire(params types.HandlerFuncParams) ([]byte, error) {
keys, err := expireKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -348,42 +346,42 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
key := keys.WriteKeys[0]
// Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64)
n, err := strconv.ParseInt(params.Command[2], 10, 64)
if err != nil {
return nil, errors.New("expire time must be integer")
}
expireAt := server.GetClock().Now().Add(time.Duration(n) * time.Second)
if strings.ToLower(cmd[0]) == "pexpire" {
expireAt = server.GetClock().Now().Add(time.Duration(n) * time.Millisecond)
expireAt := params.GetClock().Now().Add(time.Duration(n) * time.Second)
if strings.ToLower(params.Command[0]) == "pexpire" {
expireAt = params.GetClock().Now().Add(time.Duration(n) * time.Millisecond)
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
if len(cmd) == 3 {
server.SetExpiry(ctx, key, expireAt, true)
if len(params.Command) == 3 {
params.SetExpiry(params.Context, key, expireAt, true)
return []byte(":1\r\n"), nil
}
currentExpireAt := server.GetExpiry(ctx, key)
currentExpireAt := params.GetExpiry(params.Context, key)
switch strings.ToLower(cmd[3]) {
switch strings.ToLower(params.Command[3]) {
case "nx":
if currentExpireAt != (time.Time{}) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "xx":
if currentExpireAt == (time.Time{}) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "gt":
if currentExpireAt == (time.Time{}) {
return []byte(":0\r\n"), nil
@@ -391,24 +389,24 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
if expireAt.Before(currentExpireAt) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "lt":
if currentExpireAt != (time.Time{}) {
if currentExpireAt.Before(expireAt) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
default:
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(cmd[3]))
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(params.Command[3]))
}
return []byte(":1\r\n"), nil
}
func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := expireKeyFunc(cmd)
func handleExpireAt(params types.HandlerFuncParams) ([]byte, error) {
keys, err := expireKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -416,42 +414,42 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
key := keys.WriteKeys[0]
// Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64)
n, err := strconv.ParseInt(params.Command[2], 10, 64)
if err != nil {
return nil, errors.New("expire time must be integer")
}
expireAt := time.Unix(n, 0)
if strings.ToLower(cmd[0]) == "pexpireat" {
if strings.ToLower(params.Command[0]) == "pexpireat" {
expireAt = time.UnixMilli(n)
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
if len(cmd) == 3 {
server.SetExpiry(ctx, key, expireAt, true)
if len(params.Command) == 3 {
params.SetExpiry(params.Context, key, expireAt, true)
return []byte(":1\r\n"), nil
}
currentExpireAt := server.GetExpiry(ctx, key)
currentExpireAt := params.GetExpiry(params.Context, key)
switch strings.ToLower(cmd[3]) {
switch strings.ToLower(params.Command[3]) {
case "nx":
if currentExpireAt != (time.Time{}) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "xx":
if currentExpireAt == (time.Time{}) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "gt":
if currentExpireAt == (time.Time{}) {
return []byte(":0\r\n"), nil
@@ -459,17 +457,17 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
if expireAt.Before(currentExpireAt) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
case "lt":
if currentExpireAt != (time.Time{}) {
if currentExpireAt.Before(expireAt) {
return []byte(":0\r\n"), nil
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
}
server.SetExpiry(ctx, key, expireAt, false)
params.SetExpiry(params.Context, key, expireAt, false)
default:
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(cmd[3]))
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(params.Command[3]))
}
return []byte(":1\r\n"), nil

View File

@@ -23,96 +23,96 @@ import (
"time"
)
type SetParams struct {
type SetOptions struct {
exists string
get bool
expireAt interface{} // Exact expireAt time un unix milliseconds
}
func getSetCommandParams(clock clock.Clock, cmd []string, params SetParams) (SetParams, error) {
func getSetCommandOptions(clock clock.Clock, cmd []string, options SetOptions) (SetOptions, error) {
if len(cmd) == 0 {
return params, nil
return options, nil
}
switch strings.ToLower(cmd[0]) {
case "get":
params.get = true
return getSetCommandParams(clock, cmd[1:], params)
options.get = true
return getSetCommandOptions(clock, cmd[1:], options)
case "nx":
if params.exists != "" {
return SetParams{}, fmt.Errorf("cannot specify NX when %s is already specified", strings.ToUpper(params.exists))
if options.exists != "" {
return SetOptions{}, fmt.Errorf("cannot specify NX when %s is already specified", strings.ToUpper(options.exists))
}
params.exists = "NX"
return getSetCommandParams(clock, cmd[1:], params)
options.exists = "NX"
return getSetCommandOptions(clock, cmd[1:], options)
case "xx":
if params.exists != "" {
return SetParams{}, fmt.Errorf("cannot specify XX when %s is already specified", strings.ToUpper(params.exists))
if options.exists != "" {
return SetOptions{}, fmt.Errorf("cannot specify XX when %s is already specified", strings.ToUpper(options.exists))
}
params.exists = "XX"
return getSetCommandParams(clock, cmd[1:], params)
options.exists = "XX"
return getSetCommandOptions(clock, cmd[1:], options)
case "ex":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after EX")
return SetOptions{}, errors.New("seconds value required after EX")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify EX when expiry time is already set")
if options.expireAt != nil {
return SetOptions{}, errors.New("cannot specify EX when expiry time is already set")
}
secondsStr := cmd[1]
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
if err != nil {
return SetParams{}, errors.New("seconds value should be an integer")
return SetOptions{}, errors.New("seconds value should be an integer")
}
params.expireAt = clock.Now().Add(time.Duration(seconds) * time.Second)
return getSetCommandParams(clock, cmd[2:], params)
options.expireAt = clock.Now().Add(time.Duration(seconds) * time.Second)
return getSetCommandOptions(clock, cmd[2:], options)
case "px":
if len(cmd) < 2 {
return SetParams{}, errors.New("milliseconds value required after PX")
return SetOptions{}, errors.New("milliseconds value required after PX")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify PX when expiry time is already set")
if options.expireAt != nil {
return SetOptions{}, errors.New("cannot specify PX when expiry time is already set")
}
millisecondsStr := cmd[1]
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
if err != nil {
return SetParams{}, errors.New("milliseconds value should be an integer")
return SetOptions{}, errors.New("milliseconds value should be an integer")
}
params.expireAt = clock.Now().Add(time.Duration(milliseconds) * time.Millisecond)
return getSetCommandParams(clock, cmd[2:], params)
options.expireAt = clock.Now().Add(time.Duration(milliseconds) * time.Millisecond)
return getSetCommandOptions(clock, cmd[2:], options)
case "exat":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after EXAT")
return SetOptions{}, errors.New("seconds value required after EXAT")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify EXAT when expiry time is already set")
if options.expireAt != nil {
return SetOptions{}, errors.New("cannot specify EXAT when expiry time is already set")
}
secondsStr := cmd[1]
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
if err != nil {
return SetParams{}, errors.New("seconds value should be an integer")
return SetOptions{}, errors.New("seconds value should be an integer")
}
params.expireAt = time.Unix(seconds, 0)
return getSetCommandParams(clock, cmd[2:], params)
options.expireAt = time.Unix(seconds, 0)
return getSetCommandOptions(clock, cmd[2:], options)
case "pxat":
if len(cmd) < 2 {
return SetParams{}, errors.New("milliseconds value required after PXAT")
return SetOptions{}, errors.New("milliseconds value required after PXAT")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify PXAT when expiry time is already set")
if options.expireAt != nil {
return SetOptions{}, errors.New("cannot specify PXAT when expiry time is already set")
}
millisecondsStr := cmd[1]
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
if err != nil {
return SetParams{}, errors.New("milliseconds value should be an integer")
return SetOptions{}, errors.New("milliseconds value should be an integer")
}
params.expireAt = time.UnixMilli(milliseconds)
return getSetCommandParams(clock, cmd[2:], params)
options.expireAt = time.UnixMilli(milliseconds)
return getSetCommandOptions(clock, cmd[2:], options)
default:
return SetParams{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0]))
return SetOptions{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0]))
}
}

View File

@@ -15,21 +15,19 @@
package hash
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math/rand"
"net"
"slices"
"strconv"
"strings"
)
func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hsetKeyFunc(cmd)
func handleHSET(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hsetKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -37,39 +35,39 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
key := keys.WriteKeys[0]
entries := make(map[string]interface{})
if len(cmd[2:])%2 != 0 {
if len(params.Command[2:])%2 != 0 {
return nil, errors.New("each field must have a corresponding value")
}
for i := 2; i <= len(cmd)-2; i += 2 {
entries[cmd[i]] = internal.AdaptType(cmd[i+1])
for i := 2; i <= len(params.Command)-2; i += 2 {
entries[params.Command[i]] = internal.AdaptType(params.Command[i+1])
}
if !server.KeyExists(ctx, key) {
_, err = server.CreateKeyAndLock(ctx, key)
if !params.KeyExists(params.Context, key) {
_, err = params.CreateKeyAndLock(params.Context, key)
if err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
if err = server.SetValue(ctx, key, entries); err != nil {
defer params.KeyUnlock(params.Context, key)
if err = params.SetValue(params.Context, key, entries); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
count := 0
for field, value := range entries {
if strings.EqualFold(cmd[0], "hsetnx") {
if strings.EqualFold(params.Command[0], "hsetnx") {
if hash[field] == nil {
hash[field] = value
count += 1
@@ -79,32 +77,32 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
hash[field] = value
count += 1
}
if err = server.SetValue(ctx, key, hash); err != nil {
if err = params.SetValue(params.Context, key, hash); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hgetKeyFunc(cmd)
func handleHGET(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hgetKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
fields := cmd[2:]
fields := params.Command[2:]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -137,25 +135,25 @@ func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(res), nil
}
func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hstrlenKeyFunc(cmd)
func handleHSTRLEN(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hstrlenKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
fields := cmd[2:]
fields := params.Command[2:]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -188,24 +186,24 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _
return []byte(res), nil
}
func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hvalsKeyFunc(cmd)
func handleHVALS(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hvalsKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -229,8 +227,8 @@ func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(res), nil
}
func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hrandfieldKeyFunc(cmd)
func handleHRANDFIELD(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hrandfieldKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -238,8 +236,8 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
key := keys.ReadKeys[0]
count := 1
if len(cmd) >= 3 {
c, err := strconv.Atoi(cmd[2])
if len(params.Command) >= 3 {
c, err := strconv.Atoi(params.Command[2])
if err != nil {
return nil, errors.New("count must be an integer")
}
@@ -250,24 +248,24 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
}
withvalues := false
if len(cmd) == 4 {
if strings.EqualFold(cmd[3], "withvalues") {
if len(params.Command) == 4 {
if strings.EqualFold(params.Command[3], "withvalues") {
withvalues = true
} else {
return nil, errors.New("result modifier must be withvalues")
}
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -345,24 +343,24 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil
}
func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hlenKeyFunc(cmd)
func handleHLEN(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hlenKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -370,24 +368,24 @@ func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil
}
func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hkeysKeyFunc(cmd)
func handleHKEYS(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hkeysKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -400,59 +398,59 @@ func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(res), nil
}
func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hincrbyKeyFunc(cmd)
func handleHINCRBY(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hincrbyKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
field := cmd[2]
field := params.Command[2]
var intIncrement int
var floatIncrement float64
if strings.EqualFold(cmd[0], "hincrbyfloat") {
f, err := strconv.ParseFloat(cmd[3], 64)
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
f, err := strconv.ParseFloat(params.Command[3], 64)
if err != nil {
return nil, errors.New("increment must be a float")
}
floatIncrement = f
} else {
i, err := strconv.Atoi(cmd[3])
i, err := strconv.Atoi(params.Command[3])
if err != nil {
return nil, errors.New("increment must be an integer")
}
intIncrement = i
}
if !server.KeyExists(ctx, key) {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
if !params.KeyExists(params.Context, key) {
if _, err := params.CreateKeyAndLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
hash := make(map[string]interface{})
if strings.EqualFold(cmd[0], "hincrbyfloat") {
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = floatIncrement
if err = server.SetValue(ctx, key, hash); err != nil {
if err = params.SetValue(params.Context, key, hash); err != nil {
return nil, err
}
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
} else {
hash[field] = intIncrement
if err = server.SetValue(ctx, key, hash); err != nil {
if err = params.SetValue(params.Context, key, hash); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil
}
}
if _, err := server.KeyLock(ctx, key); err != nil {
if _, err := params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -466,21 +464,21 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, fmt.Errorf("value at field %s is not a number", field)
case int:
i, _ := hash[field].(int)
if strings.EqualFold(cmd[0], "hincrbyfloat") {
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = float64(i) + floatIncrement
} else {
hash[field] = i + intIncrement
}
case float64:
f, _ := hash[field].(float64)
if strings.EqualFold(cmd[0], "hincrbyfloat") {
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = f + floatIncrement
} else {
hash[field] = f + float64(intIncrement)
}
}
if err = server.SetValue(ctx, key, hash); err != nil {
if err = params.SetValue(params.Context, key, hash); err != nil {
return nil, err
}
@@ -492,24 +490,24 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
return []byte(fmt.Sprintf(":%d\r\n", i)), nil
}
func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hgetallKeyFunc(cmd)
func handleHGETALL(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hgetallKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -532,25 +530,25 @@ func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _
return []byte(res), nil
}
func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hexistsKeyFunc(cmd)
func handleHEXISTS(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hexistsKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
field := cmd[2]
field := params.Command[2]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -562,25 +560,25 @@ func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _
return []byte(":0\r\n"), nil
}
func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := hdelKeyFunc(cmd)
func handleHDEL(params types.HandlerFuncParams) ([]byte, error) {
keys, err := hdelKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
fields := cmd[2:]
fields := params.Command[2:]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -594,7 +592,7 @@ func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
}
}
if err = server.SetValue(ctx, key, hash); err != nil {
if err = params.SetValue(params.Context, key, hash); err != nil {
return nil, err
}

View File

@@ -15,65 +15,63 @@
package list
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math"
"net"
"slices"
"strings"
)
func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := llenKeyFunc(cmd)
func handleLLen(params types.HandlerFuncParams) ([]byte, error) {
keys, err := llenKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
// If key does not exist, return 0
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
if list, ok := server.GetValue(ctx, key).([]interface{}); ok {
if list, ok := params.GetValue(params.Context, key).([]interface{}); ok {
return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil
}
return nil, errors.New("LLEN command on non-list item")
}
func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lindexKeyFunc(cmd)
func handleLIndex(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lindexKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int)
index, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("index must be an integer")
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, errors.New("LINDEX command on non-list item")
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
list, ok := server.GetValue(ctx, key).([]interface{})
server.KeyRUnlock(ctx, key)
list, ok := params.GetValue(params.Context, key).([]interface{})
params.KeyRUnlock(params.Context, key)
if !ok {
return nil, errors.New("LINDEX command on non-list item")
@@ -86,30 +84,30 @@ func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *
return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
}
func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lrangeKeyFunc(cmd)
func handleLRange(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lrangeKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)
start, startOk := internal.AdaptType(params.Command[2]).(int)
end, endOk := internal.AdaptType(params.Command[3]).(int)
if !startOk || !endOk {
return nil, errors.New("start and end indices must be integers")
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, errors.New("LRANGE command on non-list item")
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
list, ok := server.GetValue(ctx, key).([]interface{})
list, ok := params.GetValue(params.Context, key).([]interface{})
if !ok {
return nil, errors.New("LRANGE command on non-list item")
}
@@ -165,29 +163,29 @@ func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *
return bytes, nil
}
func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lsetKeyFunc(cmd)
func handleLSet(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lsetKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int)
index, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("index must be an integer")
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, errors.New("LSET command on non-list item")
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
list, ok := server.GetValue(ctx, key).([]interface{})
list, ok := params.GetValue(params.Context, key).([]interface{})
if !ok {
return nil, errors.New("LSET command on non-list item")
}
@@ -196,23 +194,23 @@ func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, errors.New("index must be within list range")
}
list[index] = internal.AdaptType(cmd[3])
if err = server.SetValue(ctx, key, list); err != nil {
list[index] = internal.AdaptType(params.Command[3])
if err = params.SetValue(params.Context, key, list); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := ltrimKeyFunc(cmd)
func handleLTrim(params types.HandlerFuncParams) ([]byte, error) {
keys, err := ltrimKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)
start, startOk := internal.AdaptType(params.Command[2]).(int)
end, endOk := internal.AdaptType(params.Command[3]).(int)
if !startOk || !endOk {
return nil, errors.New("start and end indices must be integers")
@@ -222,16 +220,16 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, errors.New("end index must be greater than start index or -1")
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, errors.New("LTRIM command on non-list item")
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
list, ok := server.GetValue(ctx, key).([]interface{})
list, ok := params.GetValue(params.Context, key).([]interface{})
if !ok {
return nil, errors.New("LTRIM command on non-list item")
}
@@ -241,44 +239,44 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *n
}
if end == -1 || end > len(list) {
if err = server.SetValue(ctx, key, list[start:]); err != nil {
if err = params.SetValue(params.Context, key, list[start:]); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
if err = server.SetValue(ctx, key, list[start:end]); err != nil {
if err = params.SetValue(params.Context, key, list[start:end]); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lremKeyFunc(cmd)
func handleLRem(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lremKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
value := cmd[3]
value := params.Command[3]
count, ok := internal.AdaptType(cmd[2]).(int)
count, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("count must be an integer")
}
absoluteCount := internal.AbsInt(count)
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, errors.New("LREM command on non-list item")
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
list, ok := server.GetValue(ctx, key).([]interface{})
list, ok := params.GetValue(params.Context, key).([]interface{})
if !ok {
return nil, errors.New("LREM command on non-list item")
}
@@ -314,44 +312,44 @@ func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return elem == nil
})
if err = server.SetValue(ctx, key, list); err != nil {
if err = params.SetValue(params.Context, key, list); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lmoveKeyFunc(cmd)
func handleLMove(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lmoveKeyFunc(params.Command)
if err != nil {
return nil, err
}
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
whereFrom := strings.ToLower(cmd[3])
whereTo := strings.ToLower(cmd[4])
whereFrom := strings.ToLower(params.Command[3])
whereTo := strings.ToLower(params.Command[4])
if !slices.Contains([]string{"left", "right"}, whereFrom) || !slices.Contains([]string{"left", "right"}, whereTo) {
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
}
if !server.KeyExists(ctx, source) || !server.KeyExists(ctx, destination) {
if !params.KeyExists(params.Context, source) || !params.KeyExists(params.Context, destination) {
return nil, errors.New("both source and destination must be lists")
}
if _, err = server.KeyLock(ctx, source); err != nil {
if _, err = params.KeyLock(params.Context, source); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, source)
defer params.KeyUnlock(params.Context, source)
_, err = server.KeyLock(ctx, destination)
_, err = params.KeyLock(params.Context, destination)
if err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, destination)
defer params.KeyUnlock(params.Context, destination)
sourceList, sourceOk := server.GetValue(ctx, source).([]interface{})
destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{})
sourceList, sourceOk := params.GetValue(params.Context, source).([]interface{})
destinationList, destinationOk := params.GetValue(params.Context, destination).([]interface{})
if !sourceOk || !destinationOk {
return nil, errors.New("both source and destination must be lists")
@@ -359,18 +357,18 @@ func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *n
switch whereFrom {
case "left":
err = server.SetValue(ctx, source, append([]interface{}{}, sourceList[1:]...))
err = params.SetValue(params.Context, source, append([]interface{}{}, sourceList[1:]...))
if whereTo == "left" {
err = server.SetValue(ctx, destination, append(sourceList[0:1], destinationList...))
err = params.SetValue(params.Context, destination, append(sourceList[0:1], destinationList...))
} else if whereTo == "right" {
err = server.SetValue(ctx, destination, append(destinationList, sourceList[0]))
err = params.SetValue(params.Context, destination, append(destinationList, sourceList[0]))
}
case "right":
err = server.SetValue(ctx, source, append([]interface{}{}, sourceList[:len(sourceList)-1]...))
err = params.SetValue(params.Context, source, append([]interface{}{}, sourceList[:len(sourceList)-1]...))
if whereTo == "left" {
err = server.SetValue(ctx, destination, append(sourceList[len(sourceList)-1:], destinationList...))
err = params.SetValue(params.Context, destination, append(sourceList[len(sourceList)-1:], destinationList...))
} else if whereTo == "right" {
err = server.SetValue(ctx, destination, append(destinationList, sourceList[len(sourceList)-1]))
err = params.SetValue(params.Context, destination, append(destinationList, sourceList[len(sourceList)-1]))
}
}
@@ -381,54 +379,54 @@ func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(constants.OkResponse), nil
}
func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lpushKeyFunc(cmd)
func handleLPush(params types.HandlerFuncParams) ([]byte, error) {
keys, err := lpushKeyFunc(params.Command)
if err != nil {
return nil, err
}
var newElems []interface{}
for _, elem := range cmd[2:] {
for _, elem := range params.Command[2:] {
newElems = append(newElems, internal.AdaptType(elem))
}
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) {
if !params.KeyExists(params.Context, key) {
switch strings.ToLower(params.Command[0]) {
case "lpushx":
return nil, errors.New("LPUSHX command on non-list item")
default:
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, key); err != nil {
return nil, err
}
if err = server.SetValue(ctx, key, []interface{}{}); err != nil {
if err = params.SetValue(params.Context, key, []interface{}{}); err != nil {
return nil, err
}
}
} else {
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
currentList := server.GetValue(ctx, key)
currentList := params.GetValue(params.Context, key)
l, ok := currentList.([]interface{})
if !ok {
return nil, errors.New("LPUSH command on non-list item")
}
if err = server.SetValue(ctx, key, append(newElems, l...)); err != nil {
if err = params.SetValue(params.Context, key, append(newElems, l...)); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := rpushKeyFunc(cmd)
func handleRPush(params types.HandlerFuncParams) ([]byte, error) {
keys, err := rpushKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -437,31 +435,31 @@ func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
var newElems []interface{}
for _, elem := range cmd[2:] {
for _, elem := range params.Command[2:] {
newElems = append(newElems, internal.AdaptType(elem))
}
if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) {
if !params.KeyExists(params.Context, key) {
switch strings.ToLower(params.Command[0]) {
case "rpushx":
return nil, errors.New("RPUSHX command on non-list item")
default:
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
if err = server.SetValue(ctx, key, []interface{}{}); err != nil {
defer params.KeyUnlock(params.Context, key)
if err = params.SetValue(params.Context, key, []interface{}{}); err != nil {
return nil, err
}
}
} else {
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
}
currentList := server.GetValue(ctx, key)
currentList := params.GetValue(params.Context, key)
l, ok := currentList.([]interface{})
@@ -469,42 +467,42 @@ func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, errors.New("RPUSH command on non-list item")
}
if err = server.SetValue(ctx, key, append(l, newElems...)); err != nil {
if err = params.SetValue(params.Context, key, append(l, newElems...)); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handlePop(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := popKeyFunc(cmd)
func handlePop(params types.HandlerFuncParams) ([]byte, error) {
keys, err := popKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))
if !params.KeyExists(params.Context, key) {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(params.Command[0]))
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
list, ok := server.GetValue(ctx, key).([]interface{})
list, ok := params.GetValue(params.Context, key).([]interface{})
if !ok {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(params.Command[0]))
}
switch strings.ToLower(cmd[0]) {
switch strings.ToLower(params.Command[0]) {
default:
if err = server.SetValue(ctx, key, list[1:]); err != nil {
if err = params.SetValue(params.Context, key, list[1:]); err != nil {
return nil, err
}
return []byte(fmt.Sprintf("+%v\r\n", list[0])), nil
case "rpop":
if err = server.SetValue(ctx, key, list[:len(list)-1]); err != nil {
if err = params.SetValue(params.Context, key, list[:len(list)-1]); err != nil {
return nil, err
}
return []byte(fmt.Sprintf("+%v\r\n", list[len(list)-1])), nil

View File

@@ -15,79 +15,77 @@
package pubsub
import (
"context"
"errors"
"fmt"
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"net"
"strings"
)
func handleSubscribe(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
channels := cmd[1:]
channels := params.Command[1:]
if len(channels) == 0 {
return nil, errors.New(constants.WrongArgsResponse)
}
withPattern := strings.EqualFold(cmd[0], "psubscribe")
pubsub.Subscribe(ctx, conn, channels, withPattern)
withPattern := strings.EqualFold(params.Command[0], "psubscribe")
pubsub.Subscribe(params.Context, params.Connection, channels, withPattern)
return nil, nil
}
func handleUnsubscribe(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
channels := cmd[1:]
channels := params.Command[1:]
withPattern := strings.EqualFold(cmd[0], "punsubscribe")
withPattern := strings.EqualFold(params.Command[0], "punsubscribe")
return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil
}
func handlePublish(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
if len(cmd) != 3 {
if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
pubsub.Publish(ctx, cmd[2], cmd[1])
pubsub.Publish(params.Context, params.Command[2], params.Command[1])
return []byte(constants.OkResponse), nil
}
func handlePubSubChannels(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 {
func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
pattern := ""
if len(cmd) == 3 {
pattern = cmd[2]
if len(params.Command) == 3 {
pattern = params.Command[2]
}
return pubsub.Channels(pattern), nil
}
func handlePubSubNumPat(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
@@ -95,12 +93,12 @@ func handlePubSubNumPat(_ context.Context, _ []string, server types.EchoVault, _
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
}
func handlePubSubNumSubs(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
return pubsub.NumSub(cmd[2:]), nil
return pubsub.NumSub(params.Command[2:]), nil
}
func Commands() []types.Command {
@@ -210,7 +208,7 @@ it's currently subscribe to.`,
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
HandlerFunc: func(_ types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
},
SubCommands: []types.SubCommand{

View File

@@ -15,20 +15,18 @@
package set
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
internal_set "github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"net"
"slices"
"strings"
)
func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := saddKeyFunc(cmd)
func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
keys, err := saddKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -37,51 +35,51 @@ func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
var set *internal_set.Set
if !server.KeyExists(ctx, key) {
set = internal_set.NewSet(cmd[2:])
if ok, err := server.CreateKeyAndLock(ctx, key); !ok && err != nil {
if !params.KeyExists(params.Context, key) {
set = internal_set.NewSet(params.Command[2:])
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
return nil, err
}
if err = server.SetValue(ctx, key, set); err != nil {
if err = params.SetValue(params.Context, key, set); err != nil {
return nil, err
}
server.KeyUnlock(ctx, key)
return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil
params.KeyUnlock(params.Context, key)
return []byte(fmt.Sprintf(":%d\r\n", len(params.Command[2:]))), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
count := set.Add(cmd[2:])
count := set.Add(params.Command[2:])
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := scardKeyFunc(cmd)
func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
keys, err := scardKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(fmt.Sprintf(":0\r\n")), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -91,21 +89,21 @@ func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
}
func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffKeyFunc(cmd)
func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sdiffKeyFunc(params.Command)
if err != nil {
return nil, err
}
// Extract base set first
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
if !params.KeyExists(params.Context, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
}
if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
@@ -114,24 +112,24 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *n
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
continue
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
continue
}
locks[key] = true
}
var sets []*internal_set.Set
for _, key := range cmd[2:] {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
for _, key := range params.Command[2:] {
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
continue
}
@@ -152,8 +150,8 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(res), nil
}
func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffstoreKeyFunc(cmd)
func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sdiffstoreKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -161,14 +159,14 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
destination := keys.WriteKeys[0]
// Extract base set first
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
if !params.KeyExists(params.Context, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
}
if _, err := server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
if _, err := params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
@@ -177,16 +175,16 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
continue
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
continue
}
locks[key] = true
@@ -194,7 +192,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
var sets []*internal_set.Set
for _, key := range keys.ReadKeys[1:] {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
continue
}
@@ -206,30 +204,30 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
res := fmt.Sprintf(":%d\r\n", len(elems))
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil {
return nil, err
}
if err = server.SetValue(ctx, destination, diff); err != nil {
if err = params.SetValue(params.Context, destination, diff); err != nil {
return nil, err
}
server.KeyUnlock(ctx, destination)
params.KeyUnlock(params.Context, destination)
return []byte(res), nil
}
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
return nil, err
}
if err = server.SetValue(ctx, destination, diff); err != nil {
if err = params.SetValue(params.Context, destination, diff); err != nil {
return nil, err
}
server.KeyUnlock(ctx, destination)
params.KeyUnlock(params.Context, destination)
return []byte(res), nil
}
func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterKeyFunc(cmd)
func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sinterKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -238,17 +236,17 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
// If key does not exist, then there is no intersection
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
locks[key] = true
@@ -257,7 +255,7 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
var sets []*internal_set.Set
for key, _ := range locks {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -283,15 +281,15 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
return []byte(res), nil
}
func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sintercardKeyFunc(cmd)
func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sintercardKeyFunc(params.Command)
if err != nil {
return nil, err
}
// Extract the limit from the command
var limit int
limitIdx := slices.IndexFunc(cmd, func(s string) bool {
limitIdx := slices.IndexFunc(params.Command, func(s string) bool {
return strings.EqualFold(s, "limit")
})
if limitIdx >= 0 && limitIdx < 2 {
@@ -299,11 +297,11 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
}
if limitIdx != -1 {
limitIdx += 1
if limitIdx >= len(cmd) {
if limitIdx >= len(params.Command) {
return nil, errors.New("provide limit after LIMIT keyword")
}
if l, ok := internal.AdaptType(cmd[limitIdx]).(int); !ok {
if l, ok := internal.AdaptType(params.Command[limitIdx]).(int); !ok {
return nil, errors.New("limit must be an integer")
} else {
limit = l
@@ -314,17 +312,17 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
// If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
locks[key] = true
@@ -333,7 +331,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
var sets []*internal_set.Set
for key, _ := range locks {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -350,8 +348,8 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterstoreKeyFunc(cmd)
func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sinterstoreKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -360,17 +358,17 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
// If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
locks[key] = true
@@ -379,7 +377,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
var sets []*internal_set.Set
for key, _ := range locks {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
// If the value at the key is not a set, return error
return nil, fmt.Errorf("value at key %s is not a set", key)
@@ -390,71 +388,71 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
intersect, _ := internal_set.Intersection(0, sets...)
destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil {
return nil, err
}
} else {
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
return nil, err
}
}
if err = server.SetValue(ctx, destination, intersect); err != nil {
if err = params.SetValue(params.Context, destination, intersect); err != nil {
return nil, err
}
server.KeyUnlock(ctx, destination)
params.KeyUnlock(params.Context, destination)
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sismemberKeyFunc(cmd)
func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sismemberKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
if !set.Contains(cmd[2]) {
if !set.Contains(params.Command[2]) {
return []byte(":0\r\n"), nil
}
return []byte(":1\r\n"), nil
}
func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smembersKeyFunc(cmd)
func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
keys, err := smembersKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -472,16 +470,16 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _
return []byte(res), nil
}
func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smismemberKeyFunc(cmd)
func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
keys, err := smismemberKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
members := cmd[2:]
members := params.Command[2:]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
res := fmt.Sprintf("*%d", len(members))
for i, _ := range members {
res = fmt.Sprintf("%s\r\n:0", res)
@@ -492,12 +490,12 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -515,48 +513,48 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil
}
func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smoveKeyFunc(cmd)
func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
keys, err := smoveKeyFunc(params.Command)
if err != nil {
return nil, err
}
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
member := cmd[3]
member := params.Command[3]
if !server.KeyExists(ctx, source) {
if !params.KeyExists(params.Context, source) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, source); err != nil {
if _, err = params.KeyLock(params.Context, source); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, source)
defer params.KeyUnlock(params.Context, source)
sourceSet, ok := server.GetValue(ctx, source).(*internal_set.Set)
sourceSet, ok := params.GetValue(params.Context, source).(*internal_set.Set)
if !ok {
return nil, errors.New("source is not a set")
}
var destinationSet *internal_set.Set
if !server.KeyExists(ctx, destination) {
if !params.KeyExists(params.Context, destination) {
// Destination key does not exist
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, destination)
defer params.KeyUnlock(params.Context, destination)
destinationSet = internal_set.NewSet([]string{})
if err = server.SetValue(ctx, destination, destinationSet); err != nil {
if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
return nil, err
}
} else {
// Destination key exists
if _, err := server.KeyLock(ctx, destination); err != nil {
if _, err := params.KeyLock(params.Context, destination); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, destination)
ds, ok := server.GetValue(ctx, destination).(*internal_set.Set)
defer params.KeyUnlock(params.Context, destination)
ds, ok := params.GetValue(params.Context, destination).(*internal_set.Set)
if !ok {
return nil, errors.New("destination is not a set")
}
@@ -568,8 +566,8 @@ func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(fmt.Sprintf(":%d\r\n", res)), nil
}
func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := spopKeyFunc(cmd)
func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
keys, err := spopKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -577,24 +575,24 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
key := keys.WriteKeys[0]
count := 1
if len(cmd) == 3 {
c, ok := internal.AdaptType(cmd[2]).(int)
if len(params.Command) == 3 {
c, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("count must be an integer")
}
count = c
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*-1\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at %s is not a set", key)
}
@@ -612,8 +610,8 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(res), nil
}
func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := srandmemberKeyFunc(cmd)
func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
keys, err := srandmemberKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -621,24 +619,24 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
key := keys.ReadKeys[0]
count := 1
if len(cmd) == 3 {
c, ok := internal.AdaptType(cmd[2]).(int)
if len(params.Command) == 3 {
c, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("count must be an integer")
}
count = c
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte("*-1\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at %s is not a set", key)
}
@@ -656,25 +654,25 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
return []byte(res), nil
}
func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sremKeyFunc(cmd)
func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sremKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
members := cmd[2:]
members := params.Command[2:]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -684,8 +682,8 @@ func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionKeyFunc(cmd)
func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sunionKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -694,16 +692,16 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
continue
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
locks[key] = true
@@ -715,7 +713,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
if !locked {
continue
}
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -735,8 +733,8 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
return []byte(res), nil
}
func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionstoreKeyFunc(cmd)
func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
keys, err := sunionstoreKeyFunc(params.Command)
if err != nil {
return nil, err
}
@@ -745,16 +743,16 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(ctx, key)
params.KeyRUnlock(params.Context, key)
}
}
}()
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
continue
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
locks[key] = true
@@ -766,7 +764,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
if !locked {
continue
}
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", key)
}
@@ -777,18 +775,18 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
if params.KeyExists(params.Context, destination) {
if _, err = params.KeyLock(params.Context, destination); err != nil {
return nil, err
}
} else {
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
return nil, err
}
}
defer server.KeyUnlock(ctx, destination)
defer params.KeyUnlock(params.Context, destination)
if err = server.SetValue(ctx, destination, union); err != nil {
if err = params.SetValue(params.Context, destination, union); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil

File diff suppressed because it is too large Load Diff

View File

@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package acl

View File

@@ -21,6 +21,7 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
acl2 "github.com/echovault/echovault/pkg/modules/acl"
"github.com/tidwall/resp"
"net"
"slices"
@@ -60,8 +61,8 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
}
mockServer, _ := echovault.NewEchoVault(
echovault.WithCommands(acl2.Commands()),
echovault.WithConfig(conf),
echovault.WithCommands(Commands()),
)
// Add the initial test users to the ACL module

View File

@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package admin

View File

@@ -21,20 +21,58 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/admin"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"strings"
"testing"
)
func Test_CommandsHandler(t *testing.T) {
mockServer, _ := echovault.NewEchoVault(
var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(admin.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
echovault.WithCommands(Commands()),
)
}
res, err := handleGetAllCommands(context.Background(), []string{"commands"}, mockServer, nil)
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
GetAllCommands: mockServer.GetAllCommands,
}
}
func Test_CommandsHandler(t *testing.T) {
res, err := getHandler("COMMANDS")(getHandlerFuncParams(context.Background(), []string{"commands"}, nil))
if err != nil {
t.Error(err)
}

View File

@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package connection

View File

@@ -21,7 +21,11 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/connection"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"strings"
"testing"
)
@@ -29,6 +33,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(connection.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -36,6 +41,35 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
}
}
func Test_HandlePing(t *testing.T) {
ctx := context.Background()
@@ -62,7 +96,7 @@ func Test_HandlePing(t *testing.T) {
}
for _, test := range tests {
res, err := handlePing(ctx, test.command, mockServer, nil)
res, err := getHandler("PING")(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedErr != nil && err != nil {
if err.Error() != test.expectedErr.Error() {
t.Errorf("expected error %s, got: %s", test.expectedErr.Error(), err.Error())

View File

@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package generic
import (
"context"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"reflect"
"slices"
"strings"
@@ -27,13 +28,36 @@ import (
"time"
)
func TestEchoVault_DEL(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),
)
return ev
}
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
}
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil
}
func presetKeyData(server *echovault.EchoVault, ctx context.Context, key string, data internal.KeyData) {
_, _ = server.CreateKeyAndLock(ctx, key)
defer server.KeyUnlock(ctx, key)
_ = server.SetValue(ctx, key, data.Value)
server.SetExpiry(ctx, key, data.ExpireAt, false)
}
func TestEchoVault_DEL(t *testing.T) {
server := createEchoVault()
tests := []struct {
name string
@@ -59,7 +83,7 @@ func TestEchoVault_DEL(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
got, err := server.DEL(tt.keys...)
@@ -77,12 +101,7 @@ func TestEchoVault_DEL(t *testing.T) {
func TestEchoVault_EXPIRE(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -90,8 +109,8 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd string
key string
time int
expireOpts EXPIREOptions
pexpireOpts PEXPIREOptions
expireOpts echovault.EXPIREOptions
pexpireOpts echovault.PEXPIREOptions
want int
wantErr bool
}{
@@ -100,7 +119,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key1",
time: 100,
expireOpts: EXPIREOptions{},
expireOpts: echovault.EXPIREOptions{},
presetValues: map[string]internal.KeyData{
"key1": {Value: "value1", ExpireAt: time.Time{}},
},
@@ -112,7 +131,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "PEXPIRE",
key: "key2",
time: 1000,
pexpireOpts: PEXPIREOptions{},
pexpireOpts: echovault.PEXPIREOptions{},
presetValues: map[string]internal.KeyData{
"key2": {Value: "value2", ExpireAt: time.Time{}},
},
@@ -124,7 +143,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key3",
time: 1000,
expireOpts: EXPIREOptions{NX: true},
expireOpts: echovault.EXPIREOptions{NX: true},
presetValues: map[string]internal.KeyData{
"key3": {Value: "value3", ExpireAt: time.Time{}},
},
@@ -136,7 +155,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key4",
time: 1000,
expireOpts: EXPIREOptions{NX: true},
expireOpts: echovault.EXPIREOptions{NX: true},
presetValues: map[string]internal.KeyData{
"key4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
},
@@ -148,7 +167,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key5",
time: 1000,
expireOpts: EXPIREOptions{XX: true},
expireOpts: echovault.EXPIREOptions{XX: true},
presetValues: map[string]internal.KeyData{
"key5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)},
},
@@ -159,7 +178,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
name: "Return 0 when key does not have an expiry and the XX flag is provided",
cmd: "EXPIRE",
time: 1000,
expireOpts: EXPIREOptions{XX: true},
expireOpts: echovault.EXPIREOptions{XX: true},
key: "key6",
presetValues: map[string]internal.KeyData{
"key6": {Value: "value6", ExpireAt: time.Time{}},
@@ -172,7 +191,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key7",
time: 100000,
expireOpts: EXPIREOptions{GT: true},
expireOpts: echovault.EXPIREOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)},
},
@@ -184,7 +203,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key8",
time: 1000,
expireOpts: EXPIREOptions{GT: true},
expireOpts: echovault.EXPIREOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
},
@@ -196,7 +215,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key9",
time: 1000,
expireOpts: EXPIREOptions{GT: true},
expireOpts: echovault.EXPIREOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key9": {Value: "value9", ExpireAt: time.Time{}},
},
@@ -208,7 +227,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key10",
time: 1000,
expireOpts: EXPIREOptions{LT: true},
expireOpts: echovault.EXPIREOptions{LT: true},
presetValues: map[string]internal.KeyData{
"key10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
},
@@ -220,7 +239,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
cmd: "EXPIRE",
key: "key11",
time: 50000,
expireOpts: EXPIREOptions{LT: true},
expireOpts: echovault.EXPIREOptions{LT: true},
presetValues: map[string]internal.KeyData{
"key11": {Value: "value11", ExpireAt: mockClock.Now().Add(30 * time.Second)},
},
@@ -232,7 +251,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
var got int
@@ -256,12 +275,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
func TestEchoVault_EXPIREAT(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -269,8 +283,8 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd string
key string
time int
expireAtOpts EXPIREATOptions
pexpireAtOpts PEXPIREATOptions
expireAtOpts echovault.EXPIREATOptions
pexpireAtOpts echovault.PEXPIREATOptions
want int
wantErr bool
}{
@@ -278,7 +292,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
name: "Set new expire by unix seconds",
cmd: "EXPIREAT",
key: "key1",
expireAtOpts: EXPIREATOptions{},
expireAtOpts: echovault.EXPIREATOptions{},
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
presetValues: map[string]internal.KeyData{
"key1": {Value: "value1", ExpireAt: time.Time{}},
@@ -290,7 +304,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
name: "Set new expire by milliseconds",
cmd: "PEXPIREAT",
key: "key2",
pexpireAtOpts: PEXPIREATOptions{},
pexpireAtOpts: echovault.PEXPIREATOptions{},
time: int(mockClock.Now().Add(1000 * time.Second).UnixMilli()),
presetValues: map[string]internal.KeyData{
"key2": {Value: "value2", ExpireAt: time.Time{}},
@@ -303,7 +317,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key3",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{NX: true},
expireAtOpts: echovault.EXPIREATOptions{NX: true},
presetValues: map[string]internal.KeyData{
"key3": {Value: "value3", ExpireAt: time.Time{}},
},
@@ -314,7 +328,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
name: "Return 0, when NX flag is provided and key already has an expiry time",
cmd: "EXPIREAT",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{NX: true},
expireAtOpts: echovault.EXPIREATOptions{NX: true},
key: "key4",
presetValues: map[string]internal.KeyData{
"key4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
@@ -327,7 +341,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
key: "key5",
expireAtOpts: EXPIREATOptions{XX: true},
expireAtOpts: echovault.EXPIREATOptions{XX: true},
presetValues: map[string]internal.KeyData{
"key5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)},
},
@@ -339,7 +353,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key6",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{XX: true},
expireAtOpts: echovault.EXPIREATOptions{XX: true},
presetValues: map[string]internal.KeyData{
"key6": {Value: "value6", ExpireAt: time.Time{}},
},
@@ -351,7 +365,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key7",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{GT: true},
expireAtOpts: echovault.EXPIREATOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)},
},
@@ -363,7 +377,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key8",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{GT: true},
expireAtOpts: echovault.EXPIREATOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
},
@@ -375,7 +389,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key9",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{GT: true},
expireAtOpts: echovault.EXPIREATOptions{GT: true},
presetValues: map[string]internal.KeyData{
"key9": {Value: "value9", ExpireAt: time.Time{}},
},
@@ -386,7 +400,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key10",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{LT: true},
expireAtOpts: echovault.EXPIREATOptions{LT: true},
presetValues: map[string]internal.KeyData{
"key10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
},
@@ -398,7 +412,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key11",
time: int(mockClock.Now().Add(3000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{LT: true},
expireAtOpts: echovault.EXPIREATOptions{LT: true},
presetValues: map[string]internal.KeyData{
"key11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
},
@@ -410,7 +424,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
cmd: "EXPIREAT",
key: "key12",
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
expireAtOpts: EXPIREATOptions{LT: true},
expireAtOpts: echovault.EXPIREATOptions{LT: true},
presetValues: map[string]internal.KeyData{
"key12": {Value: "value12", ExpireAt: time.Time{}},
},
@@ -422,7 +436,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
var got int
@@ -446,12 +460,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
func TestEchoVault_EXPIRETIME(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -503,7 +512,7 @@ func TestEchoVault_EXPIRETIME(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
got, err := tt.expiretimeFunc(tt.key)
@@ -519,12 +528,7 @@ func TestEchoVault_EXPIRETIME(t *testing.T) {
}
func TestEchoVault_GET(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -551,7 +555,11 @@ func TestEchoVault_GET(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.GET(tt.key)
if (err != nil) != tt.wantErr {
@@ -566,12 +574,7 @@ func TestEchoVault_GET(t *testing.T) {
}
func TestEchoVault_MGET(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -599,7 +602,11 @@ func TestEchoVault_MGET(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.MGET(tt.keys...)
@@ -625,19 +632,14 @@ func TestEchoVault_MGET(t *testing.T) {
func TestEchoVault_SET(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
presetValues map[string]internal.KeyData
key string
value string
options SETOptions
options echovault.SETOptions
want string
wantErr bool
}{
@@ -646,7 +648,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key1",
value: "value1",
options: SETOptions{},
options: echovault.SETOptions{},
want: "OK",
wantErr: false,
},
@@ -655,7 +657,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key2",
value: "value2",
options: SETOptions{NX: true},
options: echovault.SETOptions{NX: true},
want: "OK",
wantErr: false,
},
@@ -669,7 +671,7 @@ func TestEchoVault_SET(t *testing.T) {
},
key: "key3",
value: "value3",
options: SETOptions{NX: true},
options: echovault.SETOptions{NX: true},
want: "",
wantErr: true,
},
@@ -683,7 +685,7 @@ func TestEchoVault_SET(t *testing.T) {
},
key: "key4",
value: "value4",
options: SETOptions{XX: true},
options: echovault.SETOptions{XX: true},
want: "OK",
wantErr: false,
},
@@ -692,7 +694,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key5",
value: "value5",
options: SETOptions{XX: true},
options: echovault.SETOptions{XX: true},
want: "",
wantErr: true,
},
@@ -701,7 +703,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key6",
value: "value6",
options: SETOptions{EX: 100},
options: echovault.SETOptions{EX: 100},
want: "OK",
wantErr: false,
},
@@ -710,7 +712,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key7",
value: "value7",
options: SETOptions{PX: 4096},
options: echovault.SETOptions{PX: 4096},
want: "OK",
wantErr: false,
},
@@ -719,7 +721,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key8",
value: "value8",
options: SETOptions{EXAT: int(mockClock.Now().Add(200 * time.Second).Unix())},
options: echovault.SETOptions{EXAT: int(mockClock.Now().Add(200 * time.Second).Unix())},
want: "OK",
wantErr: false,
},
@@ -727,7 +729,7 @@ func TestEchoVault_SET(t *testing.T) {
name: "Set exact expiry time in milliseconds from unix epoch",
key: "key9",
value: "value9",
options: SETOptions{PXAT: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli())},
options: echovault.SETOptions{PXAT: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli())},
presetValues: nil,
want: "OK",
wantErr: false,
@@ -742,7 +744,7 @@ func TestEchoVault_SET(t *testing.T) {
},
key: "key10",
value: "value10",
options: SETOptions{GET: true, EX: 1000},
options: echovault.SETOptions{GET: true, EX: 1000},
want: "previous-value",
wantErr: false,
},
@@ -751,7 +753,7 @@ func TestEchoVault_SET(t *testing.T) {
presetValues: nil,
key: "key11",
value: "value11",
options: SETOptions{GET: true, EX: 1000},
options: echovault.SETOptions{GET: true, EX: 1000},
want: "",
wantErr: false,
},
@@ -760,7 +762,7 @@ func TestEchoVault_SET(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
got, err := server.SET(tt.key, tt.value, tt.options)
@@ -776,12 +778,7 @@ func TestEchoVault_SET(t *testing.T) {
}
func TestEchoVault_MSET(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -813,12 +810,7 @@ func TestEchoVault_MSET(t *testing.T) {
func TestEchoVault_PERSIST(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -857,7 +849,7 @@ func TestEchoVault_PERSIST(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
got, err := server.PERSIST(tt.key)
@@ -875,12 +867,7 @@ func TestEchoVault_PERSIST(t *testing.T) {
func TestEchoVault_TTL(t *testing.T) {
mockClock := clock.NewClock()
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -933,7 +920,7 @@ func TestEchoVault_TTL(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, d := range tt.presetValues {
presetKeyData(server, k, d)
presetKeyData(server, context.Background(), k, d)
}
}
got, err := tt.ttlFunc(tt.key)

View File

@@ -23,7 +23,11 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/generic"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"strings"
"testing"
"time"
)
@@ -41,6 +45,7 @@ func init() {
mockClock = clock.NewClock()
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(generic.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -48,6 +53,47 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
GetClock: mockServer.GetClock,
GetExpiry: mockServer.GetExpiry,
SetExpiry: mockServer.SetExpiry,
DeleteKey: mockServer.DeleteKey,
}
}
func Test_HandleSET(t *testing.T) {
tests := []struct {
name string
@@ -372,7 +418,13 @@ func Test_HandleSET(t *testing.T) {
}
}
res, err := handleSet(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedErr != nil {
if err == nil {
t.Errorf("expected error \"%s\", got nil", test.expectedErr.Error())
@@ -454,7 +506,14 @@ func Test_HandleMSET(t *testing.T) {
for i, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MSET, %d", i))
res, err := handleMSet(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedErr != nil {
if err.Error() != test.expectedErr.Error() {
t.Errorf("expected error %s, got %s", test.expectedErr.Error(), err.Error())
@@ -547,7 +606,14 @@ func Test_HandleGET(t *testing.T) {
}
mockServer.KeyUnlock(ctx, key)
res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil)
handler := getHandler("GET")
if handler == nil {
t.Error("no handler found for command GET")
return
}
res, err := handler(getHandlerFuncParams(ctx, []string{"GET", key}, nil))
if err != nil {
t.Error(err)
}
@@ -559,7 +625,7 @@ func Test_HandleGET(t *testing.T) {
}
// Test get non-existent key
res, err := handleGet(context.Background(), []string{"GET", "test4"}, mockServer, nil)
res, err := getHandler("GET")(getHandlerFuncParams(context.Background(), []string{"GET", "test4"}, nil))
if err != nil {
t.Error(err)
}
@@ -585,7 +651,12 @@ func Test_HandleGET(t *testing.T) {
}
for _, test := range errorTests {
t.Run(test.name, func(t *testing.T) {
res, err = handleGet(context.Background(), test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err = handler(getHandlerFuncParams(context.Background(), test.command, nil))
if res != nil {
t.Errorf("expected nil response, got: %+v", res)
}
@@ -631,21 +702,28 @@ func Test_HandleMGET(t *testing.T) {
},
}
for _, test := range tests {
for i, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MGET, %d", i))
// Set up the values
for i, key := range test.presetKeys {
_, err := mockServer.CreateKeyAndLock(context.Background(), key)
_, err := mockServer.CreateKeyAndLock(ctx, key)
if err != nil {
t.Error(err)
}
if err = mockServer.SetValue(context.Background(), key, test.presetValues[i]); err != nil {
if err = mockServer.SetValue(ctx, key, test.presetValues[i]); err != nil {
t.Error(err)
}
mockServer.KeyUnlock(context.Background(), key)
mockServer.KeyUnlock(ctx, key)
}
// Test the command and its results
res, err := handleMGet(context.Background(), test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
// If we expect and error, branch out and check error
if err.Error() != test.expectedError.Error() {
@@ -734,7 +812,13 @@ func Test_HandleDEL(t *testing.T) {
}
}
res, err := handleDel(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedErr != nil {
if err == nil {
t.Errorf("exected error \"%s\", got nil", test.expectedErr.Error())
@@ -844,7 +928,13 @@ func Test_HandlePERSIST(t *testing.T) {
}
}
res, err := handlePersist(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err == nil {
@@ -965,7 +1055,13 @@ func Test_HandleEXPIRETIME(t *testing.T) {
}
}
res, err := handleExpireTime(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err == nil {
@@ -1067,7 +1163,13 @@ func Test_HandleTTL(t *testing.T) {
}
}
res, err := handleTTL(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err == nil {
@@ -1300,7 +1402,13 @@ func Test_HandleEXPIRE(t *testing.T) {
}
}
res, err := handleExpire(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err == nil {
@@ -1576,7 +1684,13 @@ func Test_HandleEXPIREAT(t *testing.T) {
}
}
res, err := handleExpireAt(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err == nil {

View File

@@ -12,25 +12,41 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package hash
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"reflect"
"slices"
"testing"
)
func TestEchoVault_HDEL(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
return ev
}
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
}
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil
}
func TestEchoVault_HDEL(t *testing.T) {
server := createEchoVault()
tests := []struct {
name string
@@ -76,7 +92,11 @@ func TestEchoVault_HDEL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HDEL(tt.key, tt.fields...)
if (err != nil) != tt.wantErr {
@@ -91,13 +111,7 @@ func TestEchoVault_HDEL(t *testing.T) {
}
func TestEchoVault_HEXISTS(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -135,7 +149,11 @@ func TestEchoVault_HEXISTS(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HEXISTS(tt.key, tt.field)
if (err != nil) != tt.wantErr {
@@ -150,13 +168,7 @@ func TestEchoVault_HEXISTS(t *testing.T) {
}
func TestEchoVault_HGETALL(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -190,7 +202,11 @@ func TestEchoVault_HGETALL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HGETALL(tt.key)
if (err != nil) != tt.wantErr {
@@ -212,13 +228,7 @@ func TestEchoVault_HGETALL(t *testing.T) {
}
func TestEchoVault_HINCRBY(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
const (
HINCRBY = "HINCRBY"
@@ -300,7 +310,11 @@ func TestEchoVault_HINCRBY(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
var got float64
var err error
@@ -326,13 +340,7 @@ func TestEchoVault_HINCRBY(t *testing.T) {
}
func TestEchoVault_HKEYS(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -366,7 +374,11 @@ func TestEchoVault_HKEYS(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HKEYS(tt.key)
if (err != nil) != tt.wantErr {
@@ -386,13 +398,7 @@ func TestEchoVault_HKEYS(t *testing.T) {
}
func TestEchoVault_HLEN(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -426,7 +432,11 @@ func TestEchoVault_HLEN(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HLEN(tt.key)
if (err != nil) != tt.wantErr {
@@ -441,19 +451,13 @@ func TestEchoVault_HLEN(t *testing.T) {
}
func TestEchoVault_HRANDFIELD(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
presetValue interface{}
key string
options HRANDFIELDOptions
options echovault.HRANDFIELDOptions
wantCount int
want []string
wantErr bool
@@ -462,7 +466,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
name: "Get a random field",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142},
key: "key1",
options: HRANDFIELDOptions{Count: 1},
options: echovault.HRANDFIELDOptions{Count: 1},
wantCount: 1,
want: []string{"field1", "field2", "field3"},
wantErr: false,
@@ -471,7 +475,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
name: "Get a random field with a value",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142},
key: "key2",
options: HRANDFIELDOptions{WithValues: true, Count: 1},
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 1},
wantCount: 2,
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"},
wantErr: false,
@@ -486,7 +490,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
"field5": "value5",
},
key: "key3",
options: HRANDFIELDOptions{Count: 3},
options: echovault.HRANDFIELDOptions{Count: 3},
wantCount: 3,
want: []string{"field1", "field2", "field3", "field4", "field5"},
wantErr: false,
@@ -501,7 +505,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
"field5": "value5",
},
key: "key4",
options: HRANDFIELDOptions{WithValues: true, Count: 3},
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 3},
wantCount: 6,
want: []string{
"field1", "value1", "field2", "123456789", "field3",
@@ -519,7 +523,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
"field5": "value5",
},
key: "key5",
options: HRANDFIELDOptions{Count: 5},
options: echovault.HRANDFIELDOptions{Count: 5},
wantCount: 5,
want: []string{"field1", "field2", "field3", "field4", "field5"},
wantErr: false,
@@ -534,7 +538,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
"field5": "value5",
},
key: "key5",
options: HRANDFIELDOptions{WithValues: true, Count: 5},
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 5},
wantCount: 10,
want: []string{
"field1", "value1", "field2", "123456789", "field3",
@@ -546,7 +550,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
name: "Trying to get random field on a non hash map returns error",
presetValue: "Default value",
key: "key12",
options: HRANDFIELDOptions{},
options: echovault.HRANDFIELDOptions{},
wantCount: 0,
want: nil,
wantErr: true,
@@ -555,7 +559,11 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HRANDFIELD(tt.key, tt.options)
if (err != nil) != tt.wantErr {
@@ -575,13 +583,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
}
func TestEchoVault_HSET(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -650,7 +652,11 @@ func TestEchoVault_HSET(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := tt.hsetFunc(tt.key, tt.fieldValuePairs)
if (err != nil) != tt.wantErr {
@@ -665,13 +671,7 @@ func TestEchoVault_HSET(t *testing.T) {
}
func TestEchoVault_HSTRLEN(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -719,7 +719,11 @@ func TestEchoVault_HSTRLEN(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HSTRLEN(tt.key, tt.fields...)
if (err != nil) != tt.wantErr {
@@ -734,13 +738,7 @@ func TestEchoVault_HSTRLEN(t *testing.T) {
}
func TestEchoVault_HVALS(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -774,7 +772,11 @@ func TestEchoVault_HVALS(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.HVALS(tt.key)
if (err != nil) != tt.wantErr {

View File

@@ -22,8 +22,12 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/hash"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"slices"
"strings"
"testing"
)
@@ -31,6 +35,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(hash.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -38,6 +43,43 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
}
}
func Test_HandleHSET(t *testing.T) {
// Tests for both HSET and HSETNX
tests := []struct {
@@ -144,7 +186,14 @@ func Test_HandleHSET(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHSET(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -163,11 +212,11 @@ func Test_HandleHSET(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
if !ok {
t.Errorf("value at key \"%s\" is not a hash map", test.key)
}
for field, value := range hash {
for field, value := range h {
if value != test.expectedValue[field] {
t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value)
}
@@ -303,7 +352,14 @@ func Test_HandleHINCRBY(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHINCRBY(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -331,11 +387,11 @@ func Test_HandleHINCRBY(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
if !ok {
t.Errorf("value at key \"%s\" is not a hash map", test.key)
}
for field, value := range hash {
for field, value := range h {
if value != test.expectedValue[field] {
t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value)
}
@@ -410,7 +466,14 @@ func Test_HandleHGET(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHGET(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -519,7 +582,14 @@ func Test_HandleHSTRLEN(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHSTRLEN(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -623,7 +693,14 @@ func Test_HandleHVALS(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHVALS(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -831,7 +908,14 @@ func Test_HandleHRANDFIELD(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHRANDFIELD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -954,7 +1038,14 @@ func Test_HandleHLEN(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHLEN(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1053,7 +1144,14 @@ func Test_HandleHKeys(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHKEYS(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1159,7 +1257,14 @@ func Test_HandleHGETALL(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHGETALL(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1276,7 +1381,14 @@ func Test_HandleHEXISTS(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHEXISTS(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1375,7 +1487,14 @@ func Test_HandleHDEL(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleHDEL(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1396,8 +1515,8 @@ func Test_HandleHDEL(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
if hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok {
for field, value := range hash {
if h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok {
for field, value := range h {
if value != test.expectedValue[field] {
t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value)
}

View File

@@ -12,24 +12,40 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package list
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"reflect"
"testing"
)
func TestEchoVault_LLEN(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
return ev
}
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
}
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil
}
func TestEchoVault_LLEN(t *testing.T) {
server := createEchoVault()
tests := []struct {
preset bool
@@ -67,7 +83,11 @@ func TestEchoVault_LLEN(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.LLEN(tt.key)
if (err != nil) != tt.wantErr {
@@ -82,13 +102,7 @@ func TestEchoVault_LLEN(t *testing.T) {
}
func TestEchoVault_LINDEX(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
preset bool
@@ -156,7 +170,11 @@ func TestEchoVault_LINDEX(t *testing.T) {
}
for _, tt := range tests {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
t.Run(tt.name, func(t *testing.T) {
got, err := server.LINDEX(tt.key, tt.index)
@@ -172,13 +190,7 @@ func TestEchoVault_LINDEX(t *testing.T) {
}
func TestEchoVault_LMOVE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -328,7 +340,11 @@ func TestEchoVault_LMOVE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
for k, v := range tt.presetValue {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.LMOVE(tt.source, tt.destination, tt.whereFrom, tt.whereTo)
@@ -344,13 +360,7 @@ func TestEchoVault_LMOVE(t *testing.T) {
}
func TestEchoVault_POP(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -401,7 +411,11 @@ func TestEchoVault_POP(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := tt.popFunc(tt.key)
if (err != nil) != tt.wantErr {
@@ -416,13 +430,7 @@ func TestEchoVault_POP(t *testing.T) {
}
func TestEchoVault_LPUSH(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -478,7 +486,11 @@ func TestEchoVault_LPUSH(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := tt.lpushFunc(tt.key, tt.values...)
if (err != nil) != tt.wantErr {
@@ -493,13 +505,7 @@ func TestEchoVault_LPUSH(t *testing.T) {
}
func TestEchoVault_RPUSH(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -535,7 +541,11 @@ func TestEchoVault_RPUSH(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := tt.rpushFunc(tt.key, tt.values...)
if (err != nil) != tt.wantErr {
@@ -550,13 +560,7 @@ func TestEchoVault_RPUSH(t *testing.T) {
}
func TestEchoVault_LRANGE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -656,7 +660,11 @@ func TestEchoVault_LRANGE(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.LRANGE(tt.key, tt.start, tt.end)
if (err != nil) != tt.wantErr {
@@ -671,13 +679,7 @@ func TestEchoVault_LRANGE(t *testing.T) {
}
func TestEchoVault_LREM(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -722,7 +724,11 @@ func TestEchoVault_LREM(t *testing.T) {
}
for _, tt := range tests {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
t.Run(tt.name, func(t *testing.T) {
got, err := server.LREM(tt.key, tt.count, tt.value)
@@ -738,13 +744,7 @@ func TestEchoVault_LREM(t *testing.T) {
}
func TestEchoVault_LSET(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -830,7 +830,11 @@ func TestEchoVault_LSET(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.LSET(tt.key, tt.index, tt.value)
if (err != nil) != tt.wantErr {
@@ -845,13 +849,7 @@ func TestEchoVault_LSET(t *testing.T) {
}
func TestEchoVault_LTRIM(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -940,7 +938,11 @@ func TestEchoVault_LTRIM(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.preset {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.LTRIM(tt.key, tt.start, tt.end)
if (err != nil) != tt.wantErr {

View File

@@ -22,7 +22,11 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/modules/list"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"strings"
"testing"
)
@@ -30,6 +34,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(list.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -37,6 +42,43 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
}
}
func Test_HandleLLEN(t *testing.T) {
tests := []struct {
name string
@@ -113,7 +155,14 @@ func Test_HandleLLEN(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLLen(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -258,7 +307,14 @@ func Test_HandleLINDEX(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLIndex(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -426,7 +482,14 @@ func Test_HandleLRANGE(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLRange(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -577,7 +640,14 @@ func Test_HandleLSET(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLSet(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -595,16 +665,16 @@ func Test_HandleLSET(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)
@@ -751,7 +821,14 @@ func Test_HandleLTRIM(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLTrim(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -769,16 +846,16 @@ func Test_HandleLTRIM(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)
@@ -882,7 +959,14 @@ func Test_HandleLREM(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLRem(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -900,16 +984,16 @@ func Test_HandleLREM(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)
@@ -1096,7 +1180,14 @@ func Test_HandleLMOVE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleLMove(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1115,7 +1206,7 @@ func Test_HandleLMOVE(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, key).([]interface{})
l, ok := mockServer.GetValue(ctx, key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
@@ -1123,12 +1214,12 @@ func Test_HandleLMOVE(t *testing.T) {
if !ok {
t.Error("expected test value to be list, got another type")
}
if len(list) != len(expectedList) {
t.Errorf("expected list length to be %d, got %d", len(expectedList), len(list))
if len(l) != len(expectedList) {
t.Errorf("expected list length to be %d, got %d", len(expectedList), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != expectedList[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != expectedList[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, key)
@@ -1213,7 +1304,14 @@ func Test_HandleLPUSH(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleLPush(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1231,16 +1329,16 @@ func Test_HandleLPUSH(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)
@@ -1324,7 +1422,14 @@ func Test_HandleRPUSH(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleRPush(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1342,16 +1447,16 @@ func Test_HandleRPUSH(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)
@@ -1445,7 +1550,14 @@ func Test_HandlePOP(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handlePop(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1463,16 +1575,16 @@ func Test_HandlePOP(t *testing.T) {
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err)
}
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok {
t.Error("expected value to be list, got another type")
}
if len(list) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
if len(l) != len(test.expectedValue) {
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
}
for i := 0; i < len(list); i++ {
if list[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
for i := 0; i < len(l); i++ {
if l[i] != test.expectedValue[i] {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
}
}
mockServer.KeyRUnlock(ctx, test.key)

View File

@@ -12,4 +12,4 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package pubsub

View File

@@ -22,9 +22,12 @@ import (
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
ps "github.com/echovault/echovault/pkg/modules/pubsub"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -51,7 +54,7 @@ func init() {
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
server, _ := echovault.NewEchoVault(
echovault.WithCommands(Commands()),
echovault.WithCommands(ps.Commands()),
echovault.WithConfig(config.Config{
BindAddr: bindAddr,
Port: port,
@@ -62,6 +65,36 @@ func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
return server
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
GetPubSub: mockServer.GetPubSub,
}
}
func Test_HandleSubscribe(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
@@ -86,7 +119,8 @@ func Test_HandleSubscribe(t *testing.T) {
// Test subscribe to channels
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
for _, conn := range connections {
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, conn); err != nil {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
@@ -116,7 +150,8 @@ func Test_HandleSubscribe(t *testing.T) {
// Test subscribe to patterns
patterns := []string{"psub_channel*"}
for _, conn := range connections {
if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, conn); err != nil {
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
@@ -263,24 +298,24 @@ func Test_HandleUnsubscribe(t *testing.T) {
// Subscribe all the connections to the channels and patterns
for _, conn := range append(test.otherConnections, test.targetConn) {
_, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), mockServer, conn)
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer))
if err != nil {
t.Error(err)
}
_, err = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), mockServer, conn)
_, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
// Unsubscribe the target connection from the unsub channels and patterns
res, err := handleUnsubscribe(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), mockServer, test.targetConn)
res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer))
if err != nil {
t.Error(err)
}
verifyResponse(res, test.expectedResponses["channel"])
res, err = handleUnsubscribe(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), mockServer, test.targetConn)
res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer))
if err != nil {
t.Error(err)
}
@@ -347,7 +382,7 @@ func Test_HandlePublish(t *testing.T) {
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
// Subscribe to channels
go func() {
_, _ = handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, c)
_, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer))
}()
// Verify all the responses for each channel subscription
for i := 0; i < len(channels); i++ {
@@ -355,7 +390,7 @@ func Test_HandlePublish(t *testing.T) {
}
// Subscribe to all the patterns
go func() {
_, _ = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, c)
_, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer))
}()
// Verify all the responses for each pattern subscription
for i := 0; i < len(patterns); i++ {
@@ -518,7 +553,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
// Subscribe connections to channels
go func() {
_, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &wConn1)
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &wConn1, mockServer))
if err != nil {
t.Error(err)
}
@@ -535,7 +570,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
}
}
go func() {
_, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &wConn2)
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &wConn2, mockServer))
if err != nil {
t.Error(err)
}
@@ -571,7 +606,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
}
// Check if all subscriptions are returned
res, err := handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil)
res, err := getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
if err != nil {
t.Error(err)
}
@@ -579,45 +614,45 @@ func Test_HandlePubSubChannels(t *testing.T) {
// Unsubscribe from one pattern and one channel before checking against a new slice of
// expected channels/patterns in the response of the "PUBSUB CHANNELS" command
_, err = handleUnsubscribe(
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
ctx,
append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...),
mockServer,
&wConn1,
)
mockServer,
))
if err != nil {
t.Error(err)
}
_, err = handleUnsubscribe(
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
ctx,
append([]string{"UNSUBSCRIBE"}, "channel_[456]"),
mockServer,
&wConn2,
)
mockServer,
))
if err != nil {
t.Error(err)
}
// Return all the remaining channels
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil)
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
// Return only one of the remaining channels when passed a pattern that matches it
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, mockServer, nil)
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1"})
// Return both remaining channels when passed a pattern that matches them
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, mockServer, nil)
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
// Return none channels when passed a pattern that does not match either channel
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, mockServer, nil)
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
@@ -655,7 +690,8 @@ func Test_HandleNumPat(t *testing.T) {
r *resp.Conn
}{w: &w, r: resp.NewConn(r)}
go func() {
if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &w); err != nil {
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &w, mockServer))
if err != nil {
t.Error(err)
}
}()
@@ -685,7 +721,7 @@ func Test_HandleNumPat(t *testing.T) {
}
// Check that we receive all the patterns with NUMPAT commands
res, err := handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
res, err := getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
@@ -693,12 +729,12 @@ func Test_HandleNumPat(t *testing.T) {
// Unsubscribe from a channel and check if the number of active channels is updated
for _, conn := range connections {
_, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, mockServer, conn.w)
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, conn.w, mockServer))
if err != nil {
t.Error(err)
}
}
res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
@@ -706,12 +742,12 @@ func Test_HandleNumPat(t *testing.T) {
// Unsubscribe from all the channels and check if we get a 0 response
for _, conn := range connections {
_, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE"}, mockServer, conn.w)
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE"}, conn.w, mockServer))
if err != nil {
t.Error(err)
}
}
res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
@@ -748,7 +784,8 @@ func Test_HandleNumSub(t *testing.T) {
r *resp.Conn
}{w: &w, r: resp.NewConn(r)}
go func() {
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &w); err != nil {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &w, mockServer))
if err != nil {
t.Error(err)
}
}()
@@ -793,7 +830,7 @@ func Test_HandleNumSub(t *testing.T) {
for i, test := range tests {
ctx = context.WithValue(ctx, "test_index", i)
res, err := handlePubSubNumSubs(ctx, test.cmd, mockServer, nil)
res, err := getHandler("PUBSUB", "NUMSUB")(getHandlerFuncParams(ctx, test.cmd, nil, mockServer))
if err != nil {
t.Error(err)
}

View File

@@ -12,26 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package set
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"reflect"
"slices"
"testing"
)
func TestEchoVault_SADD(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
return ev
}
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
}
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil
}
func TestEchoVault_SADD(t *testing.T) {
server := createEchoVault()
tests := []struct {
name string
@@ -69,7 +85,11 @@ func TestEchoVault_SADD(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SADD(tt.key, tt.members...)
if (err != nil) != tt.wantErr {
@@ -84,13 +104,7 @@ func TestEchoVault_SADD(t *testing.T) {
}
func TestEchoVault_SCARD(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -124,7 +138,11 @@ func TestEchoVault_SCARD(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SCARD(tt.key)
if (err != nil) != tt.wantErr {
@@ -139,13 +157,7 @@ func TestEchoVault_SCARD(t *testing.T) {
}
func TestEchoVault_SDIFF(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -212,7 +224,11 @@ func TestEchoVault_SDIFF(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SDIFF(tt.keys...)
@@ -233,13 +249,7 @@ func TestEchoVault_SDIFF(t *testing.T) {
}
func TestEchoVault_SDIFFSTORE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -312,7 +322,11 @@ func TestEchoVault_SDIFFSTORE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SDIFFSTORE(tt.destination, tt.keys...)
@@ -328,13 +342,7 @@ func TestEchoVault_SDIFFSTORE(t *testing.T) {
}
func TestEchoVault_SINTER(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -401,7 +409,11 @@ func TestEchoVault_SINTER(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SINTER(tt.keys...)
@@ -422,13 +434,7 @@ func TestEchoVault_SINTER(t *testing.T) {
}
func TestEchoVault_SINTERCARD(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -512,7 +518,11 @@ func TestEchoVault_SINTERCARD(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SINTERCARD(tt.keys, tt.limit)
@@ -528,13 +538,7 @@ func TestEchoVault_SINTERCARD(t *testing.T) {
}
func TestEchoVault_SINTERSTORE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -607,7 +611,11 @@ func TestEchoVault_SINTERSTORE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SINTERSTORE(tt.destination, tt.keys...)
@@ -623,13 +631,7 @@ func TestEchoVault_SINTERSTORE(t *testing.T) {
}
func TestEchoVault_SISMEMBER(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -667,7 +669,11 @@ func TestEchoVault_SISMEMBER(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SISMEMBER(tt.key, tt.member)
if (err != nil) != tt.wantErr {
@@ -682,13 +688,7 @@ func TestEchoVault_SISMEMBER(t *testing.T) {
}
func TestEchoVault_SMEMBERS(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -722,7 +722,11 @@ func TestEchoVault_SMEMBERS(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SMEMBERS(tt.key)
if (err != nil) != tt.wantErr {
@@ -742,13 +746,7 @@ func TestEchoVault_SMEMBERS(t *testing.T) {
}
func TestEchoVault_SMISMEMBER(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -805,7 +803,11 @@ func TestEchoVault_SMISMEMBER(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SMISMEMBER(tt.key, tt.members...)
if (err != nil) != tt.wantErr {
@@ -820,13 +822,7 @@ func TestEchoVault_SMISMEMBER(t *testing.T) {
}
func TestEchoVault_SMOVE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -890,7 +886,11 @@ func TestEchoVault_SMOVE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SMOVE(tt.source, tt.destination, tt.member)
@@ -906,13 +906,7 @@ func TestEchoVault_SMOVE(t *testing.T) {
}
func TestEchoVault_SPOP(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -942,7 +936,11 @@ func TestEchoVault_SPOP(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SPOP(tt.key, tt.count)
if (err != nil) != tt.wantErr {
@@ -959,13 +957,7 @@ func TestEchoVault_SPOP(t *testing.T) {
}
func TestEchoVault_SRANDMEMBER(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -1007,7 +999,11 @@ func TestEchoVault_SRANDMEMBER(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SRANDMEMBER(tt.key, tt.count)
if (err != nil) != tt.wantErr {
@@ -1028,13 +1024,7 @@ func TestEchoVault_SRANDMEMBER(t *testing.T) {
}
func TestEchoVault_SREM(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -1072,7 +1062,11 @@ func TestEchoVault_SREM(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SREM(tt.key, tt.members...)
if (err != nil) != tt.wantErr {
@@ -1087,13 +1081,7 @@ func TestEchoVault_SREM(t *testing.T) {
}
func TestEchoVault_SUNION(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -1153,7 +1141,11 @@ func TestEchoVault_SUNION(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SUNION(tt.keys...)
@@ -1174,13 +1166,7 @@ func TestEchoVault_SUNION(t *testing.T) {
}
func TestEchoVault_SUNIONSTORE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -1230,7 +1216,11 @@ func TestEchoVault_SUNIONSTORE(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValues != nil {
for k, v := range tt.presetValues {
presetValue(server, k, v)
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.SUNIONSTORE(tt.destination, tt.keys...)

View File

@@ -23,8 +23,12 @@ import (
"github.com/echovault/echovault/internal/set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
s "github.com/echovault/echovault/pkg/modules/set"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"slices"
"strings"
"testing"
)
@@ -32,6 +36,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(s.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -39,6 +44,43 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
}
}
func Test_HandleSADD(t *testing.T) {
tests := []struct {
name string
@@ -103,7 +145,15 @@ func Test_HandleSADD(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSADD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -214,7 +264,15 @@ func Test_HandleSCARD(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSCARD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -327,7 +385,15 @@ func Test_HandleSDIFF(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSDIFF(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -454,7 +520,15 @@ func Test_HandleSDIFFSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSDIFFSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -582,7 +656,15 @@ func Test_HandleSINTER(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSINTER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -709,7 +791,15 @@ func Test_HandleSINTERCARD(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSINTERCARD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -834,7 +924,13 @@ func Test_HandleSINTERSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSINTERSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -939,7 +1035,14 @@ func Test_HandleSISMEMBER(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSISMEMBER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1027,7 +1130,14 @@ func Test_HandleSMEMBERS(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSMEMBERS(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1118,7 +1228,14 @@ func Test_HandleSMISMEMBER(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSMISMEMBER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1160,7 +1277,7 @@ func Test_HandleSMOVE(t *testing.T) {
"SmoveSource1": set.NewSet([]string{"one", "two", "three", "four"}),
"SmoveDestination1": set.NewSet([]string{"five", "six", "seven", "eight"}),
},
command: []string{"MOVE", "SmoveSource1", "SmoveDestination1", "four"},
command: []string{"SMOVE", "SmoveSource1", "SmoveDestination1", "four"},
expectedValues: map[string]interface{}{
"SmoveSource1": set.NewSet([]string{"one", "two", "three"}),
"SmoveDestination1": set.NewSet([]string{"four", "five", "six", "seven", "eight"}),
@@ -1242,7 +1359,14 @@ func Test_HandleSMOVE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSMOVE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1350,7 +1474,14 @@ func Test_HandleSPOP(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSPOP(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1473,7 +1604,14 @@ func Test_HandleSRANDMEMBER(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSRANDMEMBER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1588,7 +1726,14 @@ func Test_HandleSREM(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSREM(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1708,7 +1853,14 @@ func Test_HandleSUNION(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSUNION(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1811,7 +1963,14 @@ func Test_HandleSUNIONSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleSUNIONSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())

View File

@@ -23,10 +23,14 @@ import (
"github.com/echovault/echovault/internal/sorted_set"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
ss "github.com/echovault/echovault/pkg/modules/sorted_set"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"math"
"net"
"slices"
"strconv"
"strings"
"testing"
)
@@ -34,6 +38,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(ss.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -41,6 +46,43 @@ func init() {
)
}
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
}
}
func Test_HandleZADD(t *testing.T) {
tests := []struct {
name string
@@ -273,7 +315,15 @@ func Test_HandleZADD(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZADD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -390,7 +440,15 @@ func Test_HandleZCARD(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZCARD(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -542,7 +600,15 @@ func Test_HandleZCOUNT(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZCOUNT(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -666,7 +732,15 @@ func Test_HandleZLEXCOUNT(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZLEXCOUNT(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -830,7 +904,15 @@ func Test_HandleZDIFF(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZDIFF(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1016,7 +1098,15 @@ func Test_HandleZDIFFSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZDIFFSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1246,7 +1336,15 @@ func Test_HandleZINCRBY(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZINCRBY(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1482,7 +1580,15 @@ func Test_HandleZMPOP(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZMPOP(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1665,7 +1771,15 @@ func Test_HandleZPOP(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZPOP(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1775,7 +1889,15 @@ func Test_HandleZMSCORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZMSCORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1886,7 +2008,15 @@ func Test_HandleZSCORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZSCORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2012,7 +2142,15 @@ func Test_HandleZRANDMEMBER(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleZRANDMEMBER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2181,7 +2319,15 @@ func Test_HandleZRANK(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZRANK(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2289,7 +2435,15 @@ func Test_HandleZREM(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZREM(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2405,7 +2559,15 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZREMRANGEBYSCORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2578,7 +2740,15 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZREMRANGEBYRANK(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2720,7 +2890,15 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZREMRANGEBYLEX(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2990,7 +3168,15 @@ func Test_HandleZRANGE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZRANGE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3297,7 +3483,15 @@ func Test_HandleZRANGESTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZRANGESTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3629,7 +3823,15 @@ func Test_HandleZINTER(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZINTER(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3992,7 +4194,15 @@ func Test_HandleZINTERSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZINTERSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -4351,7 +4561,15 @@ func Test_HandleZUNION(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZUNION(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -4753,7 +4971,15 @@ func Test_HandleZUNIONSTORE(t *testing.T) {
mockServer.KeyUnlock(ctx, key)
}
}
res, err := handleZUNIONSTORE(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())

View File

@@ -17,12 +17,21 @@ package str
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"testing"
)
func createEchoVault() *echovault.EchoVault {
ev, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
}),
)
return ev
}
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
@@ -35,13 +44,7 @@ func presetValue(server *echovault.EchoVault, ctx context.Context, key string, v
}
func TestEchoVault_SUBSTR(t *testing.T) {
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -200,13 +203,7 @@ func TestEchoVault_SUBSTR(t *testing.T) {
}
func TestEchoVault_SETRANGE(t *testing.T) {
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string
@@ -294,13 +291,7 @@ func TestEchoVault_SETRANGE(t *testing.T) {
}
func TestEchoVault_STRLEN(t *testing.T) {
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
server := createEchoVault()
tests := []struct {
name string

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package str_test
package str
import (
"bytes"
@@ -26,6 +26,7 @@ import (
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"net"
"strconv"
"strings"
"testing"
@@ -43,15 +44,43 @@ func init() {
)
}
func getHandler(command string) types.HandlerFunc {
func getHandler(commands ...string) types.HandlerFunc {
if len(commands) == 0 {
return nil
}
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(command, c.Command) {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyRLock: mockServer.KeyRLock,
KeyUnlock: mockServer.KeyUnlock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
}
}
func Test_HandleSetRange(t *testing.T) {
tests := []struct {
name string
@@ -176,17 +205,7 @@ func Test_HandleSetRange(t *testing.T) {
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyUnlock: mockServer.KeyUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
@@ -291,16 +310,7 @@ func Test_HandleStrLen(t *testing.T) {
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
KeyRLock: mockServer.KeyRLock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
@@ -436,16 +446,7 @@ func Test_HandleSubStr(t *testing.T) {
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
KeyRLock: mockServer.KeyRLock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {