Updated KeyExtractionFunc for all the modules

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-21 02:59:07 +08:00
parent 7b88122c25
commit c2c887cd75
18 changed files with 1842 additions and 1378 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -410,8 +410,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.
return errors.New("not authorised to access any keys") return errors.New("not authorised to access any keys")
} }
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys // 8. Check if readKeys are in IncludedReadKeys
if slices.Contains(categories, constants.ReadCategory) {
if !slices.ContainsFunc(readKeys, func(key string) bool { if !slices.ContainsFunc(readKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool { return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if acl.GlobPatterns[readKeyGlob].Match(key) { if acl.GlobPatterns[readKeyGlob].Match(key) {
@@ -423,10 +422,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.
}) { }) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed) return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
} }
}
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys // 9. Check if keys are in IncludedWriteKeys
if slices.Contains(categories, constants.WriteCategory) {
if !slices.ContainsFunc(writeKeys, func(key string) bool { if !slices.ContainsFunc(writeKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool { return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) { if acl.GlobPatterns[writeKeyGlob].Match(key) {
@@ -439,7 +436,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed) return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
} }
} }
}
return nil return nil
} }

View File

@@ -498,8 +498,12 @@ func Commands() []types.Command {
Categories: []string{constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.ConnectionCategory, constants.SlowCategory},
Description: "(AUTH [username] password) Authenticates the connection", Description: "(AUTH [username] password) Authenticates the connection",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleAuth, HandlerFunc: handleAuth,
}, },
@@ -509,8 +513,12 @@ func Commands() []types.Command {
Categories: []string{}, Categories: []string{},
Description: "Access-Control-List commands", Description: "Access-Control-List commands",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
SubCommands: []types.SubCommand{ SubCommands: []types.SubCommand{
{ {
@@ -520,8 +528,12 @@ func Commands() []types.Command {
Description: `(ACL CAT [category]) List all the categories. Description: `(ACL CAT [category]) List all the categories.
If the optional category is provided, list all the commands in the category`, If the optional category is provided, list all the commands in the category`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleCat, HandlerFunc: handleCat,
}, },
@@ -531,8 +543,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL USERS) List all usernames of the configured ACL users", Description: "(ACL USERS) List all usernames of the configured ACL users",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleUsers, HandlerFunc: handleUsers,
}, },
@@ -542,8 +558,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SETUSER) Configure a new or existing user", Description: "(ACL SETUSER) Configure a new or existing user",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleSetUser, HandlerFunc: handleSetUser,
}, },
@@ -553,8 +573,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL GETUSER username) List the ACL rules of a user", Description: "(ACL GETUSER username) List the ACL rules of a user",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleGetUser, HandlerFunc: handleGetUser,
}, },
@@ -564,8 +588,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL DELUSER username [username ...]) Deletes users and terminates their connections. Cannot delete default user", Description: "(ACL DELUSER username [username ...]) Deletes users and terminates their connections. Cannot delete default user",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleDelUser, HandlerFunc: handleDelUser,
}, },
@@ -575,8 +603,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.FastCategory}, Categories: []string{constants.FastCategory},
Description: "(ACL WHOAMI) Returns the authenticated user of the current connection", Description: "(ACL WHOAMI) Returns the authenticated user of the current connection",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleWhoAmI, HandlerFunc: handleWhoAmI,
}, },
@@ -586,8 +618,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL LIST) Dumps effective acl rules in acl config file format", Description: "(ACL LIST) Dumps effective acl rules in acl config file format",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleList, HandlerFunc: handleList,
}, },
@@ -600,8 +636,12 @@ If the optional category is provided, list all the commands in the category`,
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged. When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`, When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleLoad, HandlerFunc: handleLoad,
}, },
@@ -611,8 +651,12 @@ When 'REPLACE' is passed, users from config file who share a username with users
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SAVE) Saves the effective ACL rules the configured ACL config file", Description: "(ACL SAVE) Saves the effective ACL rules the configured ACL config file",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleSave, HandlerFunc: handleSave,
}, },

View File

@@ -26,7 +26,7 @@ import (
"strings" "strings"
) )
func handleGetAllCommands(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) { func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
commands := server.GetAllCommands() commands := server.GetAllCommands()
res := "" res := ""
@@ -199,7 +199,13 @@ func Commands() []types.Command {
Categories: []string{constants.AdminCategory, constants.SlowCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory},
Description: "Get a list of all the commands in available on the echovault with categories and descriptions", Description: "Get a list of all the commands in available on the echovault with categories and descriptions",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleGetAllCommands, HandlerFunc: handleGetAllCommands,
}, },
{ {
@@ -208,8 +214,12 @@ func Commands() []types.Command {
Categories: []string{}, Categories: []string{},
Description: "Commands pertaining to echovault commands", Description: "Commands pertaining to echovault commands",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
SubCommands: []types.SubCommand{ SubCommands: []types.SubCommand{
{ {
@@ -218,7 +228,13 @@ func Commands() []types.Command {
Categories: []string{constants.SlowCategory, constants.ConnectionCategory}, Categories: []string{constants.SlowCategory, constants.ConnectionCategory},
Description: "Get command documentation", Description: "Get command documentation",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandDocs, HandlerFunc: handleCommandDocs,
}, },
{ {
@@ -227,7 +243,13 @@ func Commands() []types.Command {
Categories: []string{constants.SlowCategory}, Categories: []string{constants.SlowCategory},
Description: "Get the dumber of commands in the echovault", Description: "Get the dumber of commands in the echovault",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandCount, HandlerFunc: handleCommandCount,
}, },
{ {
@@ -237,7 +259,13 @@ func Commands() []types.Command {
Description: `(COMMAND LIST [FILTERBY <ACLCAT category | PATTERN pattern | MODULE module>]) Get the list of command names. Description: `(COMMAND LIST [FILTERBY <ACLCAT category | PATTERN pattern | MODULE module>]) Get the list of command names.
Allows for filtering by ACL category or glob pattern.`, Allows for filtering by ACL category or glob pattern.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandList, HandlerFunc: handleCommandList,
}, },
}, },
@@ -248,8 +276,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(SAVE) Trigger a snapshot save", Description: "(SAVE) Trigger a snapshot save",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.TakeSnapshot(); err != nil { if err := server.TakeSnapshot(); err != nil {
@@ -264,8 +296,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.FastCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.FastCategory, constants.DangerousCategory},
Description: "(LASTSAVE) Get unix timestamp for the latest snapshot in milliseconds.", Description: "(LASTSAVE) Get unix timestamp for the latest snapshot in milliseconds.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
msec := server.GetLatestSnapshotTime() msec := server.GetLatestSnapshotTime()
@@ -281,8 +317,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory}, Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(REWRITEAOF) Trigger re-writing of append process", Description: "(REWRITEAOF) Trigger re-writing of append process",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.RewriteAOF(); err != nil { if err := server.RewriteAOF(); err != nil {

View File

@@ -42,8 +42,12 @@ func Commands() []types.Command {
Categories: []string{constants.FastCategory, constants.ConnectionCategory}, Categories: []string{constants.FastCategory, constants.ConnectionCategory},
Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.", Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return []string{}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handlePing, HandlerFunc: handlePing,
}, },

View File

@@ -39,7 +39,7 @@ func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
value := cmd[2] value := cmd[2]
res := []byte(constants.OkResponse) res := []byte(constants.OkResponse)
clock := server.GetClock() clock := server.GetClock()
@@ -99,7 +99,8 @@ func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
} }
func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) { func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if _, err := msetKeyFunc(cmd); err != nil { _, err := msetKeyFunc(cmd)
if err != nil {
return nil, err return nil, err
} }
@@ -159,7 +160,7 @@ func handleGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
@@ -185,7 +186,7 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
values := make(map[string]string) values := make(map[string]string)
locks := make(map[string]bool) locks := make(map[string]bool)
for _, key := range keys { for _, key := range keys.ReadKeys {
if _, ok := values[key]; ok { if _, ok := values[key]; ok {
// Skip if we have already locked this key // Skip if we have already locked this key
continue continue
@@ -232,7 +233,7 @@ func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err return nil, err
} }
count := 0 count := 0
for _, key := range keys { for _, key := range keys.WriteKeys {
err = server.DeleteKey(ctx, key) err = server.DeleteKey(ctx, key)
if err != nil { if err != nil {
log.Printf("could not delete key %s due to error: %+v\n", key, err) log.Printf("could not delete key %s due to error: %+v\n", key, err)
@@ -249,7 +250,7 @@ func handlePersist(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -276,7 +277,7 @@ func handleExpireTime(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":-2\r\n"), nil return []byte(":-2\r\n"), nil
@@ -307,7 +308,7 @@ func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
clock := server.GetClock() clock := server.GetClock()
@@ -344,7 +345,7 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
// Extract time // Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64) n, err := strconv.ParseInt(cmd[2], 10, 64)
@@ -412,7 +413,7 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
// Extract time // Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64) n, err := strconv.ParseInt(cmd[2], 10, 64)

View File

@@ -17,18 +17,23 @@ package generic
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func setKeyFunc(cmd []string) ([]string, error) { func setKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 7 { if len(cmd) < 3 || len(cmd) > 7 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func msetKeyFunc(cmd []string) ([]string, error) { func msetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd[1:])%2 != 0 { if len(cmd[1:])%2 != 0 {
return nil, errors.New("each key must be paired with a value") return types.AccessKeys{}, errors.New("each key must be paired with a value")
} }
var keys []string var keys []string
for i, key := range cmd[1:] { for i, key := range cmd[1:] {
@@ -36,61 +41,97 @@ func msetKeyFunc(cmd []string) ([]string, error) {
keys = append(keys, key) keys = append(keys, key)
} }
} }
return keys, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: keys,
}, nil
} }
func getKeyFunc(cmd []string) ([]string, error) { func getKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func mgetKeyFunc(cmd []string) ([]string, error) { func mgetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func delKeyFunc(cmd []string) ([]string, error) { func delKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
} }
func persistKeyFunc(cmd []string) ([]string, error) { func persistKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
} }
func expireTimeKeyFunc(cmd []string) ([]string, error) { func expireTimeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func ttlKeyFunc(cmd []string) ([]string, error) { func ttlKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func expireKeyFunc(cmd []string) ([]string, error) { func expireKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func expireAtKeyFunc(cmd []string) ([]string, error) { func expireAtKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }

View File

@@ -34,7 +34,7 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
entries := make(map[string]interface{}) entries := make(map[string]interface{})
if len(cmd[2:])%2 != 0 { if len(cmd[2:])%2 != 0 {
@@ -92,7 +92,7 @@ func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
@@ -143,7 +143,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
@@ -194,7 +194,7 @@ func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -235,7 +235,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
count := 1 count := 1
if len(cmd) >= 3 { if len(cmd) >= 3 {
@@ -351,7 +351,7 @@ func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -376,7 +376,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -406,7 +406,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
field := cmd[2] field := cmd[2]
var intIncrement int var intIncrement int
@@ -498,7 +498,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -538,7 +538,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
field := cmd[2] field := cmd[2]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
@@ -568,7 +568,7 @@ func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {

View File

@@ -17,91 +17,144 @@ package hash
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func hsetKeyFunc(cmd []string) ([]string, error) { func hsetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func hsetnxKeyFunc(cmd []string) ([]string, error) { func hsetnxKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func hgetKeyFunc(cmd []string) ([]string, error) { func hgetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func hstrlenKeyFunc(cmd []string) ([]string, error) { func hstrlenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func hvalsKeyFunc(cmd []string) ([]string, error) { func hvalsKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func hrandfieldKeyFunc(cmd []string) ([]string, error) { func hrandfieldKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 { if len(cmd) < 2 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
if len(cmd) == 2 { if len(cmd) == 2 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func hlenKeyFunc(cmd []string) ([]string, error) { func hlenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func hkeysKeyFunc(cmd []string) ([]string, error) { func hkeysKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func hincrbyKeyFunc(cmd []string) ([]string, error) { func hincrbyKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func hgetallKeyFunc(cmd []string) ([]string, error) { func hgetallKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func hexistsKeyFunc(cmd []string) ([]string, error) { func hexistsKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func hdelKeyFunc(cmd []string) ([]string, error) { func hdelKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }

View File

@@ -33,7 +33,7 @@ func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
// If key does not exist, return 0 // If key does not exist, return 0
@@ -52,13 +52,13 @@ func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, errors.New("LLEN command on non-list item") return nil, errors.New("LLEN command on non-list item")
} }
func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lindexKeyFunc(cmd) keys, err := lindexKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int) index, ok := internal.AdaptType(cmd[2]).(int)
if !ok { if !ok {
@@ -86,13 +86,13 @@ func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
} }
func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lrangeKeyFunc(cmd) keys, err := lrangeKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int) start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int) end, endOk := internal.AdaptType(cmd[3]).(int)
@@ -165,13 +165,13 @@ func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, con
return bytes, nil return bytes, nil
} }
func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lsetKeyFunc(cmd) keys, err := lsetKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int) index, ok := internal.AdaptType(cmd[2]).(int)
if !ok { if !ok {
@@ -204,13 +204,13 @@ func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := ltrimKeyFunc(cmd) keys, err := ltrimKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int) start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int) end, endOk := internal.AdaptType(cmd[3]).(int)
@@ -253,13 +253,13 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lremKeyFunc(cmd) keys, err := lremKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
value := cmd[3] value := cmd[3]
count, ok := internal.AdaptType(cmd[2]).(int) count, ok := internal.AdaptType(cmd[2]).(int)
@@ -321,14 +321,13 @@ func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lmoveKeyFunc(cmd) keys, err := lmoveKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
source := keys[0] source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
destination := keys[1]
whereFrom := strings.ToLower(cmd[3]) whereFrom := strings.ToLower(cmd[3])
whereTo := strings.ToLower(cmd[4]) whereTo := strings.ToLower(cmd[4])
@@ -394,7 +393,7 @@ func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
newElems = append(newElems, internal.AdaptType(elem)) newElems = append(newElems, internal.AdaptType(elem))
} }
key := keys[0] key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) { switch strings.ToLower(cmd[0]) {
@@ -428,13 +427,13 @@ func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := rpushKeyFunc(cmd) keys, err := rpushKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
var newElems []interface{} var newElems []interface{}
@@ -482,7 +481,7 @@ func handlePop(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0])) return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))

View File

@@ -17,74 +17,115 @@ package list
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func lpushKeyFunc(cmd []string) ([]string, error) { func lpushKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func popKeyFunc(cmd []string) ([]string, error) { func popKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
} }
func llenKeyFunc(cmd []string) ([]string, error) { func llenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func lrangeKeyFunc(cmd []string) ([]string, error) { func lrangeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func lindexKeyFunc(cmd []string) ([]string, error) { func lindexKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func lsetKeyFunc(cmd []string) ([]string, error) { func lsetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func ltrimKeyFunc(cmd []string) ([]string, error) { func ltrimKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func lremKeyFunc(cmd []string) ([]string, error) { func lremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func rpushKeyFunc(cmd []string) ([]string, error) { func rpushKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func lmoveKeyFunc(cmd []string) ([]string, error) { func lmoveKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 5 { if len(cmd) != 5 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1], cmd[2]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3],
}, nil
} }

View File

@@ -111,12 +111,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.", Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channels as keys // Treat the channels as keys
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleSubscribe, HandlerFunc: handleSubscribe,
}, },
@@ -126,12 +130,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.", Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the patterns as keys // Treat the patterns as keys
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleSubscribe, HandlerFunc: handleSubscribe,
}, },
@@ -141,12 +149,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.FastCategory}, Categories: []string{constants.PubSubCategory, constants.FastCategory},
Description: "(PUBLISH channel message) Publish a message to the specified channel.", Description: "(PUBLISH channel message) Publish a message to the specified channel.",
Sync: true, Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channel as a key // Treat the channel as a key
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: cmd[1:2],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handlePublish, HandlerFunc: handlePublish,
}, },
@@ -158,9 +170,13 @@ func Commands() []types.Command {
If the channel list is not provided, then the connection will be unsubscribed from all the channels that If the channel list is not provided, then the connection will be unsubscribed from all the channels that
it's currently subscribe to.`, it's currently subscribe to.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channels as keys // Treat the channels as keys
return cmd[1:], nil return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleUnsubscribe, HandlerFunc: handleUnsubscribe,
}, },
@@ -172,9 +188,12 @@ it's currently subscribe to.`,
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
it's currently subscribe to.`, it's currently subscribe to.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channels as keys return types.AccessKeys{
return cmd[1:], nil Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
}, },
HandlerFunc: handleUnsubscribe, HandlerFunc: handleUnsubscribe,
}, },
@@ -184,7 +203,13 @@ it's currently subscribe to.`,
Categories: []string{}, Categories: []string{},
Description: "", Description: "",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) { HandlerFunc: func(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand") return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
}, },
@@ -197,7 +222,13 @@ it's currently subscribe to.`,
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
channels with 1 or more subscribers.`, channels with 1 or more subscribers.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubChannels, HandlerFunc: handlePubSubChannels,
}, },
{ {
@@ -206,7 +237,13 @@ channels with 1 or more subscribers.`,
Categories: []string{constants.PubSubCategory, constants.SlowCategory}, Categories: []string{constants.PubSubCategory, constants.SlowCategory},
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`, Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubNumPat, HandlerFunc: handlePubSubNumPat,
}, },
{ {
@@ -216,7 +253,13 @@ channels with 1 or more subscribers.`,
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
channel name and how many clients are currently subscribed to the channel.`, channel name and how many clients are currently subscribed to the channel.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil }, KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: cmd[2:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubNumSubs, HandlerFunc: handlePubSubNumSubs,
}, },
}, },

View File

@@ -27,13 +27,13 @@ import (
"strings" "strings"
) )
func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := saddKeyFunc(cmd) keys, err := saddKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
var set *internal_set.Set var set *internal_set.Set
@@ -64,13 +64,13 @@ func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := scardKeyFunc(cmd) keys, err := scardKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(fmt.Sprintf(":0\r\n")), nil return []byte(fmt.Sprintf(":0\r\n")), nil
@@ -91,23 +91,23 @@ func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
} }
func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffKeyFunc(cmd) keys, err := sdiffKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Extract base set first // Extract base set first
if !server.KeyExists(ctx, keys[0]) { if !server.KeyExists(ctx, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[0]) return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(ctx, keys[0]) defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys[0]).(*internal_set.Set) baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[0]) return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
} }
locks := make(map[string]bool) locks := make(map[string]bool)
@@ -119,7 +119,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
continue continue
} }
@@ -152,25 +152,25 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil return []byte(res), nil
} }
func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffstoreKeyFunc(cmd) keys, err := sdiffstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
destination := keys[0] destination := keys.WriteKeys[0]
// Extract base set first // Extract base set first
if !server.KeyExists(ctx, keys[1]) { if !server.KeyExists(ctx, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[1]) return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
} }
if _, err := server.KeyRLock(ctx, keys[1]); err != nil { if _, err := server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(ctx, keys[1]) defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys[1]).(*internal_set.Set) baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[1]) return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
} }
locks := make(map[string]bool) locks := make(map[string]bool)
@@ -182,7 +182,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
} }
}() }()
for _, key := range keys[2:] { for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
continue continue
} }
@@ -193,7 +193,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
} }
var sets []*internal_set.Set var sets []*internal_set.Set
for _, key := range keys[2:] { for _, key := range keys.ReadKeys[1:] {
set, ok := server.GetValue(ctx, key).(*internal_set.Set) set, ok := server.GetValue(ctx, key).(*internal_set.Set)
if !ok { if !ok {
continue continue
@@ -228,7 +228,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil return []byte(res), nil
} }
func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterKeyFunc(cmd) keys, err := sinterKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -243,7 +243,7 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, con
} }
}() }()
for _, key := range keys[0:] { for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -319,7 +319,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
} }
}() }()
for _, key := range keys { for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -350,7 +350,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterstoreKeyFunc(cmd) keys, err := sinterstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -365,7 +365,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -388,7 +388,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
} }
intersect, _ := internal_set.Intersection(0, sets...) intersect, _ := internal_set.Intersection(0, sets...)
destination := keys[0] destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
@@ -408,13 +408,13 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sismemberKeyFunc(cmd) keys, err := sismemberKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -437,13 +437,13 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smembersKeyFunc(cmd) keys, err := smembersKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -472,13 +472,13 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, c
return []byte(res), nil return []byte(res), nil
} }
func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smismemberKeyFunc(cmd) keys, err := smismemberKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
members := cmd[2:] members := cmd[2:]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
@@ -515,14 +515,13 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil return []byte(res), nil
} }
func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smoveKeyFunc(cmd) keys, err := smoveKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
source := keys[0] source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
destination := keys[1]
member := cmd[3] member := cmd[3]
if !server.KeyExists(ctx, source) { if !server.KeyExists(ctx, source) {
@@ -569,13 +568,13 @@ func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", res)), nil return []byte(fmt.Sprintf(":%d\r\n", res)), nil
} }
func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := spopKeyFunc(cmd) keys, err := spopKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
count := 1 count := 1
if len(cmd) == 3 { if len(cmd) == 3 {
@@ -613,13 +612,13 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil return []byte(res), nil
} }
func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := srandmemberKeyFunc(cmd) keys, err := srandmemberKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
count := 1 count := 1
if len(cmd) == 3 { if len(cmd) == 3 {
@@ -657,13 +656,13 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
return []byte(res), nil return []byte(res), nil
} }
func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sremKeyFunc(cmd) keys, err := sremKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
members := cmd[2:] members := cmd[2:]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
@@ -685,7 +684,7 @@ func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionKeyFunc(cmd) keys, err := sunionKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -700,7 +699,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, con
} }
}() }()
for _, key := range keys { for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
continue continue
} }
@@ -736,7 +735,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil return []byte(res), nil
} }
func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionstoreKeyFunc(cmd) keys, err := sunionstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -751,7 +750,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
continue continue
} }
@@ -776,7 +775,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
union := internal_set.Union(sets...) union := internal_set.Union(sets...)
destination := cmd[1] destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {

View File

@@ -17,48 +17,69 @@ package set
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices" "slices"
"strings" "strings"
) )
func saddKeyFunc(cmd []string) ([]string, error) { func saddKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func scardKeyFunc(cmd []string) ([]string, error) { func scardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func sdiffKeyFunc(cmd []string) ([]string, error) { func sdiffKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func sdiffstoreKeyFunc(cmd []string) ([]string, error) { func sdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }
func sinterKeyFunc(cmd []string) ([]string, error) { func sinterKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func sintercardKeyFunc(cmd []string) ([]string, error) { func sintercardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
limitIdx := slices.IndexFunc(cmd, func(s string) bool { limitIdx := slices.IndexFunc(cmd, func(s string) bool {
@@ -66,78 +87,126 @@ func sintercardKeyFunc(cmd []string) ([]string, error) {
}) })
if limitIdx == -1 { if limitIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
return cmd[1:limitIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:limitIdx],
WriteKeys: make([]string, 0),
}, nil
} }
func sinterstoreKeyFunc(cmd []string) ([]string, error) { func sinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }
func sismemberKeyFunc(cmd []string) ([]string, error) { func sismemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func smembersKeyFunc(cmd []string) ([]string, error) { func smembersKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func smismemberKeyFunc(cmd []string) ([]string, error) { func smismemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func smoveKeyFunc(cmd []string) ([]string, error) { func smoveKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:3], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3],
}, nil
} }
func spopKeyFunc(cmd []string) ([]string, error) { func spopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func srandmemberKeyFunc(cmd []string) ([]string, error) { func srandmemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func sremKeyFunc(cmd []string) ([]string, error) { func sremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func sunionKeyFunc(cmd []string) ([]string, error) { func sunionKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func sunionstoreKeyFunc(cmd []string) ([]string, error) { func sunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }

View File

@@ -36,7 +36,7 @@ func handleZADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
var updatePolicy interface{} = nil var updatePolicy interface{} = nil
var comparison interface{} = nil var comparison interface{} = nil
@@ -181,12 +181,12 @@ func handleZADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
} }
func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zcardKeyFunc(cmd) keys, err := zcardKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -205,13 +205,13 @@ func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
} }
func handleZCOUNT(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZCOUNT(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zcountKeyFunc(cmd) keys, err := zcountKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
minimum := sorted_set.Score(math.Inf(-1)) minimum := sorted_set.Score(math.Inf(-1))
switch internal.AdaptType(cmd[2]).(type) { switch internal.AdaptType(cmd[2]).(type) {
@@ -279,7 +279,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
minimum := cmd[2] minimum := cmd[2]
maximum := cmd[3] maximum := cmd[3]
@@ -318,7 +318,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", count)), nil return []byte(fmt.Sprintf(":%d\r\n", count)), nil
} }
func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zdiffKeyFunc(cmd) keys, err := zdiffKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -341,34 +341,34 @@ func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
}() }()
// Extract base set // Extract base set
if !server.KeyExists(ctx, keys[0]) { if !server.KeyExists(ctx, keys.ReadKeys[0]) {
// If base set does not exist, return an empty array // If base set does not exist, return an empty array
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(ctx, keys[0]) defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet) baseSortedSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*sorted_set.SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
} }
// Extract the remaining sets // Extract the remaining sets
var sets []*sorted_set.SortedSet var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ { for i := 1; i < len(keys.ReadKeys); i++ {
if !server.KeyExists(ctx, keys[i]) { if !server.KeyExists(ctx, keys.ReadKeys[i]) {
continue continue
} }
locked, err := server.KeyRLock(ctx, keys[i]) locked, err := server.KeyRLock(ctx, keys.ReadKeys[i])
if err != nil { if err != nil {
return nil, err return nil, err
} }
locks[keys[i]] = locked locks[keys.ReadKeys[i]] = locked
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet) set, ok := server.GetValue(ctx, keys.ReadKeys[i]).(*sorted_set.SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
} }
sets = append(sets, set) sets = append(sets, set)
} }
@@ -391,13 +391,13 @@ func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil return []byte(res), nil
} }
func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zdiffstoreKeyFunc(cmd) keys, err := zdiffstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
destination := cmd[1] destination := keys.WriteKeys[0]
locks := make(map[string]bool) locks := make(map[string]bool)
defer func() { defer func() {
@@ -409,29 +409,29 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
}() }()
// Extract base set // Extract base set
if !server.KeyExists(ctx, keys[0]) { if !server.KeyExists(ctx, keys.ReadKeys[0]) {
// If base set does not exist, return 0 // If base set does not exist, return 0
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(ctx, keys[0]) defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet) baseSortedSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*sorted_set.SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
} }
var sets []*sorted_set.SortedSet var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ { for i := 1; i < len(keys.ReadKeys); i++ {
if server.KeyExists(ctx, keys[i]) { if server.KeyExists(ctx, keys.ReadKeys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil { if _, err = server.KeyRLock(ctx, keys.ReadKeys[i]); err != nil {
return nil, err return nil, err
} }
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet) set, ok := server.GetValue(ctx, keys.ReadKeys[i]).(*sorted_set.SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
} }
sets = append(sets, set) sets = append(sets, set)
} }
@@ -457,13 +457,13 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil
} }
func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zincrbyKeyFunc(cmd) keys, err := zincrbyKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
member := sorted_set.Value(cmd[3]) member := sorted_set.Value(cmd[3])
var increment sorted_set.Score var increment sorted_set.Score
@@ -524,8 +524,8 @@ func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, co
strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil
} }
func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zinterKeyFunc(cmd) _, err := zinterKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -584,13 +584,13 @@ func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil return []byte(res), nil
} }
func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zinterstoreKeyFunc(cmd) k, err := zinterstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
destination := keys[0] destination := k.WriteKeys[0]
// Remove the destination keys from the command before parsing it // Remove the destination keys from the command before parsing it
cmd = slices.DeleteFunc(cmd, func(s string) bool { cmd = slices.DeleteFunc(cmd, func(s string) bool {
@@ -651,7 +651,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zmpopKeyFunc(cmd) keys, err := zmpopKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -697,22 +697,22 @@ func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
} }
} }
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys.WriteKeys); i++ {
if server.KeyExists(ctx, keys[i]) { if server.KeyExists(ctx, keys.WriteKeys[i]) {
if _, err = server.KeyLock(ctx, keys[i]); err != nil { if _, err = server.KeyLock(ctx, keys.WriteKeys[i]); err != nil {
continue continue
} }
v, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet) v, ok := server.GetValue(ctx, keys.WriteKeys[i]).(*sorted_set.SortedSet)
if !ok || v.Cardinality() == 0 { if !ok || v.Cardinality() == 0 {
server.KeyUnlock(ctx, keys[i]) server.KeyUnlock(ctx, keys.WriteKeys[i])
continue continue
} }
popped, err := v.Pop(count, policy) popped, err := v.Pop(count, policy)
if err != nil { if err != nil {
server.KeyUnlock(ctx, keys[i]) server.KeyUnlock(ctx, keys.WriteKeys[i])
return nil, err return nil, err
} }
server.KeyUnlock(ctx, keys[i]) server.KeyUnlock(ctx, keys.WriteKeys[i])
res := fmt.Sprintf("*%d", popped.Cardinality()) res := fmt.Sprintf("*%d", popped.Cardinality())
@@ -729,13 +729,13 @@ func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zpopKeyFunc(cmd) keys, err := zpopKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
count := 1 count := 1
policy := "min" policy := "min"
@@ -782,13 +782,13 @@ func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil return []byte(res), nil
} }
func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zmscoreKeyFunc(cmd) keys, err := zmscoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
@@ -824,13 +824,13 @@ func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, co
return []byte(res), nil return []byte(res), nil
} }
func handleZRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrandmemberKeyFunc(cmd) keys, err := zrandmemberKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
count := 1 count := 1
if len(cmd) >= 3 { if len(cmd) >= 3 {
@@ -888,7 +888,7 @@ func handleZRANK(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
member := cmd[2] member := cmd[2]
withscores := false withscores := false
@@ -938,7 +938,7 @@ func handleZREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -964,13 +964,13 @@ func handleZREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zscoreKeyFunc(cmd) keys, err := zscoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
@@ -993,13 +993,13 @@ func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil
} }
func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebyscoreKeyFunc(cmd) keys, err := zremrangebyscoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
deletedCount := 0 deletedCount := 0
@@ -1037,13 +1037,13 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.Echo
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebyrankKeyFunc(cmd) keys, err := zremrangebyrankKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
start, err := strconv.Atoi(cmd[2]) start, err := strconv.Atoi(cmd[2])
if err != nil { if err != nil {
@@ -1102,13 +1102,13 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoV
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebylexKeyFunc(cmd) keys, err := zremrangebylexKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
minimum := cmd[2] minimum := cmd[2]
maximum := cmd[3] maximum := cmd[3]
@@ -1149,13 +1149,13 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVa
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
} }
func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrangeKeyCount(cmd) keys, err := zrangeKeyCount(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
policy := "byscore" policy := "byscore"
scoreStart := math.Inf(-1) // Lower bound if policy is "byscore" scoreStart := math.Inf(-1) // Lower bound if policy is "byscore"
scoreStop := math.Inf(1) // Upper bound if policy is "byscore" scoreStop := math.Inf(1) // Upper bound if policy is "byscore"
@@ -1289,14 +1289,14 @@ func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil return []byte(res), nil
} }
func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrangeStoreKeyFunc(cmd) keys, err := zrangeStoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
destination := keys[0] destination := keys.WriteKeys[0]
source := keys[1] source := keys.ReadKeys[0]
policy := "byscore" policy := "byscore"
scoreStart := math.Inf(-1) // Lower bound if policy is "byscore" scoreStart := math.Inf(-1) // Lower bound if policy is "byscore"
scoreStop := math.Inf(1) // Upper bound if policy is "byfloat" scoreStop := math.Inf(1) // Upper bound if policy is "byfloat"
@@ -1431,7 +1431,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil
} }
func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if _, err := zunionKeyFunc(cmd); err != nil { if _, err := zunionKeyFunc(cmd); err != nil {
return nil, err return nil, err
} }
@@ -1486,12 +1486,12 @@ func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, con
} }
func handleZUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) { func handleZUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zunionstoreKeyFunc(cmd) k, err := zunionstoreKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
destination := keys[0] destination := k.WriteKeys[0]
// Remove destination key from list of keys // Remove destination key from list of keys
cmd = slices.DeleteFunc(cmd, func(s string) bool { cmd = slices.DeleteFunc(cmd, func(s string) bool {

View File

@@ -17,34 +17,47 @@ package sorted_set
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices" "slices"
"strings" "strings"
) )
func zaddKeyFunc(cmd []string) ([]string, error) { func zaddKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 { if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zcardKeyFunc(cmd []string) ([]string, error) { func zcardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
func zcountKeyFunc(cmd []string) ([]string, error) { func zcountKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zdiffKeyFunc(cmd []string) ([]string, error) { func zdiffKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
withscoresIndex := slices.IndexFunc(cmd, func(s string) bool { withscoresIndex := slices.IndexFunc(cmd, func(s string) bool {
@@ -52,29 +65,45 @@ func zdiffKeyFunc(cmd []string) ([]string, error) {
}) })
if withscoresIndex == -1 { if withscoresIndex == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
return cmd[1:withscoresIndex], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:withscoresIndex],
WriteKeys: make([]string, 0),
}, nil
} }
func zdiffstoreKeyFunc(cmd []string) ([]string, error) { func zdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[2:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }
func zincrbyKeyFunc(cmd []string) ([]string, error) { func zincrbyKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zinterKeyFunc(cmd []string) ([]string, error) { func zinterKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -85,17 +114,25 @@ func zinterKeyFunc(cmd []string) ([]string, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return cmd[1:endIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx],
WriteKeys: make([]string, 0),
}, nil
} }
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zinterstoreKeyFunc(cmd []string) ([]string, error) { func zinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -106,124 +143,192 @@ func zinterstoreKeyFunc(cmd []string) ([]string, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }
if endIdx >= 2 { if endIdx >= 3 {
return cmd[1:endIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2],
}, nil
} }
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zmpopKeyFunc(cmd []string) ([]string, error) { func zmpopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd, func(s string) bool { endIdx := slices.IndexFunc(cmd, func(s string) bool {
return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s)) return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
} }
if endIdx >= 2 { if endIdx >= 2 {
return cmd[1:endIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:endIdx],
}, nil
} }
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zmscoreKeyFunc(cmd []string) ([]string, error) { func zmscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zpopKeyFunc(cmd []string) ([]string, error) { func zpopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 { if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zrandmemberKeyFunc(cmd []string) ([]string, error) { func zrandmemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 { if len(cmd) < 2 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zrankKeyFunc(cmd []string) ([]string, error) { func zrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 { if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zremKeyFunc(cmd []string) ([]string, error) { func zremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zrevrankKeyFunc(cmd []string) ([]string, error) { func zrevrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zscoreKeyFunc(cmd []string) ([]string, error) { func zscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zremrangebylexKeyFunc(cmd []string) ([]string, error) { func zremrangebylexKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zremrangebyrankKeyFunc(cmd []string) ([]string, error) { func zremrangebyrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zremrangebyscoreKeyFunc(cmd []string) ([]string, error) { func zremrangebyscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func zlexcountKeyFunc(cmd []string) ([]string, error) { func zlexcountKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zrangeKeyCount(cmd []string) ([]string, error) { func zrangeKeyCount(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 || len(cmd) > 10 { if len(cmd) < 4 || len(cmd) > 10 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:2], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func zrangeStoreKeyFunc(cmd []string) ([]string, error) { func zrangeStoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 5 || len(cmd) > 11 { if len(cmd) < 5 || len(cmd) > 11 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return cmd[1:3], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:3],
WriteKeys: cmd[1:2],
}, nil
} }
func zunionKeyFunc(cmd []string) ([]string, error) { func zunionKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -234,17 +339,25 @@ func zunionKeyFunc(cmd []string) ([]string, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return cmd[1:endIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx],
WriteKeys: cmd[1:endIdx],
}, nil
} }
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
func zunionstoreKeyFunc(cmd []string) ([]string, error) { func zunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 { if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") || if strings.EqualFold(s, "WEIGHTS") ||
@@ -255,10 +368,18 @@ func zunionstoreKeyFunc(cmd []string) ([]string, error) {
return false return false
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
} }
if endIdx >= 1 { if endIdx >= 1 {
return cmd[1:endIdx], nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2],
}, nil
} }
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }

View File

@@ -24,13 +24,13 @@ import (
"net" "net"
) )
func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := setRangeKeyFunc(cmd) keys, err := setRangeKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.WriteKeys[0]
offset, ok := internal.AdaptType(cmd[2]).(int) offset, ok := internal.AdaptType(cmd[2]).(int)
if !ok { if !ok {
@@ -105,7 +105,7 @@ func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, con
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
@@ -125,13 +125,13 @@ func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil
} }
func handleSubStr(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) { func handleSubStr(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := subStrKeyFunc(cmd) keys, err := subStrKeyFunc(cmd)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key := keys[0] key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int) start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int) end, endOk := internal.AdaptType(cmd[3]).(int)

View File

@@ -17,25 +17,38 @@ package str
import ( import (
"errors" "errors"
"github.com/echovault/echovault/pkg/constants" "github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
) )
func setRangeKeyFunc(cmd []string) ([]string, error) { func setRangeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
} }
func strLenKeyFunc(cmd []string) ([]string, error) { func strLenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 { if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }
func subStrKeyFunc(cmd []string) ([]string, error) { func subStrKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 { if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse) return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
} }