Exported EchoVault interface now onlu contains the keyspace methods. All other methods are private. Private methods are accessed using the reflect package in the test folder

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-26 02:33:35 +08:00
parent 44e4f06670
commit 6ad3b7baab
35 changed files with 709 additions and 634 deletions

View File

@@ -20,9 +20,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"log" "log"
@@ -286,7 +286,7 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
return errors.New("could not authenticate user") return errors.New("could not authenticate user")
} }
func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.Command, subCommand types.SubCommand) error { func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command internal.Command, subCommand internal.SubCommand) error {
acl.RLockUsers() acl.RLockUsers()
defer acl.RUnlockUsers() defer acl.RUnlockUsers()
@@ -303,7 +303,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.
readKeys := keys.ReadKeys readKeys := keys.ReadKeys
writeKeys := keys.WriteKeys writeKeys := keys.WriteKeys
if !reflect.DeepEqual(subCommand, types.SubCommand{}) { if !reflect.DeepEqual(subCommand, internal.SubCommand{}) {
comm = fmt.Sprintf("%s|%s", comm, subCommand.Command) comm = fmt.Sprintf("%s|%s", comm, subCommand.Command)
categories = append(categories, subCommand.Categories...) categories = append(categories, subCommand.Categories...)
keys, err = subCommand.KeyExtractionFunc(cmd) keys, err = subCommand.KeyExtractionFunc(cmd)

View File

@@ -18,8 +18,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"log" "log"
"os" "os"
@@ -28,7 +28,7 @@ import (
"strings" "strings"
) )
func handleAuth(params types.HandlerFuncParams) ([]byte, error) { func handleAuth(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 { if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -42,7 +42,7 @@ func handleAuth(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleGetUser(params types.HandlerFuncParams) ([]byte, error) { func handleGetUser(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 { if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -159,7 +159,7 @@ func handleGetUser(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleCat(params types.HandlerFuncParams) ([]byte, error) { func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 { if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -219,7 +219,7 @@ func handleCat(params types.HandlerFuncParams) ([]byte, error) {
return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2])) return nil, fmt.Errorf("category %s not found", strings.ToUpper(params.Command[2]))
} }
func handleUsers(params types.HandlerFuncParams) ([]byte, error) { func handleUsers(params internal.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -232,7 +232,7 @@ func handleUsers(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSetUser(params types.HandlerFuncParams) ([]byte, error) { func handleSetUser(params internal.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -243,7 +243,7 @@ func handleSetUser(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleDelUser(params types.HandlerFuncParams) ([]byte, error) { func handleDelUser(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 3 { if len(params.Command) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -257,7 +257,7 @@ func handleDelUser(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) { func handleWhoAmI(params internal.HandlerFuncParams) ([]byte, error) {
acl, ok := params.GetACL().(*ACL) acl, ok := params.GetACL().(*ACL)
if !ok { if !ok {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
@@ -266,7 +266,7 @@ func handleWhoAmI(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
} }
func handleList(params types.HandlerFuncParams) ([]byte, error) { func handleList(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 { if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -362,7 +362,7 @@ func handleList(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleLoad(params types.HandlerFuncParams) ([]byte, error) { func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 { if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -429,7 +429,7 @@ func handleLoad(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleSave(params types.HandlerFuncParams) ([]byte, error) { func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 2 { if len(params.Command) > 2 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -487,16 +487,16 @@ func handleSave(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "auth", Command: "auth",
Module: constants.ACLModule, Module: constants.ACLModule,
Categories: []string{constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.ConnectionCategory, constants.SlowCategory},
Description: "(AUTH [username] password) Authenticates the connection", Description: "(AUTH [username] password) Authenticates the connection",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -510,14 +510,14 @@ func Commands() []types.Command {
Categories: []string{}, Categories: []string{},
Description: "Access-Control-List commands", Description: "Access-Control-List commands",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
SubCommands: []types.SubCommand{ SubCommands: []internal.SubCommand{
{ {
Command: "cat", Command: "cat",
Module: constants.ACLModule, Module: constants.ACLModule,
@@ -525,8 +525,8 @@ func Commands() []types.Command {
Description: `(ACL CAT [category]) List all the categories. Description: `(ACL CAT [category]) List all the categories.
If the optional category is provided, list all the commands in the category`, If the optional category is provided, list all the commands in the category`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -540,8 +540,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL USERS) List all usernames of the configured ACL users", Description: "(ACL USERS) List all usernames of the configured ACL users",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -555,8 +555,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SETUSER) Configure a new or existing user", Description: "(ACL SETUSER) Configure a new or existing user",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -570,8 +570,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL GETUSER username) List the ACL rules of a user", Description: "(ACL GETUSER username) List the ACL rules of a user",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -585,8 +585,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL DELUSER username [username ...]) Deletes users and terminates their connections. Cannot delete default user", Description: "(ACL DELUSER username [username ...]) Deletes users and terminates their connections. Cannot delete default user",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -600,8 +600,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.FastCategory}, Categories: []string{constants.FastCategory},
Description: "(ACL WHOAMI) Returns the authenticated user of the current connection", Description: "(ACL WHOAMI) Returns the authenticated user of the current connection",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -615,8 +615,8 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL LIST) Dumps effective acl rules in acl config file format", Description: "(ACL LIST) Dumps effective acl rules in acl config file format",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -633,8 +633,8 @@ If the optional category is provided, list all the commands in the category`,
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged. When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`, When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -648,8 +648,8 @@ When 'REPLACE' is passed, users from config file who share a username with users
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SAVE) Saves the effective ACL rules the configured ACL config file", Description: "(ACL SAVE) Saves the effective ACL rules the configured ACL config file",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),

View File

@@ -17,14 +17,14 @@ package admin
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"slices" "slices"
"strings" "strings"
) )
func handleGetAllCommands(params types.HandlerFuncParams) ([]byte, error) { func handleGetAllCommands(params internal.HandlerFuncParams) ([]byte, error) {
commands := params.GetAllCommands() commands := params.GetAllCommands()
res := "" res := ""
@@ -69,7 +69,7 @@ func handleGetAllCommands(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleCommandCount(params types.HandlerFuncParams) ([]byte, error) { func handleCommandCount(params internal.HandlerFuncParams) ([]byte, error) {
var count int var count int
commands := params.GetAllCommands() commands := params.GetAllCommands()
@@ -86,7 +86,7 @@ func handleCommandCount(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleCommandList(params types.HandlerFuncParams) ([]byte, error) { func handleCommandList(params internal.HandlerFuncParams) ([]byte, error) {
switch len(params.Command) { switch len(params.Command) {
case 2: case 2:
// Command is COMMAND LIST // Command is COMMAND LIST
@@ -185,20 +185,20 @@ func handleCommandList(params types.HandlerFuncParams) ([]byte, error) {
} }
} }
func handleCommandDocs(params types.HandlerFuncParams) ([]byte, error) { func handleCommandDocs(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "commands", Command: "commands",
Module: constants.AdminModule, Module: constants.AdminModule,
Categories: []string{constants.AdminCategory, constants.SlowCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory},
Description: "Get a list of all the commands in available on the echovault with categories and descriptions", Description: "Get a list of all the commands in available on the echovault with categories and descriptions",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -212,22 +212,22 @@ func Commands() []types.Command {
Categories: []string{}, Categories: []string{},
Description: "Commands pertaining to echovault commands", Description: "Commands pertaining to echovault commands",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
SubCommands: []types.SubCommand{ SubCommands: []internal.SubCommand{
{ {
Command: "docs", Command: "docs",
Module: constants.AdminModule, Module: constants.AdminModule,
Categories: []string{constants.SlowCategory, constants.ConnectionCategory}, Categories: []string{constants.SlowCategory, constants.ConnectionCategory},
Description: "Get command documentation", Description: "Get command documentation",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -241,8 +241,8 @@ func Commands() []types.Command {
Categories: []string{constants.SlowCategory}, Categories: []string{constants.SlowCategory},
Description: "Get the dumber of commands in the echovault", Description: "Get the dumber of commands in the echovault",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -257,8 +257,8 @@ func Commands() []types.Command {
Description: `(COMMAND LIST [FILTERBY <ACLCAT category | PATTERN pattern | MODULE module>]) Get the list of command names. Description: `(COMMAND LIST [FILTERBY <ACLCAT category | PATTERN pattern | MODULE module>]) Get the list of command names.
Allows for filtering by ACL category or glob pattern.`, Allows for filtering by ACL category or glob pattern.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -274,14 +274,14 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(SAVE) Trigger a snapshot save", Description: "(SAVE) Trigger a snapshot save",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) { HandlerFunc: func(params internal.HandlerFuncParams) ([]byte, error) {
if err := params.TakeSnapshot(); err != nil { if err := params.TakeSnapshot(); err != nil {
return nil, err return nil, err
} }
@@ -294,14 +294,14 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.FastCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.FastCategory, constants.DangerousCategory},
Description: "(LASTSAVE) Get unix timestamp for the latest snapshot in milliseconds.", Description: "(LASTSAVE) Get unix timestamp for the latest snapshot in milliseconds.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) { HandlerFunc: func(params internal.HandlerFuncParams) ([]byte, error) {
msec := params.GetLatestSnapshotTime() msec := params.GetLatestSnapshotTime()
if msec == 0 { if msec == 0 {
return nil, errors.New("no snapshot") return nil, errors.New("no snapshot")
@@ -315,14 +315,14 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(REWRITEAOF) Trigger re-writing of append process", Description: "(REWRITEAOF) Trigger re-writing of append process",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
HandlerFunc: func(params types.HandlerFuncParams) ([]byte, error) { HandlerFunc: func(params internal.HandlerFuncParams) ([]byte, error) {
if err := params.RewriteAOF(); err != nil { if err := params.RewriteAOF(); err != nil {
return nil, err return nil, err
} }

View File

@@ -17,11 +17,11 @@ package connection
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func handlePing(params types.HandlerFuncParams) ([]byte, error) { func handlePing(params internal.HandlerFuncParams) ([]byte, error) {
switch len(params.Command) { switch len(params.Command) {
default: default:
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
@@ -32,16 +32,16 @@ func handlePing(params types.HandlerFuncParams) ([]byte, error) {
} }
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "ping", Command: "ping",
Module: constants.ConnectionModule, Module: constants.ConnectionModule,
Categories: []string{constants.FastCategory, constants.ConnectionCategory}, Categories: []string{constants.FastCategory, constants.ConnectionCategory},
Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.", Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),

View File

@@ -19,7 +19,6 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"log" "log"
"strconv" "strconv"
"strings" "strings"
@@ -31,7 +30,7 @@ type KeyObject struct {
locked bool locked bool
} }
func handleSet(params types.HandlerFuncParams) ([]byte, error) { func handleSet(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := setKeyFunc(params.Command) keys, err := setKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -96,7 +95,7 @@ func handleSet(params types.HandlerFuncParams) ([]byte, error) {
return res, nil return res, nil
} }
func handleMSet(params types.HandlerFuncParams) ([]byte, error) { func handleMSet(params internal.HandlerFuncParams) ([]byte, error) {
_, err := msetKeyFunc(params.Command) _, err := msetKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -153,7 +152,7 @@ func handleMSet(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleGet(params types.HandlerFuncParams) ([]byte, error) { func handleGet(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := getKeyFunc(params.Command) keys, err := getKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -175,7 +174,7 @@ func handleGet(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("+%v\r\n", value)), nil return []byte(fmt.Sprintf("+%v\r\n", value)), nil
} }
func handleMGet(params types.HandlerFuncParams) ([]byte, error) { func handleMGet(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := mgetKeyFunc(params.Command) keys, err := mgetKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -225,7 +224,7 @@ func handleMGet(params types.HandlerFuncParams) ([]byte, error) {
return bytes, nil return bytes, nil
} }
func handleDel(params types.HandlerFuncParams) ([]byte, error) { func handleDel(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := delKeyFunc(params.Command) keys, err := delKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -242,7 +241,7 @@ func handleDel(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handlePersist(params types.HandlerFuncParams) ([]byte, error) { func handlePersist(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := persistKeyFunc(params.Command) keys, err := persistKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -269,7 +268,7 @@ func handlePersist(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
func handleExpireTime(params types.HandlerFuncParams) ([]byte, error) { func handleExpireTime(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := expireTimeKeyFunc(params.Command) keys, err := expireTimeKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -300,7 +299,7 @@ func handleExpireTime(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", t)), nil return []byte(fmt.Sprintf(":%d\r\n", t)), nil
} }
func handleTTL(params types.HandlerFuncParams) ([]byte, error) { func handleTTL(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := ttlKeyFunc(params.Command) keys, err := ttlKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -337,7 +336,7 @@ func handleTTL(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", t)), nil return []byte(fmt.Sprintf(":%d\r\n", t)), nil
} }
func handleExpire(params types.HandlerFuncParams) ([]byte, error) { func handleExpire(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := expireKeyFunc(params.Command) keys, err := expireKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -405,7 +404,7 @@ func handleExpire(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
func handleExpireAt(params types.HandlerFuncParams) ([]byte, error) { func handleExpireAt(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := expireKeyFunc(params.Command) keys, err := expireKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -473,8 +472,8 @@ func handleExpireAt(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "set", Command: "set",
Module: constants.GenericModule, Module: constants.GenericModule,

View File

@@ -16,24 +16,24 @@ package generic
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func setKeyFunc(cmd []string) (types.AccessKeys, error) { func setKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 7 { if len(cmd) < 3 || len(cmd) > 7 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func msetKeyFunc(cmd []string) (types.AccessKeys, error) { func msetKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd[1:])%2 != 0 { if len(cmd[1:])%2 != 0 {
return types.AccessKeys{}, errors.New("each key must be paired with a value") return internal.AccessKeys{}, errors.New("each key must be paired with a value")
} }
var keys []string var keys []string
for i, key := range cmd[1:] { for i, key := range cmd[1:] {
@@ -41,95 +41,95 @@ func msetKeyFunc(cmd []string) (types.AccessKeys, error) {
keys = append(keys, key) keys = append(keys, key)
} }
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: keys, WriteKeys: keys,
}, nil }, nil
} }
func getKeyFunc(cmd []string) (types.AccessKeys, error) { func getKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func mgetKeyFunc(cmd []string) (types.AccessKeys, error) { func mgetKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func delKeyFunc(cmd []string) (types.AccessKeys, error) { func delKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:], WriteKeys: cmd[1:],
}, nil }, nil
} }
func persistKeyFunc(cmd []string) (types.AccessKeys, error) { func persistKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:], WriteKeys: cmd[1:],
}, nil }, nil
} }
func expireTimeKeyFunc(cmd []string) (types.AccessKeys, error) { func expireTimeKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func ttlKeyFunc(cmd []string) (types.AccessKeys, error) { func ttlKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func expireKeyFunc(cmd []string) (types.AccessKeys, error) { func expireKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func expireAtKeyFunc(cmd []string) (types.AccessKeys, error) { func expireAtKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],

View File

@@ -19,14 +19,13 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math/rand" "math/rand"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
) )
func handleHSET(params types.HandlerFuncParams) ([]byte, error) { func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hsetKeyFunc(params.Command) keys, err := hsetKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -84,7 +83,7 @@ func handleHSET(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleHGET(params types.HandlerFuncParams) ([]byte, error) { func handleHGET(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hgetKeyFunc(params.Command) keys, err := hgetKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -135,7 +134,7 @@ func handleHGET(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHSTRLEN(params types.HandlerFuncParams) ([]byte, error) { func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hstrlenKeyFunc(params.Command) keys, err := hstrlenKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -186,7 +185,7 @@ func handleHSTRLEN(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHVALS(params types.HandlerFuncParams) ([]byte, error) { func handleHVALS(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hvalsKeyFunc(params.Command) keys, err := hvalsKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -227,7 +226,7 @@ func handleHVALS(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHRANDFIELD(params types.HandlerFuncParams) ([]byte, error) { func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hrandfieldKeyFunc(params.Command) keys, err := hrandfieldKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -343,7 +342,7 @@ func handleHRANDFIELD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHLEN(params types.HandlerFuncParams) ([]byte, error) { func handleHLEN(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hlenKeyFunc(params.Command) keys, err := hlenKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -368,7 +367,7 @@ func handleHLEN(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil
} }
func handleHKEYS(params types.HandlerFuncParams) ([]byte, error) { func handleHKEYS(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hkeysKeyFunc(params.Command) keys, err := hkeysKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -398,7 +397,7 @@ func handleHKEYS(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHINCRBY(params types.HandlerFuncParams) ([]byte, error) { func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hincrbyKeyFunc(params.Command) keys, err := hincrbyKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -490,7 +489,7 @@ func handleHINCRBY(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", i)), nil return []byte(fmt.Sprintf(":%d\r\n", i)), nil
} }
func handleHGETALL(params types.HandlerFuncParams) ([]byte, error) { func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hgetallKeyFunc(params.Command) keys, err := hgetallKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -530,7 +529,7 @@ func handleHGETALL(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleHEXISTS(params types.HandlerFuncParams) ([]byte, error) { func handleHEXISTS(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hexistsKeyFunc(params.Command) keys, err := hexistsKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -560,7 +559,7 @@ func handleHEXISTS(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
func handleHDEL(params types.HandlerFuncParams) ([]byte, error) { func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := hdelKeyFunc(params.Command) keys, err := hdelKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -599,8 +598,8 @@ func handleHDEL(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "hset", Command: "hset",
Module: constants.HashModule, Module: constants.HashModule,

View File

@@ -16,143 +16,143 @@ package hash
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func hsetKeyFunc(cmd []string) (types.AccessKeys, error) { func hsetKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func hsetnxKeyFunc(cmd []string) (types.AccessKeys, error) { func hsetnxKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func hgetKeyFunc(cmd []string) (types.AccessKeys, error) { func hgetKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hstrlenKeyFunc(cmd []string) (types.AccessKeys, error) { func hstrlenKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hvalsKeyFunc(cmd []string) (types.AccessKeys, error) { func hvalsKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hrandfieldKeyFunc(cmd []string) (types.AccessKeys, error) { func hrandfieldKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 { if len(cmd) < 2 || len(cmd) > 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
if len(cmd) == 2 { if len(cmd) == 2 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hlenKeyFunc(cmd []string) (types.AccessKeys, error) { func hlenKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hkeysKeyFunc(cmd []string) (types.AccessKeys, error) { func hkeysKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hincrbyKeyFunc(cmd []string) (types.AccessKeys, error) { func hincrbyKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func hgetallKeyFunc(cmd []string) (types.AccessKeys, error) { func hgetallKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hexistsKeyFunc(cmd []string) (types.AccessKeys, error) { func hexistsKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func hdelKeyFunc(cmd []string) (types.AccessKeys, error) { func hdelKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],

View File

@@ -19,13 +19,12 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math" "math"
"slices" "slices"
"strings" "strings"
) )
func handleLLen(params types.HandlerFuncParams) ([]byte, error) { func handleLLen(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := llenKeyFunc(params.Command) keys, err := llenKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -50,7 +49,7 @@ func handleLLen(params types.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("LLEN command on non-list item") return nil, errors.New("LLEN command on non-list item")
} }
func handleLIndex(params types.HandlerFuncParams) ([]byte, error) { func handleLIndex(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lindexKeyFunc(params.Command) keys, err := lindexKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -84,7 +83,7 @@ func handleLIndex(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
} }
func handleLRange(params types.HandlerFuncParams) ([]byte, error) { func handleLRange(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lrangeKeyFunc(params.Command) keys, err := lrangeKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -163,7 +162,7 @@ func handleLRange(params types.HandlerFuncParams) ([]byte, error) {
return bytes, nil return bytes, nil
} }
func handleLSet(params types.HandlerFuncParams) ([]byte, error) { func handleLSet(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lsetKeyFunc(params.Command) keys, err := lsetKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -202,7 +201,7 @@ func handleLSet(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLTrim(params types.HandlerFuncParams) ([]byte, error) { func handleLTrim(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := ltrimKeyFunc(params.Command) keys, err := ltrimKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -251,7 +250,7 @@ func handleLTrim(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLRem(params types.HandlerFuncParams) ([]byte, error) { func handleLRem(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lremKeyFunc(params.Command) keys, err := lremKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -319,7 +318,7 @@ func handleLRem(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLMove(params types.HandlerFuncParams) ([]byte, error) { func handleLMove(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lmoveKeyFunc(params.Command) keys, err := lmoveKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -379,7 +378,7 @@ func handleLMove(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLPush(params types.HandlerFuncParams) ([]byte, error) { func handleLPush(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := lpushKeyFunc(params.Command) keys, err := lpushKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -425,7 +424,7 @@ func handleLPush(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleRPush(params types.HandlerFuncParams) ([]byte, error) { func handleRPush(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := rpushKeyFunc(params.Command) keys, err := rpushKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -473,7 +472,7 @@ func handleRPush(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handlePop(params types.HandlerFuncParams) ([]byte, error) { func handlePop(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := popKeyFunc(params.Command) keys, err := popKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -509,8 +508,8 @@ func handlePop(params types.HandlerFuncParams) ([]byte, error) {
} }
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "lpush", Command: "lpush",
Module: constants.ListModule, Module: constants.ListModule,

View File

@@ -16,114 +16,114 @@ package list
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func lpushKeyFunc(cmd []string) (types.AccessKeys, error) { func lpushKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func popKeyFunc(cmd []string) (types.AccessKeys, error) { func popKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:], WriteKeys: cmd[1:],
}, nil }, nil
} }
func llenKeyFunc(cmd []string) (types.AccessKeys, error) { func llenKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func lrangeKeyFunc(cmd []string) (types.AccessKeys, error) { func lrangeKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func lindexKeyFunc(cmd []string) (types.AccessKeys, error) { func lindexKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func lsetKeyFunc(cmd []string) (types.AccessKeys, error) { func lsetKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func ltrimKeyFunc(cmd []string) (types.AccessKeys, error) { func ltrimKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func lremKeyFunc(cmd []string) (types.AccessKeys, error) { func lremKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func rpushKeyFunc(cmd []string) (types.AccessKeys, error) { func rpushKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func lmoveKeyFunc(cmd []string) (types.AccessKeys, error) { func lmoveKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 5 { if len(cmd) != 5 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3], WriteKeys: cmd[1:3],

View File

@@ -17,12 +17,12 @@ package pubsub
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"strings" "strings"
) )
func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) { func handleSubscribe(params internal.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -40,7 +40,7 @@ func handleSubscribe(params types.HandlerFuncParams) ([]byte, error) {
return nil, nil return nil, nil
} }
func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) { func handleUnsubscribe(params internal.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -53,7 +53,7 @@ func handleUnsubscribe(params types.HandlerFuncParams) ([]byte, error) {
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil
} }
func handlePublish(params types.HandlerFuncParams) ([]byte, error) { func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -65,7 +65,7 @@ func handlePublish(params types.HandlerFuncParams) ([]byte, error) {
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) { func handlePubSubChannels(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 { if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
@@ -83,7 +83,7 @@ func handlePubSubChannels(params types.HandlerFuncParams) ([]byte, error) {
return pubsub.Channels(pattern), nil return pubsub.Channels(pattern), nil
} }
func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) { func handlePubSubNumPat(params internal.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -92,7 +92,7 @@ func handlePubSubNumPat(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", num)), nil return []byte(fmt.Sprintf(":%d\r\n", num)), nil
} }
func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) { func handlePubSubNumSubs(params internal.HandlerFuncParams) ([]byte, error) {
pubsub, ok := params.GetPubSub().(*PubSub) pubsub, ok := params.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -100,20 +100,20 @@ func handlePubSubNumSubs(params types.HandlerFuncParams) ([]byte, error) {
return pubsub.NumSub(params.Command[2:]), nil return pubsub.NumSub(params.Command[2:]), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "subscribe", Command: "subscribe",
Module: constants.PubSubModule, Module: constants.PubSubModule,
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.", Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
// Treat the channels as keys // Treat the channels as keys
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[1:], Channels: cmd[1:],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -127,12 +127,12 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.", Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
// Treat the patterns as keys // Treat the patterns as keys
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[1:], Channels: cmd[1:],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -146,12 +146,12 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.FastCategory}, Categories: []string{constants.PubSubCategory, constants.FastCategory},
Description: "(PUBLISH channel message) Publish a message to the specified channel.", Description: "(PUBLISH channel message) Publish a message to the specified channel.",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
// Treat the channel as a key // Treat the channel as a key
if len(cmd) != 3 { if len(cmd) != 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[1:2], Channels: cmd[1:2],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -167,9 +167,9 @@ func Commands() []types.Command {
If the channel list is not provided, then the connection will be unsubscribed from all the channels that If the channel list is not provided, then the connection will be unsubscribed from all the channels that
it's currently subscribe to.`, it's currently subscribe to.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
// Treat the channels as keys // Treat the channels as keys
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[1:], Channels: cmd[1:],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -185,8 +185,8 @@ it's currently subscribe to.`,
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
it's currently subscribe to.`, it's currently subscribe to.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[1:], Channels: cmd[1:],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -200,17 +200,17 @@ it's currently subscribe to.`,
Categories: []string{}, Categories: []string{},
Description: "", Description: "",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
}, },
HandlerFunc: func(_ types.HandlerFuncParams) ([]byte, error) { HandlerFunc: func(_ internal.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand") return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
}, },
SubCommands: []types.SubCommand{ SubCommands: []internal.SubCommand{
{ {
Command: "channels", Command: "channels",
Module: constants.PubSubModule, Module: constants.PubSubModule,
@@ -219,8 +219,8 @@ it's currently subscribe to.`,
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
channels with 1 or more subscribers.`, channels with 1 or more subscribers.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -234,8 +234,8 @@ channels with 1 or more subscribers.`,
Categories: []string{constants.PubSubCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.SlowCategory},
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`, Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
@@ -250,8 +250,8 @@ channels with 1 or more subscribers.`,
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
channel name and how many clients are currently subscribed to the channel.`, channel name and how many clients are currently subscribed to the channel.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) { KeyExtractionFunc: func(cmd []string) (internal.AccessKeys, error) {
return types.AccessKeys{ return internal.AccessKeys{
Channels: cmd[2:], Channels: cmd[2:],
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),

View File

@@ -19,12 +19,11 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices" "slices"
"strings" "strings"
) )
func handleSADD(params types.HandlerFuncParams) ([]byte, error) { func handleSADD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := saddKeyFunc(params.Command) keys, err := saddKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -61,7 +60,7 @@ func handleSADD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleSCARD(params types.HandlerFuncParams) ([]byte, error) { func handleSCARD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := scardKeyFunc(params.Command) keys, err := scardKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -88,7 +87,7 @@ func handleSCARD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
} }
func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) { func handleSDIFF(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sdiffKeyFunc(params.Command) keys, err := sdiffKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -149,7 +148,7 @@ func handleSDIFF(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleSDIFFSTORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sdiffstoreKeyFunc(params.Command) keys, err := sdiffstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -225,7 +224,7 @@ func handleSDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSINTER(params types.HandlerFuncParams) ([]byte, error) { func handleSINTER(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sinterKeyFunc(params.Command) keys, err := sinterKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -280,7 +279,7 @@ func handleSINTER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) { func handleSINTERCARD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sintercardKeyFunc(params.Command) keys, err := sintercardKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -347,7 +346,7 @@ func handleSINTERCARD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleSINTERSTORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sinterstoreKeyFunc(params.Command) keys, err := sinterstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -405,7 +404,7 @@ func handleSINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) { func handleSISMEMBER(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sismemberKeyFunc(params.Command) keys, err := sismemberKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -434,7 +433,7 @@ func handleSISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) { func handleSMEMBERS(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := smembersKeyFunc(params.Command) keys, err := smembersKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -469,7 +468,7 @@ func handleSMEMBERS(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) { func handleSMISMEMBER(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := smismemberKeyFunc(params.Command) keys, err := smismemberKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -512,7 +511,7 @@ func handleSMISMEMBER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) { func handleSMOVE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := smoveKeyFunc(params.Command) keys, err := smoveKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -565,7 +564,7 @@ func handleSMOVE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", res)), nil return []byte(fmt.Sprintf(":%d\r\n", res)), nil
} }
func handleSPOP(params types.HandlerFuncParams) ([]byte, error) { func handleSPOP(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := spopKeyFunc(params.Command) keys, err := spopKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -609,7 +608,7 @@ func handleSPOP(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) { func handleSRANDMEMBER(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := srandmemberKeyFunc(params.Command) keys, err := srandmemberKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -653,7 +652,7 @@ func handleSRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSREM(params types.HandlerFuncParams) ([]byte, error) { func handleSREM(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sremKeyFunc(params.Command) keys, err := sremKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -681,7 +680,7 @@ func handleSREM(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleSUNION(params types.HandlerFuncParams) ([]byte, error) { func handleSUNION(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sunionKeyFunc(params.Command) keys, err := sunionKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -732,7 +731,7 @@ func handleSUNION(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleSUNIONSTORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := sunionstoreKeyFunc(params.Command) keys, err := sunionstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -791,8 +790,8 @@ func handleSUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "sadd", Command: "sadd",
Module: constants.SetModule, Module: constants.SetModule,

View File

@@ -16,70 +16,70 @@ package set
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices" "slices"
"strings" "strings"
) )
func saddKeyFunc(cmd []string) (types.AccessKeys, error) { func saddKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func scardKeyFunc(cmd []string) (types.AccessKeys, error) { func scardKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sdiffKeyFunc(cmd []string) (types.AccessKeys, error) { func sdiffKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func sdiffstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func sinterKeyFunc(cmd []string) (types.AccessKeys, error) { func sinterKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sintercardKeyFunc(cmd []string) (types.AccessKeys, error) { func sintercardKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
limitIdx := slices.IndexFunc(cmd, func(s string) bool { limitIdx := slices.IndexFunc(cmd, func(s string) bool {
@@ -87,124 +87,124 @@ func sintercardKeyFunc(cmd []string) (types.AccessKeys, error) {
}) })
if limitIdx == -1 { if limitIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:limitIdx], ReadKeys: cmd[1:limitIdx],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func sinterstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func sismemberKeyFunc(cmd []string) (types.AccessKeys, error) { func sismemberKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func smembersKeyFunc(cmd []string) (types.AccessKeys, error) { func smembersKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func smismemberKeyFunc(cmd []string) (types.AccessKeys, error) { func smismemberKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func smoveKeyFunc(cmd []string) (types.AccessKeys, error) { func smoveKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3], WriteKeys: cmd[1:3],
}, nil }, nil
} }
func spopKeyFunc(cmd []string) (types.AccessKeys, error) { func spopKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func srandmemberKeyFunc(cmd []string) (types.AccessKeys, error) { func srandmemberKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sremKeyFunc(cmd []string) (types.AccessKeys, error) { func sremKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func sunionKeyFunc(cmd []string) (types.AccessKeys, error) { func sunionKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func sunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func sunionstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],

View File

@@ -20,14 +20,13 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"math" "math"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
) )
func handleZADD(params types.HandlerFuncParams) ([]byte, error) { func handleZADD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zaddKeyFunc(params.Command) keys, err := zaddKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -178,7 +177,7 @@ func handleZADD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
} }
func handleZCARD(params types.HandlerFuncParams) ([]byte, error) { func handleZCARD(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zcardKeyFunc(params.Command) keys, err := zcardKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -202,7 +201,7 @@ func handleZCARD(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
} }
func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) { func handleZCOUNT(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zcountKeyFunc(params.Command) keys, err := zcountKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -270,7 +269,7 @@ func handleZCOUNT(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", len(members))), nil return []byte(fmt.Sprintf(":%d\r\n", len(members))), nil
} }
func handleZLEXCOUNT(params types.HandlerFuncParams) ([]byte, error) { func handleZLEXCOUNT(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zlexcountKeyFunc(params.Command) keys, err := zlexcountKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -315,7 +314,7 @@ func handleZLEXCOUNT(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) { func handleZDIFF(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zdiffKeyFunc(params.Command) keys, err := zdiffKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -388,7 +387,7 @@ func handleZDIFF(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleZDIFFSTORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zdiffstoreKeyFunc(params.Command) keys, err := zdiffstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -454,7 +453,7 @@ func handleZDIFFSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil
} }
func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) { func handleZINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zincrbyKeyFunc(params.Command) keys, err := zincrbyKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -521,7 +520,7 @@ func handleZINCRBY(params types.HandlerFuncParams) ([]byte, error) {
strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil
} }
func handleZINTER(params types.HandlerFuncParams) ([]byte, error) { func handleZINTER(params internal.HandlerFuncParams) ([]byte, error) {
_, err := zinterKeyFunc(params.Command) _, err := zinterKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -581,7 +580,7 @@ func handleZINTER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleZINTERSTORE(params internal.HandlerFuncParams) ([]byte, error) {
k, err := zinterstoreKeyFunc(params.Command) k, err := zinterstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -648,7 +647,7 @@ func handleZINTERSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleZMPOP(params types.HandlerFuncParams) ([]byte, error) { func handleZMPOP(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zmpopKeyFunc(params.Command) keys, err := zmpopKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -726,7 +725,7 @@ func handleZMPOP(params types.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
func handleZPOP(params types.HandlerFuncParams) ([]byte, error) { func handleZPOP(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zpopKeyFunc(params.Command) keys, err := zpopKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -779,7 +778,7 @@ func handleZPOP(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) { func handleZMSCORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zmscoreKeyFunc(params.Command) keys, err := zmscoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -821,7 +820,7 @@ func handleZMSCORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) { func handleZRANDMEMBER(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zrandmemberKeyFunc(params.Command) keys, err := zrandmemberKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -879,7 +878,7 @@ func handleZRANDMEMBER(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZRANK(params types.HandlerFuncParams) ([]byte, error) { func handleZRANK(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zrankKeyFunc(params.Command) keys, err := zrankKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -929,7 +928,7 @@ func handleZRANK(params types.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
func handleZREM(params types.HandlerFuncParams) ([]byte, error) { func handleZREM(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zremKeyFunc(params.Command) keys, err := zremKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -961,7 +960,7 @@ func handleZREM(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZSCORE(params types.HandlerFuncParams) ([]byte, error) { func handleZSCORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zscoreKeyFunc(params.Command) keys, err := zscoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -990,7 +989,7 @@ func handleZSCORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil
} }
func handleZREMRANGEBYSCORE(params types.HandlerFuncParams) ([]byte, error) { func handleZREMRANGEBYSCORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zremrangebyscoreKeyFunc(params.Command) keys, err := zremrangebyscoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1034,7 +1033,7 @@ func handleZREMRANGEBYSCORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) { func handleZREMRANGEBYRANK(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zremrangebyrankKeyFunc(params.Command) keys, err := zremrangebyrankKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1099,7 +1098,7 @@ func handleZREMRANGEBYRANK(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZREMRANGEBYLEX(params types.HandlerFuncParams) ([]byte, error) { func handleZREMRANGEBYLEX(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zremrangebylexKeyFunc(params.Command) keys, err := zremrangebylexKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1146,7 +1145,7 @@ func handleZREMRANGEBYLEX(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) { func handleZRANGE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zrangeKeyCount(params.Command) keys, err := zrangeKeyCount(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1286,7 +1285,7 @@ func handleZRANGE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) { func handleZRANGESTORE(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := zrangeStoreKeyFunc(params.Command) keys, err := zrangeStoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1428,7 +1427,7 @@ func handleZRANGESTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil
} }
func handleZUNION(params types.HandlerFuncParams) ([]byte, error) { func handleZUNION(params internal.HandlerFuncParams) ([]byte, error) {
if _, err := zunionKeyFunc(params.Command); err != nil { if _, err := zunionKeyFunc(params.Command); err != nil {
return nil, err return nil, err
} }
@@ -1482,7 +1481,7 @@ func handleZUNION(params types.HandlerFuncParams) ([]byte, error) {
return []byte(res), nil return []byte(res), nil
} }
func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) { func handleZUNIONSTORE(params internal.HandlerFuncParams) ([]byte, error) {
k, err := zunionstoreKeyFunc(params.Command) k, err := zunionstoreKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1548,8 +1547,8 @@ func handleZUNIONSTORE(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "zadd", Command: "zadd",
Module: constants.SortedSetModule, Module: constants.SortedSetModule,

View File

@@ -16,48 +16,48 @@ package sorted_set
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices" "slices"
"strings" "strings"
) )
func zaddKeyFunc(cmd []string) (types.AccessKeys, error) { func zaddKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zcardKeyFunc(cmd []string) (types.AccessKeys, error) { func zcardKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zcountKeyFunc(cmd []string) (types.AccessKeys, error) { func zcountKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zdiffKeyFunc(cmd []string) (types.AccessKeys, error) { func zdiffKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
withscoresIndex := slices.IndexFunc(cmd, func(s string) bool { withscoresIndex := slices.IndexFunc(cmd, func(s string) bool {
@@ -65,45 +65,45 @@ func zdiffKeyFunc(cmd []string) (types.AccessKeys, error) {
}) })
if withscoresIndex == -1 { if withscoresIndex == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:withscoresIndex], ReadKeys: cmd[1:withscoresIndex],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zdiffstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zincrbyKeyFunc(cmd []string) (types.AccessKeys, error) { func zincrbyKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zinterKeyFunc(cmd []string) (types.AccessKeys, error) { func zinterKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -114,25 +114,25 @@ func zinterKeyFunc(cmd []string) (types.AccessKeys, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx], ReadKeys: cmd[1:endIdx],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zinterstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -143,192 +143,192 @@ func zinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
if endIdx >= 3 { if endIdx >= 3 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx], ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zmpopKeyFunc(cmd []string) (types.AccessKeys, error) { func zmpopKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd, func(s string) bool { endIdx := slices.IndexFunc(cmd, func(s string) bool {
return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s)) return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
}) })
if endIdx == -1 { if endIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:], WriteKeys: cmd[1:],
}, nil }, nil
} }
if endIdx >= 2 { if endIdx >= 2 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:endIdx], WriteKeys: cmd[1:endIdx],
}, nil }, nil
} }
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zmscoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zmscoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zpopKeyFunc(cmd []string) (types.AccessKeys, error) { func zpopKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zrandmemberKeyFunc(cmd []string) (types.AccessKeys, error) { func zrandmemberKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 { if len(cmd) < 2 || len(cmd) > 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zrankKeyFunc(cmd []string) (types.AccessKeys, error) { func zrankKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zremKeyFunc(cmd []string) (types.AccessKeys, error) { func zremKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zrevrankKeyFunc(cmd []string) (types.AccessKeys, error) { func zrevrankKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zscoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zscoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zremrangebylexKeyFunc(cmd []string) (types.AccessKeys, error) { func zremrangebylexKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zremrangebyrankKeyFunc(cmd []string) (types.AccessKeys, error) { func zremrangebyrankKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zremrangebyscoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zremrangebyscoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zlexcountKeyFunc(cmd []string) (types.AccessKeys, error) { func zlexcountKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zrangeKeyCount(cmd []string) (types.AccessKeys, error) { func zrangeKeyCount(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 4 || len(cmd) > 10 { if len(cmd) < 4 || len(cmd) > 10 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func zrangeStoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zrangeStoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 5 || len(cmd) > 11 { if len(cmd) < 5 || len(cmd) > 11 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:3], ReadKeys: cmd[2:3],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func zunionKeyFunc(cmd []string) (types.AccessKeys, error) { func zunionKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -339,25 +339,25 @@ func zunionKeyFunc(cmd []string) (types.AccessKeys, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:], ReadKeys: cmd[1:],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx], ReadKeys: cmd[1:endIdx],
WriteKeys: cmd[1:endIdx], WriteKeys: cmd[1:endIdx],
}, nil }, nil
} }
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) { func zunionstoreKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -368,18 +368,18 @@ func zunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:], ReadKeys: cmd[2:],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx], ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }

View File

@@ -19,10 +19,9 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func handleSetRange(params types.HandlerFuncParams) ([]byte, error) { func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := setRangeKeyFunc(params.Command) keys, err := setRangeKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -97,7 +96,7 @@ func handleSetRange(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", len(strRunes))), nil return []byte(fmt.Sprintf(":%d\r\n", len(strRunes))), nil
} }
func handleStrLen(params types.HandlerFuncParams) ([]byte, error) { func handleStrLen(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := strLenKeyFunc(params.Command) keys, err := strLenKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -123,7 +122,7 @@ func handleStrLen(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil
} }
func handleSubStr(params types.HandlerFuncParams) ([]byte, error) { func handleSubStr(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := subStrKeyFunc(params.Command) keys, err := subStrKeyFunc(params.Command)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -186,8 +185,8 @@ func handleSubStr(params types.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(str), str)), nil return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(str), str)), nil
} }
func Commands() []types.Command { func Commands() []internal.Command {
return []types.Command{ return []internal.Command{
{ {
Command: "setrange", Command: "setrange",
Module: constants.StringModule, Module: constants.StringModule,

View File

@@ -16,37 +16,37 @@ package str
import ( import (
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func setRangeKeyFunc(cmd []string) (types.AccessKeys, error) { func setRangeKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: make([]string, 0), ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2], WriteKeys: cmd[1:2],
}, nil }, nil
} }
func strLenKeyFunc(cmd []string) (types.AccessKeys, error) { func strLenKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),
}, nil }, nil
} }
func subStrKeyFunc(cmd []string) (types.AccessKeys, error) { func subStrKeyFunc(cmd []string) (internal.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse) return internal.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return types.AccessKeys{ return internal.AccessKeys{
Channels: make([]string, 0), Channels: make([]string, 0),
ReadKeys: cmd[1:2], ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0), WriteKeys: make([]string, 0),

View File

@@ -32,12 +32,12 @@ type FSMOpts struct {
Config config.Config Config config.Config
EchoVault types.EchoVault EchoVault types.EchoVault
GetState func() map[string]internal.KeyData GetState func() map[string]internal.KeyData
GetCommand func(command string) (types.Command, error) GetCommand func(command string) (internal.Command, error)
DeleteKey func(ctx context.Context, key string) error DeleteKey func(ctx context.Context, key string) error
StartSnapshot func() StartSnapshot func()
FinishSnapshot func() FinishSnapshot func()
SetLatestSnapshotTime func(msec int64) SetLatestSnapshotTime func(msec int64)
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams
} }
type FSM struct { type FSM struct {
@@ -99,7 +99,7 @@ func (fsm *FSM) Apply(log *raft.Log) interface{} {
handler := command.HandlerFunc handler := command.HandlerFunc
subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand) subCommand, ok := internal.GetSubCommand(command, request.CMD).(internal.SubCommand)
if ok { if ok {
handler = subCommand.HandlerFunc handler = subCommand.HandlerFunc
} }

View File

@@ -36,12 +36,12 @@ type Opts struct {
Config config.Config Config config.Config
EchoVault types.EchoVault EchoVault types.EchoVault
GetState func() map[string]internal.KeyData GetState func() map[string]internal.KeyData
GetCommand func(command string) (types.Command, error) GetCommand func(command string) (internal.Command, error)
DeleteKey func(ctx context.Context, key string) error DeleteKey func(ctx context.Context, key string) error
StartSnapshot func() StartSnapshot func()
FinishSnapshot func() FinishSnapshot func()
SetLatestSnapshotTime func(msec int64) SetLatestSnapshotTime func(msec int64)
GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams GetHandlerFuncParams func(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams
} }
type Raft struct { type Raft struct {

View File

@@ -14,7 +14,12 @@
package internal package internal
import "time" import (
"context"
"github.com/echovault/echovault/internal/clock"
"net"
"time"
)
type KeyData struct { type KeyData struct {
Value interface{} Value interface{}
@@ -41,3 +46,59 @@ type SnapshotObject struct {
State map[string]KeyData State map[string]KeyData
LatestSnapshotMilliseconds int64 LatestSnapshotMilliseconds int64
} }
type AccessKeys struct {
Channels []string
ReadKeys []string
WriteKeys []string
}
type KeyExtractionFunc func(cmd []string) (AccessKeys, error)
type HandlerFuncParams struct {
Context context.Context
Command []string
Connection *net.Conn
KeyLock func(ctx context.Context, key string) (bool, error)
KeyUnlock func(ctx context.Context, key string)
KeyRLock func(ctx context.Context, key string) (bool, error)
KeyRUnlock func(ctx context.Context, key string)
KeyExists func(ctx context.Context, key string) bool
CreateKeyAndLock func(ctx context.Context, key string) (bool, error)
GetValue func(ctx context.Context, key string) interface{}
SetValue func(ctx context.Context, key string, value interface{}) error
GetExpiry func(ctx context.Context, key string) time.Time
SetExpiry func(ctx context.Context, key string, expire time.Time, touch bool)
RemoveExpiry func(ctx context.Context, key string)
DeleteKey func(ctx context.Context, key string) error
GetClock func() clock.Clock
GetAllCommands func() []Command
GetACL func() interface{}
GetPubSub func() interface{}
TakeSnapshot func() error
RewriteAOF func() error
GetLatestSnapshotTime func() int64
}
type HandlerFunc func(params HandlerFuncParams) ([]byte, error)
type Command struct {
Command string
Module string
Categories []string
Description string
SubCommands []SubCommand
Sync bool // Specifies if command should be synced across replication cluster
KeyExtractionFunc
HandlerFunc
}
type SubCommand struct {
Command string
Module string
Categories []string
Description string
Sync bool // Specifies if sub-command should be synced across replication cluster
KeyExtractionFunc
HandlerFunc
}

View File

@@ -21,7 +21,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"io" "io"
"log" "log"
"math/big" "math/big"
@@ -129,7 +128,7 @@ func GetIPAddress() (string, error) {
return localAddr, nil return localAddr, nil
} }
func GetSubCommand(command types.Command, cmd []string) interface{} { func GetSubCommand(command Command, cmd []string) interface{} {
if len(command.SubCommands) == 0 || len(cmd) < 2 { if len(command.SubCommands) == 0 || len(cmd) < 2 {
return nil return nil
} }
@@ -141,7 +140,7 @@ func GetSubCommand(command types.Command, cmd []string) interface{} {
return nil return nil
} }
func IsWriteCommand(command types.Command, subCommand types.SubCommand) bool { func IsWriteCommand(command Command, subCommand SubCommand) bool {
return slices.Contains(append(command.Categories, subCommand.Categories...), constants.WriteCategory) return slices.Contains(append(command.Categories, subCommand.Categories...), constants.WriteCategory)
} }

View File

@@ -39,7 +39,6 @@ import (
"github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/raft"
"github.com/echovault/echovault/internal/snapshot" "github.com/echovault/echovault/internal/snapshot"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"io" "io"
"log" "log"
"net" "net"
@@ -51,7 +50,8 @@ import (
type EchoVault struct { type EchoVault struct {
// clock is an implementation of a time interface that allows mocking of time functions during testing. // clock is an implementation of a time interface that allows mocking of time functions during testing.
clock clock.Clock clock clock.Clock
getClock func() clock.Clock
// config holds the echovault configuration variables. // config holds the echovault configuration variables.
config config.Config config config.Config
@@ -82,7 +82,8 @@ type EchoVault struct {
} }
// Holds the list of all commands supported by the echovault. // Holds the list of all commands supported by the echovault.
commands []types.Command commands []internal.Command
getCommands func() []internal.Command
raft *raft.Raft // The raft replication layer for the echovault. raft *raft.Raft // The raft replication layer for the echovault.
memberList *memberlist.MemberList // The memberlist layer for the echovault. memberList *memberlist.MemberList // The memberlist layer for the echovault.
@@ -90,7 +91,10 @@ type EchoVault struct {
context context.Context context context.Context
acl *acl.ACL acl *acl.ACL
pubSub *pubsub.PubSub getACL func() interface{}
pubSub *pubsub.PubSub
getPubSub func() interface{}
snapshotInProgress atomic.Bool // Atomic boolean that's true when actively taking a snapshot. snapshotInProgress atomic.Bool // Atomic boolean that's true when actively taking a snapshot.
rewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress. rewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress.
@@ -119,15 +123,6 @@ func WithConfig(config config.Config) func(echovault *EchoVault) {
} }
} }
// WithCommands is an options for the NewEchoVault function that allows you to pass a
// list of commands that should be supported by your EchoVault instance.
// If you don't pass this option, EchoVault will start with no commands loaded.
func WithCommands(commands []types.Command) func(echovault *EchoVault) {
return func(echovault *EchoVault) {
echovault.commands = commands
}
}
// NewEchoVault creates a new EchoVault instance. // NewEchoVault creates a new EchoVault instance.
// This functions accepts the WithContext, WithConfig and WithCommands options. // This functions accepts the WithContext, WithConfig and WithCommands options.
func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
@@ -138,8 +133,8 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
store: make(map[string]internal.KeyData), store: make(map[string]internal.KeyData),
keyLocks: make(map[string]*sync.RWMutex), keyLocks: make(map[string]*sync.RWMutex),
keyCreationLock: &sync.Mutex{}, keyCreationLock: &sync.Mutex{},
commands: func() []types.Command { commands: func() []internal.Command {
var commands []types.Command var commands []internal.Command
commands = append(commands, acl.Commands()...) commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...) commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...) commands = append(commands, generic.Commands()...)
@@ -163,11 +158,27 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
internal.ContextServerID(echovault.config.ServerID), internal.ContextServerID(echovault.config.ServerID),
) )
// Function for server commands retrieval
echovault.getCommands = func() []internal.Command {
return echovault.commands
}
// Function for clock retrieval
echovault.getClock = func() clock.Clock {
return echovault.clock
}
// Set up ACL module // Set up ACL module
echovault.acl = acl.NewACL(echovault.config) echovault.acl = acl.NewACL(echovault.config)
echovault.getACL = func() interface{} {
return echovault.acl
}
// Set up Pub/Sub module // Set up Pub/Sub module
echovault.pubSub = pubsub.NewPubSub() echovault.pubSub = pubsub.NewPubSub()
echovault.getPubSub = func() interface{} {
return echovault.pubSub
}
if echovault.isInCluster() { if echovault.isInCluster() {
echovault.raft = raft.NewRaft(raft.Opts{ echovault.raft = raft.NewRaft(raft.Opts{
@@ -208,7 +219,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
snapshot.WithStartSnapshotFunc(echovault.startSnapshot), snapshot.WithStartSnapshotFunc(echovault.startSnapshot),
snapshot.WithFinishSnapshotFunc(echovault.finishSnapshot), snapshot.WithFinishSnapshotFunc(echovault.finishSnapshot),
snapshot.WithSetLatestSnapshotTimeFunc(echovault.setLatestSnapshot), snapshot.WithSetLatestSnapshotTimeFunc(echovault.setLatestSnapshot),
snapshot.WithGetLatestSnapshotTimeFunc(echovault.GetLatestSnapshotTime), snapshot.WithGetLatestSnapshotTimeFunc(echovault.getLatestSnapshotTime),
snapshot.WithGetStateFunc(func() map[string]internal.KeyData { snapshot.WithGetStateFunc(func() map[string]internal.KeyData {
state := make(map[string]internal.KeyData) state := make(map[string]internal.KeyData)
for k, v := range echovault.getState() { for k, v := range echovault.getState() {
@@ -471,8 +482,8 @@ func (server *EchoVault) Start() {
server.startTCP() server.startTCP()
} }
// TakeSnapshot triggers a snapshot when called. // takeSnapshot triggers a snapshot when called.
func (server *EchoVault) TakeSnapshot() error { func (server *EchoVault) takeSnapshot() error {
if server.snapshotInProgress.Load() { if server.snapshotInProgress.Load() {
return errors.New("snapshot already in progress") return errors.New("snapshot already in progress")
} }
@@ -494,11 +505,6 @@ func (server *EchoVault) TakeSnapshot() error {
return nil return nil
} }
// GetClock returns the server's clock implementation
func (server *EchoVault) GetClock() clock.Clock {
return server.clock
}
func (server *EchoVault) startSnapshot() { func (server *EchoVault) startSnapshot() {
server.snapshotInProgress.Store(true) server.snapshotInProgress.Store(true)
} }
@@ -511,8 +517,8 @@ func (server *EchoVault) setLatestSnapshot(msec int64) {
server.latestSnapshotMilliseconds.Store(msec) server.latestSnapshotMilliseconds.Store(msec)
} }
// GetLatestSnapshotTime returns the latest snapshot time in unix epoch milliseconds. // getLatestSnapshotTime returns the latest snapshot time in unix epoch milliseconds.
func (server *EchoVault) GetLatestSnapshotTime() int64 { func (server *EchoVault) getLatestSnapshotTime() int64 {
return server.latestSnapshotMilliseconds.Load() return server.latestSnapshotMilliseconds.Load()
} }
@@ -524,8 +530,8 @@ func (server *EchoVault) finishRewriteAOF() {
server.rewriteAOFInProgress.Store(false) server.rewriteAOFInProgress.Store(false)
} }
// RewriteAOF triggers an AOF compaction when running in standalone mode. // rewriteAOF triggers an AOF compaction when running in standalone mode.
func (server *EchoVault) RewriteAOF() error { func (server *EchoVault) rewriteAOF() error {
if server.rewriteAOFInProgress.Load() { if server.rewriteAOFInProgress.Load() {
return errors.New("aof rewrite in progress") return errors.New("aof rewrite in progress")
} }

View File

@@ -247,7 +247,7 @@ func (server *EchoVault) SetExpiry(ctx context.Context, key string, expireAt tim
// RemoveExpiry is called by commands that remove key expiry (e.g. PERSIST). // RemoveExpiry is called by commands that remove key expiry (e.g. PERSIST).
// The key must be locked prior ro calling this function. // The key must be locked prior ro calling this function.
func (server *EchoVault) RemoveExpiry(key string) { func (server *EchoVault) RemoveExpiry(_ context.Context, key string) {
// Reset expiry time // Reset expiry time
server.store[key] = internal.KeyData{ server.store[key] = internal.KeyData{
Value: server.store[key].Value, Value: server.store[key].Value,
@@ -292,7 +292,7 @@ func (server *EchoVault) DeleteKey(ctx context.Context, key string) error {
} }
// Remove key expiry. // Remove key expiry.
server.RemoveExpiry(key) server.RemoveExpiry(ctx, key)
// Delete the key from keyLocks and store. // Delete the key from keyLocks and store.
delete(server.keyLocks, key) delete(server.keyLocks, key)

View File

@@ -20,52 +20,42 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"net" "net"
"strings" "strings"
) )
func (server *EchoVault) GetAllCommands() []types.Command { func (server *EchoVault) getCommand(cmd string) (internal.Command, error) {
return server.commands
}
func (server *EchoVault) GetACL() interface{} {
return server.acl
}
func (server *EchoVault) GetPubSub() interface{} {
return server.pubSub
}
func (server *EchoVault) getCommand(cmd string) (types.Command, error) {
for _, command := range server.commands { for _, command := range server.commands {
if strings.EqualFold(command.Command, cmd) { if strings.EqualFold(command.Command, cmd) {
return command, nil return command, nil
} }
} }
return types.Command{}, fmt.Errorf("command %s not supported", cmd) return internal.Command{}, fmt.Errorf("command %s not supported", cmd)
} }
func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,
KeyExists: server.KeyExists, KeyExists: server.KeyExists,
CreateKeyAndLock: server.CreateKeyAndLock, CreateKeyAndLock: server.CreateKeyAndLock,
KeyLock: server.KeyLock, KeyLock: server.KeyLock,
KeyRLock: server.KeyRLock, KeyRLock: server.KeyRLock,
KeyUnlock: server.KeyUnlock, KeyUnlock: server.KeyUnlock,
KeyRUnlock: server.KeyRUnlock, KeyRUnlock: server.KeyRUnlock,
GetValue: server.GetValue, GetValue: server.GetValue,
SetValue: server.SetValue, SetValue: server.SetValue,
GetClock: server.GetClock, GetExpiry: server.GetExpiry,
GetExpiry: server.GetExpiry, SetExpiry: server.SetExpiry,
SetExpiry: server.SetExpiry, DeleteKey: server.DeleteKey,
DeleteKey: server.DeleteKey, TakeSnapshot: server.takeSnapshot,
GetPubSub: server.GetPubSub, GetLatestSnapshotTime: server.getLatestSnapshotTime,
GetACL: server.GetACL, RewriteAOF: server.rewriteAOF,
GetAllCommands: server.GetAllCommands, GetClock: server.getClock,
GetPubSub: server.getPubSub,
GetACL: server.getACL,
GetAllCommands: server.getCommands,
} }
} }
@@ -83,7 +73,7 @@ func (server *EchoVault) handleCommand(ctx context.Context, message []byte, conn
synchronize := command.Sync synchronize := command.Sync
handler := command.HandlerFunc handler := command.HandlerFunc
subCommand, ok := internal.GetSubCommand(command, cmd).(types.SubCommand) subCommand, ok := internal.GetSubCommand(command, cmd).(internal.SubCommand)
if ok { if ok {
synchronize = subCommand.Sync synchronize = subCommand.Sync
handler = subCommand.HandlerFunc handler = subCommand.HandlerFunc

View File

@@ -16,8 +16,6 @@ package types
import ( import (
"context" "context"
"github.com/echovault/echovault/internal/clock"
"net"
"time" "time"
) )
@@ -32,67 +30,6 @@ type EchoVault interface {
SetValue(ctx context.Context, key string, value interface{}) error SetValue(ctx context.Context, key string, value interface{}) error
GetExpiry(ctx context.Context, key string) time.Time GetExpiry(ctx context.Context, key string) time.Time
SetExpiry(ctx context.Context, key string, expire time.Time, touch bool) SetExpiry(ctx context.Context, key string, expire time.Time, touch bool)
RemoveExpiry(key string) RemoveExpiry(ctx context.Context, key string)
DeleteKey(ctx context.Context, key string) error DeleteKey(ctx context.Context, key string) error
GetClock() clock.Clock
GetAllCommands() []Command
GetACL() interface{}
GetPubSub() interface{}
TakeSnapshot() error
RewriteAOF() error
GetLatestSnapshotTime() int64
}
type AccessKeys struct {
Channels []string
ReadKeys []string
WriteKeys []string
}
type KeyExtractionFunc func(cmd []string) (AccessKeys, error)
type HandlerFuncParams struct {
Context context.Context
Command []string
Connection *net.Conn
KeyLock func(ctx context.Context, key string) (bool, error)
KeyUnlock func(ctx context.Context, key string)
KeyRLock func(ctx context.Context, key string) (bool, error)
KeyRUnlock func(ctx context.Context, key string)
KeyExists func(ctx context.Context, key string) bool
CreateKeyAndLock func(ctx context.Context, key string) (bool, error)
GetValue func(ctx context.Context, key string) interface{}
SetValue func(ctx context.Context, key string, value interface{}) error
GetExpiry func(ctx context.Context, key string) time.Time
SetExpiry func(ctx context.Context, key string, expire time.Time, touch bool)
RemoveExpiry func(key string)
DeleteKey func(ctx context.Context, key string) error
GetClock func() clock.Clock
GetAllCommands func() []Command
GetACL func() interface{}
GetPubSub func() interface{}
TakeSnapshot func() error
RewriteAOF func() error
GetLatestSnapshotTime func() int64
}
type HandlerFunc func(params HandlerFuncParams) ([]byte, error)
type SubCommand struct {
Command string
Module string
Categories []string
Description string
Sync bool // Specifies if sub-command should be synced across replication cluster
KeyExtractionFunc
HandlerFunc
}
type Command struct {
Command string
Module string
Categories []string
Description string
SubCommands []SubCommand
Sync bool // Specifies if command should be synced across replication cluster
KeyExtractionFunc
HandlerFunc
} }

View File

@@ -23,10 +23,12 @@ import (
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"unsafe"
) )
var bindAddr string var bindAddr string
@@ -64,12 +66,22 @@ func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig strin
) )
// Add the initial test users to the ACL module // Add the initial test users to the ACL module
a := mockServer.GetACL().(*acl.ACL) a := getACL(mockServer)
a.AddUsers(generateInitialTestUsers()) a.AddUsers(generateInitialTestUsers())
return mockServer return mockServer
} }
func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getACL(mockServer *echovault.EchoVault) *acl.ACL {
method := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getACL"))
f := method.(func() interface{})
return f().(*acl.ACL)
}
func generateInitialTestUsers() []*acl.User { func generateInitialTestUsers() []*acl.User {
// User with both hash password and plaintext password // User with both hash password and plaintext password
withPasswordUser := acl.CreateUser("with_password_user") withPasswordUser := acl.CreateUser("with_password_user")
@@ -459,10 +471,7 @@ func Test_HandleSetUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
a, ok := mockServer.GetACL().(*acl.ACL) a := getACL(mockServer)
if !ok {
t.Error("error loading ACL module")
}
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1055,7 +1064,7 @@ func Test_HandleGetUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
a, _ := mockServer.GetACL().(*acl.ACL) a := getACL(mockServer)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1208,7 +1217,7 @@ func Test_HandleDelUser(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
a, _ := mockServer.GetACL().(*acl.ACL) a := getACL(mockServer)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {
@@ -1358,7 +1367,7 @@ func Test_HandleList(t *testing.T) {
}() }()
wg.Wait() wg.Wait()
a, _ := mockServer.GetACL().(*acl.ACL) a := getACL(mockServer)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil { if err != nil {

View File

@@ -18,14 +18,16 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -39,11 +41,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -60,12 +68,14 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,
GetAllCommands: mockServer.GetAllCommands, GetAllCommands: getCommands,
} }
} }

View File

@@ -18,14 +18,16 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -39,11 +41,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -60,8 +68,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,

View File

@@ -19,16 +19,18 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/clock"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -51,11 +53,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -72,8 +80,10 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ getClock :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getClock")).(func() clock.Clock)
return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,
@@ -85,10 +95,10 @@ func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) typ
KeyRUnlock: mockServer.KeyRUnlock, KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue, GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue, SetValue: mockServer.SetValue,
GetClock: mockServer.GetClock,
GetExpiry: mockServer.GetExpiry, GetExpiry: mockServer.GetExpiry,
SetExpiry: mockServer.SetExpiry, SetExpiry: mockServer.SetExpiry,
DeleteKey: mockServer.DeleteKey, DeleteKey: mockServer.DeleteKey,
GetClock: getClock,
} }
} }

View File

@@ -19,15 +19,17 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"slices" "slices"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -41,11 +43,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -62,8 +70,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,

View File

@@ -19,14 +19,16 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -40,11 +42,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -61,8 +69,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,

View File

@@ -18,18 +18,20 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/modules/pubsub" "github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"unsafe"
) )
var ps *pubsub.PubSub var ps *pubsub.PubSub
@@ -40,7 +42,9 @@ var port uint16 = 7490
func init() { func init() {
mockServer = setUpServer(bindAddr, port) mockServer = setUpServer(bindAddr, port)
ps = mockServer.GetPubSub().(*pubsub.PubSub)
getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
ps = getPubSub().(*pubsub.PubSub)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -63,11 +67,17 @@ func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
return server return server
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -84,12 +94,14 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) internal.HandlerFuncParams {
return types.HandlerFuncParams{ getPubSub :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,
GetPubSub: mockServer.GetPubSub, GetPubSub: getPubSub,
} }
} }

View File

@@ -19,16 +19,18 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/modules/set" "github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"slices" "slices"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -42,11 +44,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -63,8 +71,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,

View File

@@ -19,18 +19,20 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/modules/sorted_set" "github.com/echovault/echovault/internal/modules/sorted_set"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"math" "math"
"net" "net"
"reflect"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -44,11 +46,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -65,8 +73,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,

View File

@@ -23,12 +23,13 @@ import (
"github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault" "github.com/echovault/echovault/pkg/echovault"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"reflect"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"unsafe"
) )
var mockServer *echovault.EchoVault var mockServer *echovault.EchoVault
@@ -42,11 +43,17 @@ func init() {
) )
} }
func getHandler(commands ...string) types.HandlerFunc { func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 { if len(commands) == 0 {
return nil return nil
} }
for _, c := range mockServer.GetAllCommands() { getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler // Get command handler
return c.HandlerFunc return c.HandlerFunc
@@ -63,8 +70,8 @@ func getHandler(commands ...string) types.HandlerFunc {
return nil return nil
} }
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams { func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
return types.HandlerFuncParams{ return internal.HandlerFuncParams{
Context: ctx, Context: ctx,
Command: cmd, Command: cmd,
Connection: conn, Connection: conn,