mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-11-02 20:24:02 +08:00
Moved tests for module commands and apis into 'test' folder
This commit is contained in:
2
Makefile
2
Makefile
@@ -7,7 +7,7 @@ build:
|
||||
run:
|
||||
make build && docker-compose up --build
|
||||
|
||||
test:
|
||||
test-normal:
|
||||
go clean -testcache && go test ./... -coverprofile coverage/coverage.out
|
||||
|
||||
test-race:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/hashicorp/raft"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -36,6 +37,7 @@ type FSMOpts struct {
|
||||
StartSnapshot func()
|
||||
FinishSnapshot func()
|
||||
SetLatestSnapshotTime func(msec int64)
|
||||
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams
|
||||
}
|
||||
|
||||
type FSM struct {
|
||||
@@ -86,34 +88,33 @@ func (fsm *FSM) Apply(log *raft.Log) interface{} {
|
||||
}
|
||||
|
||||
case "command":
|
||||
// TODO: Re-Implement Command handling with dependency injection
|
||||
// Handle command
|
||||
// command, err := fsm.options.GetCommand(request.CMD[0])
|
||||
// if err != nil {
|
||||
// return internal.ApplyResponse{
|
||||
// Error: err,
|
||||
// Response: nil,
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// handler := command.HandlerFunc
|
||||
//
|
||||
// subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
|
||||
// if ok {
|
||||
// handler = subCommand.HandlerFunc
|
||||
// }
|
||||
//
|
||||
// if res, err := handler(ctx, request.CMD, fsm.options.EchoVault, nil); err != nil {
|
||||
// return internal.ApplyResponse{
|
||||
// Error: err,
|
||||
// Response: nil,
|
||||
// }
|
||||
// } else {
|
||||
// return internal.ApplyResponse{
|
||||
// Error: nil,
|
||||
// Response: res,
|
||||
// }
|
||||
// }
|
||||
command, err := fsm.options.GetCommand(request.CMD[0])
|
||||
if err != nil {
|
||||
return internal.ApplyResponse{
|
||||
Error: err,
|
||||
Response: nil,
|
||||
}
|
||||
}
|
||||
|
||||
handler := command.HandlerFunc
|
||||
|
||||
subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
|
||||
if ok {
|
||||
handler = subCommand.HandlerFunc
|
||||
}
|
||||
|
||||
if res, err := handler(fsm.options.GetHandlerFuncParams(ctx, request.CMD, nil)); err != nil {
|
||||
return internal.ApplyResponse{
|
||||
Error: err,
|
||||
Response: nil,
|
||||
}
|
||||
} else {
|
||||
return internal.ApplyResponse{
|
||||
Error: nil,
|
||||
Response: res,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ type Opts struct {
|
||||
StartSnapshot func()
|
||||
FinishSnapshot func()
|
||||
SetLatestSnapshotTime func(msec int64)
|
||||
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams
|
||||
}
|
||||
|
||||
type Raft struct {
|
||||
@@ -120,6 +121,7 @@ func (r *Raft) RaftInit(ctx context.Context) {
|
||||
StartSnapshot: r.options.StartSnapshot,
|
||||
FinishSnapshot: r.options.FinishSnapshot,
|
||||
SetLatestSnapshotTime: r.options.SetLatestSnapshotTime,
|
||||
GetHandlerFuncParams: r.options.GetHandlerFuncParams,
|
||||
}),
|
||||
logStore,
|
||||
stableStore,
|
||||
|
||||
@@ -157,6 +157,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
|
||||
StartSnapshot: echovault.startSnapshot,
|
||||
FinishSnapshot: echovault.finishSnapshot,
|
||||
SetLatestSnapshotTime: echovault.setLatestSnapshot,
|
||||
GetHandlerFuncParams: echovault.getHandlerFuncParams,
|
||||
GetState: func() map[string]internal.KeyData {
|
||||
state := make(map[string]internal.KeyData)
|
||||
for k, v := range echovault.getState() {
|
||||
|
||||
@@ -48,7 +48,6 @@ func (server *EchoVault) getCommand(cmd string) (types.Command, error) {
|
||||
|
||||
func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
// TODO: Add all the required methods here
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
@@ -60,6 +59,13 @@ func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string,
|
||||
KeyRUnlock: server.KeyRUnlock,
|
||||
GetValue: server.GetValue,
|
||||
SetValue: server.SetValue,
|
||||
GetClock: server.GetClock,
|
||||
GetExpiry: server.GetExpiry,
|
||||
SetExpiry: server.SetExpiry,
|
||||
DeleteKey: server.DeleteKey,
|
||||
GetPubSub: server.GetPubSub,
|
||||
GetACL: server.GetACL,
|
||||
GetAllCommands: server.GetAllCommands,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -24,33 +23,32 @@ import (
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"gopkg.in/yaml.v3"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleAuth(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
if len(cmd) < 2 || len(cmd) > 3 {
|
||||
func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) < 2 || len(params.Command) > 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
if err := acl.AuthenticateConnection(ctx, conn, cmd); err != nil {
|
||||
if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) != 3 {
|
||||
func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) != 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
@@ -58,7 +56,7 @@ func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
var user *internal_acl.User
|
||||
userFound := false
|
||||
for _, u := range acl.Users {
|
||||
if u.Username == cmd[2] {
|
||||
if u.Username == params.Command[2] {
|
||||
user = u
|
||||
userFound = true
|
||||
break
|
||||
@@ -162,14 +160,14 @@ func handleGetUser(_ context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 3 {
|
||||
func handleCat(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
categories := make(map[string][]string)
|
||||
|
||||
commands := server.GetAllCommands()
|
||||
commands := params.GetAllCommands()
|
||||
|
||||
for _, command := range commands {
|
||||
if len(command.SubCommands) == 0 {
|
||||
@@ -186,7 +184,7 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
|
||||
}
|
||||
}
|
||||
|
||||
if len(cmd) == 2 {
|
||||
if len(params.Command) == 2 {
|
||||
var cats []string
|
||||
length := 0
|
||||
for key, _ := range categories {
|
||||
@@ -203,10 +201,10 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
if len(cmd) == 3 {
|
||||
if len(params.Command) == 3 {
|
||||
var res string
|
||||
for category, commands := range categories {
|
||||
if strings.EqualFold(category, cmd[2]) {
|
||||
if strings.EqualFold(category, params.Command[2]) {
|
||||
res = fmt.Sprintf("*%d", len(commands))
|
||||
for i, command := range commands {
|
||||
res = fmt.Sprintf("%s\r\n+%s", res, command)
|
||||
@@ -219,11 +217,11 @@ func handleCat(_ context.Context, cmd []string, server types.EchoVault, _ *net.C
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
|
||||
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
|
||||
}
|
||||
|
||||
func handleUsers(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
@@ -235,45 +233,45 @@ func handleUsers(_ context.Context, _ []string, server types.EchoVault, _ *net.C
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSetUser(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
if err := acl.SetUser(cmd[2:]); err != nil {
|
||||
if err := acl.SetUser(params.Command[2:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleDelUser(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) < 3 {
|
||||
func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) < 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
if err := acl.DeleteUser(ctx, cmd[2:]); err != nil {
|
||||
if err := acl.DeleteUser(params.Context, params.Command[2:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleWhoAmI(_ context.Context, _ []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
connectionInfo := acl.Connections[conn]
|
||||
connectionInfo := acl.Connections[params.Connection]
|
||||
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
|
||||
}
|
||||
|
||||
func handleList(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 2 {
|
||||
func handleList(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 2 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
@@ -365,12 +363,12 @@ func handleList(_ context.Context, cmd []string, server types.EchoVault, _ *net.
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) != 3 {
|
||||
func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) != 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
@@ -414,7 +412,7 @@ func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.
|
||||
if u.Username == user.Username {
|
||||
userFound = true
|
||||
// If we have a user with the current username and are in merge mode, merge the two users.
|
||||
if strings.EqualFold(cmd[2], "merge") {
|
||||
if strings.EqualFold(params.Command[2], "merge") {
|
||||
u.Merge(user)
|
||||
} else {
|
||||
// If we have a user with the current username and are in replace mode, merge the two users.
|
||||
@@ -432,12 +430,12 @@ func handleLoad(_ context.Context, cmd []string, server types.EchoVault, _ *net.
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleSave(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 2 {
|
||||
func handleSave(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 2 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
acl, ok := server.GetACL().(*internal_acl.ACL)
|
||||
acl, ok := params.GetACL().(*internal_acl.ACL)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load ACL")
|
||||
}
|
||||
|
||||
@@ -15,19 +15,17 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/gobwas/glob"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
commands := server.GetAllCommands()
|
||||
func handleGetAllCommands(params types.HandlerFuncParams) ([]byte, error) {
|
||||
commands := params.GetAllCommands()
|
||||
|
||||
res := ""
|
||||
commandCount := 0
|
||||
@@ -71,10 +69,10 @@ func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault,
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleCommandCount(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
func handleCommandCount(params types.HandlerFuncParams) ([]byte, error) {
|
||||
var count int
|
||||
|
||||
commands := server.GetAllCommands()
|
||||
commands := params.GetAllCommands()
|
||||
for _, command := range commands {
|
||||
if command.SubCommands != nil && len(command.SubCommands) > 0 {
|
||||
for _, _ = range command.SubCommands {
|
||||
@@ -88,13 +86,13 @@ func handleCommandCount(_ context.Context, _ []string, server types.EchoVault, _
|
||||
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
|
||||
}
|
||||
|
||||
func handleCommandList(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
switch len(cmd) {
|
||||
func handleCommandList(params types.HandlerFuncParams) ([]byte, error) {
|
||||
switch len(params.Command) {
|
||||
case 2:
|
||||
// Command is COMMAND LIST
|
||||
var count int
|
||||
var res string
|
||||
commands := server.GetAllCommands()
|
||||
commands := params.GetAllCommands()
|
||||
for _, command := range commands {
|
||||
if command.SubCommands != nil && len(command.SubCommands) > 0 {
|
||||
for _, subcommand := range command.SubCommands {
|
||||
@@ -114,13 +112,13 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
|
||||
var count int
|
||||
var res string
|
||||
// Command has filter
|
||||
if !strings.EqualFold("FILTERBY", cmd[2]) {
|
||||
return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(cmd[2]))
|
||||
if !strings.EqualFold("FILTERBY", params.Command[2]) {
|
||||
return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(params.Command[2]))
|
||||
}
|
||||
if strings.EqualFold("ACLCAT", cmd[3]) {
|
||||
if strings.EqualFold("ACLCAT", params.Command[3]) {
|
||||
// ACL Category filter
|
||||
commands := server.GetAllCommands()
|
||||
category := strings.ToLower(cmd[4])
|
||||
commands := params.GetAllCommands()
|
||||
category := strings.ToLower(params.Command[4])
|
||||
for _, command := range commands {
|
||||
if command.SubCommands != nil && len(command.SubCommands) > 0 {
|
||||
for _, subcommand := range command.SubCommands {
|
||||
@@ -137,10 +135,10 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
} else if strings.EqualFold("PATTERN", cmd[3]) {
|
||||
} else if strings.EqualFold("PATTERN", params.Command[3]) {
|
||||
// Pattern filter
|
||||
commands := server.GetAllCommands()
|
||||
g := glob.MustCompile(cmd[4])
|
||||
commands := params.GetAllCommands()
|
||||
g := glob.MustCompile(params.Command[4])
|
||||
for _, command := range commands {
|
||||
if command.SubCommands != nil && len(command.SubCommands) > 0 {
|
||||
for _, subcommand := range command.SubCommands {
|
||||
@@ -157,10 +155,10 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
} else if strings.EqualFold("MODULE", cmd[3]) {
|
||||
} else if strings.EqualFold("MODULE", params.Command[3]) {
|
||||
// Module filter
|
||||
commands := server.GetAllCommands()
|
||||
module := strings.ToLower(cmd[4])
|
||||
commands := params.GetAllCommands()
|
||||
module := strings.ToLower(params.Command[4])
|
||||
for _, command := range commands {
|
||||
if command.SubCommands != nil && len(command.SubCommands) > 0 {
|
||||
for _, subcommand := range command.SubCommands {
|
||||
@@ -178,7 +176,7 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(cmd[3]))
|
||||
return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(params.Command[3]))
|
||||
}
|
||||
res = fmt.Sprintf("*%d\r\n%s", count, res)
|
||||
return []byte(res), nil
|
||||
@@ -187,7 +185,7 @@ func handleCommandList(_ context.Context, cmd []string, server types.EchoVault,
|
||||
}
|
||||
}
|
||||
|
||||
func handleCommandDocs(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
func handleCommandDocs(params types.HandlerFuncParams) ([]byte, error) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
@@ -283,8 +281,8 @@ Allows for filtering by ACL category or glob pattern.`,
|
||||
WriteKeys: make([]string, 0),
|
||||
}, nil
|
||||
},
|
||||
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
if err := server.TakeSnapshot(); err != nil {
|
||||
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if err := params.TakeSnapshot(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
@@ -303,8 +301,8 @@ Allows for filtering by ACL category or glob pattern.`,
|
||||
WriteKeys: make([]string, 0),
|
||||
}, nil
|
||||
},
|
||||
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
msec := server.GetLatestSnapshotTime()
|
||||
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
|
||||
msec := params.GetLatestSnapshotTime()
|
||||
if msec == 0 {
|
||||
return nil, errors.New("no snapshot")
|
||||
}
|
||||
@@ -324,8 +322,8 @@ Allows for filtering by ACL category or glob pattern.`,
|
||||
WriteKeys: make([]string, 0),
|
||||
}, nil
|
||||
},
|
||||
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
if err := server.RewriteAOF(); err != nil {
|
||||
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if err := params.RewriteAOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
|
||||
@@ -15,29 +15,27 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"net"
|
||||
)
|
||||
|
||||
func handlePing(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
switch len(cmd) {
|
||||
func handlePing(params types.HandlerFuncParams) ([]byte, error) {
|
||||
switch len(params.Command) {
|
||||
default:
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
case 1:
|
||||
return []byte("+PONG\r\n"), nil
|
||||
case 2:
|
||||
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(cmd[1]), cmd[1])), nil
|
||||
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(params.Command[1]), params.Command[1])), nil
|
||||
}
|
||||
}
|
||||
|
||||
func Commands() []types.Command {
|
||||
return []types.Command{
|
||||
{
|
||||
Command: "connection",
|
||||
Command: "ping",
|
||||
Module: constants.ConnectionModule,
|
||||
Categories: []string{constants.FastCategory, constants.ConnectionCategory},
|
||||
Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.",
|
||||
|
||||
@@ -15,14 +15,12 @@
|
||||
package generic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -33,73 +31,73 @@ type KeyObject struct {
|
||||
locked bool
|
||||
}
|
||||
|
||||
func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := setKeyFunc(cmd)
|
||||
func handleSet(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := setKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
value := cmd[2]
|
||||
value := params.Command[2]
|
||||
res := []byte(constants.OkResponse)
|
||||
clock := server.GetClock()
|
||||
clock := params.GetClock()
|
||||
|
||||
params, err := getSetCommandParams(clock, cmd[3:], SetParams{})
|
||||
options, err := getSetCommandOptions(clock, params.Command[3:], SetOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If GET is provided, the response should be the current stored value.
|
||||
// If there's no current value, then the response should be nil.
|
||||
if params.get {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if options.get {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
res = []byte("$-1\r\n")
|
||||
} else {
|
||||
res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key)))
|
||||
res = []byte(fmt.Sprintf("+%v\r\n", params.GetValue(params.Context, key)))
|
||||
}
|
||||
}
|
||||
|
||||
if "xx" == strings.ToLower(params.exists) {
|
||||
if "xx" == strings.ToLower(options.exists) {
|
||||
// If XX is specified, make sure the key exists.
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, fmt.Errorf("key %s does not exist", key)
|
||||
}
|
||||
_, err = server.KeyLock(ctx, key)
|
||||
} else if "nx" == strings.ToLower(params.exists) {
|
||||
_, err = params.KeyLock(params.Context, key)
|
||||
} else if "nx" == strings.ToLower(options.exists) {
|
||||
// If NX is specified, make sure that the key does not currently exist.
|
||||
if server.KeyExists(ctx, key) {
|
||||
if params.KeyExists(params.Context, key) {
|
||||
return nil, fmt.Errorf("key %s already exists", key)
|
||||
}
|
||||
_, err = server.CreateKeyAndLock(ctx, key)
|
||||
_, err = params.CreateKeyAndLock(params.Context, key)
|
||||
} else {
|
||||
// Neither XX not NX are specified, lock or create the lock
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
// Key does not exist, create it
|
||||
_, err = server.CreateKeyAndLock(ctx, key)
|
||||
_, err = params.CreateKeyAndLock(params.Context, key)
|
||||
} else {
|
||||
// Key exists, acquire the lock
|
||||
_, err = server.KeyLock(ctx, key)
|
||||
_, err = params.KeyLock(params.Context, key)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
if err = server.SetValue(ctx, key, internal.AdaptType(value)); err != nil {
|
||||
if err = params.SetValue(params.Context, key, internal.AdaptType(value)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If expiresAt is set, set the key's expiry time as well
|
||||
if params.expireAt != nil {
|
||||
server.SetExpiry(ctx, key, params.expireAt.(time.Time), false)
|
||||
if options.expireAt != nil {
|
||||
params.SetExpiry(params.Context, key, options.expireAt.(time.Time), false)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
_, err := msetKeyFunc(cmd)
|
||||
func handleMSet(params types.HandlerFuncParams) ([]byte, error) {
|
||||
_, err := msetKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -110,7 +108,7 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
defer func() {
|
||||
for k, v := range entries {
|
||||
if v.locked {
|
||||
server.KeyUnlock(ctx, k)
|
||||
params.KeyUnlock(params.Context, k)
|
||||
entries[k] = KeyObject{
|
||||
value: v.value,
|
||||
locked: false,
|
||||
@@ -120,10 +118,10 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
}()
|
||||
|
||||
// Extract all the key/value pairs
|
||||
for i, key := range cmd[1:] {
|
||||
for i, key := range params.Command[1:] {
|
||||
if i%2 == 0 {
|
||||
entries[key] = KeyObject{
|
||||
value: internal.AdaptType(cmd[1:][i+1]),
|
||||
value: internal.AdaptType(params.Command[1:][i+1]),
|
||||
locked: false,
|
||||
}
|
||||
}
|
||||
@@ -132,14 +130,14 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
// Acquire all the locks for each key first
|
||||
// If any key cannot be acquired, abandon transaction and release all currently held keys
|
||||
for k, v := range entries {
|
||||
if server.KeyExists(ctx, k) {
|
||||
if _, err := server.KeyLock(ctx, k); err != nil {
|
||||
if params.KeyExists(params.Context, k) {
|
||||
if _, err := params.KeyLock(params.Context, k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries[k] = KeyObject{value: v.value, locked: true}
|
||||
continue
|
||||
}
|
||||
if _, err := server.CreateKeyAndLock(ctx, k); err != nil {
|
||||
if _, err := params.CreateKeyAndLock(params.Context, k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries[k] = KeyObject{value: v.value, locked: true}
|
||||
@@ -147,7 +145,7 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
|
||||
// Set all the values
|
||||
for k, v := range entries {
|
||||
if err := server.SetValue(ctx, k, v.value); err != nil {
|
||||
if err := params.SetValue(params.Context, k, v.value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -155,30 +153,30 @@ func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := getKeyFunc(cmd)
|
||||
func handleGet(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := getKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("$-1\r\n"), nil
|
||||
}
|
||||
|
||||
_, err = server.KeyRLock(ctx, key)
|
||||
_, err = params.KeyRLock(params.Context, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
value := server.GetValue(ctx, key)
|
||||
value := params.GetValue(params.Context, key)
|
||||
|
||||
return []byte(fmt.Sprintf("+%v\r\n", value)), nil
|
||||
}
|
||||
|
||||
func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := mgetKeyFunc(cmd)
|
||||
func handleMGet(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := mgetKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -191,8 +189,8 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
// Skip if we have already locked this key
|
||||
continue
|
||||
}
|
||||
if server.KeyExists(ctx, key) {
|
||||
_, err = server.KeyRLock(ctx, key)
|
||||
if params.KeyExists(params.Context, key) {
|
||||
_, err = params.KeyRLock(params.Context, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not obtain lock for %s key", key)
|
||||
}
|
||||
@@ -204,19 +202,19 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
locks[key] = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for key, _ := range locks {
|
||||
values[key] = fmt.Sprintf("%v", server.GetValue(ctx, key))
|
||||
values[key] = fmt.Sprintf("%v", params.GetValue(params.Context, key))
|
||||
}
|
||||
|
||||
bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:])))
|
||||
bytes := []byte(fmt.Sprintf("*%d\r\n", len(params.Command[1:])))
|
||||
|
||||
for _, key := range cmd[1:] {
|
||||
for _, key := range params.Command[1:] {
|
||||
if values[key] == "" {
|
||||
bytes = append(bytes, []byte("$-1\r\n")...)
|
||||
continue
|
||||
@@ -227,14 +225,14 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := delKeyFunc(cmd)
|
||||
func handleDel(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := delKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count := 0
|
||||
for _, key := range keys.WriteKeys {
|
||||
err = server.DeleteKey(ctx, key)
|
||||
err = params.DeleteKey(params.Context, key)
|
||||
if err != nil {
|
||||
log.Printf("could not delete key %s due to error: %+v\n", key, err)
|
||||
continue
|
||||
@@ -244,91 +242,91 @@ func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net
|
||||
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
|
||||
}
|
||||
|
||||
func handlePersist(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := persistKeyFunc(cmd)
|
||||
func handlePersist(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := persistKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
expireAt := server.GetExpiry(ctx, key)
|
||||
expireAt := params.GetExpiry(params.Context, key)
|
||||
if expireAt == (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
server.SetExpiry(ctx, key, time.Time{}, false)
|
||||
params.SetExpiry(params.Context, key, time.Time{}, false)
|
||||
|
||||
return []byte(":1\r\n"), nil
|
||||
}
|
||||
|
||||
func handleExpireTime(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := expireTimeKeyFunc(cmd)
|
||||
func handleExpireTime(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := expireTimeKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":-2\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
expireAt := server.GetExpiry(ctx, key)
|
||||
expireAt := params.GetExpiry(params.Context, key)
|
||||
|
||||
if expireAt == (time.Time{}) {
|
||||
return []byte(":-1\r\n"), nil
|
||||
}
|
||||
|
||||
t := expireAt.Unix()
|
||||
if strings.ToLower(cmd[0]) == "pexpiretime" {
|
||||
if strings.ToLower(params.Command[0]) == "pexpiretime" {
|
||||
t = expireAt.UnixMilli()
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf(":%d\r\n", t)), nil
|
||||
}
|
||||
|
||||
func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := ttlKeyFunc(cmd)
|
||||
func handleTTL(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := ttlKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
clock := server.GetClock()
|
||||
clock := params.GetClock()
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":-2\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
expireAt := server.GetExpiry(ctx, key)
|
||||
expireAt := params.GetExpiry(params.Context, key)
|
||||
|
||||
if expireAt == (time.Time{}) {
|
||||
return []byte(":-1\r\n"), nil
|
||||
}
|
||||
|
||||
t := expireAt.Unix() - clock.Now().Unix()
|
||||
if strings.ToLower(cmd[0]) == "pttl" {
|
||||
if strings.ToLower(params.Command[0]) == "pttl" {
|
||||
t = expireAt.UnixMilli() - clock.Now().UnixMilli()
|
||||
}
|
||||
|
||||
@@ -339,8 +337,8 @@ func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net
|
||||
return []byte(fmt.Sprintf(":%d\r\n", t)), nil
|
||||
}
|
||||
|
||||
func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := expireKeyFunc(cmd)
|
||||
func handleExpire(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := expireKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -348,42 +346,42 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
// Extract time
|
||||
n, err := strconv.ParseInt(cmd[2], 10, 64)
|
||||
n, err := strconv.ParseInt(params.Command[2], 10, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("expire time must be integer")
|
||||
}
|
||||
expireAt := server.GetClock().Now().Add(time.Duration(n) * time.Second)
|
||||
if strings.ToLower(cmd[0]) == "pexpire" {
|
||||
expireAt = server.GetClock().Now().Add(time.Duration(n) * time.Millisecond)
|
||||
expireAt := params.GetClock().Now().Add(time.Duration(n) * time.Second)
|
||||
if strings.ToLower(params.Command[0]) == "pexpire" {
|
||||
expireAt = params.GetClock().Now().Add(time.Duration(n) * time.Millisecond)
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
if len(cmd) == 3 {
|
||||
server.SetExpiry(ctx, key, expireAt, true)
|
||||
if len(params.Command) == 3 {
|
||||
params.SetExpiry(params.Context, key, expireAt, true)
|
||||
return []byte(":1\r\n"), nil
|
||||
}
|
||||
|
||||
currentExpireAt := server.GetExpiry(ctx, key)
|
||||
currentExpireAt := params.GetExpiry(params.Context, key)
|
||||
|
||||
switch strings.ToLower(cmd[3]) {
|
||||
switch strings.ToLower(params.Command[3]) {
|
||||
case "nx":
|
||||
if currentExpireAt != (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "xx":
|
||||
if currentExpireAt == (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "gt":
|
||||
if currentExpireAt == (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
@@ -391,24 +389,24 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
if expireAt.Before(currentExpireAt) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "lt":
|
||||
if currentExpireAt != (time.Time{}) {
|
||||
if currentExpireAt.Before(expireAt) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(cmd[3]))
|
||||
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(params.Command[3]))
|
||||
}
|
||||
|
||||
return []byte(":1\r\n"), nil
|
||||
}
|
||||
|
||||
func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := expireKeyFunc(cmd)
|
||||
func handleExpireAt(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := expireKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -416,42 +414,42 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
// Extract time
|
||||
n, err := strconv.ParseInt(cmd[2], 10, 64)
|
||||
n, err := strconv.ParseInt(params.Command[2], 10, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("expire time must be integer")
|
||||
}
|
||||
expireAt := time.Unix(n, 0)
|
||||
if strings.ToLower(cmd[0]) == "pexpireat" {
|
||||
if strings.ToLower(params.Command[0]) == "pexpireat" {
|
||||
expireAt = time.UnixMilli(n)
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
if len(cmd) == 3 {
|
||||
server.SetExpiry(ctx, key, expireAt, true)
|
||||
if len(params.Command) == 3 {
|
||||
params.SetExpiry(params.Context, key, expireAt, true)
|
||||
return []byte(":1\r\n"), nil
|
||||
}
|
||||
|
||||
currentExpireAt := server.GetExpiry(ctx, key)
|
||||
currentExpireAt := params.GetExpiry(params.Context, key)
|
||||
|
||||
switch strings.ToLower(cmd[3]) {
|
||||
switch strings.ToLower(params.Command[3]) {
|
||||
case "nx":
|
||||
if currentExpireAt != (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "xx":
|
||||
if currentExpireAt == (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "gt":
|
||||
if currentExpireAt == (time.Time{}) {
|
||||
return []byte(":0\r\n"), nil
|
||||
@@ -459,17 +457,17 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
if expireAt.Before(currentExpireAt) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
case "lt":
|
||||
if currentExpireAt != (time.Time{}) {
|
||||
if currentExpireAt.Before(expireAt) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
}
|
||||
server.SetExpiry(ctx, key, expireAt, false)
|
||||
params.SetExpiry(params.Context, key, expireAt, false)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(cmd[3]))
|
||||
return nil, fmt.Errorf("unknown option %s", strings.ToUpper(params.Command[3]))
|
||||
}
|
||||
|
||||
return []byte(":1\r\n"), nil
|
||||
|
||||
@@ -23,96 +23,96 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type SetParams struct {
|
||||
type SetOptions struct {
|
||||
exists string
|
||||
get bool
|
||||
expireAt interface{} // Exact expireAt time un unix milliseconds
|
||||
}
|
||||
|
||||
func getSetCommandParams(clock clock.Clock, cmd []string, params SetParams) (SetParams, error) {
|
||||
func getSetCommandOptions(clock clock.Clock, cmd []string, options SetOptions) (SetOptions, error) {
|
||||
if len(cmd) == 0 {
|
||||
return params, nil
|
||||
return options, nil
|
||||
}
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
case "get":
|
||||
params.get = true
|
||||
return getSetCommandParams(clock, cmd[1:], params)
|
||||
options.get = true
|
||||
return getSetCommandOptions(clock, cmd[1:], options)
|
||||
|
||||
case "nx":
|
||||
if params.exists != "" {
|
||||
return SetParams{}, fmt.Errorf("cannot specify NX when %s is already specified", strings.ToUpper(params.exists))
|
||||
if options.exists != "" {
|
||||
return SetOptions{}, fmt.Errorf("cannot specify NX when %s is already specified", strings.ToUpper(options.exists))
|
||||
}
|
||||
params.exists = "NX"
|
||||
return getSetCommandParams(clock, cmd[1:], params)
|
||||
options.exists = "NX"
|
||||
return getSetCommandOptions(clock, cmd[1:], options)
|
||||
|
||||
case "xx":
|
||||
if params.exists != "" {
|
||||
return SetParams{}, fmt.Errorf("cannot specify XX when %s is already specified", strings.ToUpper(params.exists))
|
||||
if options.exists != "" {
|
||||
return SetOptions{}, fmt.Errorf("cannot specify XX when %s is already specified", strings.ToUpper(options.exists))
|
||||
}
|
||||
params.exists = "XX"
|
||||
return getSetCommandParams(clock, cmd[1:], params)
|
||||
options.exists = "XX"
|
||||
return getSetCommandOptions(clock, cmd[1:], options)
|
||||
|
||||
case "ex":
|
||||
if len(cmd) < 2 {
|
||||
return SetParams{}, errors.New("seconds value required after EX")
|
||||
return SetOptions{}, errors.New("seconds value required after EX")
|
||||
}
|
||||
if params.expireAt != nil {
|
||||
return SetParams{}, errors.New("cannot specify EX when expiry time is already set")
|
||||
if options.expireAt != nil {
|
||||
return SetOptions{}, errors.New("cannot specify EX when expiry time is already set")
|
||||
}
|
||||
secondsStr := cmd[1]
|
||||
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
|
||||
if err != nil {
|
||||
return SetParams{}, errors.New("seconds value should be an integer")
|
||||
return SetOptions{}, errors.New("seconds value should be an integer")
|
||||
}
|
||||
params.expireAt = clock.Now().Add(time.Duration(seconds) * time.Second)
|
||||
return getSetCommandParams(clock, cmd[2:], params)
|
||||
options.expireAt = clock.Now().Add(time.Duration(seconds) * time.Second)
|
||||
return getSetCommandOptions(clock, cmd[2:], options)
|
||||
|
||||
case "px":
|
||||
if len(cmd) < 2 {
|
||||
return SetParams{}, errors.New("milliseconds value required after PX")
|
||||
return SetOptions{}, errors.New("milliseconds value required after PX")
|
||||
}
|
||||
if params.expireAt != nil {
|
||||
return SetParams{}, errors.New("cannot specify PX when expiry time is already set")
|
||||
if options.expireAt != nil {
|
||||
return SetOptions{}, errors.New("cannot specify PX when expiry time is already set")
|
||||
}
|
||||
millisecondsStr := cmd[1]
|
||||
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
|
||||
if err != nil {
|
||||
return SetParams{}, errors.New("milliseconds value should be an integer")
|
||||
return SetOptions{}, errors.New("milliseconds value should be an integer")
|
||||
}
|
||||
params.expireAt = clock.Now().Add(time.Duration(milliseconds) * time.Millisecond)
|
||||
return getSetCommandParams(clock, cmd[2:], params)
|
||||
options.expireAt = clock.Now().Add(time.Duration(milliseconds) * time.Millisecond)
|
||||
return getSetCommandOptions(clock, cmd[2:], options)
|
||||
|
||||
case "exat":
|
||||
if len(cmd) < 2 {
|
||||
return SetParams{}, errors.New("seconds value required after EXAT")
|
||||
return SetOptions{}, errors.New("seconds value required after EXAT")
|
||||
}
|
||||
if params.expireAt != nil {
|
||||
return SetParams{}, errors.New("cannot specify EXAT when expiry time is already set")
|
||||
if options.expireAt != nil {
|
||||
return SetOptions{}, errors.New("cannot specify EXAT when expiry time is already set")
|
||||
}
|
||||
secondsStr := cmd[1]
|
||||
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
|
||||
if err != nil {
|
||||
return SetParams{}, errors.New("seconds value should be an integer")
|
||||
return SetOptions{}, errors.New("seconds value should be an integer")
|
||||
}
|
||||
params.expireAt = time.Unix(seconds, 0)
|
||||
return getSetCommandParams(clock, cmd[2:], params)
|
||||
options.expireAt = time.Unix(seconds, 0)
|
||||
return getSetCommandOptions(clock, cmd[2:], options)
|
||||
|
||||
case "pxat":
|
||||
if len(cmd) < 2 {
|
||||
return SetParams{}, errors.New("milliseconds value required after PXAT")
|
||||
return SetOptions{}, errors.New("milliseconds value required after PXAT")
|
||||
}
|
||||
if params.expireAt != nil {
|
||||
return SetParams{}, errors.New("cannot specify PXAT when expiry time is already set")
|
||||
if options.expireAt != nil {
|
||||
return SetOptions{}, errors.New("cannot specify PXAT when expiry time is already set")
|
||||
}
|
||||
millisecondsStr := cmd[1]
|
||||
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
|
||||
if err != nil {
|
||||
return SetParams{}, errors.New("milliseconds value should be an integer")
|
||||
return SetOptions{}, errors.New("milliseconds value should be an integer")
|
||||
}
|
||||
params.expireAt = time.UnixMilli(milliseconds)
|
||||
return getSetCommandParams(clock, cmd[2:], params)
|
||||
options.expireAt = time.UnixMilli(milliseconds)
|
||||
return getSetCommandOptions(clock, cmd[2:], options)
|
||||
|
||||
default:
|
||||
return SetParams{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0]))
|
||||
return SetOptions{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,21 +15,19 @@
|
||||
package hash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"math/rand"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hsetKeyFunc(cmd)
|
||||
func handleHSET(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hsetKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -37,39 +35,39 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
key := keys.WriteKeys[0]
|
||||
entries := make(map[string]interface{})
|
||||
|
||||
if len(cmd[2:])%2 != 0 {
|
||||
if len(params.Command[2:])%2 != 0 {
|
||||
return nil, errors.New("each field must have a corresponding value")
|
||||
}
|
||||
|
||||
for i := 2; i <= len(cmd)-2; i += 2 {
|
||||
entries[cmd[i]] = internal.AdaptType(cmd[i+1])
|
||||
for i := 2; i <= len(params.Command)-2; i += 2 {
|
||||
entries[params.Command[i]] = internal.AdaptType(params.Command[i+1])
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
_, err = server.CreateKeyAndLock(ctx, key)
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
_, err = params.CreateKeyAndLock(params.Context, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
if err = server.SetValue(ctx, key, entries); err != nil {
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
if err = params.SetValue(params.Context, key, entries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for field, value := range entries {
|
||||
if strings.EqualFold(cmd[0], "hsetnx") {
|
||||
if strings.EqualFold(params.Command[0], "hsetnx") {
|
||||
if hash[field] == nil {
|
||||
hash[field] = value
|
||||
count += 1
|
||||
@@ -79,32 +77,32 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
hash[field] = value
|
||||
count += 1
|
||||
}
|
||||
if err = server.SetValue(ctx, key, hash); err != nil {
|
||||
if err = params.SetValue(params.Context, key, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
|
||||
}
|
||||
|
||||
func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hgetKeyFunc(cmd)
|
||||
func handleHGET(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hgetKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
fields := cmd[2:]
|
||||
fields := params.Command[2:]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("$-1\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -137,25 +135,25 @@ func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hstrlenKeyFunc(cmd)
|
||||
func handleHSTRLEN(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hstrlenKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
fields := cmd[2:]
|
||||
fields := params.Command[2:]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("$-1\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -188,24 +186,24 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hvalsKeyFunc(cmd)
|
||||
func handleHVALS(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hvalsKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -229,8 +227,8 @@ func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hrandfieldKeyFunc(cmd)
|
||||
func handleHRANDFIELD(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hrandfieldKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -238,8 +236,8 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
count := 1
|
||||
if len(cmd) >= 3 {
|
||||
c, err := strconv.Atoi(cmd[2])
|
||||
if len(params.Command) >= 3 {
|
||||
c, err := strconv.Atoi(params.Command[2])
|
||||
if err != nil {
|
||||
return nil, errors.New("count must be an integer")
|
||||
}
|
||||
@@ -250,24 +248,24 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
}
|
||||
|
||||
withvalues := false
|
||||
if len(cmd) == 4 {
|
||||
if strings.EqualFold(cmd[3], "withvalues") {
|
||||
if len(params.Command) == 4 {
|
||||
if strings.EqualFold(params.Command[3], "withvalues") {
|
||||
withvalues = true
|
||||
} else {
|
||||
return nil, errors.New("result modifier must be withvalues")
|
||||
}
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -345,24 +343,24 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hlenKeyFunc(cmd)
|
||||
func handleHLEN(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hlenKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -370,24 +368,24 @@ func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil
|
||||
}
|
||||
|
||||
func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hkeysKeyFunc(cmd)
|
||||
func handleHKEYS(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hkeysKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -400,59 +398,59 @@ func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hincrbyKeyFunc(cmd)
|
||||
func handleHINCRBY(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hincrbyKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
field := cmd[2]
|
||||
field := params.Command[2]
|
||||
|
||||
var intIncrement int
|
||||
var floatIncrement float64
|
||||
|
||||
if strings.EqualFold(cmd[0], "hincrbyfloat") {
|
||||
f, err := strconv.ParseFloat(cmd[3], 64)
|
||||
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
|
||||
f, err := strconv.ParseFloat(params.Command[3], 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("increment must be a float")
|
||||
}
|
||||
floatIncrement = f
|
||||
} else {
|
||||
i, err := strconv.Atoi(cmd[3])
|
||||
i, err := strconv.Atoi(params.Command[3])
|
||||
if err != nil {
|
||||
return nil, errors.New("increment must be an integer")
|
||||
}
|
||||
intIncrement = i
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
if _, err := params.CreateKeyAndLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
hash := make(map[string]interface{})
|
||||
if strings.EqualFold(cmd[0], "hincrbyfloat") {
|
||||
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
|
||||
hash[field] = floatIncrement
|
||||
if err = server.SetValue(ctx, key, hash); err != nil {
|
||||
if err = params.SetValue(params.Context, key, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
|
||||
} else {
|
||||
hash[field] = intIncrement
|
||||
if err = server.SetValue(ctx, key, hash); err != nil {
|
||||
if err = params.SetValue(params.Context, key, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := server.KeyLock(ctx, key); err != nil {
|
||||
if _, err := params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -466,21 +464,21 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return nil, fmt.Errorf("value at field %s is not a number", field)
|
||||
case int:
|
||||
i, _ := hash[field].(int)
|
||||
if strings.EqualFold(cmd[0], "hincrbyfloat") {
|
||||
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
|
||||
hash[field] = float64(i) + floatIncrement
|
||||
} else {
|
||||
hash[field] = i + intIncrement
|
||||
}
|
||||
case float64:
|
||||
f, _ := hash[field].(float64)
|
||||
if strings.EqualFold(cmd[0], "hincrbyfloat") {
|
||||
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
|
||||
hash[field] = f + floatIncrement
|
||||
} else {
|
||||
hash[field] = f + float64(intIncrement)
|
||||
}
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, key, hash); err != nil {
|
||||
if err = params.SetValue(params.Context, key, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -492,24 +490,24 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return []byte(fmt.Sprintf(":%d\r\n", i)), nil
|
||||
}
|
||||
|
||||
func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hgetallKeyFunc(cmd)
|
||||
func handleHGETALL(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hgetallKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -532,25 +530,25 @@ func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hexistsKeyFunc(cmd)
|
||||
func handleHEXISTS(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hexistsKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
field := cmd[2]
|
||||
field := params.Command[2]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -562,25 +560,25 @@ func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := hdelKeyFunc(cmd)
|
||||
func handleHDEL(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := hdelKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
fields := cmd[2:]
|
||||
fields := params.Command[2:]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
hash, ok := server.GetValue(ctx, key).(map[string]interface{})
|
||||
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a hash", key)
|
||||
}
|
||||
@@ -594,7 +592,7 @@ func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
}
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, key, hash); err != nil {
|
||||
if err = params.SetValue(params.Context, key, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -15,65 +15,63 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := llenKeyFunc(cmd)
|
||||
func handleLLen(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := llenKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
// If key does not exist, return 0
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
if list, ok := server.GetValue(ctx, key).([]interface{}); ok {
|
||||
if list, ok := params.GetValue(params.Context, key).([]interface{}); ok {
|
||||
return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil
|
||||
}
|
||||
|
||||
return nil, errors.New("LLEN command on non-list item")
|
||||
}
|
||||
|
||||
func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lindexKeyFunc(cmd)
|
||||
func handleLIndex(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lindexKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
index, ok := internal.AdaptType(cmd[2]).(int)
|
||||
index, ok := internal.AdaptType(params.Command[2]).(int)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("index must be an integer")
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, errors.New("LINDEX command on non-list item")
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
server.KeyRUnlock(ctx, key)
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("LINDEX command on non-list item")
|
||||
@@ -86,30 +84,30 @@ func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
|
||||
}
|
||||
|
||||
func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lrangeKeyFunc(cmd)
|
||||
func handleLRange(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lrangeKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
start, startOk := internal.AdaptType(cmd[2]).(int)
|
||||
end, endOk := internal.AdaptType(cmd[3]).(int)
|
||||
start, startOk := internal.AdaptType(params.Command[2]).(int)
|
||||
end, endOk := internal.AdaptType(params.Command[3]).(int)
|
||||
|
||||
if !startOk || !endOk {
|
||||
return nil, errors.New("start and end indices must be integers")
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, errors.New("LRANGE command on non-list item")
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("LRANGE command on non-list item")
|
||||
}
|
||||
@@ -165,29 +163,29 @@ func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lsetKeyFunc(cmd)
|
||||
func handleLSet(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lsetKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
index, ok := internal.AdaptType(cmd[2]).(int)
|
||||
index, ok := internal.AdaptType(params.Command[2]).(int)
|
||||
if !ok {
|
||||
return nil, errors.New("index must be an integer")
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, errors.New("LSET command on non-list item")
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("LSET command on non-list item")
|
||||
}
|
||||
@@ -196,23 +194,23 @@ func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return nil, errors.New("index must be within list range")
|
||||
}
|
||||
|
||||
list[index] = internal.AdaptType(cmd[3])
|
||||
if err = server.SetValue(ctx, key, list); err != nil {
|
||||
list[index] = internal.AdaptType(params.Command[3])
|
||||
if err = params.SetValue(params.Context, key, list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := ltrimKeyFunc(cmd)
|
||||
func handleLTrim(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := ltrimKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
start, startOk := internal.AdaptType(cmd[2]).(int)
|
||||
end, endOk := internal.AdaptType(cmd[3]).(int)
|
||||
start, startOk := internal.AdaptType(params.Command[2]).(int)
|
||||
end, endOk := internal.AdaptType(params.Command[3]).(int)
|
||||
|
||||
if !startOk || !endOk {
|
||||
return nil, errors.New("start and end indices must be integers")
|
||||
@@ -222,16 +220,16 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return nil, errors.New("end index must be greater than start index or -1")
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, errors.New("LTRIM command on non-list item")
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("LTRIM command on non-list item")
|
||||
}
|
||||
@@ -241,44 +239,44 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
}
|
||||
|
||||
if end == -1 || end > len(list) {
|
||||
if err = server.SetValue(ctx, key, list[start:]); err != nil {
|
||||
if err = params.SetValue(params.Context, key, list[start:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, key, list[start:end]); err != nil {
|
||||
if err = params.SetValue(params.Context, key, list[start:end]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lremKeyFunc(cmd)
|
||||
func handleLRem(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lremKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
value := cmd[3]
|
||||
value := params.Command[3]
|
||||
|
||||
count, ok := internal.AdaptType(cmd[2]).(int)
|
||||
count, ok := internal.AdaptType(params.Command[2]).(int)
|
||||
if !ok {
|
||||
return nil, errors.New("count must be an integer")
|
||||
}
|
||||
|
||||
absoluteCount := internal.AbsInt(count)
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, errors.New("LREM command on non-list item")
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("LREM command on non-list item")
|
||||
}
|
||||
@@ -314,44 +312,44 @@ func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return elem == nil
|
||||
})
|
||||
|
||||
if err = server.SetValue(ctx, key, list); err != nil {
|
||||
if err = params.SetValue(params.Context, key, list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lmoveKeyFunc(cmd)
|
||||
func handleLMove(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lmoveKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
|
||||
whereFrom := strings.ToLower(cmd[3])
|
||||
whereTo := strings.ToLower(cmd[4])
|
||||
whereFrom := strings.ToLower(params.Command[3])
|
||||
whereTo := strings.ToLower(params.Command[4])
|
||||
|
||||
if !slices.Contains([]string{"left", "right"}, whereFrom) || !slices.Contains([]string{"left", "right"}, whereTo) {
|
||||
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, source) || !server.KeyExists(ctx, destination) {
|
||||
if !params.KeyExists(params.Context, source) || !params.KeyExists(params.Context, destination) {
|
||||
return nil, errors.New("both source and destination must be lists")
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, source); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, source); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, source)
|
||||
defer params.KeyUnlock(params.Context, source)
|
||||
|
||||
_, err = server.KeyLock(ctx, destination)
|
||||
_, err = params.KeyLock(params.Context, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, destination)
|
||||
defer params.KeyUnlock(params.Context, destination)
|
||||
|
||||
sourceList, sourceOk := server.GetValue(ctx, source).([]interface{})
|
||||
destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{})
|
||||
sourceList, sourceOk := params.GetValue(params.Context, source).([]interface{})
|
||||
destinationList, destinationOk := params.GetValue(params.Context, destination).([]interface{})
|
||||
|
||||
if !sourceOk || !destinationOk {
|
||||
return nil, errors.New("both source and destination must be lists")
|
||||
@@ -359,18 +357,18 @@ func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
|
||||
switch whereFrom {
|
||||
case "left":
|
||||
err = server.SetValue(ctx, source, append([]interface{}{}, sourceList[1:]...))
|
||||
err = params.SetValue(params.Context, source, append([]interface{}{}, sourceList[1:]...))
|
||||
if whereTo == "left" {
|
||||
err = server.SetValue(ctx, destination, append(sourceList[0:1], destinationList...))
|
||||
err = params.SetValue(params.Context, destination, append(sourceList[0:1], destinationList...))
|
||||
} else if whereTo == "right" {
|
||||
err = server.SetValue(ctx, destination, append(destinationList, sourceList[0]))
|
||||
err = params.SetValue(params.Context, destination, append(destinationList, sourceList[0]))
|
||||
}
|
||||
case "right":
|
||||
err = server.SetValue(ctx, source, append([]interface{}{}, sourceList[:len(sourceList)-1]...))
|
||||
err = params.SetValue(params.Context, source, append([]interface{}{}, sourceList[:len(sourceList)-1]...))
|
||||
if whereTo == "left" {
|
||||
err = server.SetValue(ctx, destination, append(sourceList[len(sourceList)-1:], destinationList...))
|
||||
err = params.SetValue(params.Context, destination, append(sourceList[len(sourceList)-1:], destinationList...))
|
||||
} else if whereTo == "right" {
|
||||
err = server.SetValue(ctx, destination, append(destinationList, sourceList[len(sourceList)-1]))
|
||||
err = params.SetValue(params.Context, destination, append(destinationList, sourceList[len(sourceList)-1]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -381,54 +379,54 @@ func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := lpushKeyFunc(cmd)
|
||||
func handleLPush(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := lpushKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newElems []interface{}
|
||||
|
||||
for _, elem := range cmd[2:] {
|
||||
for _, elem := range params.Command[2:] {
|
||||
newElems = append(newElems, internal.AdaptType(elem))
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
switch strings.ToLower(params.Command[0]) {
|
||||
case "lpushx":
|
||||
return nil, errors.New("LPUSHX command on non-list item")
|
||||
default:
|
||||
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = server.SetValue(ctx, key, []interface{}{}); err != nil {
|
||||
if err = params.SetValue(params.Context, key, []interface{}{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
currentList := server.GetValue(ctx, key)
|
||||
currentList := params.GetValue(params.Context, key)
|
||||
|
||||
l, ok := currentList.([]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("LPUSH command on non-list item")
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, key, append(newElems, l...)); err != nil {
|
||||
if err = params.SetValue(params.Context, key, append(newElems, l...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := rpushKeyFunc(cmd)
|
||||
func handleRPush(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := rpushKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -437,31 +435,31 @@ func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
|
||||
var newElems []interface{}
|
||||
|
||||
for _, elem := range cmd[2:] {
|
||||
for _, elem := range params.Command[2:] {
|
||||
newElems = append(newElems, internal.AdaptType(elem))
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
switch strings.ToLower(params.Command[0]) {
|
||||
case "rpushx":
|
||||
return nil, errors.New("RPUSHX command on non-list item")
|
||||
default:
|
||||
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
if err = server.SetValue(ctx, key, []interface{}{}); err != nil {
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
if err = params.SetValue(params.Context, key, []interface{}{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
}
|
||||
|
||||
currentList := server.GetValue(ctx, key)
|
||||
currentList := params.GetValue(params.Context, key)
|
||||
|
||||
l, ok := currentList.([]interface{})
|
||||
|
||||
@@ -469,42 +467,42 @@ func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return nil, errors.New("RPUSH command on non-list item")
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, key, append(l, newElems...)); err != nil {
|
||||
if err = params.SetValue(params.Context, key, append(l, newElems...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handlePop(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := popKeyFunc(cmd)
|
||||
func handlePop(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := popKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(params.Command[0]))
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
list, ok := server.GetValue(ctx, key).([]interface{})
|
||||
list, ok := params.GetValue(params.Context, key).([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))
|
||||
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(params.Command[0]))
|
||||
}
|
||||
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
switch strings.ToLower(params.Command[0]) {
|
||||
default:
|
||||
if err = server.SetValue(ctx, key, list[1:]); err != nil {
|
||||
if err = params.SetValue(params.Context, key, list[1:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf("+%v\r\n", list[0])), nil
|
||||
case "rpop":
|
||||
if err = server.SetValue(ctx, key, list[:len(list)-1]); err != nil {
|
||||
if err = params.SetValue(params.Context, key, list[:len(list)-1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf("+%v\r\n", list[len(list)-1])), nil
|
||||
|
||||
@@ -15,79 +15,77 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleSubscribe(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
|
||||
channels := cmd[1:]
|
||||
channels := params.Command[1:]
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
withPattern := strings.EqualFold(cmd[0], "psubscribe")
|
||||
pubsub.Subscribe(ctx, conn, channels, withPattern)
|
||||
withPattern := strings.EqualFold(params.Command[0], "psubscribe")
|
||||
pubsub.Subscribe(params.Context, params.Connection, channels, withPattern)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func handleUnsubscribe(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
|
||||
channels := cmd[1:]
|
||||
channels := params.Command[1:]
|
||||
|
||||
withPattern := strings.EqualFold(cmd[0], "punsubscribe")
|
||||
withPattern := strings.EqualFold(params.Command[0], "punsubscribe")
|
||||
|
||||
return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
|
||||
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil
|
||||
}
|
||||
|
||||
func handlePublish(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
if len(cmd) != 3 {
|
||||
if len(params.Command) != 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
pubsub.Publish(ctx, cmd[2], cmd[1])
|
||||
pubsub.Publish(params.Context, params.Command[2], params.Command[1])
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
func handlePubSubChannels(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 3 {
|
||||
func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) > 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
|
||||
pattern := ""
|
||||
if len(cmd) == 3 {
|
||||
pattern = cmd[2]
|
||||
if len(params.Command) == 3 {
|
||||
pattern = params.Command[2]
|
||||
}
|
||||
|
||||
return pubsub.Channels(pattern), nil
|
||||
}
|
||||
|
||||
func handlePubSubNumPat(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
@@ -95,12 +93,12 @@ func handlePubSubNumPat(_ context.Context, _ []string, server types.EchoVault, _
|
||||
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
|
||||
}
|
||||
|
||||
func handlePubSubNumSubs(_ context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*internal_pubsub.PubSub)
|
||||
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
|
||||
pubsub, ok := params.GetPubSub().(*internal_pubsub.PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
return pubsub.NumSub(cmd[2:]), nil
|
||||
return pubsub.NumSub(params.Command[2:]), nil
|
||||
}
|
||||
|
||||
func Commands() []types.Command {
|
||||
@@ -210,7 +208,7 @@ it's currently subscribe to.`,
|
||||
WriteKeys: make([]string, 0),
|
||||
}, nil
|
||||
},
|
||||
HandlerFunc: func(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
HandlerFunc: func(_ types.HandlerFuncParams) ([]byte, error) {
|
||||
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
|
||||
},
|
||||
SubCommands: []types.SubCommand{
|
||||
|
||||
@@ -15,20 +15,18 @@
|
||||
package set
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal"
|
||||
internal_set "github.com/echovault/echovault/internal/set"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := saddKeyFunc(cmd)
|
||||
func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := saddKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -37,51 +35,51 @@ func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
|
||||
var set *internal_set.Set
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
set = internal_set.NewSet(cmd[2:])
|
||||
if ok, err := server.CreateKeyAndLock(ctx, key); !ok && err != nil {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
set = internal_set.NewSet(params.Command[2:])
|
||||
if ok, err := params.CreateKeyAndLock(params.Context, key); !ok && err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = server.SetValue(ctx, key, set); err != nil {
|
||||
if err = params.SetValue(params.Context, key, set); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.KeyUnlock(ctx, key)
|
||||
return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil
|
||||
params.KeyUnlock(params.Context, key)
|
||||
return []byte(fmt.Sprintf(":%d\r\n", len(params.Command[2:]))), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
|
||||
count := set.Add(cmd[2:])
|
||||
count := set.Add(params.Command[2:])
|
||||
|
||||
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
|
||||
}
|
||||
|
||||
func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := scardKeyFunc(cmd)
|
||||
func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := scardKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(fmt.Sprintf(":0\r\n")), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -91,21 +89,21 @@ func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
|
||||
}
|
||||
|
||||
func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sdiffKeyFunc(cmd)
|
||||
func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sdiffKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract base set first
|
||||
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
|
||||
if !params.KeyExists(params.Context, keys.ReadKeys[0]) {
|
||||
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
|
||||
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
|
||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
||||
}
|
||||
@@ -114,24 +112,24 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys[1:] {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
continue
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
continue
|
||||
}
|
||||
locks[key] = true
|
||||
}
|
||||
|
||||
var sets []*internal_set.Set
|
||||
for _, key := range cmd[2:] {
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
for _, key := range params.Command[2:] {
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -152,8 +150,8 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sdiffstoreKeyFunc(cmd)
|
||||
func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sdiffstoreKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -161,14 +159,14 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
destination := keys.WriteKeys[0]
|
||||
|
||||
// Extract base set first
|
||||
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
|
||||
if !params.KeyExists(params.Context, keys.ReadKeys[0]) {
|
||||
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
|
||||
}
|
||||
if _, err := server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
|
||||
if _, err := params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
|
||||
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
|
||||
defer params.KeyRUnlock(params.Context, keys.ReadKeys[0])
|
||||
baseSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
|
||||
}
|
||||
@@ -177,16 +175,16 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys[1:] {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
continue
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
continue
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -194,7 +192,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
|
||||
var sets []*internal_set.Set
|
||||
for _, key := range keys.ReadKeys[1:] {
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -206,30 +204,30 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
|
||||
res := fmt.Sprintf(":%d\r\n", len(elems))
|
||||
|
||||
if server.KeyExists(ctx, destination) {
|
||||
if _, err = server.KeyLock(ctx, destination); err != nil {
|
||||
if params.KeyExists(params.Context, destination) {
|
||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = server.SetValue(ctx, destination, diff); err != nil {
|
||||
if err = params.SetValue(params.Context, destination, diff); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.KeyUnlock(ctx, destination)
|
||||
params.KeyUnlock(params.Context, destination)
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = server.SetValue(ctx, destination, diff); err != nil {
|
||||
if err = params.SetValue(params.Context, destination, diff); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.KeyUnlock(ctx, destination)
|
||||
params.KeyUnlock(params.Context, destination)
|
||||
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sinterKeyFunc(cmd)
|
||||
func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sinterKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -238,17 +236,17 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
// If key does not exist, then there is no intersection
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -257,7 +255,7 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
var sets []*internal_set.Set
|
||||
|
||||
for key, _ := range locks {
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
// If the value at the key is not a set, return error
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
@@ -283,15 +281,15 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sintercardKeyFunc(cmd)
|
||||
func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sintercardKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the limit from the command
|
||||
var limit int
|
||||
limitIdx := slices.IndexFunc(cmd, func(s string) bool {
|
||||
limitIdx := slices.IndexFunc(params.Command, func(s string) bool {
|
||||
return strings.EqualFold(s, "limit")
|
||||
})
|
||||
if limitIdx >= 0 && limitIdx < 2 {
|
||||
@@ -299,11 +297,11 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
}
|
||||
if limitIdx != -1 {
|
||||
limitIdx += 1
|
||||
if limitIdx >= len(cmd) {
|
||||
if limitIdx >= len(params.Command) {
|
||||
return nil, errors.New("provide limit after LIMIT keyword")
|
||||
}
|
||||
|
||||
if l, ok := internal.AdaptType(cmd[limitIdx]).(int); !ok {
|
||||
if l, ok := internal.AdaptType(params.Command[limitIdx]).(int); !ok {
|
||||
return nil, errors.New("limit must be an integer")
|
||||
} else {
|
||||
limit = l
|
||||
@@ -314,17 +312,17 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
// If key does not exist, then there is no intersection
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -333,7 +331,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
var sets []*internal_set.Set
|
||||
|
||||
for key, _ := range locks {
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
// If the value at the key is not a set, return error
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
@@ -350,8 +348,8 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
|
||||
}
|
||||
|
||||
func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sinterstoreKeyFunc(cmd)
|
||||
func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sinterstoreKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -360,17 +358,17 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
// If key does not exist, then there is no intersection
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -379,7 +377,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
var sets []*internal_set.Set
|
||||
|
||||
for key, _ := range locks {
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
// If the value at the key is not a set, return error
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
@@ -390,71 +388,71 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
intersect, _ := internal_set.Intersection(0, sets...)
|
||||
destination := keys.WriteKeys[0]
|
||||
|
||||
if server.KeyExists(ctx, destination) {
|
||||
if _, err = server.KeyLock(ctx, destination); err != nil {
|
||||
if params.KeyExists(params.Context, destination) {
|
||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = server.SetValue(ctx, destination, intersect); err != nil {
|
||||
if err = params.SetValue(params.Context, destination, intersect); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.KeyUnlock(ctx, destination)
|
||||
params.KeyUnlock(params.Context, destination)
|
||||
|
||||
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
|
||||
}
|
||||
|
||||
func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sismemberKeyFunc(cmd)
|
||||
func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sismemberKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
|
||||
if !set.Contains(cmd[2]) {
|
||||
if !set.Contains(params.Command[2]) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
return []byte(":1\r\n"), nil
|
||||
}
|
||||
|
||||
func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := smembersKeyFunc(cmd)
|
||||
func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := smembersKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -472,16 +470,16 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := smismemberKeyFunc(cmd)
|
||||
func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := smismemberKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.ReadKeys[0]
|
||||
members := cmd[2:]
|
||||
members := params.Command[2:]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
res := fmt.Sprintf("*%d", len(members))
|
||||
for i, _ := range members {
|
||||
res = fmt.Sprintf("%s\r\n:0", res)
|
||||
@@ -492,12 +490,12 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyRUnlock(ctx, key)
|
||||
defer params.KeyRUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -515,48 +513,48 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := smoveKeyFunc(cmd)
|
||||
func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := smoveKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
|
||||
member := cmd[3]
|
||||
member := params.Command[3]
|
||||
|
||||
if !server.KeyExists(ctx, source) {
|
||||
if !params.KeyExists(params.Context, source) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, source); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, source); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, source)
|
||||
defer params.KeyUnlock(params.Context, source)
|
||||
|
||||
sourceSet, ok := server.GetValue(ctx, source).(*internal_set.Set)
|
||||
sourceSet, ok := params.GetValue(params.Context, source).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, errors.New("source is not a set")
|
||||
}
|
||||
|
||||
var destinationSet *internal_set.Set
|
||||
|
||||
if !server.KeyExists(ctx, destination) {
|
||||
if !params.KeyExists(params.Context, destination) {
|
||||
// Destination key does not exist
|
||||
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, destination)
|
||||
defer params.KeyUnlock(params.Context, destination)
|
||||
destinationSet = internal_set.NewSet([]string{})
|
||||
if err = server.SetValue(ctx, destination, destinationSet); err != nil {
|
||||
if err = params.SetValue(params.Context, destination, destinationSet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Destination key exists
|
||||
if _, err := server.KeyLock(ctx, destination); err != nil {
|
||||
if _, err := params.KeyLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, destination)
|
||||
ds, ok := server.GetValue(ctx, destination).(*internal_set.Set)
|
||||
defer params.KeyUnlock(params.Context, destination)
|
||||
ds, ok := params.GetValue(params.Context, destination).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, errors.New("destination is not a set")
|
||||
}
|
||||
@@ -568,8 +566,8 @@ func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *n
|
||||
return []byte(fmt.Sprintf(":%d\r\n", res)), nil
|
||||
}
|
||||
|
||||
func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := spopKeyFunc(cmd)
|
||||
func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := spopKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -577,24 +575,24 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
key := keys.WriteKeys[0]
|
||||
count := 1
|
||||
|
||||
if len(cmd) == 3 {
|
||||
c, ok := internal.AdaptType(cmd[2]).(int)
|
||||
if len(params.Command) == 3 {
|
||||
c, ok := internal.AdaptType(params.Command[2]).(int)
|
||||
if !ok {
|
||||
return nil, errors.New("count must be an integer")
|
||||
}
|
||||
count = c
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*-1\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a set", key)
|
||||
}
|
||||
@@ -612,8 +610,8 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := srandmemberKeyFunc(cmd)
|
||||
func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := srandmemberKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -621,24 +619,24 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
|
||||
key := keys.ReadKeys[0]
|
||||
count := 1
|
||||
|
||||
if len(cmd) == 3 {
|
||||
c, ok := internal.AdaptType(cmd[2]).(int)
|
||||
if len(params.Command) == 3 {
|
||||
c, ok := internal.AdaptType(params.Command[2]).(int)
|
||||
if !ok {
|
||||
return nil, errors.New("count must be an integer")
|
||||
}
|
||||
count = c
|
||||
}
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte("*-1\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at %s is not a set", key)
|
||||
}
|
||||
@@ -656,25 +654,25 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sremKeyFunc(cmd)
|
||||
func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sremKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := keys.WriteKeys[0]
|
||||
members := cmd[2:]
|
||||
members := params.Command[2:]
|
||||
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
return []byte(":0\r\n"), nil
|
||||
}
|
||||
|
||||
if _, err = server.KeyLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
defer params.KeyUnlock(params.Context, key)
|
||||
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -684,8 +682,8 @@ func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
|
||||
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
|
||||
}
|
||||
|
||||
func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sunionKeyFunc(cmd)
|
||||
func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sunionKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -694,16 +692,16 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
continue
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -715,7 +713,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
if !locked {
|
||||
continue
|
||||
}
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -735,8 +733,8 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
|
||||
keys, err := sunionstoreKeyFunc(cmd)
|
||||
func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := sunionstoreKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -745,16 +743,16 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
defer func() {
|
||||
for key, locked := range locks {
|
||||
if locked {
|
||||
server.KeyRUnlock(ctx, key)
|
||||
params.KeyRUnlock(params.Context, key)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range keys.ReadKeys {
|
||||
if !server.KeyExists(ctx, key) {
|
||||
if !params.KeyExists(params.Context, key) {
|
||||
continue
|
||||
}
|
||||
if _, err = server.KeyRLock(ctx, key); err != nil {
|
||||
if _, err = params.KeyRLock(params.Context, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locks[key] = true
|
||||
@@ -766,7 +764,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
if !locked {
|
||||
continue
|
||||
}
|
||||
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
|
||||
set, ok := params.GetValue(params.Context, key).(*internal_set.Set)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value at key %s is not a set", key)
|
||||
}
|
||||
@@ -777,18 +775,18 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
|
||||
|
||||
destination := keys.WriteKeys[0]
|
||||
|
||||
if server.KeyExists(ctx, destination) {
|
||||
if _, err = server.KeyLock(ctx, destination); err != nil {
|
||||
if params.KeyExists(params.Context, destination) {
|
||||
if _, err = params.KeyLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
|
||||
if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
defer server.KeyUnlock(ctx, destination)
|
||||
defer params.KeyUnlock(params.Context, destination)
|
||||
|
||||
if err = server.SetValue(ctx, destination, union); err != nil {
|
||||
if err = params.SetValue(params.Context, destination, union); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,4 +12,4 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package acl
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
acl2 "github.com/echovault/echovault/pkg/modules/acl"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
@@ -60,8 +61,8 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
|
||||
}
|
||||
|
||||
mockServer, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(acl2.Commands()),
|
||||
echovault.WithConfig(conf),
|
||||
echovault.WithCommands(Commands()),
|
||||
)
|
||||
|
||||
// Add the initial test users to the ACL module
|
||||
@@ -12,4 +12,4 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package admin
|
||||
@@ -21,20 +21,58 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/admin"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_CommandsHandler(t *testing.T) {
|
||||
mockServer, _ := echovault.NewEchoVault(
|
||||
var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(admin.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
echovault.WithCommands(Commands()),
|
||||
)
|
||||
}
|
||||
|
||||
res, err := handleGetAllCommands(context.Background(), []string{"commands"}, mockServer, nil)
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
GetAllCommands: mockServer.GetAllCommands,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CommandsHandler(t *testing.T) {
|
||||
res, err := getHandler("COMMANDS")(getHandlerFuncParams(context.Background(), []string{"commands"}, nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -12,4 +12,4 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package connection
|
||||
@@ -21,7 +21,11 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/connection"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -29,6 +33,7 @@ var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(connection.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -36,6 +41,35 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -62,7 +96,7 @@ func Test_HandlePing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
res, err := handlePing(ctx, test.command, mockServer, nil)
|
||||
res, err := getHandler("PING")(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedErr != nil && err != nil {
|
||||
if err.Error() != test.expectedErr.Error() {
|
||||
t.Errorf("expected error %s, got: %s", test.expectedErr.Error(), err.Error())
|
||||
@@ -12,14 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package generic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/clock"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/commands"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/generic"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -27,13 +28,36 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEchoVault_DEL(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
func createEchoVault() *echovault.EchoVault {
|
||||
ev, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(generic.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
}),
|
||||
)
|
||||
return ev
|
||||
}
|
||||
|
||||
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := server.SetValue(ctx, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
server.KeyUnlock(ctx, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func presetKeyData(server *echovault.EchoVault, ctx context.Context, key string, data internal.KeyData) {
|
||||
_, _ = server.CreateKeyAndLock(ctx, key)
|
||||
defer server.KeyUnlock(ctx, key)
|
||||
_ = server.SetValue(ctx, key, data.Value)
|
||||
server.SetExpiry(ctx, key, data.ExpireAt, false)
|
||||
}
|
||||
|
||||
func TestEchoVault_DEL(t *testing.T) {
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -59,7 +83,7 @@ func TestEchoVault_DEL(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
got, err := server.DEL(tt.keys...)
|
||||
@@ -77,12 +101,7 @@ func TestEchoVault_DEL(t *testing.T) {
|
||||
func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -90,8 +109,8 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd string
|
||||
key string
|
||||
time int
|
||||
expireOpts EXPIREOptions
|
||||
pexpireOpts PEXPIREOptions
|
||||
expireOpts echovault.EXPIREOptions
|
||||
pexpireOpts echovault.PEXPIREOptions
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -100,7 +119,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key1",
|
||||
time: 100,
|
||||
expireOpts: EXPIREOptions{},
|
||||
expireOpts: echovault.EXPIREOptions{},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key1": {Value: "value1", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -112,7 +131,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "PEXPIRE",
|
||||
key: "key2",
|
||||
time: 1000,
|
||||
pexpireOpts: PEXPIREOptions{},
|
||||
pexpireOpts: echovault.PEXPIREOptions{},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key2": {Value: "value2", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -124,7 +143,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key3",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{NX: true},
|
||||
expireOpts: echovault.EXPIREOptions{NX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key3": {Value: "value3", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -136,7 +155,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key4",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{NX: true},
|
||||
expireOpts: echovault.EXPIREOptions{NX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
|
||||
},
|
||||
@@ -148,7 +167,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key5",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{XX: true},
|
||||
expireOpts: echovault.EXPIREOptions{XX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)},
|
||||
},
|
||||
@@ -159,7 +178,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
name: "Return 0 when key does not have an expiry and the XX flag is provided",
|
||||
cmd: "EXPIRE",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{XX: true},
|
||||
expireOpts: echovault.EXPIREOptions{XX: true},
|
||||
key: "key6",
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key6": {Value: "value6", ExpireAt: time.Time{}},
|
||||
@@ -172,7 +191,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key7",
|
||||
time: 100000,
|
||||
expireOpts: EXPIREOptions{GT: true},
|
||||
expireOpts: echovault.EXPIREOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)},
|
||||
},
|
||||
@@ -184,7 +203,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key8",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{GT: true},
|
||||
expireOpts: echovault.EXPIREOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
|
||||
},
|
||||
@@ -196,7 +215,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key9",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{GT: true},
|
||||
expireOpts: echovault.EXPIREOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key9": {Value: "value9", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -208,7 +227,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key10",
|
||||
time: 1000,
|
||||
expireOpts: EXPIREOptions{LT: true},
|
||||
expireOpts: echovault.EXPIREOptions{LT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
|
||||
},
|
||||
@@ -220,7 +239,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
cmd: "EXPIRE",
|
||||
key: "key11",
|
||||
time: 50000,
|
||||
expireOpts: EXPIREOptions{LT: true},
|
||||
expireOpts: echovault.EXPIREOptions{LT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key11": {Value: "value11", ExpireAt: mockClock.Now().Add(30 * time.Second)},
|
||||
},
|
||||
@@ -232,7 +251,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
var got int
|
||||
@@ -256,12 +275,7 @@ func TestEchoVault_EXPIRE(t *testing.T) {
|
||||
func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -269,8 +283,8 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd string
|
||||
key string
|
||||
time int
|
||||
expireAtOpts EXPIREATOptions
|
||||
pexpireAtOpts PEXPIREATOptions
|
||||
expireAtOpts echovault.EXPIREATOptions
|
||||
pexpireAtOpts echovault.PEXPIREATOptions
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -278,7 +292,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
name: "Set new expire by unix seconds",
|
||||
cmd: "EXPIREAT",
|
||||
key: "key1",
|
||||
expireAtOpts: EXPIREATOptions{},
|
||||
expireAtOpts: echovault.EXPIREATOptions{},
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key1": {Value: "value1", ExpireAt: time.Time{}},
|
||||
@@ -290,7 +304,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
name: "Set new expire by milliseconds",
|
||||
cmd: "PEXPIREAT",
|
||||
key: "key2",
|
||||
pexpireAtOpts: PEXPIREATOptions{},
|
||||
pexpireAtOpts: echovault.PEXPIREATOptions{},
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).UnixMilli()),
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key2": {Value: "value2", ExpireAt: time.Time{}},
|
||||
@@ -303,7 +317,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key3",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{NX: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{NX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key3": {Value: "value3", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -314,7 +328,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
name: "Return 0, when NX flag is provided and key already has an expiry time",
|
||||
cmd: "EXPIREAT",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{NX: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{NX: true},
|
||||
key: "key4",
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
|
||||
@@ -327,7 +341,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
key: "key5",
|
||||
expireAtOpts: EXPIREATOptions{XX: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{XX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)},
|
||||
},
|
||||
@@ -339,7 +353,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key6",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{XX: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{XX: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key6": {Value: "value6", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -351,7 +365,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key7",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{GT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)},
|
||||
},
|
||||
@@ -363,7 +377,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key8",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{GT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
|
||||
},
|
||||
@@ -375,7 +389,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key9",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{GT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{GT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key9": {Value: "value9", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -386,7 +400,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key10",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{LT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{LT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)},
|
||||
},
|
||||
@@ -398,7 +412,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key11",
|
||||
time: int(mockClock.Now().Add(3000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{LT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{LT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)},
|
||||
},
|
||||
@@ -410,7 +424,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
cmd: "EXPIREAT",
|
||||
key: "key12",
|
||||
time: int(mockClock.Now().Add(1000 * time.Second).Unix()),
|
||||
expireAtOpts: EXPIREATOptions{LT: true},
|
||||
expireAtOpts: echovault.EXPIREATOptions{LT: true},
|
||||
presetValues: map[string]internal.KeyData{
|
||||
"key12": {Value: "value12", ExpireAt: time.Time{}},
|
||||
},
|
||||
@@ -422,7 +436,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
var got int
|
||||
@@ -446,12 +460,7 @@ func TestEchoVault_EXPIREAT(t *testing.T) {
|
||||
func TestEchoVault_EXPIRETIME(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -503,7 +512,7 @@ func TestEchoVault_EXPIRETIME(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
got, err := tt.expiretimeFunc(tt.key)
|
||||
@@ -519,12 +528,7 @@ func TestEchoVault_EXPIRETIME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_GET(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -551,7 +555,11 @@ func TestEchoVault_GET(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.GET(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -566,12 +574,7 @@ func TestEchoVault_GET(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_MGET(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -599,7 +602,11 @@ func TestEchoVault_MGET(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.MGET(tt.keys...)
|
||||
@@ -625,19 +632,14 @@ func TestEchoVault_MGET(t *testing.T) {
|
||||
func TestEchoVault_SET(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
presetValues map[string]internal.KeyData
|
||||
key string
|
||||
value string
|
||||
options SETOptions
|
||||
options echovault.SETOptions
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -646,7 +648,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key1",
|
||||
value: "value1",
|
||||
options: SETOptions{},
|
||||
options: echovault.SETOptions{},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -655,7 +657,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key2",
|
||||
value: "value2",
|
||||
options: SETOptions{NX: true},
|
||||
options: echovault.SETOptions{NX: true},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -669,7 +671,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
},
|
||||
key: "key3",
|
||||
value: "value3",
|
||||
options: SETOptions{NX: true},
|
||||
options: echovault.SETOptions{NX: true},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -683,7 +685,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
},
|
||||
key: "key4",
|
||||
value: "value4",
|
||||
options: SETOptions{XX: true},
|
||||
options: echovault.SETOptions{XX: true},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -692,7 +694,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key5",
|
||||
value: "value5",
|
||||
options: SETOptions{XX: true},
|
||||
options: echovault.SETOptions{XX: true},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -701,7 +703,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key6",
|
||||
value: "value6",
|
||||
options: SETOptions{EX: 100},
|
||||
options: echovault.SETOptions{EX: 100},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -710,7 +712,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key7",
|
||||
value: "value7",
|
||||
options: SETOptions{PX: 4096},
|
||||
options: echovault.SETOptions{PX: 4096},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -719,7 +721,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key8",
|
||||
value: "value8",
|
||||
options: SETOptions{EXAT: int(mockClock.Now().Add(200 * time.Second).Unix())},
|
||||
options: echovault.SETOptions{EXAT: int(mockClock.Now().Add(200 * time.Second).Unix())},
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -727,7 +729,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
name: "Set exact expiry time in milliseconds from unix epoch",
|
||||
key: "key9",
|
||||
value: "value9",
|
||||
options: SETOptions{PXAT: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli())},
|
||||
options: echovault.SETOptions{PXAT: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli())},
|
||||
presetValues: nil,
|
||||
want: "OK",
|
||||
wantErr: false,
|
||||
@@ -742,7 +744,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
},
|
||||
key: "key10",
|
||||
value: "value10",
|
||||
options: SETOptions{GET: true, EX: 1000},
|
||||
options: echovault.SETOptions{GET: true, EX: 1000},
|
||||
want: "previous-value",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -751,7 +753,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
presetValues: nil,
|
||||
key: "key11",
|
||||
value: "value11",
|
||||
options: SETOptions{GET: true, EX: 1000},
|
||||
options: echovault.SETOptions{GET: true, EX: 1000},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -760,7 +762,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
got, err := server.SET(tt.key, tt.value, tt.options)
|
||||
@@ -776,12 +778,7 @@ func TestEchoVault_SET(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_MSET(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -813,12 +810,7 @@ func TestEchoVault_MSET(t *testing.T) {
|
||||
func TestEchoVault_PERSIST(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -857,7 +849,7 @@ func TestEchoVault_PERSIST(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
got, err := server.PERSIST(tt.key)
|
||||
@@ -875,12 +867,7 @@ func TestEchoVault_PERSIST(t *testing.T) {
|
||||
func TestEchoVault_TTL(t *testing.T) {
|
||||
mockClock := clock.NewClock()
|
||||
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -933,7 +920,7 @@ func TestEchoVault_TTL(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, d := range tt.presetValues {
|
||||
presetKeyData(server, k, d)
|
||||
presetKeyData(server, context.Background(), k, d)
|
||||
}
|
||||
}
|
||||
got, err := tt.ttlFunc(tt.key)
|
||||
@@ -23,7 +23,11 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/generic"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -41,6 +45,7 @@ func init() {
|
||||
mockClock = clock.NewClock()
|
||||
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(generic.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -48,6 +53,47 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
GetClock: mockServer.GetClock,
|
||||
GetExpiry: mockServer.GetExpiry,
|
||||
SetExpiry: mockServer.SetExpiry,
|
||||
DeleteKey: mockServer.DeleteKey,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSET(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -372,7 +418,13 @@ func Test_HandleSET(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleSet(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("expected error \"%s\", got nil", test.expectedErr.Error())
|
||||
@@ -454,7 +506,14 @@ func Test_HandleMSET(t *testing.T) {
|
||||
for i, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MSET, %d", i))
|
||||
res, err := handleMSet(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedErr != nil {
|
||||
if err.Error() != test.expectedErr.Error() {
|
||||
t.Errorf("expected error %s, got %s", test.expectedErr.Error(), err.Error())
|
||||
@@ -547,7 +606,14 @@ func Test_HandleGET(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
|
||||
res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil)
|
||||
handler := getHandler("GET")
|
||||
if handler == nil {
|
||||
t.Error("no handler found for command GET")
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, []string{"GET", key}, nil))
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -559,7 +625,7 @@ func Test_HandleGET(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test get non-existent key
|
||||
res, err := handleGet(context.Background(), []string{"GET", "test4"}, mockServer, nil)
|
||||
res, err := getHandler("GET")(getHandlerFuncParams(context.Background(), []string{"GET", "test4"}, nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -585,7 +651,12 @@ func Test_HandleGET(t *testing.T) {
|
||||
}
|
||||
for _, test := range errorTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
res, err = handleGet(context.Background(), test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
res, err = handler(getHandlerFuncParams(context.Background(), test.command, nil))
|
||||
if res != nil {
|
||||
t.Errorf("expected nil response, got: %+v", res)
|
||||
}
|
||||
@@ -631,21 +702,28 @@ func Test_HandleMGET(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for i, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MGET, %d", i))
|
||||
// Set up the values
|
||||
for i, key := range test.presetKeys {
|
||||
_, err := mockServer.CreateKeyAndLock(context.Background(), key)
|
||||
_, err := mockServer.CreateKeyAndLock(ctx, key)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err = mockServer.SetValue(context.Background(), key, test.presetValues[i]); err != nil {
|
||||
if err = mockServer.SetValue(ctx, key, test.presetValues[i]); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
mockServer.KeyUnlock(context.Background(), key)
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
// Test the command and its results
|
||||
res, err := handleMGet(context.Background(), test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
// If we expect and error, branch out and check error
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
@@ -734,7 +812,13 @@ func Test_HandleDEL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleDel(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("exected error \"%s\", got nil", test.expectedErr.Error())
|
||||
@@ -844,7 +928,13 @@ func Test_HandlePERSIST(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handlePersist(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err == nil {
|
||||
@@ -965,7 +1055,13 @@ func Test_HandleEXPIRETIME(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleExpireTime(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err == nil {
|
||||
@@ -1067,7 +1163,13 @@ func Test_HandleTTL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleTTL(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err == nil {
|
||||
@@ -1300,7 +1402,13 @@ func Test_HandleEXPIRE(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleExpire(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err == nil {
|
||||
@@ -1576,7 +1684,13 @@ func Test_HandleEXPIREAT(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := handleExpireAt(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err == nil {
|
||||
@@ -12,25 +12,41 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package hash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/commands"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/hash"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEchoVault_HDEL(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
func createEchoVault() *echovault.EchoVault {
|
||||
ev, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(hash.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
return ev
|
||||
}
|
||||
|
||||
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := server.SetValue(ctx, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
server.KeyUnlock(ctx, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestEchoVault_HDEL(t *testing.T) {
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -76,7 +92,11 @@ func TestEchoVault_HDEL(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HDEL(tt.key, tt.fields...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -91,13 +111,7 @@ func TestEchoVault_HDEL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HEXISTS(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -135,7 +149,11 @@ func TestEchoVault_HEXISTS(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HEXISTS(tt.key, tt.field)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -150,13 +168,7 @@ func TestEchoVault_HEXISTS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HGETALL(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -190,7 +202,11 @@ func TestEchoVault_HGETALL(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HGETALL(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -212,13 +228,7 @@ func TestEchoVault_HGETALL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HINCRBY(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
const (
|
||||
HINCRBY = "HINCRBY"
|
||||
@@ -300,7 +310,11 @@ func TestEchoVault_HINCRBY(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
var got float64
|
||||
var err error
|
||||
@@ -326,13 +340,7 @@ func TestEchoVault_HINCRBY(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HKEYS(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -366,7 +374,11 @@ func TestEchoVault_HKEYS(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HKEYS(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -386,13 +398,7 @@ func TestEchoVault_HKEYS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HLEN(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -426,7 +432,11 @@ func TestEchoVault_HLEN(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HLEN(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -441,19 +451,13 @@ func TestEchoVault_HLEN(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
presetValue interface{}
|
||||
key string
|
||||
options HRANDFIELDOptions
|
||||
options echovault.HRANDFIELDOptions
|
||||
wantCount int
|
||||
want []string
|
||||
wantErr bool
|
||||
@@ -462,7 +466,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
name: "Get a random field",
|
||||
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142},
|
||||
key: "key1",
|
||||
options: HRANDFIELDOptions{Count: 1},
|
||||
options: echovault.HRANDFIELDOptions{Count: 1},
|
||||
wantCount: 1,
|
||||
want: []string{"field1", "field2", "field3"},
|
||||
wantErr: false,
|
||||
@@ -471,7 +475,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
name: "Get a random field with a value",
|
||||
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142},
|
||||
key: "key2",
|
||||
options: HRANDFIELDOptions{WithValues: true, Count: 1},
|
||||
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 1},
|
||||
wantCount: 2,
|
||||
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"},
|
||||
wantErr: false,
|
||||
@@ -486,7 +490,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
"field5": "value5",
|
||||
},
|
||||
key: "key3",
|
||||
options: HRANDFIELDOptions{Count: 3},
|
||||
options: echovault.HRANDFIELDOptions{Count: 3},
|
||||
wantCount: 3,
|
||||
want: []string{"field1", "field2", "field3", "field4", "field5"},
|
||||
wantErr: false,
|
||||
@@ -501,7 +505,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
"field5": "value5",
|
||||
},
|
||||
key: "key4",
|
||||
options: HRANDFIELDOptions{WithValues: true, Count: 3},
|
||||
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 3},
|
||||
wantCount: 6,
|
||||
want: []string{
|
||||
"field1", "value1", "field2", "123456789", "field3",
|
||||
@@ -519,7 +523,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
"field5": "value5",
|
||||
},
|
||||
key: "key5",
|
||||
options: HRANDFIELDOptions{Count: 5},
|
||||
options: echovault.HRANDFIELDOptions{Count: 5},
|
||||
wantCount: 5,
|
||||
want: []string{"field1", "field2", "field3", "field4", "field5"},
|
||||
wantErr: false,
|
||||
@@ -534,7 +538,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
"field5": "value5",
|
||||
},
|
||||
key: "key5",
|
||||
options: HRANDFIELDOptions{WithValues: true, Count: 5},
|
||||
options: echovault.HRANDFIELDOptions{WithValues: true, Count: 5},
|
||||
wantCount: 10,
|
||||
want: []string{
|
||||
"field1", "value1", "field2", "123456789", "field3",
|
||||
@@ -546,7 +550,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
name: "Trying to get random field on a non hash map returns error",
|
||||
presetValue: "Default value",
|
||||
key: "key12",
|
||||
options: HRANDFIELDOptions{},
|
||||
options: echovault.HRANDFIELDOptions{},
|
||||
wantCount: 0,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
@@ -555,7 +559,11 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HRANDFIELD(tt.key, tt.options)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -575,13 +583,7 @@ func TestEchoVault_HRANDFIELD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HSET(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -650,7 +652,11 @@ func TestEchoVault_HSET(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := tt.hsetFunc(tt.key, tt.fieldValuePairs)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -665,13 +671,7 @@ func TestEchoVault_HSET(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HSTRLEN(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -719,7 +719,11 @@ func TestEchoVault_HSTRLEN(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HSTRLEN(tt.key, tt.fields...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -734,13 +738,7 @@ func TestEchoVault_HSTRLEN(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_HVALS(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -774,7 +772,11 @@ func TestEchoVault_HVALS(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.HVALS(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -22,8 +22,12 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/hash"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -31,6 +35,7 @@ var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(hash.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -38,6 +43,43 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleHSET(t *testing.T) {
|
||||
// Tests for both HSET and HSETNX
|
||||
tests := []struct {
|
||||
@@ -144,7 +186,14 @@ func Test_HandleHSET(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHSET(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -163,11 +212,11 @@ func Test_HandleHSET(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
|
||||
h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("value at key \"%s\" is not a hash map", test.key)
|
||||
}
|
||||
for field, value := range hash {
|
||||
for field, value := range h {
|
||||
if value != test.expectedValue[field] {
|
||||
t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value)
|
||||
}
|
||||
@@ -303,7 +352,14 @@ func Test_HandleHINCRBY(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHINCRBY(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -331,11 +387,11 @@ func Test_HandleHINCRBY(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
|
||||
h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("value at key \"%s\" is not a hash map", test.key)
|
||||
}
|
||||
for field, value := range hash {
|
||||
for field, value := range h {
|
||||
if value != test.expectedValue[field] {
|
||||
t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value)
|
||||
}
|
||||
@@ -410,7 +466,14 @@ func Test_HandleHGET(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHGET(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -519,7 +582,14 @@ func Test_HandleHSTRLEN(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHSTRLEN(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -623,7 +693,14 @@ func Test_HandleHVALS(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHVALS(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -831,7 +908,14 @@ func Test_HandleHRANDFIELD(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHRANDFIELD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -954,7 +1038,14 @@ func Test_HandleHLEN(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHLEN(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1053,7 +1144,14 @@ func Test_HandleHKeys(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHKEYS(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1159,7 +1257,14 @@ func Test_HandleHGETALL(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHGETALL(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1276,7 +1381,14 @@ func Test_HandleHEXISTS(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHEXISTS(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1375,7 +1487,14 @@ func Test_HandleHDEL(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleHDEL(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1396,8 +1515,8 @@ func Test_HandleHDEL(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok {
|
||||
for field, value := range hash {
|
||||
if h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok {
|
||||
for field, value := range h {
|
||||
if value != test.expectedValue[field] {
|
||||
t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value)
|
||||
}
|
||||
@@ -12,24 +12,40 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package list
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/commands"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/list"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEchoVault_LLEN(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
func createEchoVault() *echovault.EchoVault {
|
||||
ev, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(list.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
return ev
|
||||
}
|
||||
|
||||
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := server.SetValue(ctx, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
server.KeyUnlock(ctx, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestEchoVault_LLEN(t *testing.T) {
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
preset bool
|
||||
@@ -67,7 +83,11 @@ func TestEchoVault_LLEN(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.LLEN(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -82,13 +102,7 @@ func TestEchoVault_LLEN(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LINDEX(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
preset bool
|
||||
@@ -156,7 +170,11 @@ func TestEchoVault_LINDEX(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := server.LINDEX(tt.key, tt.index)
|
||||
@@ -172,13 +190,7 @@ func TestEchoVault_LINDEX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LMOVE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -328,7 +340,11 @@ func TestEchoVault_LMOVE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
for k, v := range tt.presetValue {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.LMOVE(tt.source, tt.destination, tt.whereFrom, tt.whereTo)
|
||||
@@ -344,13 +360,7 @@ func TestEchoVault_LMOVE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_POP(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -401,7 +411,11 @@ func TestEchoVault_POP(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := tt.popFunc(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -416,13 +430,7 @@ func TestEchoVault_POP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LPUSH(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -478,7 +486,11 @@ func TestEchoVault_LPUSH(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := tt.lpushFunc(tt.key, tt.values...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -493,13 +505,7 @@ func TestEchoVault_LPUSH(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_RPUSH(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -535,7 +541,11 @@ func TestEchoVault_RPUSH(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := tt.rpushFunc(tt.key, tt.values...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -550,13 +560,7 @@ func TestEchoVault_RPUSH(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LRANGE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -656,7 +660,11 @@ func TestEchoVault_LRANGE(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.LRANGE(tt.key, tt.start, tt.end)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -671,13 +679,7 @@ func TestEchoVault_LRANGE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LREM(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -722,7 +724,11 @@ func TestEchoVault_LREM(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := server.LREM(tt.key, tt.count, tt.value)
|
||||
@@ -738,13 +744,7 @@ func TestEchoVault_LREM(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LSET(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -830,7 +830,11 @@ func TestEchoVault_LSET(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.LSET(tt.key, tt.index, tt.value)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -845,13 +849,7 @@ func TestEchoVault_LSET(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_LTRIM(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -940,7 +938,11 @@ func TestEchoVault_LTRIM(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.preset {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.LTRIM(tt.key, tt.start, tt.end)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -22,7 +22,11 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
"github.com/echovault/echovault/pkg/modules/list"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -30,6 +34,7 @@ var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(list.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -37,6 +42,43 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleLLEN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -113,7 +155,14 @@ func Test_HandleLLEN(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLLen(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -258,7 +307,14 @@ func Test_HandleLINDEX(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLIndex(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -426,7 +482,14 @@ func Test_HandleLRANGE(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLRange(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -577,7 +640,14 @@ func Test_HandleLSET(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLSet(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -595,16 +665,16 @@ func Test_HandleLSET(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -751,7 +821,14 @@ func Test_HandleLTRIM(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLTrim(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -769,16 +846,16 @@ func Test_HandleLTRIM(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -882,7 +959,14 @@ func Test_HandleLREM(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLRem(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -900,16 +984,16 @@ func Test_HandleLREM(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -1096,7 +1180,14 @@ func Test_HandleLMOVE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleLMove(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1115,7 +1206,7 @@ func Test_HandleLMOVE(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
@@ -1123,12 +1214,12 @@ func Test_HandleLMOVE(t *testing.T) {
|
||||
if !ok {
|
||||
t.Error("expected test value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(expectedList) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(expectedList), len(list))
|
||||
if len(l) != len(expectedList) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(expectedList), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != expectedList[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != expectedList[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, key)
|
||||
@@ -1213,7 +1304,14 @@ func Test_HandleLPUSH(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleLPush(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1231,16 +1329,16 @@ func Test_HandleLPUSH(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -1324,7 +1422,14 @@ func Test_HandleRPUSH(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleRPush(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1342,16 +1447,16 @@ func Test_HandleRPUSH(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -1445,7 +1550,14 @@ func Test_HandlePOP(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handlePop(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1463,16 +1575,16 @@ func Test_HandlePOP(t *testing.T) {
|
||||
if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
l, ok := mockServer.GetValue(ctx, test.key).([]interface{})
|
||||
if !ok {
|
||||
t.Error("expected value to be list, got another type")
|
||||
}
|
||||
if len(list) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(list))
|
||||
if len(l) != len(test.expectedValue) {
|
||||
t.Errorf("expected list length to be %d, got %d", len(test.expectedValue), len(l))
|
||||
}
|
||||
for i := 0; i < len(list); i++ {
|
||||
if list[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
|
||||
for i := 0; i < len(l); i++ {
|
||||
if l[i] != test.expectedValue[i] {
|
||||
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], l[i])
|
||||
}
|
||||
}
|
||||
mockServer.KeyRUnlock(ctx, test.key)
|
||||
@@ -12,4 +12,4 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package pubsub
|
||||
@@ -22,9 +22,12 @@ import (
|
||||
internal_pubsub "github.com/echovault/echovault/internal/pubsub"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
ps "github.com/echovault/echovault/pkg/modules/pubsub"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -51,7 +54,7 @@ func init() {
|
||||
|
||||
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
|
||||
server, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(Commands()),
|
||||
echovault.WithCommands(ps.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
BindAddr: bindAddr,
|
||||
Port: port,
|
||||
@@ -62,6 +65,36 @@ func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
|
||||
return server
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
GetPubSub: mockServer.GetPubSub,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSubscribe(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
|
||||
|
||||
@@ -86,7 +119,8 @@ func Test_HandleSubscribe(t *testing.T) {
|
||||
// Test subscribe to channels
|
||||
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
|
||||
for _, conn := range connections {
|
||||
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, conn); err != nil {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +150,8 @@ func Test_HandleSubscribe(t *testing.T) {
|
||||
// Test subscribe to patterns
|
||||
patterns := []string{"psub_channel*"}
|
||||
for _, conn := range connections {
|
||||
if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, conn); err != nil {
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -263,24 +298,24 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
||||
|
||||
// Subscribe all the connections to the channels and patterns
|
||||
for _, conn := range append(test.otherConnections, test.targetConn) {
|
||||
_, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), mockServer, conn)
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), mockServer, conn)
|
||||
_, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe the target connection from the unsub channels and patterns
|
||||
res, err := handleUnsubscribe(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), mockServer, test.targetConn)
|
||||
res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyResponse(res, test.expectedResponses["channel"])
|
||||
|
||||
res, err = handleUnsubscribe(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), mockServer, test.targetConn)
|
||||
res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -347,7 +382,7 @@ func Test_HandlePublish(t *testing.T) {
|
||||
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
|
||||
// Subscribe to channels
|
||||
go func() {
|
||||
_, _ = handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, c)
|
||||
_, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer))
|
||||
}()
|
||||
// Verify all the responses for each channel subscription
|
||||
for i := 0; i < len(channels); i++ {
|
||||
@@ -355,7 +390,7 @@ func Test_HandlePublish(t *testing.T) {
|
||||
}
|
||||
// Subscribe to all the patterns
|
||||
go func() {
|
||||
_, _ = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, c)
|
||||
_, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer))
|
||||
}()
|
||||
// Verify all the responses for each pattern subscription
|
||||
for i := 0; i < len(patterns); i++ {
|
||||
@@ -518,7 +553,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
|
||||
|
||||
// Subscribe connections to channels
|
||||
go func() {
|
||||
_, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &wConn1)
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &wConn1, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -535,7 +570,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
_, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &wConn2)
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &wConn2, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -571,7 +606,7 @@ func Test_HandlePubSubChannels(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check if all subscriptions are returned
|
||||
res, err := handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil)
|
||||
res, err := getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -579,45 +614,45 @@ func Test_HandlePubSubChannels(t *testing.T) {
|
||||
|
||||
// Unsubscribe from one pattern and one channel before checking against a new slice of
|
||||
// expected channels/patterns in the response of the "PUBSUB CHANNELS" command
|
||||
_, err = handleUnsubscribe(
|
||||
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
|
||||
ctx,
|
||||
append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...),
|
||||
mockServer,
|
||||
&wConn1,
|
||||
)
|
||||
mockServer,
|
||||
))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = handleUnsubscribe(
|
||||
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
|
||||
ctx,
|
||||
append([]string{"UNSUBSCRIBE"}, "channel_[456]"),
|
||||
mockServer,
|
||||
&wConn2,
|
||||
)
|
||||
mockServer,
|
||||
))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Return all the remaining channels
|
||||
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
|
||||
// Return only one of the remaining channels when passed a pattern that matches it
|
||||
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1"})
|
||||
// Return both remaining channels when passed a pattern that matches them
|
||||
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
|
||||
// Return none channels when passed a pattern that does not match either channel
|
||||
res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -655,7 +690,8 @@ func Test_HandleNumPat(t *testing.T) {
|
||||
r *resp.Conn
|
||||
}{w: &w, r: resp.NewConn(r)}
|
||||
go func() {
|
||||
if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &w); err != nil {
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
@@ -685,7 +721,7 @@ func Test_HandleNumPat(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check that we receive all the patterns with NUMPAT commands
|
||||
res, err := handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
|
||||
res, err := getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -693,12 +729,12 @@ func Test_HandleNumPat(t *testing.T) {
|
||||
|
||||
// Unsubscribe from a channel and check if the number of active channels is updated
|
||||
for _, conn := range connections {
|
||||
_, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, mockServer, conn.w)
|
||||
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, conn.w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -706,12 +742,12 @@ func Test_HandleNumPat(t *testing.T) {
|
||||
|
||||
// Unsubscribe from all the channels and check if we get a 0 response
|
||||
for _, conn := range connections {
|
||||
_, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE"}, mockServer, conn.w)
|
||||
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE"}, conn.w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil)
|
||||
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -748,7 +784,8 @@ func Test_HandleNumSub(t *testing.T) {
|
||||
r *resp.Conn
|
||||
}{w: &w, r: resp.NewConn(r)}
|
||||
go func() {
|
||||
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &w); err != nil {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
@@ -793,7 +830,7 @@ func Test_HandleNumSub(t *testing.T) {
|
||||
for i, test := range tests {
|
||||
ctx = context.WithValue(ctx, "test_index", i)
|
||||
|
||||
res, err := handlePubSubNumSubs(ctx, test.cmd, mockServer, nil)
|
||||
res, err := getHandler("PUBSUB", "NUMSUB")(getHandlerFuncParams(ctx, test.cmd, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -12,26 +12,42 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package echovault
|
||||
package set
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/set"
|
||||
"github.com/echovault/echovault/pkg/commands"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
s "github.com/echovault/echovault/pkg/modules/set"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEchoVault_SADD(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
func createEchoVault() *echovault.EchoVault {
|
||||
ev, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(s.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
return ev
|
||||
}
|
||||
|
||||
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := server.SetValue(ctx, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
server.KeyUnlock(ctx, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestEchoVault_SADD(t *testing.T) {
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -69,7 +85,11 @@ func TestEchoVault_SADD(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SADD(tt.key, tt.members...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -84,13 +104,7 @@ func TestEchoVault_SADD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SCARD(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -124,7 +138,11 @@ func TestEchoVault_SCARD(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SCARD(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -139,13 +157,7 @@ func TestEchoVault_SCARD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SDIFF(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -212,7 +224,11 @@ func TestEchoVault_SDIFF(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SDIFF(tt.keys...)
|
||||
@@ -233,13 +249,7 @@ func TestEchoVault_SDIFF(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SDIFFSTORE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -312,7 +322,11 @@ func TestEchoVault_SDIFFSTORE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SDIFFSTORE(tt.destination, tt.keys...)
|
||||
@@ -328,13 +342,7 @@ func TestEchoVault_SDIFFSTORE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SINTER(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -401,7 +409,11 @@ func TestEchoVault_SINTER(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SINTER(tt.keys...)
|
||||
@@ -422,13 +434,7 @@ func TestEchoVault_SINTER(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SINTERCARD(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -512,7 +518,11 @@ func TestEchoVault_SINTERCARD(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SINTERCARD(tt.keys, tt.limit)
|
||||
@@ -528,13 +538,7 @@ func TestEchoVault_SINTERCARD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SINTERSTORE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -607,7 +611,11 @@ func TestEchoVault_SINTERSTORE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SINTERSTORE(tt.destination, tt.keys...)
|
||||
@@ -623,13 +631,7 @@ func TestEchoVault_SINTERSTORE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SISMEMBER(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -667,7 +669,11 @@ func TestEchoVault_SISMEMBER(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SISMEMBER(tt.key, tt.member)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -682,13 +688,7 @@ func TestEchoVault_SISMEMBER(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SMEMBERS(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -722,7 +722,11 @@ func TestEchoVault_SMEMBERS(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SMEMBERS(tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -742,13 +746,7 @@ func TestEchoVault_SMEMBERS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SMISMEMBER(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -805,7 +803,11 @@ func TestEchoVault_SMISMEMBER(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SMISMEMBER(tt.key, tt.members...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -820,13 +822,7 @@ func TestEchoVault_SMISMEMBER(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SMOVE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -890,7 +886,11 @@ func TestEchoVault_SMOVE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SMOVE(tt.source, tt.destination, tt.member)
|
||||
@@ -906,13 +906,7 @@ func TestEchoVault_SMOVE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SPOP(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -942,7 +936,11 @@ func TestEchoVault_SPOP(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SPOP(tt.key, tt.count)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -959,13 +957,7 @@ func TestEchoVault_SPOP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SRANDMEMBER(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1007,7 +999,11 @@ func TestEchoVault_SRANDMEMBER(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SRANDMEMBER(tt.key, tt.count)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -1028,13 +1024,7 @@ func TestEchoVault_SRANDMEMBER(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SREM(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1072,7 +1062,11 @@ func TestEchoVault_SREM(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValue != nil {
|
||||
presetValue(server, tt.key, tt.presetValue)
|
||||
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
got, err := server.SREM(tt.key, tt.members...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@@ -1087,13 +1081,7 @@ func TestEchoVault_SREM(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SUNION(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1153,7 +1141,11 @@ func TestEchoVault_SUNION(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SUNION(tt.keys...)
|
||||
@@ -1174,13 +1166,7 @@ func TestEchoVault_SUNION(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SUNIONSTORE(t *testing.T) {
|
||||
server, _ := NewEchoVault(
|
||||
WithCommands(commands.All()),
|
||||
WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1230,7 +1216,11 @@ func TestEchoVault_SUNIONSTORE(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
presetValue(server, k, v)
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.SUNIONSTORE(tt.destination, tt.keys...)
|
||||
@@ -23,8 +23,12 @@ import (
|
||||
"github.com/echovault/echovault/internal/set"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
s "github.com/echovault/echovault/pkg/modules/set"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -32,6 +36,7 @@ var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(s.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -39,6 +44,43 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSADD(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -103,7 +145,15 @@ func Test_HandleSADD(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSADD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -214,7 +264,15 @@ func Test_HandleSCARD(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSCARD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -327,7 +385,15 @@ func Test_HandleSDIFF(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSDIFF(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -454,7 +520,15 @@ func Test_HandleSDIFFSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSDIFFSTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -582,7 +656,15 @@ func Test_HandleSINTER(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSINTER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -709,7 +791,15 @@ func Test_HandleSINTERCARD(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSINTERCARD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -834,7 +924,13 @@ func Test_HandleSINTERSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSINTERSTORE(ctx, test.command, mockServer, nil)
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -939,7 +1035,14 @@ func Test_HandleSISMEMBER(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSISMEMBER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1027,7 +1130,14 @@ func Test_HandleSMEMBERS(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSMEMBERS(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1118,7 +1228,14 @@ func Test_HandleSMISMEMBER(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSMISMEMBER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1160,7 +1277,7 @@ func Test_HandleSMOVE(t *testing.T) {
|
||||
"SmoveSource1": set.NewSet([]string{"one", "two", "three", "four"}),
|
||||
"SmoveDestination1": set.NewSet([]string{"five", "six", "seven", "eight"}),
|
||||
},
|
||||
command: []string{"MOVE", "SmoveSource1", "SmoveDestination1", "four"},
|
||||
command: []string{"SMOVE", "SmoveSource1", "SmoveDestination1", "four"},
|
||||
expectedValues: map[string]interface{}{
|
||||
"SmoveSource1": set.NewSet([]string{"one", "two", "three"}),
|
||||
"SmoveDestination1": set.NewSet([]string{"four", "five", "six", "seven", "eight"}),
|
||||
@@ -1242,7 +1359,14 @@ func Test_HandleSMOVE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSMOVE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1350,7 +1474,14 @@ func Test_HandleSPOP(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSPOP(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1473,7 +1604,14 @@ func Test_HandleSRANDMEMBER(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSRANDMEMBER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1588,7 +1726,14 @@ func Test_HandleSREM(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleSREM(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1708,7 +1853,14 @@ func Test_HandleSUNION(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSUNION(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1811,7 +1963,14 @@ func Test_HandleSUNIONSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleSUNIONSTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,10 +23,14 @@ import (
|
||||
"github.com/echovault/echovault/internal/sorted_set"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
ss "github.com/echovault/echovault/pkg/modules/sorted_set"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -34,6 +38,7 @@ var mockServer *echovault.EchoVault
|
||||
|
||||
func init() {
|
||||
mockServer, _ = echovault.NewEchoVault(
|
||||
echovault.WithCommands(ss.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
@@ -41,6 +46,43 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleZADD(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -273,7 +315,15 @@ func Test_HandleZADD(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZADD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -390,7 +440,15 @@ func Test_HandleZCARD(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZCARD(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -542,7 +600,15 @@ func Test_HandleZCOUNT(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZCOUNT(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -666,7 +732,15 @@ func Test_HandleZLEXCOUNT(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZLEXCOUNT(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -830,7 +904,15 @@ func Test_HandleZDIFF(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZDIFF(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1016,7 +1098,15 @@ func Test_HandleZDIFFSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZDIFFSTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1246,7 +1336,15 @@ func Test_HandleZINCRBY(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZINCRBY(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1482,7 +1580,15 @@ func Test_HandleZMPOP(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZMPOP(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1665,7 +1771,15 @@ func Test_HandleZPOP(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZPOP(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1775,7 +1889,15 @@ func Test_HandleZMSCORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZMSCORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -1886,7 +2008,15 @@ func Test_HandleZSCORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZSCORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2012,7 +2142,15 @@ func Test_HandleZRANDMEMBER(t *testing.T) {
|
||||
}
|
||||
mockServer.KeyUnlock(ctx, test.key)
|
||||
}
|
||||
res, err := handleZRANDMEMBER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2181,7 +2319,15 @@ func Test_HandleZRANK(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZRANK(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2289,7 +2435,15 @@ func Test_HandleZREM(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZREM(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2405,7 +2559,15 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZREMRANGEBYSCORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2578,7 +2740,15 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZREMRANGEBYRANK(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2720,7 +2890,15 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZREMRANGEBYLEX(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -2990,7 +3168,15 @@ func Test_HandleZRANGE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZRANGE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -3297,7 +3483,15 @@ func Test_HandleZRANGESTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZRANGESTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -3629,7 +3823,15 @@ func Test_HandleZINTER(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZINTER(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -3992,7 +4194,15 @@ func Test_HandleZINTERSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZINTERSTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -4351,7 +4561,15 @@ func Test_HandleZUNION(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZUNION(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -4753,7 +4971,15 @@ func Test_HandleZUNIONSTORE(t *testing.T) {
|
||||
mockServer.KeyUnlock(ctx, key)
|
||||
}
|
||||
}
|
||||
res, err := handleZUNIONSTORE(ctx, test.command, mockServer, nil)
|
||||
|
||||
handler := getHandler(test.command[0])
|
||||
if handler == nil {
|
||||
t.Errorf("no handler found for command %s", test.command[0])
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
@@ -17,12 +17,21 @@ package str
|
||||
import (
|
||||
"context"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/pkg/constants"
|
||||
"github.com/echovault/echovault/pkg/echovault"
|
||||
str "github.com/echovault/echovault/pkg/modules/string"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func createEchoVault() *echovault.EchoVault {
|
||||
ev, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(str.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
}),
|
||||
)
|
||||
return ev
|
||||
}
|
||||
|
||||
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
|
||||
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
|
||||
return err
|
||||
@@ -35,13 +44,7 @@ func presetValue(server *echovault.EchoVault, ctx context.Context, key string, v
|
||||
}
|
||||
|
||||
func TestEchoVault_SUBSTR(t *testing.T) {
|
||||
server, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(str.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -200,13 +203,7 @@ func TestEchoVault_SUBSTR(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_SETRANGE(t *testing.T) {
|
||||
server, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(str.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -294,13 +291,7 @@ func TestEchoVault_SETRANGE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoVault_STRLEN(t *testing.T) {
|
||||
server, _ := echovault.NewEchoVault(
|
||||
echovault.WithCommands(str.Commands()),
|
||||
echovault.WithConfig(config.Config{
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
server := createEchoVault()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package str_test
|
||||
package str
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
str "github.com/echovault/echovault/pkg/modules/string"
|
||||
"github.com/echovault/echovault/pkg/types"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -43,15 +44,43 @@ func init() {
|
||||
)
|
||||
}
|
||||
|
||||
func getHandler(command string) types.HandlerFunc {
|
||||
func getHandler(commands ...string) types.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, c := range mockServer.GetAllCommands() {
|
||||
if strings.EqualFold(command, c.Command) {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
|
||||
return types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSetRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -176,17 +205,7 @@ func Test_HandleSetRange(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: test.command,
|
||||
Connection: nil,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
CreateKeyAndLock: mockServer.CreateKeyAndLock,
|
||||
KeyLock: mockServer.KeyLock,
|
||||
KeyUnlock: mockServer.KeyUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
})
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
@@ -291,16 +310,7 @@ func Test_HandleStrLen(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: test.command,
|
||||
Connection: nil,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
})
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
@@ -436,16 +446,7 @@ func Test_HandleSubStr(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(types.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: test.command,
|
||||
Connection: nil,
|
||||
KeyExists: mockServer.KeyExists,
|
||||
KeyRLock: mockServer.KeyRLock,
|
||||
KeyRUnlock: mockServer.KeyRUnlock,
|
||||
GetValue: mockServer.GetValue,
|
||||
SetValue: mockServer.SetValue,
|
||||
})
|
||||
res, err := handler(getHandlerFuncParams(ctx, test.command, nil))
|
||||
|
||||
if test.expectedError != nil {
|
||||
if err.Error() != test.expectedError.Error() {
|
||||
|
||||
Reference in New Issue
Block a user