Updated KeyExtractionFunc for all the modules

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-21 02:59:07 +08:00
parent 7b88122c25
commit c2c887cd75
18 changed files with 1842 additions and 1378 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -410,34 +410,30 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command types.
return errors.New("not authorised to access any keys")
}
// 8. If @read is in the list of categories, check if keys are in IncludedReadKeys
if slices.Contains(categories, constants.ReadCategory) {
if !slices.ContainsFunc(readKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if acl.GlobPatterns[readKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
return false
})
}) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
}
// 8. Check if readKeys are in IncludedReadKeys
if !slices.ContainsFunc(readKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if acl.GlobPatterns[readKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
return false
})
}) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
}
// 9. If @write is in the list of categories, check if keys are in IncludedWriteKeys
if slices.Contains(categories, constants.WriteCategory) {
if !slices.ContainsFunc(writeKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
return false
})
}) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
}
// 9. Check if keys are in IncludedWriteKeys
if !slices.ContainsFunc(writeKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true
}
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
return false
})
}) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
}
}

View File

@@ -498,8 +498,12 @@ func Commands() []types.Command {
Categories: []string{constants.ConnectionCategory, constants.SlowCategory},
Description: "(AUTH [username] password) Authenticates the connection",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleAuth,
},
@@ -509,8 +513,12 @@ func Commands() []types.Command {
Categories: []string{},
Description: "Access-Control-List commands",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
SubCommands: []types.SubCommand{
{
@@ -520,8 +528,12 @@ func Commands() []types.Command {
Description: `(ACL CAT [category]) List all the categories.
If the optional category is provided, list all the commands in the category`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCat,
},
@@ -531,8 +543,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL USERS) List all usernames of the configured ACL users",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleUsers,
},
@@ -542,8 +558,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SETUSER) Configure a new or existing user",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleSetUser,
},
@@ -553,8 +573,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL GETUSER username) List the ACL rules of a user",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleGetUser,
},
@@ -564,8 +588,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL DELUSER username [username ...]) Deletes users and terminates their connections. Cannot delete default user",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleDelUser,
},
@@ -575,8 +603,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.FastCategory},
Description: "(ACL WHOAMI) Returns the authenticated user of the current connection",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleWhoAmI,
},
@@ -586,8 +618,12 @@ If the optional category is provided, list all the commands in the category`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL LIST) Dumps effective acl rules in acl config file format",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleList,
},
@@ -600,8 +636,12 @@ If the optional category is provided, list all the commands in the category`,
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleLoad,
},
@@ -611,8 +651,12 @@ When 'REPLACE' is passed, users from config file who share a username with users
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(ACL SAVE) Saves the effective ACL rules the configured ACL config file",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleSave,
},

View File

@@ -26,7 +26,7 @@ import (
"strings"
)
func handleGetAllCommands(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
func handleGetAllCommands(_ context.Context, _ []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
commands := server.GetAllCommands()
res := ""
@@ -194,13 +194,19 @@ func handleCommandDocs(_ context.Context, _ []string, _ types.EchoVault, _ *net.
func Commands() []types.Command {
return []types.Command{
{
Command: "commands",
Module: constants.AdminModule,
Categories: []string{constants.AdminCategory, constants.SlowCategory},
Description: "Get a list of all the commands in available on the echovault with categories and descriptions",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handleGetAllCommands,
Command: "commands",
Module: constants.AdminModule,
Categories: []string{constants.AdminCategory, constants.SlowCategory},
Description: "Get a list of all the commands in available on the echovault with categories and descriptions",
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleGetAllCommands,
},
{
Command: "command",
@@ -208,27 +214,43 @@ func Commands() []types.Command {
Categories: []string{},
Description: "Commands pertaining to echovault commands",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
SubCommands: []types.SubCommand{
{
Command: "docs",
Module: constants.AdminModule,
Categories: []string{constants.SlowCategory, constants.ConnectionCategory},
Description: "Get command documentation",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handleCommandDocs,
Command: "docs",
Module: constants.AdminModule,
Categories: []string{constants.SlowCategory, constants.ConnectionCategory},
Description: "Get command documentation",
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandDocs,
},
{
Command: "count",
Module: constants.AdminModule,
Categories: []string{constants.SlowCategory},
Description: "Get the dumber of commands in the echovault",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handleCommandCount,
Command: "count",
Module: constants.AdminModule,
Categories: []string{constants.SlowCategory},
Description: "Get the dumber of commands in the echovault",
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandCount,
},
{
Command: "list",
@@ -236,9 +258,15 @@ func Commands() []types.Command {
Categories: []string{constants.SlowCategory},
Description: `(COMMAND LIST [FILTERBY <ACLCAT category | PATTERN pattern | MODULE module>]) Get the list of command names.
Allows for filtering by ACL category or glob pattern.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handleCommandList,
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleCommandList,
},
},
},
@@ -248,8 +276,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(SAVE) Trigger a snapshot save",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.TakeSnapshot(); err != nil {
@@ -264,8 +296,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.FastCategory, constants.DangerousCategory},
Description: "(LASTSAVE) Get unix timestamp for the latest snapshot in milliseconds.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
msec := server.GetLatestSnapshotTime()
@@ -281,8 +317,12 @@ Allows for filtering by ACL category or glob pattern.`,
Categories: []string{constants.AdminCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(REWRITEAOF) Trigger re-writing of append process",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
if err := server.RewriteAOF(); err != nil {

View File

@@ -42,8 +42,12 @@ func Commands() []types.Command {
Categories: []string{constants.FastCategory, constants.ConnectionCategory},
Description: "(PING [value]) Ping the echovault. If a value is provided, the value will be echoed.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePing,
},

View File

@@ -39,7 +39,7 @@ func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
value := cmd[2]
res := []byte(constants.OkResponse)
clock := server.GetClock()
@@ -99,7 +99,8 @@ func handleSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
}
func handleMSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if _, err := msetKeyFunc(cmd); err != nil {
_, err := msetKeyFunc(cmd)
if err != nil {
return nil, err
}
@@ -159,7 +160,7 @@ func handleGet(ctx context.Context, cmd []string, server types.EchoVault, _ *net
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil
@@ -185,7 +186,7 @@ func handleMGet(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
values := make(map[string]string)
locks := make(map[string]bool)
for _, key := range keys {
for _, key := range keys.ReadKeys {
if _, ok := values[key]; ok {
// Skip if we have already locked this key
continue
@@ -232,7 +233,7 @@ func handleDel(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err
}
count := 0
for _, key := range keys {
for _, key := range keys.WriteKeys {
err = server.DeleteKey(ctx, key)
if err != nil {
log.Printf("could not delete key %s due to error: %+v\n", key, err)
@@ -249,7 +250,7 @@ func handlePersist(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -276,7 +277,7 @@ func handleExpireTime(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":-2\r\n"), nil
@@ -307,7 +308,7 @@ func handleTTL(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
clock := server.GetClock()
@@ -344,7 +345,7 @@ func handleExpire(ctx context.Context, cmd []string, server types.EchoVault, _ *
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
// Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64)
@@ -412,7 +413,7 @@ func handleExpireAt(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
// Extract time
n, err := strconv.ParseInt(cmd[2], 10, 64)

View File

@@ -17,18 +17,23 @@ package generic
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
)
func setKeyFunc(cmd []string) ([]string, error) {
func setKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 7 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func msetKeyFunc(cmd []string) ([]string, error) {
func msetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd[1:])%2 != 0 {
return nil, errors.New("each key must be paired with a value")
return types.AccessKeys{}, errors.New("each key must be paired with a value")
}
var keys []string
for i, key := range cmd[1:] {
@@ -36,61 +41,97 @@ func msetKeyFunc(cmd []string) ([]string, error) {
keys = append(keys, key)
}
}
return keys, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: keys,
}, nil
}
func getKeyFunc(cmd []string) ([]string, error) {
func getKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func mgetKeyFunc(cmd []string) ([]string, error) {
func mgetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func delKeyFunc(cmd []string) ([]string, error) {
func delKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
}
func persistKeyFunc(cmd []string) ([]string, error) {
func persistKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
}
func expireTimeKeyFunc(cmd []string) ([]string, error) {
func expireTimeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func ttlKeyFunc(cmd []string) ([]string, error) {
func ttlKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func expireKeyFunc(cmd []string) ([]string, error) {
func expireKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func expireAtKeyFunc(cmd []string) ([]string, error) {
func expireAtKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}

View File

@@ -34,7 +34,7 @@ func handleHSET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
entries := make(map[string]interface{})
if len(cmd[2:])%2 != 0 {
@@ -92,7 +92,7 @@ func handleHGET(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
fields := cmd[2:]
if !server.KeyExists(ctx, key) {
@@ -143,7 +143,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
fields := cmd[2:]
if !server.KeyExists(ctx, key) {
@@ -194,7 +194,7 @@ func handleHVALS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil
@@ -235,7 +235,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
count := 1
if len(cmd) >= 3 {
@@ -351,7 +351,7 @@ func handleHLEN(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -376,7 +376,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil
@@ -406,7 +406,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
field := cmd[2]
var intIncrement int
@@ -498,7 +498,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil
@@ -538,7 +538,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server types.EchoVault, _
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
field := cmd[2]
if !server.KeyExists(ctx, key) {
@@ -568,7 +568,7 @@ func handleHDEL(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
fields := cmd[2:]
if !server.KeyExists(ctx, key) {

View File

@@ -17,91 +17,144 @@ package hash
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
)
func hsetKeyFunc(cmd []string) ([]string, error) {
func hsetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func hsetnxKeyFunc(cmd []string) ([]string, error) {
func hsetnxKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func hgetKeyFunc(cmd []string) ([]string, error) {
func hgetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func hstrlenKeyFunc(cmd []string) ([]string, error) {
func hstrlenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func hvalsKeyFunc(cmd []string) ([]string, error) {
func hvalsKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func hrandfieldKeyFunc(cmd []string) ([]string, error) {
func hrandfieldKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
if len(cmd) == 2 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func hlenKeyFunc(cmd []string) ([]string, error) {
func hlenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func hkeysKeyFunc(cmd []string) ([]string, error) {
func hkeysKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func hincrbyKeyFunc(cmd []string) ([]string, error) {
func hincrbyKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func hgetallKeyFunc(cmd []string) ([]string, error) {
func hgetallKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func hexistsKeyFunc(cmd []string) ([]string, error) {
func hexistsKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func hdelKeyFunc(cmd []string) ([]string, error) {
func hdelKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}

View File

@@ -33,7 +33,7 @@ func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
// If key does not exist, return 0
@@ -52,13 +52,13 @@ func handleLLen(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, errors.New("LLEN command on non-list item")
}
func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lindexKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int)
if !ok {
@@ -86,13 +86,13 @@ func handleLIndex(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
}
func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lrangeKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)
@@ -165,13 +165,13 @@ func handleLRange(ctx context.Context, cmd []string, server types.EchoVault, con
return bytes, nil
}
func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lsetKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
index, ok := internal.AdaptType(cmd[2]).(int)
if !ok {
@@ -204,13 +204,13 @@ func handleLSet(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil
}
func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := ltrimKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)
@@ -253,13 +253,13 @@ func handleLTrim(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil
}
func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lremKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
value := cmd[3]
count, ok := internal.AdaptType(cmd[2]).(int)
@@ -321,14 +321,13 @@ func handleLRem(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(constants.OkResponse), nil
}
func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleLMove(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := lmoveKeyFunc(cmd)
if err != nil {
return nil, err
}
source := keys[0]
destination := keys[1]
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
whereFrom := strings.ToLower(cmd[3])
whereTo := strings.ToLower(cmd[4])
@@ -394,7 +393,7 @@ func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
newElems = append(newElems, internal.AdaptType(elem))
}
key := keys[0]
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) {
@@ -428,13 +427,13 @@ func handleLPush(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return []byte(constants.OkResponse), nil
}
func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleRPush(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := rpushKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
var newElems []interface{}
@@ -482,7 +481,7 @@ func handlePop(ctx context.Context, cmd []string, server types.EchoVault, _ *net
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))

View File

@@ -17,74 +17,115 @@ package list
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
)
func lpushKeyFunc(cmd []string) ([]string, error) {
func lpushKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func popKeyFunc(cmd []string) ([]string, error) {
func popKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
}
func llenKeyFunc(cmd []string) ([]string, error) {
func llenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func lrangeKeyFunc(cmd []string) ([]string, error) {
func lrangeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func lindexKeyFunc(cmd []string) ([]string, error) {
func lindexKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func lsetKeyFunc(cmd []string) ([]string, error) {
func lsetKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func ltrimKeyFunc(cmd []string) ([]string, error) {
func ltrimKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func lremKeyFunc(cmd []string) ([]string, error) {
func lremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func rpushKeyFunc(cmd []string) ([]string, error) {
func rpushKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func lmoveKeyFunc(cmd []string) ([]string, error) {
func lmoveKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 5 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1], cmd[2]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3],
}, nil
}

View File

@@ -111,12 +111,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channels as keys
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleSubscribe,
},
@@ -126,12 +130,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.ConnectionCategory, constants.SlowCategory},
Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the patterns as keys
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleSubscribe,
},
@@ -141,12 +149,16 @@ func Commands() []types.Command {
Categories: []string{constants.PubSubCategory, constants.FastCategory},
Description: "(PUBLISH channel message) Publish a message to the specified channel.",
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channel as a key
if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: cmd[1:2],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePublish,
},
@@ -158,9 +170,13 @@ func Commands() []types.Command {
If the channel list is not provided, then the connection will be unsubscribed from all the channels that
it's currently subscribe to.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
// Treat the channels as keys
return cmd[1:], nil
return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleUnsubscribe,
},
@@ -172,19 +188,28 @@ it's currently subscribe to.`,
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
it's currently subscribe to.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channels as keys
return cmd[1:], nil
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: cmd[1:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleUnsubscribe,
},
{
Command: "pubsub",
Module: constants.PubSubModule,
Categories: []string{},
Description: "",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
Command: "pubsub",
Module: constants.PubSubModule,
Categories: []string{},
Description: "",
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: func(_ context.Context, _ []string, _ types.EchoVault, _ *net.Conn) ([]byte, error) {
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
},
@@ -196,18 +221,30 @@ it's currently subscribe to.`,
Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
channels with 1 or more subscribers.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handlePubSubChannels,
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubChannels,
},
{
Command: "numpat",
Module: constants.PubSubModule,
Categories: []string{constants.PubSubCategory, constants.SlowCategory},
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handlePubSubNumPat,
Command: "numpat",
Module: constants.PubSubModule,
Categories: []string{constants.PubSubCategory, constants.SlowCategory},
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubNumPat,
},
{
Command: "numsub",
@@ -215,9 +252,15 @@ channels with 1 or more subscribers.`,
Categories: []string{constants.PubSubCategory, constants.SlowCategory},
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
channel name and how many clients are currently subscribed to the channel.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil },
HandlerFunc: handlePubSubNumSubs,
Sync: false,
KeyExtractionFunc: func(cmd []string) (types.AccessKeys, error) {
return types.AccessKeys{
Channels: cmd[2:],
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handlePubSubNumSubs,
},
},
},

View File

@@ -27,13 +27,13 @@ import (
"strings"
)
func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := saddKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
var set *internal_set.Set
@@ -64,13 +64,13 @@ func handleSADD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := scardKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(fmt.Sprintf(":0\r\n")), nil
@@ -91,23 +91,23 @@ func handleSCARD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
}
func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffKeyFunc(cmd)
if err != nil {
return nil, err
}
// Extract base set first
if !server.KeyExists(ctx, keys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[0])
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
}
if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys[0])
baseSet, ok := server.GetValue(ctx, keys[0]).(*internal_set.Set)
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[0])
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
locks := make(map[string]bool)
@@ -119,7 +119,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
}
}()
for _, key := range keys[1:] {
for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) {
continue
}
@@ -152,25 +152,25 @@ func handleSDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil
}
func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sdiffstoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := keys[0]
destination := keys.WriteKeys[0]
// Extract base set first
if !server.KeyExists(ctx, keys[1]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[1])
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys.ReadKeys[0])
}
if _, err := server.KeyRLock(ctx, keys[1]); err != nil {
if _, err := server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys[1])
baseSet, ok := server.GetValue(ctx, keys[1]).(*internal_set.Set)
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*internal_set.Set)
if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[1])
return nil, fmt.Errorf("value at key %s is not a set", keys.ReadKeys[0])
}
locks := make(map[string]bool)
@@ -182,7 +182,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
}
}()
for _, key := range keys[2:] {
for _, key := range keys.ReadKeys[1:] {
if !server.KeyExists(ctx, key) {
continue
}
@@ -193,7 +193,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
}
var sets []*internal_set.Set
for _, key := range keys[2:] {
for _, key := range keys.ReadKeys[1:] {
set, ok := server.GetValue(ctx, key).(*internal_set.Set)
if !ok {
continue
@@ -228,7 +228,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil
}
func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterKeyFunc(cmd)
if err != nil {
return nil, err
@@ -243,7 +243,7 @@ func handleSINTER(ctx context.Context, cmd []string, server types.EchoVault, con
}
}()
for _, key := range keys[0:] {
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection
return []byte("*0\r\n"), nil
@@ -319,7 +319,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
}
}()
for _, key := range keys {
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil
@@ -350,7 +350,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sinterstoreKeyFunc(cmd)
if err != nil {
return nil, err
@@ -365,7 +365,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
}
}()
for _, key := range keys[1:] {
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil
@@ -388,7 +388,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
}
intersect, _ := internal_set.Intersection(0, sets...)
destination := keys[0]
destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
@@ -408,13 +408,13 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sismemberKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -437,13 +437,13 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(":1\r\n"), nil
}
func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smembersKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil
@@ -472,13 +472,13 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server types.EchoVault, c
return []byte(res), nil
}
func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smismemberKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
members := cmd[2:]
if !server.KeyExists(ctx, key) {
@@ -515,14 +515,13 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(res), nil
}
func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := smoveKeyFunc(cmd)
if err != nil {
return nil, err
}
source := keys[0]
destination := keys[1]
source, destination := keys.WriteKeys[0], keys.WriteKeys[1]
member := cmd[3]
if !server.KeyExists(ctx, source) {
@@ -569,13 +568,13 @@ func handleSMOVE(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", res)), nil
}
func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := spopKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
count := 1
if len(cmd) == 3 {
@@ -613,13 +612,13 @@ func handleSPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil
}
func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := srandmemberKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
count := 1
if len(cmd) == 3 {
@@ -657,13 +656,13 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault
return []byte(res), nil
}
func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sremKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
members := cmd[2:]
if !server.KeyExists(ctx, key) {
@@ -685,7 +684,7 @@ func handleSREM(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionKeyFunc(cmd)
if err != nil {
return nil, err
@@ -700,7 +699,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, con
}
}()
for _, key := range keys {
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
continue
}
@@ -736,7 +735,7 @@ func handleSUNION(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil
}
func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := sunionstoreKeyFunc(cmd)
if err != nil {
return nil, err
@@ -751,7 +750,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
}
}()
for _, key := range keys[1:] {
for _, key := range keys.ReadKeys {
if !server.KeyExists(ctx, key) {
continue
}
@@ -776,7 +775,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault
union := internal_set.Union(sets...)
destination := cmd[1]
destination := keys.WriteKeys[0]
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {

View File

@@ -17,48 +17,69 @@ package set
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices"
"strings"
)
func saddKeyFunc(cmd []string) ([]string, error) {
func saddKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func scardKeyFunc(cmd []string) ([]string, error) {
func scardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func sdiffKeyFunc(cmd []string) ([]string, error) {
func sdiffKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func sdiffstoreKeyFunc(cmd []string) ([]string, error) {
func sdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}
func sinterKeyFunc(cmd []string) ([]string, error) {
func sinterKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func sintercardKeyFunc(cmd []string) ([]string, error) {
func sintercardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
limitIdx := slices.IndexFunc(cmd, func(s string) bool {
@@ -66,78 +87,126 @@ func sintercardKeyFunc(cmd []string) ([]string, error) {
})
if limitIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
return cmd[1:limitIdx], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:limitIdx],
WriteKeys: make([]string, 0),
}, nil
}
func sinterstoreKeyFunc(cmd []string) ([]string, error) {
func sinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}
func sismemberKeyFunc(cmd []string) ([]string, error) {
func sismemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func smembersKeyFunc(cmd []string) ([]string, error) {
func smembersKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func smismemberKeyFunc(cmd []string) ([]string, error) {
func smismemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func smoveKeyFunc(cmd []string) ([]string, error) {
func smoveKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:3], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:3],
}, nil
}
func spopKeyFunc(cmd []string) ([]string, error) {
func spopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func srandmemberKeyFunc(cmd []string) ([]string, error) {
func srandmemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func sremKeyFunc(cmd []string) ([]string, error) {
func sremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func sunionKeyFunc(cmd []string) ([]string, error) {
func sunionKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func sunionstoreKeyFunc(cmd []string) ([]string, error) {
func sunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}

View File

@@ -36,7 +36,7 @@ func handleZADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
var updatePolicy interface{} = nil
var comparison interface{} = nil
@@ -181,12 +181,12 @@ func handleZADD(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
}
func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zcardKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -205,13 +205,13 @@ func handleZCARD(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
}
func handleZCOUNT(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZCOUNT(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zcountKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
minimum := sorted_set.Score(math.Inf(-1))
switch internal.AdaptType(cmd[2]).(type) {
@@ -279,7 +279,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server types.EchoVault,
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
minimum := cmd[2]
maximum := cmd[3]
@@ -318,7 +318,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zdiffKeyFunc(cmd)
if err != nil {
return nil, err
@@ -341,34 +341,34 @@ func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
}()
// Extract base set
if !server.KeyExists(ctx, keys[0]) {
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
// If base set does not exist, return an empty array
return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet)
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSortedSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
}
// Extract the remaining sets
var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ {
if !server.KeyExists(ctx, keys[i]) {
for i := 1; i < len(keys.ReadKeys); i++ {
if !server.KeyExists(ctx, keys.ReadKeys[i]) {
continue
}
locked, err := server.KeyRLock(ctx, keys[i])
locked, err := server.KeyRLock(ctx, keys.ReadKeys[i])
if err != nil {
return nil, err
}
locks[keys[i]] = locked
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
locks[keys.ReadKeys[i]] = locked
set, ok := server.GetValue(ctx, keys.ReadKeys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
}
sets = append(sets, set)
}
@@ -391,13 +391,13 @@ func handleZDIFF(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil
}
func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zdiffstoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := cmd[1]
destination := keys.WriteKeys[0]
locks := make(map[string]bool)
defer func() {
@@ -409,29 +409,29 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
}()
// Extract base set
if !server.KeyExists(ctx, keys[0]) {
if !server.KeyExists(ctx, keys.ReadKeys[0]) {
// If base set does not exist, return 0
return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
if _, err = server.KeyRLock(ctx, keys.ReadKeys[0]); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet)
defer server.KeyRUnlock(ctx, keys.ReadKeys[0])
baseSortedSet, ok := server.GetValue(ctx, keys.ReadKeys[0]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0])
}
var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ {
if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
for i := 1; i < len(keys.ReadKeys); i++ {
if server.KeyExists(ctx, keys.ReadKeys[i]) {
if _, err = server.KeyRLock(ctx, keys.ReadKeys[i]); err != nil {
return nil, err
}
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
set, ok := server.GetValue(ctx, keys.ReadKeys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i])
}
sets = append(sets, set)
}
@@ -457,13 +457,13 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server types.EchoVault,
return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil
}
func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zincrbyKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
member := sorted_set.Value(cmd[3])
var increment sorted_set.Score
@@ -524,8 +524,8 @@ func handleZINCRBY(ctx context.Context, cmd []string, server types.EchoVault, co
strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil
}
func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
keys, err := zinterKeyFunc(cmd)
func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
_, err := zinterKeyFunc(cmd)
if err != nil {
return nil, err
}
@@ -584,13 +584,13 @@ func handleZINTER(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil
}
func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
keys, err := zinterstoreKeyFunc(cmd)
func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
k, err := zinterstoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := keys[0]
destination := k.WriteKeys[0]
// Remove the destination keys from the command before parsing it
cmd = slices.DeleteFunc(cmd, func(s string) bool {
@@ -651,7 +651,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zmpopKeyFunc(cmd)
if err != nil {
return nil, err
@@ -697,22 +697,22 @@ func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
}
}
for i := 0; i < len(keys); i++ {
if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyLock(ctx, keys[i]); err != nil {
for i := 0; i < len(keys.WriteKeys); i++ {
if server.KeyExists(ctx, keys.WriteKeys[i]) {
if _, err = server.KeyLock(ctx, keys.WriteKeys[i]); err != nil {
continue
}
v, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
v, ok := server.GetValue(ctx, keys.WriteKeys[i]).(*sorted_set.SortedSet)
if !ok || v.Cardinality() == 0 {
server.KeyUnlock(ctx, keys[i])
server.KeyUnlock(ctx, keys.WriteKeys[i])
continue
}
popped, err := v.Pop(count, policy)
if err != nil {
server.KeyUnlock(ctx, keys[i])
server.KeyUnlock(ctx, keys.WriteKeys[i])
return nil, err
}
server.KeyUnlock(ctx, keys[i])
server.KeyUnlock(ctx, keys.WriteKeys[i])
res := fmt.Sprintf("*%d", popped.Cardinality())
@@ -729,13 +729,13 @@ func handleZMPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte("*0\r\n"), nil
}
func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zpopKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
count := 1
policy := "min"
@@ -782,13 +782,13 @@ func handleZPOP(ctx context.Context, cmd []string, server types.EchoVault, conn
return []byte(res), nil
}
func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zmscoreKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil
@@ -824,13 +824,13 @@ func handleZMSCORE(ctx context.Context, cmd []string, server types.EchoVault, co
return []byte(res), nil
}
func handleZRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZRANDMEMBER(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrandmemberKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
count := 1
if len(cmd) >= 3 {
@@ -888,7 +888,7 @@ func handleZRANK(ctx context.Context, cmd []string, server types.EchoVault, _ *n
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
member := cmd[2]
withscores := false
@@ -938,7 +938,7 @@ func handleZREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -964,13 +964,13 @@ func handleZREM(ctx context.Context, cmd []string, server types.EchoVault, _ *ne
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zscoreKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil
@@ -993,13 +993,13 @@ func handleZSCORE(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil
}
func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebyscoreKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
deletedCount := 0
@@ -1037,13 +1037,13 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server types.Echo
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebyrankKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
start, err := strconv.Atoi(cmd[2])
if err != nil {
@@ -1102,13 +1102,13 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server types.EchoV
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zremrangebylexKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
minimum := cmd[2]
maximum := cmd[3]
@@ -1149,13 +1149,13 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server types.EchoVa
return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrangeKeyCount(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
policy := "byscore"
scoreStart := math.Inf(-1) // Lower bound if policy is "byscore"
scoreStop := math.Inf(1) // Upper bound if policy is "byscore"
@@ -1289,14 +1289,14 @@ func handleZRANGE(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(res), nil
}
func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrangeStoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := keys[0]
source := keys[1]
destination := keys.WriteKeys[0]
source := keys.ReadKeys[0]
policy := "byscore"
scoreStart := math.Inf(-1) // Lower bound if policy is "byscore"
scoreStop := math.Inf(1) // Upper bound if policy is "byfloat"
@@ -1431,7 +1431,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server types.EchoVault
return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil
}
func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
if _, err := zunionKeyFunc(cmd); err != nil {
return nil, err
}
@@ -1486,12 +1486,12 @@ func handleZUNION(ctx context.Context, cmd []string, server types.EchoVault, con
}
func handleZUNIONSTORE(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zunionstoreKeyFunc(cmd)
k, err := zunionstoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := keys[0]
destination := k.WriteKeys[0]
// Remove destination key from list of keys
cmd = slices.DeleteFunc(cmd, func(s string) bool {

View File

@@ -17,34 +17,47 @@ package sorted_set
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"slices"
"strings"
)
func zaddKeyFunc(cmd []string) ([]string, error) {
func zaddKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zcardKeyFunc(cmd []string) ([]string, error) {
func zcardKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
func zcountKeyFunc(cmd []string) ([]string, error) {
func zcountKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zdiffKeyFunc(cmd []string) ([]string, error) {
func zdiffKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
withscoresIndex := slices.IndexFunc(cmd, func(s string) bool {
@@ -52,29 +65,45 @@ func zdiffKeyFunc(cmd []string) ([]string, error) {
})
if withscoresIndex == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
return cmd[1:withscoresIndex], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:withscoresIndex],
WriteKeys: make([]string, 0),
}, nil
}
func zdiffstoreKeyFunc(cmd []string) ([]string, error) {
func zdiffstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[2:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}
func zincrbyKeyFunc(cmd []string) ([]string, error) {
func zincrbyKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zinterKeyFunc(cmd []string) ([]string, error) {
func zinterKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") ||
@@ -85,17 +114,25 @@ func zinterKeyFunc(cmd []string) ([]string, error) {
return false
})
if endIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
if endIdx >= 1 {
return cmd[1:endIdx], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx],
WriteKeys: make([]string, 0),
}, nil
}
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
func zinterstoreKeyFunc(cmd []string) ([]string, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
func zinterstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") ||
@@ -106,124 +143,192 @@ func zinterstoreKeyFunc(cmd []string) ([]string, error) {
return false
})
if endIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}
if endIdx >= 2 {
return cmd[1:endIdx], nil
if endIdx >= 3 {
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2],
}, nil
}
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
func zmpopKeyFunc(cmd []string) ([]string, error) {
func zmpopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
endIdx := slices.IndexFunc(cmd, func(s string) bool {
return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
})
if endIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:],
}, nil
}
if endIdx >= 2 {
return cmd[1:endIdx], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:endIdx],
}, nil
}
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
func zmscoreKeyFunc(cmd []string) ([]string, error) {
func zmscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zpopKeyFunc(cmd []string) ([]string, error) {
func zpopKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zrandmemberKeyFunc(cmd []string) ([]string, error) {
func zrandmemberKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zrankKeyFunc(cmd []string) ([]string, error) {
func zrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 || len(cmd) > 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zremKeyFunc(cmd []string) ([]string, error) {
func zremKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zrevrankKeyFunc(cmd []string) ([]string, error) {
func zrevrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zscoreKeyFunc(cmd []string) ([]string, error) {
func zscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zremrangebylexKeyFunc(cmd []string) ([]string, error) {
func zremrangebylexKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zremrangebyrankKeyFunc(cmd []string) ([]string, error) {
func zremrangebyrankKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zremrangebyscoreKeyFunc(cmd []string) ([]string, error) {
func zremrangebyscoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func zlexcountKeyFunc(cmd []string) ([]string, error) {
func zlexcountKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zrangeKeyCount(cmd []string) ([]string, error) {
func zrangeKeyCount(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 4 || len(cmd) > 10 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:2], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func zrangeStoreKeyFunc(cmd []string) ([]string, error) {
func zrangeStoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 5 || len(cmd) > 11 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return cmd[1:3], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:3],
WriteKeys: cmd[1:2],
}, nil
}
func zunionKeyFunc(cmd []string) ([]string, error) {
func zunionKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") ||
@@ -234,17 +339,25 @@ func zunionKeyFunc(cmd []string) ([]string, error) {
return false
})
if endIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:],
WriteKeys: make([]string, 0),
}, nil
}
if endIdx >= 1 {
return cmd[1:endIdx], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:endIdx],
WriteKeys: cmd[1:endIdx],
}, nil
}
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
func zunionstoreKeyFunc(cmd []string) ([]string, error) {
if len(cmd) < 2 {
return nil, errors.New(constants.WrongArgsResponse)
func zunionstoreKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) < 3 {
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
endIdx := slices.IndexFunc(cmd[1:], func(s string) bool {
if strings.EqualFold(s, "WEIGHTS") ||
@@ -255,10 +368,18 @@ func zunionstoreKeyFunc(cmd []string) ([]string, error) {
return false
})
if endIdx == -1 {
return cmd[1:], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:],
WriteKeys: cmd[1:2],
}, nil
}
if endIdx >= 1 {
return cmd[1:endIdx], nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[2:endIdx],
WriteKeys: cmd[1:2],
}, nil
}
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}

View File

@@ -24,13 +24,13 @@ import (
"net"
)
func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := setRangeKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.WriteKeys[0]
offset, ok := internal.AdaptType(cmd[2]).(int)
if !ok {
@@ -105,7 +105,7 @@ func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, con
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil
@@ -125,13 +125,13 @@ func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil
}
func handleSubStr(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
func handleSubStr(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := subStrKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)

View File

@@ -17,25 +17,38 @@ package str
import (
"errors"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
)
func setRangeKeyFunc(cmd []string) ([]string, error) {
func setRangeKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: cmd[1:2],
}, nil
}
func strLenKeyFunc(cmd []string) ([]string, error) {
func strLenKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 2 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}
func subStrKeyFunc(cmd []string) ([]string, error) {
func subStrKeyFunc(cmd []string) (types.AccessKeys, error) {
if len(cmd) != 4 {
return nil, errors.New(constants.WrongArgsResponse)
return types.AccessKeys{}, errors.New(constants.WrongArgsResponse)
}
return []string{cmd[1]}, nil
return types.AccessKeys{
Channels: make([]string, 0),
ReadKeys: cmd[1:2],
WriteKeys: make([]string, 0),
}, nil
}