Updated generic and hash package tests to use tcp connection instead of calling the handler directly

This commit is contained in:
Kelvin Clement Mwinuka
2024-05-24 13:40:40 +08:00
parent 43361cdd42
commit 926a008c23
9 changed files with 1655 additions and 1531 deletions

View File

@@ -16,7 +16,6 @@ package echovault
import ( import (
"context" "context"
"fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"slices" "slices"
"strings" "strings"
@@ -230,133 +229,133 @@ func (server *EchoVault) RewriteAOF() (string, error) {
// Errors: // Errors:
// //
// "command <command> already exists" - If a command with the same command name as the passed command already exists. // "command <command> already exists" - If a command with the same command name as the passed command already exists.
func (server *EchoVault) AddCommand(command CommandOptions) error { // func (server *EchoVault) AddCommand(command CommandOptions) error {
server.commandsRWMut.Lock() // server.commandsRWMut.Lock()
defer server.commandsRWMut.Unlock() // defer server.commandsRWMut.Unlock()
// Check if command already exists // // Check if command already exists
for _, c := range server.commands { // for _, c := range server.commands {
if strings.EqualFold(c.Command, command.Command) { // if strings.EqualFold(c.Command, command.Command) {
return fmt.Errorf("command %s already exists", command.Command) // return fmt.Errorf("command %s already exists", command.Command)
} // }
} // }
//
if command.SubCommand == nil || len(command.SubCommand) == 0 { // if command.SubCommand == nil || len(command.SubCommand) == 0 {
// Add command with no subcommands // // Add command with no subcommands
server.commands = append(server.commands, internal.Command{ // server.commands = append(server.commands, internal.Command{
Command: command.Command, // Command: command.Command,
Module: strings.ToLower(command.Module), // Convert module to lower case for uniformity // Module: strings.ToLower(command.Module), // Convert module to lower case for uniformity
Categories: func() []string { // Categories: func() []string {
// Convert all the categories to lower case for uniformity // // Convert all the categories to lower case for uniformity
cats := make([]string, len(command.Categories)) // cats := make([]string, len(command.Categories))
for i, cat := range command.Categories { // for i, cat := range command.Categories {
cats[i] = strings.ToLower(cat) // cats[i] = strings.ToLower(cat)
} // }
return cats // return cats
}(), // }(),
Description: command.Description, // Description: command.Description,
Sync: command.Sync, // Sync: command.Sync,
KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { // KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) {
accessKeys, err := command.KeyExtractionFunc(cmd) // accessKeys, err := command.KeyExtractionFunc(cmd)
if err != nil { // if err != nil {
return internal.KeyExtractionFuncResult{}, err // return internal.KeyExtractionFuncResult{}, err
} // }
return internal.KeyExtractionFuncResult{ // return internal.KeyExtractionFuncResult{
Channels: []string{}, // Channels: []string{},
ReadKeys: accessKeys.ReadKeys, // ReadKeys: accessKeys.ReadKeys,
WriteKeys: accessKeys.WriteKeys, // WriteKeys: accessKeys.WriteKeys,
}, nil // }, nil
}), // }),
HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { // HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) {
return command.HandlerFunc(CommandHandlerFuncParams{ // return command.HandlerFunc(CommandHandlerFuncParams{
Context: params.Context, // Context: params.Context,
Command: params.Command, // Command: params.Command,
KeyLock: params.KeyLock, // KeyLock: params.KeyLock,
KeyUnlock: params.KeyUnlock, // KeyUnlock: params.KeyUnlock,
KeyRLock: params.KeyRLock, // KeyRLock: params.KeyRLock,
KeyRUnlock: params.KeyRUnlock, // KeyRUnlock: params.KeyRUnlock,
KeyExists: params.KeyExists, // KeyExists: params.KeyExists,
CreateKeyAndLock: params.CreateKeyAndLock, // CreateKeyAndLock: params.CreateKeyAndLock,
GetValue: params.GetValue, // GetValue: params.GetValue,
SetValue: params.SetValue, // SetValue: params.SetValue,
}) // })
}), // }),
}) // })
return nil // return nil
} // }
//
// Add command with subcommands // // Add command with subcommands
newCommand := internal.Command{ // newCommand := internal.Command{
Command: command.Command, // Command: command.Command,
Module: command.Module, // Module: command.Module,
Categories: func() []string { // Categories: func() []string {
// Convert all the categories to lower case for uniformity // // Convert all the categories to lower case for uniformity
cats := make([]string, len(command.Categories)) // cats := make([]string, len(command.Categories))
for j, cat := range command.Categories { // for j, cat := range command.Categories {
cats[j] = strings.ToLower(cat) // cats[j] = strings.ToLower(cat)
} // }
return cats // return cats
}(), // }(),
Description: command.Description, // Description: command.Description,
Sync: command.Sync, // Sync: command.Sync,
KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { // KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) {
return internal.KeyExtractionFuncResult{}, nil // return internal.KeyExtractionFuncResult{}, nil
}, // },
HandlerFunc: func(param internal.HandlerFuncParams) ([]byte, error) { return nil, nil }, // HandlerFunc: func(param internal.HandlerFuncParams) ([]byte, error) { return nil, nil },
SubCommands: make([]internal.SubCommand, len(command.SubCommand)), // SubCommands: make([]internal.SubCommand, len(command.SubCommand)),
} // }
//
for i, sc := range command.SubCommand { // for i, sc := range command.SubCommand {
// Skip the subcommand if it already exists in newCommand // // Skip the subcommand if it already exists in newCommand
if slices.ContainsFunc(newCommand.SubCommands, func(subcommand internal.SubCommand) bool { // if slices.ContainsFunc(newCommand.SubCommands, func(subcommand internal.SubCommand) bool {
return strings.EqualFold(subcommand.Command, sc.Command) // return strings.EqualFold(subcommand.Command, sc.Command)
}) { // }) {
continue // continue
} // }
newCommand.SubCommands[i] = internal.SubCommand{ // newCommand.SubCommands[i] = internal.SubCommand{
Command: sc.Command, // Command: sc.Command,
Module: strings.ToLower(command.Module), // Module: strings.ToLower(command.Module),
Categories: func() []string { // Categories: func() []string {
// Convert all the categories to lower case for uniformity // // Convert all the categories to lower case for uniformity
cats := make([]string, len(sc.Categories)) // cats := make([]string, len(sc.Categories))
for j, cat := range sc.Categories { // for j, cat := range sc.Categories {
cats[j] = strings.ToLower(cat) // cats[j] = strings.ToLower(cat)
} // }
return cats // return cats
}(), // }(),
Description: sc.Description, // Description: sc.Description,
Sync: sc.Sync, // Sync: sc.Sync,
KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { // KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) {
accessKeys, err := sc.KeyExtractionFunc(cmd) // accessKeys, err := sc.KeyExtractionFunc(cmd)
if err != nil { // if err != nil {
return internal.KeyExtractionFuncResult{}, err // return internal.KeyExtractionFuncResult{}, err
} // }
return internal.KeyExtractionFuncResult{ // return internal.KeyExtractionFuncResult{
Channels: []string{}, // Channels: []string{},
ReadKeys: accessKeys.ReadKeys, // ReadKeys: accessKeys.ReadKeys,
WriteKeys: accessKeys.WriteKeys, // WriteKeys: accessKeys.WriteKeys,
}, nil // }, nil
}), // }),
HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { // HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) {
return sc.HandlerFunc(CommandHandlerFuncParams{ // return sc.HandlerFunc(CommandHandlerFuncParams{
Context: params.Context, // Context: params.Context,
Command: params.Command, // Command: params.Command,
KeyLock: params.KeyLock, // KeyLock: params.KeyLock,
KeyUnlock: params.KeyUnlock, // KeyUnlock: params.KeyUnlock,
KeyRLock: params.KeyRLock, // KeyRLock: params.KeyRLock,
KeyRUnlock: params.KeyRUnlock, // KeyRUnlock: params.KeyRUnlock,
KeyExists: params.KeyExists, // KeyExists: params.KeyExists,
CreateKeyAndLock: params.CreateKeyAndLock, // CreateKeyAndLock: params.CreateKeyAndLock,
GetValue: params.GetValue, // GetValue: params.GetValue,
SetValue: params.SetValue, // SetValue: params.SetValue,
}) // })
}), // }),
} // }
} // }
//
server.commands = append(server.commands, newCommand) // server.commands = append(server.commands, newCommand)
//
return nil // return nil
} // }
// ExecuteCommand executes the command passed to it. If 1 string is passed, EchoVault will try to // ExecuteCommand executes the command passed to it. If 1 string is passed, EchoVault will try to
// execute the command. If 2 strings are passed, EchoVault will attempt to execute the subcommand of the command. // execute the command. If 2 strings are passed, EchoVault will attempt to execute the subcommand of the command.

View File

@@ -29,14 +29,9 @@ import (
"github.com/echovault/echovault/internal/memberlist" "github.com/echovault/echovault/internal/memberlist"
"github.com/echovault/echovault/internal/modules/acl" "github.com/echovault/echovault/internal/modules/acl"
"github.com/echovault/echovault/internal/modules/admin" "github.com/echovault/echovault/internal/modules/admin"
"github.com/echovault/echovault/internal/modules/connection"
"github.com/echovault/echovault/internal/modules/generic" "github.com/echovault/echovault/internal/modules/generic"
"github.com/echovault/echovault/internal/modules/hash" "github.com/echovault/echovault/internal/modules/hash"
"github.com/echovault/echovault/internal/modules/list"
"github.com/echovault/echovault/internal/modules/pubsub" "github.com/echovault/echovault/internal/modules/pubsub"
"github.com/echovault/echovault/internal/modules/set"
"github.com/echovault/echovault/internal/modules/sorted_set"
str "github.com/echovault/echovault/internal/modules/string"
"github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/raft"
"github.com/echovault/echovault/internal/snapshot" "github.com/echovault/echovault/internal/snapshot"
"io" "io"
@@ -139,12 +134,12 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
commands = append(commands, admin.Commands()...) commands = append(commands, admin.Commands()...)
commands = append(commands, generic.Commands()...) commands = append(commands, generic.Commands()...)
commands = append(commands, hash.Commands()...) commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...) // commands = append(commands, list.Commands()...)
commands = append(commands, connection.Commands()...) // commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...) // commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...) // commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...) // commands = append(commands, sorted_set.Commands()...)
commands = append(commands, str.Commands()...) // commands = append(commands, str.Commands()...)
return commands return commands
}(), }(),
} }
@@ -159,13 +154,14 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
) )
// Load .so modules from config // Load .so modules from config
for _, path := range echovault.config.Modules { // TODO: Uncomment this
if err := echovault.LoadModule(path); err != nil { // for _, path := range echovault.config.Modules {
log.Printf("%s %v\n", path, err) // if err := echovault.LoadModule(path); err != nil {
continue // log.Printf("%s %v\n", path, err)
} // continue
log.Printf("loaded plugin %s\n", path) // }
} // log.Printf("loaded plugin %s\n", path)
// }
// Function for server commands retrieval // Function for server commands retrieval
echovault.getCommands = func() []internal.Command { echovault.getCommands = func() []internal.Command {
@@ -190,35 +186,36 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
} }
if echovault.isInCluster() { if echovault.isInCluster() {
echovault.raft = raft.NewRaft(raft.Opts{ // TODO: Uncomment this
Config: echovault.config, // echovault.raft = raft.NewRaft(raft.Opts{
GetCommand: echovault.getCommand, // Config: echovault.config,
SetValue: echovault.SetValue, // GetCommand: echovault.getCommand,
SetExpiry: echovault.SetExpiry, // SetValue: echovault.SetValue,
DeleteKey: echovault.DeleteKey, // SetExpiry: echovault.SetExpiry,
StartSnapshot: echovault.startSnapshot, // DeleteKey: echovault.DeleteKey,
FinishSnapshot: echovault.finishSnapshot, // StartSnapshot: echovault.startSnapshot,
SetLatestSnapshotTime: echovault.setLatestSnapshot, // FinishSnapshot: echovault.finishSnapshot,
GetHandlerFuncParams: echovault.getHandlerFuncParams, // SetLatestSnapshotTime: echovault.setLatestSnapshot,
GetState: func() map[string]internal.KeyData { // GetHandlerFuncParams: echovault.getHandlerFuncParams,
state := make(map[string]internal.KeyData) // GetState: func() map[string]internal.KeyData {
for k, v := range echovault.getState() { // state := make(map[string]internal.KeyData)
if data, ok := v.(internal.KeyData); ok { // for k, v := range echovault.getState() {
state[k] = data // if data, ok := v.(internal.KeyData); ok {
} // state[k] = data
} // }
return state // }
}, // return state
}) // },
echovault.memberList = memberlist.NewMemberList(memberlist.Opts{ // })
Config: echovault.config, // echovault.memberList = memberlist.NewMemberList(memberlist.Opts{
HasJoinedCluster: echovault.raft.HasJoinedCluster, // Config: echovault.config,
AddVoter: echovault.raft.AddVoter, // HasJoinedCluster: echovault.raft.HasJoinedCluster,
RemoveRaftServer: echovault.raft.RemoveServer, // AddVoter: echovault.raft.AddVoter,
IsRaftLeader: echovault.raft.IsRaftLeader, // RemoveRaftServer: echovault.raft.RemoveServer,
ApplyMutate: echovault.raftApplyCommand, // IsRaftLeader: echovault.raft.IsRaftLeader,
ApplyDeleteKey: echovault.raftApplyDeleteKey, // ApplyMutate: echovault.raftApplyCommand,
}) // ApplyDeleteKey: echovault.raftApplyDeleteKey,
// })
} else { } else {
// Set up standalone snapshot engine // Set up standalone snapshot engine
echovault.snapshotEngine = snapshot.NewSnapshotEngine( echovault.snapshotEngine = snapshot.NewSnapshotEngine(
@@ -241,10 +238,10 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
}), }),
snapshot.WithSetKeyDataFunc(func(key string, data internal.KeyData) { snapshot.WithSetKeyDataFunc(func(key string, data internal.KeyData) {
ctx := context.Background() ctx := context.Background()
if err := echovault.SetValue(ctx, key, data.Value); err != nil { if err := echovault.setValues(ctx, map[string]interface{}{key: data.Value}); err != nil {
log.Println(err) log.Println(err)
} }
echovault.SetExpiry(ctx, key, data.ExpireAt, false) echovault.setExpiry(ctx, key, data.ExpireAt, false)
}), }),
) )
// Set up standalone AOF engine // Set up standalone AOF engine
@@ -265,10 +262,10 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
}), }),
aof.WithSetKeyDataFunc(func(key string, value internal.KeyData) { aof.WithSetKeyDataFunc(func(key string, value internal.KeyData) {
ctx := context.Background() ctx := context.Background()
if err := echovault.SetValue(ctx, key, value.Value); err != nil { if err := echovault.setValues(ctx, map[string]interface{}{key: value.Value}); err != nil {
log.Println(err) log.Println(err)
} }
echovault.SetExpiry(ctx, key, value.ExpireAt, false) echovault.setExpiry(ctx, key, value.ExpireAt, false)
}), }),
aof.WithHandleCommandFunc(func(command []byte) { aof.WithHandleCommandFunc(func(command []byte) {
_, err := echovault.handleCommand(context.Background(), command, nil, true, false) _, err := echovault.handleCommand(context.Background(), command, nil, true, false)

View File

@@ -113,9 +113,13 @@ func (server *EchoVault) setValues(ctx context.Context, entries map[string]inter
} }
for key, value := range entries { for key, value := range entries {
expireAt := time.Time{}
if _, ok := server.store[key]; ok {
expireAt = server.store[key].ExpireAt
}
server.store[key] = internal.KeyData{ server.store[key] = internal.KeyData{
Value: value, Value: value,
ExpireAt: server.store[key].ExpireAt, ExpireAt: expireAt,
} }
if !server.isInCluster() { if !server.isInCluster() {
server.snapshotEngine.IncrementChangeCount() server.snapshotEngine.IncrementChangeCount()

View File

@@ -102,14 +102,9 @@ func (server *EchoVault) LoadModule(path string, args ...string) error {
handlerFunc, ok := handlerFuncSymbol.(func( handlerFunc, ok := handlerFuncSymbol.(func(
ctx context.Context, ctx context.Context,
command []string, command []string,
keyExists func(ctx context.Context, key string) bool, keysExist func(key []string) map[string]bool,
keyLock func(ctx context.Context, key string) (bool, error), getValues func(ctx context.Context, key []string) map[string]interface{},
keyUnlock func(ctx context.Context, key string), setValues func(ctx context.Context, entries map[string]interface{}) error,
keyRLock func(ctx context.Context, key string) (bool, error),
keyRUnlock func(ctx context.Context, key string),
createKeyAndLock func(ctx context.Context, key string) (bool, error),
getValue func(ctx context.Context, key string) interface{},
setValue func(ctx context.Context, key string, value interface{}) error,
args ...string, args ...string,
) ([]byte, error)) ) ([]byte, error))
if !ok { if !ok {
@@ -151,14 +146,9 @@ func (server *EchoVault) LoadModule(path string, args ...string) error {
return handlerFunc( return handlerFunc(
params.Context, params.Context,
params.Command, params.Command,
params.KeyExists, params.KeysExist,
params.KeyLock, params.GetValues,
params.KeyUnlock, params.SetValues,
params.KeyRLock,
params.KeyRUnlock,
params.CreateKeyAndLock,
params.GetValue,
params.SetValue,
args..., args...,
) )
}, },

View File

@@ -16,19 +16,13 @@ func createEchoVault() *EchoVault {
} }
func presetValue(server *EchoVault, ctx context.Context, key string, value interface{}) error { func presetValue(server *EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil { if err := server.setValues(ctx, map[string]interface{}{key: value}); err != nil {
return err return err
} }
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil return nil
} }
func presetKeyData(server *EchoVault, ctx context.Context, key string, data internal.KeyData) { func presetKeyData(server *EchoVault, ctx context.Context, key string, data internal.KeyData) {
_, _ = server.CreateKeyAndLock(ctx, key) _ = server.setValues(ctx, map[string]interface{}{key: data.Value})
defer server.KeyUnlock(ctx, key) server.setExpiry(ctx, key, data.ExpireAt, false)
_ = server.SetValue(ctx, key, data.Value)
server.SetExpiry(ctx, key, data.ExpireAt, false)
} }

View File

@@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/constants" "github.com/echovault/echovault/internal/constants"
"log"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -36,6 +37,7 @@ func handleSet(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
value := params.Command[2] value := params.Command[2]
res := []byte(constants.OkResponse) res := []byte(constants.OkResponse)
clock := params.GetClock() clock := params.GetClock()
@@ -48,41 +50,28 @@ func handleSet(params internal.HandlerFuncParams) ([]byte, error) {
// If Get is provided, the response should be the current stored value. // If Get is provided, the response should be the current stored value.
// If there's no current value, then the response should be nil. // If there's no current value, then the response should be nil.
if options.get { if options.get {
if !params.KeyExists(params.Context, key) { if !keyExists {
res = []byte("$-1\r\n") res = []byte("$-1\r\n")
} else { } else {
res = []byte(fmt.Sprintf("+%v\r\n", params.GetValue(params.Context, key))) res = []byte(fmt.Sprintf("+%v\r\n", params.GetValues(params.Context, []string{key})[key]))
} }
} }
if "xx" == strings.ToLower(options.exists) { if "xx" == strings.ToLower(options.exists) {
// If XX is specified, make sure the key exists. // If XX is specified, make sure the key exists.
if !params.KeyExists(params.Context, key) { if !keyExists {
return nil, fmt.Errorf("key %s does not exist", key) return nil, fmt.Errorf("key %s does not exist", key)
} }
_, err = params.KeyLock(params.Context, key)
} else if "nx" == strings.ToLower(options.exists) { } else if "nx" == strings.ToLower(options.exists) {
// If NX is specified, make sure that the key does not currently exist. // If NX is specified, make sure that the key does not currently exist.
if params.KeyExists(params.Context, key) { if keyExists {
return nil, fmt.Errorf("key %s already exists", key) return nil, fmt.Errorf("key %s already exists", key)
} }
_, err = params.CreateKeyAndLock(params.Context, key)
} else {
// Neither XX not NX are specified, lock or create the lock
if !params.KeyExists(params.Context, key) {
// Key does not exist, create it
_, err = params.CreateKeyAndLock(params.Context, key)
} else {
// Key exists, acquire the lock
_, err = params.KeyLock(params.Context, key)
} }
}
if err != nil {
return nil, err
}
defer params.KeyUnlock(params.Context, key)
if err = params.SetValue(params.Context, key, internal.AdaptType(value)); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{
key: internal.AdaptType(value),
}); err != nil {
return nil, err return nil, err
} }
@@ -100,53 +89,19 @@ func handleMSet(params internal.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
entries := make(map[string]KeyObject) entries := make(map[string]interface{})
// Release all acquired key locks
defer func() {
for k, v := range entries {
if v.locked {
params.KeyUnlock(params.Context, k)
entries[k] = KeyObject{
value: v.value,
locked: false,
}
}
}
}()
// Extract all the key/value pairs // Extract all the key/value pairs
for i, key := range params.Command[1:] { for i, key := range params.Command[1:] {
if i%2 == 0 { if i%2 == 0 {
entries[key] = KeyObject{ entries[key] = internal.AdaptType(params.Command[1:][i+1])
value: internal.AdaptType(params.Command[1:][i+1]),
locked: false,
} }
} }
}
// Acquire all the locks for each key first
// If any key cannot be acquired, abandon transaction and release all currently held keys
for k, v := range entries {
if params.KeyExists(params.Context, k) {
if _, err := params.KeyLock(params.Context, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
continue
}
if _, err := params.CreateKeyAndLock(params.Context, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
}
// Set all the values // Set all the values
for k, v := range entries { if err = params.SetValues(params.Context, entries); err != nil {
if err := params.SetValue(params.Context, k, v.value); err != nil {
return nil, err return nil, err
} }
}
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }
@@ -157,18 +112,13 @@ func handleGet(params internal.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist([]string{key})[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
_, err = params.KeyRLock(params.Context, key) value := params.GetValues(params.Context, []string{key})[key]
if err != nil {
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
value := params.GetValue(params.Context, key)
return []byte(fmt.Sprintf("+%v\r\n", value)), nil return []byte(fmt.Sprintf("+%v\r\n", value)), nil
} }
@@ -180,34 +130,12 @@ func handleMGet(params internal.HandlerFuncParams) ([]byte, error) {
} }
values := make(map[string]string) values := make(map[string]string)
for key, value := range params.GetValues(params.Context, keys.ReadKeys) {
locks := make(map[string]bool) if value == nil {
for _, key := range keys.ReadKeys {
if _, ok := values[key]; ok {
// Skip if we have already locked this key
continue
}
if params.KeyExists(params.Context, key) {
_, err = params.KeyRLock(params.Context, key)
if err != nil {
return nil, fmt.Errorf("could not obtain lock for %s key", key)
}
locks[key] = true
continue
}
values[key] = "" values[key] = ""
continue
} }
defer func() { values[key] = fmt.Sprintf("%v", value)
for key, locked := range locks {
if locked {
params.KeyRUnlock(params.Context, key)
locks[key] = false
}
}
}()
for key, _ := range locks {
values[key] = fmt.Sprintf("%v", params.GetValue(params.Context, key))
} }
bytes := []byte(fmt.Sprintf("*%d\r\n", len(params.Command[1:]))) bytes := []byte(fmt.Sprintf("*%d\r\n", len(params.Command[1:])))
@@ -229,10 +157,13 @@ func handleDel(params internal.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
count := 0 count := 0
for _, key := range keys.WriteKeys { for key, exists := range params.KeysExist(keys.WriteKeys) {
err = params.DeleteKey(params.Context, key) if !exists {
continue
}
err = params.DeleteKey(key)
if err != nil { if err != nil {
// log.Printf("could not delete key %s due to error: %+v\n", key, err) // TODO: Uncomment this log.Printf("could not delete key %s due to error: %+v\n", key, err)
continue continue
} }
count += 1 count += 1
@@ -247,17 +178,13 @@ func handlePersist(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyLock(params.Context, key); err != nil { expireAt := params.GetExpiry(key)
return nil, err
}
defer params.KeyUnlock(params.Context, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) { if expireAt == (time.Time{}) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
@@ -274,17 +201,13 @@ func handleExpireTime(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":-2\r\n"), nil return []byte(":-2\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { expireAt := params.GetExpiry(key)
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) { if expireAt == (time.Time{}) {
return []byte(":-1\r\n"), nil return []byte(":-1\r\n"), nil
@@ -305,19 +228,15 @@ func handleTTL(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
clock := params.GetClock() clock := params.GetClock()
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":-2\r\n"), nil return []byte(":-2\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { expireAt := params.GetExpiry(key)
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
expireAt := params.GetExpiry(params.Context, key)
if expireAt == (time.Time{}) { if expireAt == (time.Time{}) {
return []byte(":-1\r\n"), nil return []byte(":-1\r\n"), nil
@@ -342,6 +261,7 @@ func handleExpire(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
// Extract time // Extract time
n, err := strconv.ParseInt(params.Command[2], 10, 64) n, err := strconv.ParseInt(params.Command[2], 10, 64)
@@ -353,21 +273,16 @@ func handleExpire(params internal.HandlerFuncParams) ([]byte, error) {
expireAt = params.GetClock().Now().Add(time.Duration(n) * time.Millisecond) expireAt = params.GetClock().Now().Add(time.Duration(n) * time.Millisecond)
} }
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyLock(params.Context, key); err != nil {
return []byte(":0\r\n"), err
}
defer params.KeyUnlock(params.Context, key)
if len(params.Command) == 3 { if len(params.Command) == 3 {
params.SetExpiry(params.Context, key, expireAt, true) params.SetExpiry(params.Context, key, expireAt, true)
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
currentExpireAt := params.GetExpiry(params.Context, key) currentExpireAt := params.GetExpiry(key)
switch strings.ToLower(params.Command[3]) { switch strings.ToLower(params.Command[3]) {
case "nx": case "nx":
@@ -410,6 +325,7 @@ func handleExpireAt(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
// Extract time // Extract time
n, err := strconv.ParseInt(params.Command[2], 10, 64) n, err := strconv.ParseInt(params.Command[2], 10, 64)
@@ -421,21 +337,16 @@ func handleExpireAt(params internal.HandlerFuncParams) ([]byte, error) {
expireAt = time.UnixMilli(n) expireAt = time.UnixMilli(n)
} }
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer params.KeyUnlock(params.Context, key)
if len(params.Command) == 3 { if len(params.Command) == 3 {
params.SetExpiry(params.Context, key, expireAt, true) params.SetExpiry(params.Context, key, expireAt, true)
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
currentExpireAt := params.GetExpiry(params.Context, key) currentExpireAt := params.GetExpiry(key)
switch strings.ToLower(params.Command[3]) { switch strings.ToLower(params.Command[3]) {
case "nx": case "nx":

File diff suppressed because it is too large Load Diff

View File

@@ -32,6 +32,7 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
entries := make(map[string]interface{}) entries := make(map[string]interface{})
if len(params.Command[2:])%2 != 0 { if len(params.Command[2:])%2 != 0 {
@@ -42,26 +43,19 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
entries[params.Command[i]] = internal.AdaptType(params.Command[i+1]) entries[params.Command[i]] = internal.AdaptType(params.Command[i+1])
} }
if !params.KeyExists(params.Context, key) { if !keyExists {
_, err = params.CreateKeyAndLock(params.Context, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer params.KeyUnlock(params.Context, key) if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil {
if err = params.SetValue(params.Context, key, entries); err != nil {
return nil, err return nil, err
} }
return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
} }
if _, err = params.KeyLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) hash = make(map[string]interface{})
} }
count := 0 count := 0
@@ -76,7 +70,7 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
hash[field] = value hash[field] = value
count += 1 count += 1
} }
if err = params.SetValue(params.Context, key, hash); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
@@ -90,18 +84,14 @@ func handleHGET(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
fields := params.Command[2:] fields := params.Command[2:]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -141,18 +131,14 @@ func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
fields := params.Command[2:] fields := params.Command[2:]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -192,17 +178,13 @@ func handleHVALS(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -233,6 +215,7 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
count := 1 count := 1
if len(params.Command) >= 3 { if len(params.Command) >= 3 {
@@ -255,16 +238,11 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
} }
} }
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -349,17 +327,13 @@ func handleHLEN(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -374,17 +348,13 @@ func handleHKEYS(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -404,6 +374,7 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
field := params.Command[2] field := params.Command[2]
var intIncrement int var intIncrement int
@@ -423,33 +394,24 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
intIncrement = i intIncrement = i
} }
if !params.KeyExists(params.Context, key) { if !keyExists {
if _, err := params.CreateKeyAndLock(params.Context, key); err != nil {
return nil, err
}
defer params.KeyUnlock(params.Context, key)
hash := make(map[string]interface{}) hash := make(map[string]interface{})
if strings.EqualFold(params.Command[0], "hincrbyfloat") { if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = floatIncrement hash[field] = floatIncrement
if err = params.SetValue(params.Context, key, hash); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
} else { } else {
hash[field] = intIncrement hash[field] = intIncrement
if err = params.SetValue(params.Context, key, hash); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil
} }
} }
if _, err := params.KeyLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -477,7 +439,7 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
} }
} }
if err = params.SetValue(params.Context, key, hash); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
@@ -496,17 +458,13 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -536,18 +494,14 @@ func handleHEXISTS(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.ReadKeys[0] key := keys.ReadKeys[0]
keyExists := params.KeysExist(keys.ReadKeys)[key]
field := params.Command[2] field := params.Command[2]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyRLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyRUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -566,18 +520,14 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
} }
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(keys.WriteKeys)[key]
fields := params.Command[2:] fields := params.Command[2:]
if !params.KeyExists(params.Context, key) { if !keyExists {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = params.KeyLock(params.Context, key); err != nil { hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
return nil, err
}
defer params.KeyUnlock(params.Context, key)
hash, ok := params.GetValue(params.Context, key).(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -591,7 +541,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
} }
} }
if err = params.SetValue(params.Context, key, hash); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }

File diff suppressed because it is too large Load Diff