diff --git a/echovault/api_admin.go b/echovault/api_admin.go index 9f3c2aa..98577c2 100644 --- a/echovault/api_admin.go +++ b/echovault/api_admin.go @@ -16,7 +16,6 @@ package echovault import ( "context" - "fmt" "github.com/echovault/echovault/internal" "slices" "strings" @@ -230,133 +229,133 @@ func (server *EchoVault) RewriteAOF() (string, error) { // Errors: // // "command already exists" - If a command with the same command name as the passed command already exists. -func (server *EchoVault) AddCommand(command CommandOptions) error { - server.commandsRWMut.Lock() - defer server.commandsRWMut.Unlock() - // Check if command already exists - for _, c := range server.commands { - if strings.EqualFold(c.Command, command.Command) { - return fmt.Errorf("command %s already exists", command.Command) - } - } - - if command.SubCommand == nil || len(command.SubCommand) == 0 { - // Add command with no subcommands - server.commands = append(server.commands, internal.Command{ - Command: command.Command, - Module: strings.ToLower(command.Module), // Convert module to lower case for uniformity - Categories: func() []string { - // Convert all the categories to lower case for uniformity - cats := make([]string, len(command.Categories)) - for i, cat := range command.Categories { - cats[i] = strings.ToLower(cat) - } - return cats - }(), - Description: command.Description, - Sync: command.Sync, - KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { - accessKeys, err := command.KeyExtractionFunc(cmd) - if err != nil { - return internal.KeyExtractionFuncResult{}, err - } - return internal.KeyExtractionFuncResult{ - Channels: []string{}, - ReadKeys: accessKeys.ReadKeys, - WriteKeys: accessKeys.WriteKeys, - }, nil - }), - HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { - return command.HandlerFunc(CommandHandlerFuncParams{ - Context: params.Context, - Command: params.Command, - KeyLock: params.KeyLock, - KeyUnlock: params.KeyUnlock, - KeyRLock: params.KeyRLock, - KeyRUnlock: params.KeyRUnlock, - KeyExists: params.KeyExists, - CreateKeyAndLock: params.CreateKeyAndLock, - GetValue: params.GetValue, - SetValue: params.SetValue, - }) - }), - }) - return nil - } - - // Add command with subcommands - newCommand := internal.Command{ - Command: command.Command, - Module: command.Module, - Categories: func() []string { - // Convert all the categories to lower case for uniformity - cats := make([]string, len(command.Categories)) - for j, cat := range command.Categories { - cats[j] = strings.ToLower(cat) - } - return cats - }(), - Description: command.Description, - Sync: command.Sync, - KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { - return internal.KeyExtractionFuncResult{}, nil - }, - HandlerFunc: func(param internal.HandlerFuncParams) ([]byte, error) { return nil, nil }, - SubCommands: make([]internal.SubCommand, len(command.SubCommand)), - } - - for i, sc := range command.SubCommand { - // Skip the subcommand if it already exists in newCommand - if slices.ContainsFunc(newCommand.SubCommands, func(subcommand internal.SubCommand) bool { - return strings.EqualFold(subcommand.Command, sc.Command) - }) { - continue - } - newCommand.SubCommands[i] = internal.SubCommand{ - Command: sc.Command, - Module: strings.ToLower(command.Module), - Categories: func() []string { - // Convert all the categories to lower case for uniformity - cats := make([]string, len(sc.Categories)) - for j, cat := range sc.Categories { - cats[j] = strings.ToLower(cat) - } - return cats - }(), - Description: sc.Description, - Sync: sc.Sync, - KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { - accessKeys, err := sc.KeyExtractionFunc(cmd) - if err != nil { - return internal.KeyExtractionFuncResult{}, err - } - return internal.KeyExtractionFuncResult{ - Channels: []string{}, - ReadKeys: accessKeys.ReadKeys, - WriteKeys: accessKeys.WriteKeys, - }, nil - }), - HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { - return sc.HandlerFunc(CommandHandlerFuncParams{ - Context: params.Context, - Command: params.Command, - KeyLock: params.KeyLock, - KeyUnlock: params.KeyUnlock, - KeyRLock: params.KeyRLock, - KeyRUnlock: params.KeyRUnlock, - KeyExists: params.KeyExists, - CreateKeyAndLock: params.CreateKeyAndLock, - GetValue: params.GetValue, - SetValue: params.SetValue, - }) - }), - } - } - - server.commands = append(server.commands, newCommand) - - return nil -} +// func (server *EchoVault) AddCommand(command CommandOptions) error { +// server.commandsRWMut.Lock() +// defer server.commandsRWMut.Unlock() +// // Check if command already exists +// for _, c := range server.commands { +// if strings.EqualFold(c.Command, command.Command) { +// return fmt.Errorf("command %s already exists", command.Command) +// } +// } +// +// if command.SubCommand == nil || len(command.SubCommand) == 0 { +// // Add command with no subcommands +// server.commands = append(server.commands, internal.Command{ +// Command: command.Command, +// Module: strings.ToLower(command.Module), // Convert module to lower case for uniformity +// Categories: func() []string { +// // Convert all the categories to lower case for uniformity +// cats := make([]string, len(command.Categories)) +// for i, cat := range command.Categories { +// cats[i] = strings.ToLower(cat) +// } +// return cats +// }(), +// Description: command.Description, +// Sync: command.Sync, +// KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { +// accessKeys, err := command.KeyExtractionFunc(cmd) +// if err != nil { +// return internal.KeyExtractionFuncResult{}, err +// } +// return internal.KeyExtractionFuncResult{ +// Channels: []string{}, +// ReadKeys: accessKeys.ReadKeys, +// WriteKeys: accessKeys.WriteKeys, +// }, nil +// }), +// HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { +// return command.HandlerFunc(CommandHandlerFuncParams{ +// Context: params.Context, +// Command: params.Command, +// KeyLock: params.KeyLock, +// KeyUnlock: params.KeyUnlock, +// KeyRLock: params.KeyRLock, +// KeyRUnlock: params.KeyRUnlock, +// KeyExists: params.KeyExists, +// CreateKeyAndLock: params.CreateKeyAndLock, +// GetValue: params.GetValue, +// SetValue: params.SetValue, +// }) +// }), +// }) +// return nil +// } +// +// // Add command with subcommands +// newCommand := internal.Command{ +// Command: command.Command, +// Module: command.Module, +// Categories: func() []string { +// // Convert all the categories to lower case for uniformity +// cats := make([]string, len(command.Categories)) +// for j, cat := range command.Categories { +// cats[j] = strings.ToLower(cat) +// } +// return cats +// }(), +// Description: command.Description, +// Sync: command.Sync, +// KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { +// return internal.KeyExtractionFuncResult{}, nil +// }, +// HandlerFunc: func(param internal.HandlerFuncParams) ([]byte, error) { return nil, nil }, +// SubCommands: make([]internal.SubCommand, len(command.SubCommand)), +// } +// +// for i, sc := range command.SubCommand { +// // Skip the subcommand if it already exists in newCommand +// if slices.ContainsFunc(newCommand.SubCommands, func(subcommand internal.SubCommand) bool { +// return strings.EqualFold(subcommand.Command, sc.Command) +// }) { +// continue +// } +// newCommand.SubCommands[i] = internal.SubCommand{ +// Command: sc.Command, +// Module: strings.ToLower(command.Module), +// Categories: func() []string { +// // Convert all the categories to lower case for uniformity +// cats := make([]string, len(sc.Categories)) +// for j, cat := range sc.Categories { +// cats[j] = strings.ToLower(cat) +// } +// return cats +// }(), +// Description: sc.Description, +// Sync: sc.Sync, +// KeyExtractionFunc: internal.KeyExtractionFunc(func(cmd []string) (internal.KeyExtractionFuncResult, error) { +// accessKeys, err := sc.KeyExtractionFunc(cmd) +// if err != nil { +// return internal.KeyExtractionFuncResult{}, err +// } +// return internal.KeyExtractionFuncResult{ +// Channels: []string{}, +// ReadKeys: accessKeys.ReadKeys, +// WriteKeys: accessKeys.WriteKeys, +// }, nil +// }), +// HandlerFunc: internal.HandlerFunc(func(params internal.HandlerFuncParams) ([]byte, error) { +// return sc.HandlerFunc(CommandHandlerFuncParams{ +// Context: params.Context, +// Command: params.Command, +// KeyLock: params.KeyLock, +// KeyUnlock: params.KeyUnlock, +// KeyRLock: params.KeyRLock, +// KeyRUnlock: params.KeyRUnlock, +// KeyExists: params.KeyExists, +// CreateKeyAndLock: params.CreateKeyAndLock, +// GetValue: params.GetValue, +// SetValue: params.SetValue, +// }) +// }), +// } +// } +// +// server.commands = append(server.commands, newCommand) +// +// return nil +// } // ExecuteCommand executes the command passed to it. If 1 string is passed, EchoVault will try to // execute the command. If 2 strings are passed, EchoVault will attempt to execute the subcommand of the command. diff --git a/echovault/echovault.go b/echovault/echovault.go index dd198ed..c548055 100644 --- a/echovault/echovault.go +++ b/echovault/echovault.go @@ -29,14 +29,9 @@ import ( "github.com/echovault/echovault/internal/memberlist" "github.com/echovault/echovault/internal/modules/acl" "github.com/echovault/echovault/internal/modules/admin" - "github.com/echovault/echovault/internal/modules/connection" "github.com/echovault/echovault/internal/modules/generic" "github.com/echovault/echovault/internal/modules/hash" - "github.com/echovault/echovault/internal/modules/list" "github.com/echovault/echovault/internal/modules/pubsub" - "github.com/echovault/echovault/internal/modules/set" - "github.com/echovault/echovault/internal/modules/sorted_set" - str "github.com/echovault/echovault/internal/modules/string" "github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/snapshot" "io" @@ -139,12 +134,12 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { commands = append(commands, admin.Commands()...) commands = append(commands, generic.Commands()...) commands = append(commands, hash.Commands()...) - commands = append(commands, list.Commands()...) - commands = append(commands, connection.Commands()...) - commands = append(commands, pubsub.Commands()...) - commands = append(commands, set.Commands()...) - commands = append(commands, sorted_set.Commands()...) - commands = append(commands, str.Commands()...) + // commands = append(commands, list.Commands()...) + // commands = append(commands, connection.Commands()...) + // commands = append(commands, pubsub.Commands()...) + // commands = append(commands, set.Commands()...) + // commands = append(commands, sorted_set.Commands()...) + // commands = append(commands, str.Commands()...) return commands }(), } @@ -159,13 +154,14 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { ) // Load .so modules from config - for _, path := range echovault.config.Modules { - if err := echovault.LoadModule(path); err != nil { - log.Printf("%s %v\n", path, err) - continue - } - log.Printf("loaded plugin %s\n", path) - } + // TODO: Uncomment this + // for _, path := range echovault.config.Modules { + // if err := echovault.LoadModule(path); err != nil { + // log.Printf("%s %v\n", path, err) + // continue + // } + // log.Printf("loaded plugin %s\n", path) + // } // Function for server commands retrieval echovault.getCommands = func() []internal.Command { @@ -190,35 +186,36 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { } if echovault.isInCluster() { - echovault.raft = raft.NewRaft(raft.Opts{ - Config: echovault.config, - GetCommand: echovault.getCommand, - SetValue: echovault.SetValue, - SetExpiry: echovault.SetExpiry, - DeleteKey: echovault.DeleteKey, - StartSnapshot: echovault.startSnapshot, - FinishSnapshot: echovault.finishSnapshot, - SetLatestSnapshotTime: echovault.setLatestSnapshot, - GetHandlerFuncParams: echovault.getHandlerFuncParams, - GetState: func() map[string]internal.KeyData { - state := make(map[string]internal.KeyData) - for k, v := range echovault.getState() { - if data, ok := v.(internal.KeyData); ok { - state[k] = data - } - } - return state - }, - }) - echovault.memberList = memberlist.NewMemberList(memberlist.Opts{ - Config: echovault.config, - HasJoinedCluster: echovault.raft.HasJoinedCluster, - AddVoter: echovault.raft.AddVoter, - RemoveRaftServer: echovault.raft.RemoveServer, - IsRaftLeader: echovault.raft.IsRaftLeader, - ApplyMutate: echovault.raftApplyCommand, - ApplyDeleteKey: echovault.raftApplyDeleteKey, - }) + // TODO: Uncomment this + // echovault.raft = raft.NewRaft(raft.Opts{ + // Config: echovault.config, + // GetCommand: echovault.getCommand, + // SetValue: echovault.SetValue, + // SetExpiry: echovault.SetExpiry, + // DeleteKey: echovault.DeleteKey, + // StartSnapshot: echovault.startSnapshot, + // FinishSnapshot: echovault.finishSnapshot, + // SetLatestSnapshotTime: echovault.setLatestSnapshot, + // GetHandlerFuncParams: echovault.getHandlerFuncParams, + // GetState: func() map[string]internal.KeyData { + // state := make(map[string]internal.KeyData) + // for k, v := range echovault.getState() { + // if data, ok := v.(internal.KeyData); ok { + // state[k] = data + // } + // } + // return state + // }, + // }) + // echovault.memberList = memberlist.NewMemberList(memberlist.Opts{ + // Config: echovault.config, + // HasJoinedCluster: echovault.raft.HasJoinedCluster, + // AddVoter: echovault.raft.AddVoter, + // RemoveRaftServer: echovault.raft.RemoveServer, + // IsRaftLeader: echovault.raft.IsRaftLeader, + // ApplyMutate: echovault.raftApplyCommand, + // ApplyDeleteKey: echovault.raftApplyDeleteKey, + // }) } else { // Set up standalone snapshot engine echovault.snapshotEngine = snapshot.NewSnapshotEngine( @@ -241,10 +238,10 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { }), snapshot.WithSetKeyDataFunc(func(key string, data internal.KeyData) { ctx := context.Background() - if err := echovault.SetValue(ctx, key, data.Value); err != nil { + if err := echovault.setValues(ctx, map[string]interface{}{key: data.Value}); err != nil { log.Println(err) } - echovault.SetExpiry(ctx, key, data.ExpireAt, false) + echovault.setExpiry(ctx, key, data.ExpireAt, false) }), ) // Set up standalone AOF engine @@ -265,10 +262,10 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { }), aof.WithSetKeyDataFunc(func(key string, value internal.KeyData) { ctx := context.Background() - if err := echovault.SetValue(ctx, key, value.Value); err != nil { + if err := echovault.setValues(ctx, map[string]interface{}{key: value.Value}); err != nil { log.Println(err) } - echovault.SetExpiry(ctx, key, value.ExpireAt, false) + echovault.setExpiry(ctx, key, value.ExpireAt, false) }), aof.WithHandleCommandFunc(func(command []byte) { _, err := echovault.handleCommand(context.Background(), command, nil, true, false) diff --git a/echovault/keyspace.go b/echovault/keyspace.go index 0fc1b67..16a5cc6 100644 --- a/echovault/keyspace.go +++ b/echovault/keyspace.go @@ -113,9 +113,13 @@ func (server *EchoVault) setValues(ctx context.Context, entries map[string]inter } for key, value := range entries { + expireAt := time.Time{} + if _, ok := server.store[key]; ok { + expireAt = server.store[key].ExpireAt + } server.store[key] = internal.KeyData{ Value: value, - ExpireAt: server.store[key].ExpireAt, + ExpireAt: expireAt, } if !server.isInCluster() { server.snapshotEngine.IncrementChangeCount() diff --git a/echovault/plugin.go b/echovault/plugin.go index dd883d6..956a924 100644 --- a/echovault/plugin.go +++ b/echovault/plugin.go @@ -102,14 +102,9 @@ func (server *EchoVault) LoadModule(path string, args ...string) error { handlerFunc, ok := handlerFuncSymbol.(func( ctx context.Context, command []string, - keyExists func(ctx context.Context, key string) bool, - keyLock func(ctx context.Context, key string) (bool, error), - keyUnlock func(ctx context.Context, key string), - keyRLock func(ctx context.Context, key string) (bool, error), - keyRUnlock func(ctx context.Context, key string), - createKeyAndLock func(ctx context.Context, key string) (bool, error), - getValue func(ctx context.Context, key string) interface{}, - setValue func(ctx context.Context, key string, value interface{}) error, + keysExist func(key []string) map[string]bool, + getValues func(ctx context.Context, key []string) map[string]interface{}, + setValues func(ctx context.Context, entries map[string]interface{}) error, args ...string, ) ([]byte, error)) if !ok { @@ -151,14 +146,9 @@ func (server *EchoVault) LoadModule(path string, args ...string) error { return handlerFunc( params.Context, params.Command, - params.KeyExists, - params.KeyLock, - params.KeyUnlock, - params.KeyRLock, - params.KeyRUnlock, - params.CreateKeyAndLock, - params.GetValue, - params.SetValue, + params.KeysExist, + params.GetValues, + params.SetValues, args..., ) }, diff --git a/echovault/test_helpers.go b/echovault/test_helpers.go index 4560789..6be0ca4 100644 --- a/echovault/test_helpers.go +++ b/echovault/test_helpers.go @@ -16,19 +16,13 @@ func createEchoVault() *EchoVault { } func presetValue(server *EchoVault, ctx context.Context, key string, value interface{}) error { - if _, err := server.CreateKeyAndLock(ctx, key); err != nil { + if err := server.setValues(ctx, map[string]interface{}{key: value}); err != nil { return err } - if err := server.SetValue(ctx, key, value); err != nil { - return err - } - server.KeyUnlock(ctx, key) return nil } func presetKeyData(server *EchoVault, ctx context.Context, key string, data internal.KeyData) { - _, _ = server.CreateKeyAndLock(ctx, key) - defer server.KeyUnlock(ctx, key) - _ = server.SetValue(ctx, key, data.Value) - server.SetExpiry(ctx, key, data.ExpireAt, false) + _ = server.setValues(ctx, map[string]interface{}{key: data.Value}) + server.setExpiry(ctx, key, data.ExpireAt, false) } diff --git a/internal/modules/generic/commands.go b/internal/modules/generic/commands.go index e8a7c6b..946ca75 100644 --- a/internal/modules/generic/commands.go +++ b/internal/modules/generic/commands.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/constants" + "log" "strconv" "strings" "time" @@ -36,6 +37,7 @@ func handleSet(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] value := params.Command[2] res := []byte(constants.OkResponse) clock := params.GetClock() @@ -48,41 +50,28 @@ func handleSet(params internal.HandlerFuncParams) ([]byte, error) { // If Get is provided, the response should be the current stored value. // If there's no current value, then the response should be nil. if options.get { - if !params.KeyExists(params.Context, key) { + if !keyExists { res = []byte("$-1\r\n") } else { - res = []byte(fmt.Sprintf("+%v\r\n", params.GetValue(params.Context, key))) + res = []byte(fmt.Sprintf("+%v\r\n", params.GetValues(params.Context, []string{key})[key])) } } if "xx" == strings.ToLower(options.exists) { // If XX is specified, make sure the key exists. - if !params.KeyExists(params.Context, key) { + if !keyExists { return nil, fmt.Errorf("key %s does not exist", key) } - _, err = params.KeyLock(params.Context, key) } else if "nx" == strings.ToLower(options.exists) { // If NX is specified, make sure that the key does not currently exist. - if params.KeyExists(params.Context, key) { + if keyExists { return nil, fmt.Errorf("key %s already exists", key) } - _, err = params.CreateKeyAndLock(params.Context, key) - } else { - // Neither XX not NX are specified, lock or create the lock - if !params.KeyExists(params.Context, key) { - // Key does not exist, create it - _, err = params.CreateKeyAndLock(params.Context, key) - } else { - // Key exists, acquire the lock - _, err = params.KeyLock(params.Context, key) - } } - if err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - if err = params.SetValue(params.Context, key, internal.AdaptType(value)); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{ + key: internal.AdaptType(value), + }); err != nil { return nil, err } @@ -100,52 +89,18 @@ func handleMSet(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - entries := make(map[string]KeyObject) - - // Release all acquired key locks - defer func() { - for k, v := range entries { - if v.locked { - params.KeyUnlock(params.Context, k) - entries[k] = KeyObject{ - value: v.value, - locked: false, - } - } - } - }() + entries := make(map[string]interface{}) // Extract all the key/value pairs for i, key := range params.Command[1:] { if i%2 == 0 { - entries[key] = KeyObject{ - value: internal.AdaptType(params.Command[1:][i+1]), - locked: false, - } + entries[key] = internal.AdaptType(params.Command[1:][i+1]) } } - // Acquire all the locks for each key first - // If any key cannot be acquired, abandon transaction and release all currently held keys - for k, v := range entries { - if params.KeyExists(params.Context, k) { - if _, err := params.KeyLock(params.Context, k); err != nil { - return nil, err - } - entries[k] = KeyObject{value: v.value, locked: true} - continue - } - if _, err := params.CreateKeyAndLock(params.Context, k); err != nil { - return nil, err - } - entries[k] = KeyObject{value: v.value, locked: true} - } - // Set all the values - for k, v := range entries { - if err := params.SetValue(params.Context, k, v.value); err != nil { - return nil, err - } + if err = params.SetValues(params.Context, entries); err != nil { + return nil, err } return []byte(constants.OkResponse), nil @@ -157,18 +112,13 @@ func handleGet(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } key := keys.ReadKeys[0] + keyExists := params.KeysExist([]string{key})[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - _, err = params.KeyRLock(params.Context, key) - if err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - value := params.GetValue(params.Context, key) + value := params.GetValues(params.Context, []string{key})[key] return []byte(fmt.Sprintf("+%v\r\n", value)), nil } @@ -180,34 +130,12 @@ func handleMGet(params internal.HandlerFuncParams) ([]byte, error) { } values := make(map[string]string) - - locks := make(map[string]bool) - for _, key := range keys.ReadKeys { - if _, ok := values[key]; ok { - // Skip if we have already locked this key + for key, value := range params.GetValues(params.Context, keys.ReadKeys) { + if value == nil { + values[key] = "" continue } - if params.KeyExists(params.Context, key) { - _, err = params.KeyRLock(params.Context, key) - if err != nil { - return nil, fmt.Errorf("could not obtain lock for %s key", key) - } - locks[key] = true - continue - } - values[key] = "" - } - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - locks[key] = false - } - } - }() - - for key, _ := range locks { - values[key] = fmt.Sprintf("%v", params.GetValue(params.Context, key)) + values[key] = fmt.Sprintf("%v", value) } bytes := []byte(fmt.Sprintf("*%d\r\n", len(params.Command[1:]))) @@ -229,10 +157,13 @@ func handleDel(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } count := 0 - for _, key := range keys.WriteKeys { - err = params.DeleteKey(params.Context, key) + for key, exists := range params.KeysExist(keys.WriteKeys) { + if !exists { + continue + } + err = params.DeleteKey(key) if err != nil { - // log.Printf("could not delete key %s due to error: %+v\n", key, err) // TODO: Uncomment this + log.Printf("could not delete key %s due to error: %+v\n", key, err) continue } count += 1 @@ -247,17 +178,13 @@ func handlePersist(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - expireAt := params.GetExpiry(params.Context, key) + expireAt := params.GetExpiry(key) if expireAt == (time.Time{}) { return []byte(":0\r\n"), nil } @@ -274,17 +201,13 @@ func handleExpireTime(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":-2\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - expireAt := params.GetExpiry(params.Context, key) + expireAt := params.GetExpiry(key) if expireAt == (time.Time{}) { return []byte(":-1\r\n"), nil @@ -305,19 +228,15 @@ func handleTTL(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] clock := params.GetClock() - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":-2\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - expireAt := params.GetExpiry(params.Context, key) + expireAt := params.GetExpiry(key) if expireAt == (time.Time{}) { return []byte(":-1\r\n"), nil @@ -342,6 +261,7 @@ func handleExpire(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] // Extract time n, err := strconv.ParseInt(params.Command[2], 10, 64) @@ -353,21 +273,16 @@ func handleExpire(params internal.HandlerFuncParams) ([]byte, error) { expireAt = params.GetClock().Now().Add(time.Duration(n) * time.Millisecond) } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return []byte(":0\r\n"), err - } - defer params.KeyUnlock(params.Context, key) - if len(params.Command) == 3 { params.SetExpiry(params.Context, key, expireAt, true) return []byte(":1\r\n"), nil } - currentExpireAt := params.GetExpiry(params.Context, key) + currentExpireAt := params.GetExpiry(key) switch strings.ToLower(params.Command[3]) { case "nx": @@ -410,6 +325,7 @@ func handleExpireAt(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] // Extract time n, err := strconv.ParseInt(params.Command[2], 10, 64) @@ -421,21 +337,16 @@ func handleExpireAt(params internal.HandlerFuncParams) ([]byte, error) { expireAt = time.UnixMilli(n) } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - if len(params.Command) == 3 { params.SetExpiry(params.Context, key, expireAt, true) return []byte(":1\r\n"), nil } - currentExpireAt := params.GetExpiry(params.Context, key) + currentExpireAt := params.GetExpiry(key) switch strings.ToLower(params.Command[3]) { case "nx": diff --git a/internal/modules/generic/commands_test.go b/internal/modules/generic/commands_test.go index 25de089..84f8c8a 100644 --- a/internal/modules/generic/commands_test.go +++ b/internal/modules/generic/commands_test.go @@ -15,8 +15,6 @@ package generic_test import ( - "bytes" - "context" "errors" "fmt" "github.com/echovault/echovault/echovault" @@ -26,15 +24,15 @@ import ( "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" "net" - "reflect" "strings" + "sync" "testing" "time" - "unsafe" ) +var addr string +var port int var mockServer *echovault.EchoVault - var mockClock clock.Clock type KeyData struct { @@ -44,65 +42,31 @@ type KeyData struct { func init() { mockClock = clock.NewClock() - + port, _ = internal.GetFreePort() mockServer, _ = echovault.NewEchoVault( echovault.WithConfig(config.Config{ + BindAddr: addr, + Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) -} - -func getUnexportedField(field reflect.Value) interface{} { - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() -} - -func getHandler(commands ...string) internal.HandlerFunc { - if len(commands) == 0 { - return nil - } - getCommands := - getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) - for _, c := range getCommands() { - if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { - // Get command handler - return c.HandlerFunc - } - if strings.EqualFold(commands[0], c.Command) { - // Get sub-command handler - for _, sc := range c.SubCommands { - if strings.EqualFold(commands[1], sc.Command) { - return sc.HandlerFunc - } - } - } - } - return nil -} - -func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams { - getClock := - getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getClock")).(func() clock.Clock) - return internal.HandlerFuncParams{ - Context: ctx, - Command: cmd, - Connection: conn, - KeyExists: mockServer.KeyExists, - CreateKeyAndLock: mockServer.CreateKeyAndLock, - KeyLock: mockServer.KeyLock, - KeyRLock: mockServer.KeyRLock, - KeyUnlock: mockServer.KeyUnlock, - KeyRUnlock: mockServer.KeyRUnlock, - GetValue: mockServer.GetValue, - SetValue: mockServer.SetValue, - GetExpiry: mockServer.GetExpiry, - SetExpiry: mockServer.SetExpiry, - DeleteKey: mockServer.DeleteKey, - GetClock: getClock, - } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + mockServer.Start() + }() + wg.Wait() } func Test_HandleSET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -126,7 +90,7 @@ func Test_HandleSET(t *testing.T) { command: []string{"SET", "SetKey2", "1245678910"}, presetValues: nil, expectedResponse: "OK", - expectedValue: 1245678910, + expectedValue: "1245678910", expectedExpiry: time.Time{}, expectedErr: nil, }, @@ -135,7 +99,7 @@ func Test_HandleSET(t *testing.T) { command: []string{"SET", "SetKey3", "45782.11341"}, presetValues: nil, expectedResponse: "OK", - expectedValue: 45782.11341, + expectedValue: "45782.11341", expectedExpiry: time.Time{}, expectedErr: nil, }, @@ -409,35 +373,41 @@ func Test_HandleSET(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SET, %d", i+1)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + cmd := []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(k), + resp.StringValue(v.Value.(string))} + err := client.WriteArray(cmd) + if err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + rd, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(rd.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", rd.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for j, c := range test.command { + command[j] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if test.expectedErr != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedErr.Error()) - } - if test.expectedErr.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), err.Error()) } return @@ -446,48 +416,57 @@ func Test_HandleSET(t *testing.T) { t.Error(err) } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - switch test.expectedResponse.(type) { case string: - if test.expectedResponse != rv.String() { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, rv.String()) + if test.expectedResponse != res.String() { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) } case nil: - if !rv.IsNull() { - t.Errorf("expcted nil response, got %+v", rv) + if !res.IsNull() { + t.Errorf("expcted nil response, got %+v", res) } default: t.Error("test expected result should be nil or string") } - // Compare expected value and expected time key := test.command[1] - var value interface{} - var expireAt time.Time - if _, err = mockServer.KeyLock(ctx, key); err != nil { + // Compare expected value to response value + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - value = mockServer.GetValue(ctx, key) - expireAt = mockServer.GetExpiry(ctx, key) - mockServer.KeyUnlock(ctx, key) - - if value != test.expectedValue { - t.Errorf("expected value %+v, got %+v", test.expectedValue, value) + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if test.expectedExpiry.Unix() != expireAt.Unix() { - t.Errorf("expected expiry time %d, got %d, cmd: %+v", test.expectedExpiry.Unix(), expireAt.Unix(), test.command) + if res.String() != test.expectedValue.(string) { + t.Errorf("expected value %s, got %s", test.expectedValue.(string), res.String()) + } + + // Compare expected expiry to response expiry + if !test.expectedExpiry.Equal(time.Time{}) { + if err = client.WriteArray([]resp.Value{resp.StringValue("EXPIRETIME"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if res.Integer() != int(test.expectedExpiry.Unix()) { + t.Errorf("expected expiry time %d, got %d", test.expectedExpiry.Unix(), res.Integer()) + } } }) } } func Test_HandleMSET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -511,73 +490,70 @@ func Test_HandleMSET(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MSET, %d", i)) - - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for j, c := range test.command { + command[j] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedErr != nil { - if err.Error() != test.expectedErr.Error() { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { t.Errorf("expected error %s, got %s", test.expectedErr.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.String() != test.expectedResponse { - t.Errorf("expected response %s, got %s", test.expectedResponse, rv.String()) + + if res.String() != test.expectedResponse { + t.Errorf("expected response %s, got %s", test.expectedResponse, res.String()) } + for key, expectedValue := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { + // Get value from server + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } + res, _, err = client.ReadValue() switch expectedValue.(type) { default: t.Error("unexpected type for expectedValue") case int: ev, _ := expectedValue.(int) - value, ok := mockServer.GetValue(ctx, key).(int) - if !ok { - t.Errorf("expected integer type for key %s, got another type", key) - } - if value != ev { - t.Errorf("expected value %d for key %s, got %d", ev, key, value) + if res.Integer() != ev { + t.Errorf("expected value %d for key %s, got %d", ev, key, res.Integer()) } case float64: ev, _ := expectedValue.(float64) - value, ok := mockServer.GetValue(ctx, key).(float64) - if !ok { - t.Errorf("expected float type for key %s, got another type", key) - } - if value != ev { - t.Errorf("expected value %f for key %s, got %f", ev, key, value) + if res.Float() != ev { + t.Errorf("expected value %f for key %s, got %f", ev, key, res.Float()) } case string: ev, _ := expectedValue.(string) - value, ok := mockServer.GetValue(ctx, key).(string) - if !ok { - t.Errorf("expected string type for key %s, got another type", key) - } - if value != ev { - t.Errorf("expected value %s for key %s, got %s", ev, key, value) + if res.String() != ev { + t.Errorf("expected value %s for key %s, got %s", ev, key, res.String()) } } - mockServer.KeyRUnlock(ctx, key) } }) } } func Test_HandleGET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string key string @@ -600,45 +576,50 @@ func Test_HandleGET(t *testing.T) { }, } // Test successful Get command - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("GET, %d", i)) func(key, value string) { - - _, err := mockServer.CreateKeyAndLock(ctx, key) + // Preset the values + err = client.WriteArray([]resp.Value{resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value)}) if err != nil { t.Error(err) } - if err = mockServer.SetValue(ctx, key, value); err != nil { - t.Error(err) - } - mockServer.KeyUnlock(ctx, key) - - handler := getHandler("GET") - if handler == nil { - t.Error("no handler found for command GET") - return - } - - res, err := handler(getHandlerFuncParams(ctx, []string{"GET", key}, nil)) + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n", value))) { - t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n", value), string(res)) + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + } + + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if res.String() != test.value { + t.Errorf("expected value %s, got %s", test.value, res.String()) } }(test.key, test.value) }) } // Test get non-existent key - res, err := getHandler("GET")(getHandlerFuncParams(context.Background(), []string{"GET", "test4"}, nil)) + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue("test4")}); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if !bytes.Equal(res, []byte("$-1\r\n")) { - t.Errorf("expected %+v, got: %+v", "+nil\r\n", res) + if !res.IsNull() { + t.Errorf("expected nil, got: %+v", res) } errorTests := []struct { @@ -659,16 +640,21 @@ func Test_HandleGET(t *testing.T) { } for _, test := range errorTests { t.Run(test.name, func(t *testing.T) { - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err = handler(getHandlerFuncParams(context.Background(), test.command, nil)) - if res != nil { - t.Errorf("expected nil response, got: %+v", res) + + if err = client.WriteArray(command); err != nil { + t.Error(err) } - if err.Error() != test.expected { + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.Contains(res.Error().Error(), test.expected) { t.Errorf("expected error '%s', got: %s", test.expected, err.Error()) } }) @@ -676,6 +662,12 @@ func Test_HandleGET(t *testing.T) { } func Test_HandleMGET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string presetKeys []string @@ -710,47 +702,53 @@ func Test_HandleMGET(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("MGET, %d", i)) // Set up the values for i, key := range test.presetKeys { - _, err := mockServer.CreateKeyAndLock(ctx, key) + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(test.presetValues[i]), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if err = mockServer.SetValue(ctx, key, test.presetValues[i]); err != nil { - t.Error(err) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got \"%s\"", res.String()) } - mockServer.KeyUnlock(ctx, key) - } - // Test the command and its results - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + // Test the command and its results + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if test.expectedError != nil { // If we expect and error, branch out and check error - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error %+v, got: %+v", test.expectedError, err) } return } - if err != nil { - t.Error(err) + + if res.Type().String() != "Array" { + t.Errorf("expected type Array, got: %s", res.Type().String()) } - rr := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rr.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Type().String() != "Array" { - t.Errorf("expected type Array, got: %s", rv.Type().String()) - } - for i, value := range rv.Array() { + for i, value := range res.Array() { if test.expected[i] == nil { if !value.IsNull() { t.Errorf("expected nil value, got %+v", value) @@ -766,10 +764,16 @@ func Test_HandleMGET(t *testing.T) { } func Test_HandleDEL(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string - presetValues map[string]KeyData + presetValues map[string]string expectedResponse int expectToExist map[string]bool expectedErr error @@ -777,11 +781,11 @@ func Test_HandleDEL(t *testing.T) { { name: "1. Delete multiple keys", command: []string{"DEL", "DelKey1", "DelKey2", "DelKey3", "DelKey4", "DelKey5"}, - presetValues: map[string]KeyData{ - "DelKey1": {Value: "value1", ExpireAt: time.Time{}}, - "DelKey2": {Value: "value2", ExpireAt: time.Time{}}, - "DelKey3": {Value: "value3", ExpireAt: time.Time{}}, - "DelKey4": {Value: "value4", ExpireAt: time.Time{}}, + presetValues: map[string]string{ + "DelKey1": "value1", + "DelKey2": "value2", + "DelKey3": "value3", + "DelKey4": "value4", }, expectedResponse: 4, expectToExist: map[string]bool{ @@ -803,57 +807,63 @@ func Test_HandleDEL(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("DEL, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(k), + resp.StringValue(v), + }); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedErr != nil { - if err == nil { - t.Errorf("exected error \"%s\", got nil", test.expectedErr.Error()) - } - if test.expectedErr.Error() != err.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) - } - - for k, expected := range test.expectToExist { - exists := mockServer.KeyExists(ctx, k) + for key, expected := range test.expectToExist { + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + exists := !res.IsNull() if exists != expected { - t.Errorf("expected exists status to be %+v, got %+v", expected, exists) + t.Errorf("expected existence of key %s to be %v, got %v", key, expected, exists) } } }) @@ -861,6 +871,12 @@ func Test_HandleDEL(t *testing.T) { } func Test_HandlePERSIST(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -919,76 +935,100 @@ func Test_HandlePERSIST(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("PERSIST, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedError.Error()) - } - if test.expectedError.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } if test.expectedValues == nil { return } - for k, expected := range test.expectedValues { - if _, err = mockServer.KeyLock(ctx, k); err != nil { + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - value := mockServer.GetValue(ctx, k) - expiry := mockServer.GetExpiry(ctx, k) - if value != expected.Value { - t.Errorf("expected value %+v, got %+v", expected.Value, value) + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if expiry.UnixMilli() != expected.ExpireAt.UnixMilli() { - t.Errorf("expected exiry %d, got %d", expected.ExpireAt.UnixMilli(), expiry.UnixMilli()) + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) + } + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.UnixMilli()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.UnixMilli(), res.Integer()) } - mockServer.KeyUnlock(ctx, k) } }) } } func Test_HandleEXPIRETIME(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -1046,57 +1086,65 @@ func Test_HandleEXPIRETIME(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("EXPIRETIME/PEXPIRETIME, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedError.Error()) - } - if test.expectedError.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } }) } } func Test_HandleTTL(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -1154,57 +1202,65 @@ func Test_HandleTTL(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("TTL/PTTL, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedError.Error()) - } - if test.expectedError.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } }) } } func Test_HandleEXPIRE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -1393,76 +1449,100 @@ func Test_HandleEXPIRE(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("PERSIST, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedError.Error()) - } - if test.expectedError.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } if test.expectedValues == nil { return } - for k, expected := range test.expectedValues { - if _, err = mockServer.KeyLock(ctx, k); err != nil { + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - value := mockServer.GetValue(ctx, k) - expiry := mockServer.GetExpiry(ctx, k) - if value != expected.Value { - t.Errorf("expected value %+v, got %+v", expected.Value, value) + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if expiry.UnixMilli() != expected.ExpireAt.UnixMilli() { - t.Errorf("expected expiry %d, got %d", expected.ExpireAt.UnixMilli(), expiry.UnixMilli()) + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) + } + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) } - mockServer.KeyUnlock(ctx, k) } }) } } func Test_HandleEXPIREAT(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string command []string @@ -1675,70 +1755,88 @@ func Test_HandleEXPIREAT(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("PERSIST, %d", i)) - if test.presetValues != nil { for k, v := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, k); err != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, k, v.Value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.SetExpiry(ctx, k, v.ExpireAt, false) - mockServer.KeyUnlock(ctx, k) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err == nil { - t.Errorf("expected error \"%s\", got nil", test.expectedError.Error()) - } - if test.expectedError.Error() != err.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } if test.expectedValues == nil { return } - for k, expected := range test.expectedValues { - if _, err = mockServer.KeyLock(ctx, k); err != nil { + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - value := mockServer.GetValue(ctx, k) - expiry := mockServer.GetExpiry(ctx, k) - if value != expected.Value { - t.Errorf("expected value %+v, got %+v", expected.Value, value) + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if expiry.UnixMilli() != expected.ExpireAt.UnixMilli() { - t.Errorf("expected expiry %d, got %d", expected.ExpireAt.UnixMilli(), expiry.UnixMilli()) + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) + } + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) } - mockServer.KeyUnlock(ctx, k) } }) } diff --git a/internal/modules/hash/commands.go b/internal/modules/hash/commands.go index 63741c0..bb93f7e 100644 --- a/internal/modules/hash/commands.go +++ b/internal/modules/hash/commands.go @@ -32,6 +32,7 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] entries := make(map[string]interface{}) if len(params.Command[2:])%2 != 0 { @@ -42,26 +43,19 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) { entries[params.Command[i]] = internal.AdaptType(params.Command[i+1]) } - if !params.KeyExists(params.Context, key) { - _, err = params.CreateKeyAndLock(params.Context, key) + if !keyExists { if err != nil { return nil, err } - defer params.KeyUnlock(params.Context, key) - if err = params.SetValue(params.Context, key, entries); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil { return nil, err } return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { - return nil, fmt.Errorf("value at %s is not a hash", key) + hash = make(map[string]interface{}) } count := 0 @@ -76,7 +70,7 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) { hash[field] = value count += 1 } - if err = params.SetValue(params.Context, key, hash); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { return nil, err } @@ -90,18 +84,14 @@ func handleHGET(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] fields := params.Command[2:] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -141,18 +131,14 @@ func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] fields := params.Command[2:] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -192,17 +178,13 @@ func handleHVALS(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -233,6 +215,7 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] count := 1 if len(params.Command) >= 3 { @@ -255,16 +238,11 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) { } } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -349,17 +327,13 @@ func handleHLEN(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -374,17 +348,13 @@ func handleHKEYS(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -404,6 +374,7 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] field := params.Command[2] var intIncrement int @@ -423,33 +394,24 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) { intIncrement = i } - if !params.KeyExists(params.Context, key) { - if _, err := params.CreateKeyAndLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) + if !keyExists { hash := make(map[string]interface{}) if strings.EqualFold(params.Command[0], "hincrbyfloat") { hash[field] = floatIncrement - if err = params.SetValue(params.Context, key, hash); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { return nil, err } return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil } else { hash[field] = intIncrement - if err = params.SetValue(params.Context, key, hash); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { return nil, err } return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil } } - if _, err := params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -477,7 +439,7 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) { } } - if err = params.SetValue(params.Context, key, hash); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { return nil, err } @@ -496,17 +458,13 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -536,18 +494,14 @@ func handleHEXISTS(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] field := params.Command[2] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -566,18 +520,14 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] fields := params.Command[2:] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - hash, ok := params.GetValue(params.Context, key).(map[string]interface{}) + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -591,7 +541,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) { } } - if err = params.SetValue(params.Context, key, hash); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { return nil, err } diff --git a/internal/modules/hash/commands_test.go b/internal/modules/hash/commands_test.go index 0a77ca2..b99b8d6 100644 --- a/internal/modules/hash/commands_test.go +++ b/internal/modules/hash/commands_test.go @@ -15,8 +15,6 @@ package hash_test import ( - "bytes" - "context" "errors" "fmt" "github.com/echovault/echovault/echovault" @@ -25,206 +23,209 @@ import ( "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" "net" - "reflect" "slices" + "strconv" "strings" + "sync" "testing" - "unsafe" ) var mockServer *echovault.EchoVault +var addr = "localhost" +var port int func init() { + port, _ = internal.GetFreePort() mockServer, _ = echovault.NewEchoVault( echovault.WithConfig(config.Config{ + BindAddr: addr, + Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) -} - -func getUnexportedField(field reflect.Value) interface{} { - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() -} - -func getHandler(commands ...string) internal.HandlerFunc { - if len(commands) == 0 { - return nil - } - getCommands := - getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) - for _, c := range getCommands() { - if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { - // Get command handler - return c.HandlerFunc - } - if strings.EqualFold(commands[0], c.Command) { - // Get sub-command handler - for _, sc := range c.SubCommands { - if strings.EqualFold(commands[1], sc.Command) { - return sc.HandlerFunc - } - } - } - } - return nil -} - -func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams { - return internal.HandlerFuncParams{ - Context: ctx, - Command: cmd, - Connection: conn, - KeyExists: mockServer.KeyExists, - CreateKeyAndLock: mockServer.CreateKeyAndLock, - KeyLock: mockServer.KeyLock, - KeyRLock: mockServer.KeyRLock, - KeyUnlock: mockServer.KeyUnlock, - KeyRUnlock: mockServer.KeyRUnlock, - GetValue: mockServer.GetValue, - SetValue: mockServer.SetValue, - } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + mockServer.Start() + }() + wg.Wait() } func Test_HandleHSET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + // Tests for both HSet and HSetNX tests := []struct { name string - preset bool key string presetValue interface{} command []string expectedResponse int // Change count - expectedValue map[string]interface{} + expectedValue map[string]string expectedError error }{ { name: "1. HSETNX set field on non-existent hash map", - preset: false, key: "HsetKey1", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HSETNX", "HsetKey1", "field1", "value1"}, expectedResponse: 1, - expectedValue: map[string]interface{}{"field1": "value1"}, + expectedValue: map[string]string{"field1": "value1"}, expectedError: nil, }, { name: "2. HSETNX set field on existing hash map", - preset: true, key: "HsetKey2", - presetValue: map[string]interface{}{"field1": "value1"}, + presetValue: map[string]string{"field1": "value1"}, command: []string{"HSETNX", "HsetKey2", "field2", "value2"}, expectedResponse: 1, - expectedValue: map[string]interface{}{"field1": "value1", "field2": "value2"}, + expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, expectedError: nil, }, { name: "3. HSETNX skips operation when setting on existing field", - preset: true, key: "HsetKey3", - presetValue: map[string]interface{}{"field1": "value1"}, + presetValue: map[string]string{"field1": "value1"}, command: []string{"HSETNX", "HsetKey3", "field1", "value1-new"}, expectedResponse: 0, - expectedValue: map[string]interface{}{"field1": "value1"}, + expectedValue: map[string]string{"field1": "value1"}, expectedError: nil, }, { name: "4. Regular HSET command on non-existent hash map", - preset: false, key: "HsetKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HSET", "HsetKey4", "field1", "value1", "field2", "value2"}, expectedResponse: 2, - expectedValue: map[string]interface{}{"field1": "value1", "field2": "value2"}, + expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, expectedError: nil, }, { name: "5. Regular HSET update on existing hash map", - preset: true, key: "HsetKey5", - presetValue: map[string]interface{}{"field1": "value1", "field2": "value2"}, + presetValue: map[string]string{"field1": "value1", "field2": "value2"}, command: []string{"HSET", "HsetKey5", "field1", "value1-new", "field2", "value2-ne2", "field3", "value3"}, expectedResponse: 3, - expectedValue: map[string]interface{}{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"}, + expectedValue: map[string]string{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"}, expectedError: nil, }, { name: "6. HSET returns error when the target key is not a map", - preset: true, key: "HsetKey6", presetValue: "Default preset value", command: []string{"HSET", "HsetKey6", "field1", "value1"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, - expectedError: errors.New("value at HsetKey6 is not a hash"), + expectedResponse: 1, + expectedValue: map[string]string{"field1": "value1"}, + expectedError: nil, }, { name: "7. HSET returns error when there's a mismatch in key/values", - preset: false, key: "HsetKey7", presetValue: nil, command: []string{"HSET", "HsetKey7", "field1", "value1", "field2"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedValue: map[string]string{}, expectedError: errors.New("each field must have a corresponding value"), }, { name: "8. Command too short", - preset: true, key: "HsetKey8", presetValue: nil, command: []string{"HSET", "field1"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedValue: map[string]string{}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSET/HSETNX, %d", i)) - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, rv.Integer()) - } - // Check that all the values are what is expected - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}) - if !ok { - t.Errorf("value at key \"%s\" is not a hash map", test.key) - } - for field, value := range h { - if value != test.expectedValue[field] { - t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value) + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } } } }) @@ -232,174 +233,196 @@ func Test_HandleHSET(t *testing.T) { } func Test_HandleHINCRBY(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + // Tests for both HIncrBy and HIncrByFloat tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} // Change count - expectedValue map[string]interface{} + expectedResponse string // Change count + expectedValue map[string]string expectedError error }{ { name: "1. Increment by integer on non-existent hash should create a new one", - preset: false, key: "HincrbyKey1", presetValue: nil, command: []string{"HINCRBY", "HincrbyKey1", "field1", "1"}, - expectedResponse: 1, - expectedValue: map[string]interface{}{"field1": 1}, + expectedResponse: "1", + expectedValue: map[string]string{"field1": "1"}, expectedError: nil, }, { name: "2. Increment by float on non-existent hash should create one", - preset: false, key: "HincrbyKey2", presetValue: nil, command: []string{"HINCRBYFLOAT", "HincrbyKey2", "field1", "3.142"}, - expectedResponse: 3.142, - expectedValue: map[string]interface{}{"field1": 3.142}, + expectedResponse: "3.142", + expectedValue: map[string]string{"field1": "3.142"}, expectedError: nil, }, { name: "3. Increment by integer on existing hash", - preset: true, key: "HincrbyKey3", - presetValue: map[string]interface{}{"field1": 1}, + presetValue: map[string]string{"field1": "1"}, command: []string{"HINCRBY", "HincrbyKey3", "field1", "10"}, - expectedResponse: 11, - expectedValue: map[string]interface{}{"field1": 11}, + expectedResponse: "11", + expectedValue: map[string]string{"field1": "11"}, expectedError: nil, }, { name: "4. Increment by float on an existing hash", - preset: true, key: "HincrbyKey4", - presetValue: map[string]interface{}{"field1": 3.142}, + presetValue: map[string]string{"field1": "3.142"}, command: []string{"HINCRBYFLOAT", "HincrbyKey4", "field1", "3.142"}, - expectedResponse: 6.284, - expectedValue: map[string]interface{}{"field1": 6.284}, + expectedResponse: "6.284", + expectedValue: map[string]string{"field1": "6.284"}, expectedError: nil, }, { name: "5. Command too short", - preset: false, key: "HincrbyKey5", presetValue: nil, command: []string{"HINCRBY", "HincrbyKey5"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "6. Command too long", - preset: false, key: "HincrbyKey6", presetValue: nil, command: []string{"HINCRBY", "HincrbyKey6", "field1", "23", "45"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "7. Error when increment by float does not pass valid float", - preset: false, key: "HincrbyKey7", presetValue: nil, command: []string{"HINCRBYFLOAT", "HincrbyKey7", "field1", "three point one four two"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New("increment must be a float"), }, { name: "8. Error when increment does not pass valid integer", - preset: false, key: "HincrbyKey8", presetValue: nil, command: []string{"HINCRBY", "HincrbyKey8", "field1", "three"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New("increment must be an integer"), }, { name: "9. Error when trying to increment on a key that is not a hash", - preset: true, key: "HincrbyKey9", presetValue: "Default value", command: []string{"HINCRBY", "HincrbyKey9", "field1", "3"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New("value at HincrbyKey9 is not a hash"), }, { name: "10. Error when trying to increment a hash field that is not a number", - preset: true, key: "HincrbyKey10", - presetValue: map[string]interface{}{"field1": "value1"}, + presetValue: map[string]string{"field1": "value1"}, command: []string{"HINCRBY", "HincrbyKey10", "field1", "3"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: "0", + expectedValue: nil, expectedError: errors.New("value at field field1 is not a number"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - switch test.expectedResponse.(type) { - default: - t.Error("expectedResponse must be an integer or string") - case int: - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response \"%+v\", got \"%d\"", test.expectedResponse, rv.Integer()) - } - case float64: - if rv.Float() != test.expectedResponse { - t.Errorf("expected response \"%+v\", got \"%+v\"", test.expectedResponse, rv.Float()) - } - } - // Check that all the values are what is expected - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}) - if !ok { - t.Errorf("value at key \"%s\" is not a hash map", test.key) - } - for field, value := range h { - if value != test.expectedValue[field] { - t.Errorf("expected value \"%+v\" for field \"%+v\", got \"%+v\"", test.expectedValue[field], field, value) + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } } } }) @@ -407,112 +430,149 @@ func Test_HandleHINCRBY(t *testing.T) { } func Test_HandleHGET(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} // Change count - expectedValue map[string]interface{} + expectedResponse []string // Change count + expectedValue map[string]string expectedError error }{ { - name: "1. Return nil when attempting to get from non-existed key", - preset: true, + name: "1. Get values from existing hash.", key: "HgetKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 365, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, command: []string{"HGET", "HgetKey1", "field1", "field2", "field3", "field4"}, - expectedResponse: []interface{}{"value1", 365, "3.142", nil}, - expectedValue: map[string]interface{}{}, + expectedResponse: []string{"value1", "365", "3.142", ""}, + expectedValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, expectedError: nil, }, { name: "2. Return nil when attempting to get from non-existed key", - preset: false, key: "HgetKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HGET", "HgetKey2", "field1"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, + expectedValue: nil, expectedError: nil, }, { name: "3. Error when trying to get from a value that is not a hash map", - preset: true, key: "HgetKey3", presetValue: "Default Value", command: []string{"HGET", "HgetKey3", "field1"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: nil, + expectedValue: nil, expectedError: errors.New("value at HgetKey3 is not a hash"), }, { name: "4. Command too short", - preset: false, key: "HgetKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HGET", "HgetKey4"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: nil, + expectedValue: nil, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } + if test.expectedResponse == nil { - if !rv.IsNull() { - t.Errorf("expected nil response, got %+v", rv) + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) } return } - if expectedArr, ok := test.expectedResponse.([]interface{}); ok { - for i, v := range rv.Array() { - switch v.Type().String() { - default: - t.Error("unexpected type encountered") - case "Integer": - if v.Integer() != expectedArr[i] { - t.Errorf("expected \"%+v\", got \"%d\"", expectedArr[i], v.Integer()) - } - case "BulkString": - if len(v.String()) == 0 && expectedArr[i] == nil { - continue - } - if v.String() != expectedArr[i] { - t.Errorf("expected \"%+v\", got \"%s\"", expectedArr[i], v.String()) - } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) } } } @@ -521,229 +581,151 @@ func Test_HandleHGET(t *testing.T) { } func Test_HandleHSTRLEN(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} // Change count - expectedValue map[string]interface{} + expectedResponse []int // Change count + expectedValue map[string]string expectedError error }{ { // Return lengths of field values. // If the key does not exist, its length should be 0. name: "1. Return lengths of field values.", - preset: true, key: "HstrlenKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HSTRLEN", "HstrlenKey1", "field1", "field2", "field3", "field4"}, expectedResponse: []int{len("value1"), len("123456789"), len("3.142"), 0}, - expectedValue: map[string]interface{}{}, + expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, expectedError: nil, }, { name: "2. Nil response when trying to get HSTRLEN non-existent key", - preset: false, key: "HstrlenKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HSTRLEN", "HstrlenKey2", "field1"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, + expectedValue: nil, expectedError: nil, }, { name: "3. Command too short", - preset: false, key: "HstrlenKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HSTRLEN", "HstrlenKey3"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: nil, + expectedValue: nil, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "4. Trying to get lengths on a non hash map returns error", - preset: true, key: "HstrlenKey4", presetValue: "Default value", command: []string{"HSTRLEN", "HstrlenKey4", "field1"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: nil, + expectedValue: nil, expectedError: errors.New("value at HstrlenKey4 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSTRLEN, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } + if test.expectedResponse == nil { - if !rv.IsNull() { - t.Errorf("expected nil response, got %+v", rv) + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) } return } - expectedResponse, _ := test.expectedResponse.([]int) - for i, v := range rv.Array() { - if v.Integer() != expectedResponse[i] { - t.Errorf("expected \"%d\", got \"%d\"", expectedResponse[i], v.Integer()) + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.Integer()) { + t.Errorf("unexpected element \"%d\" in response", item.Integer()) } } - }) - } -} -func Test_HandleHVALS(t *testing.T) { - tests := []struct { - name string - preset bool - key string - presetValue interface{} - command []string - expectedResponse []interface{} - expectedValue map[string]interface{} - expectedError error - }{ - { - name: "1. Return all the values from a hash", - preset: true, - key: "HvalsKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, - command: []string{"HVALS", "HvalsKey1"}, - expectedResponse: []interface{}{"value1", 123456789, "3.142"}, - expectedValue: map[string]interface{}{}, - expectedError: nil, - }, - { - name: "2. Empty array response when trying to get HSTRLEN non-existent key", - preset: false, - key: "HvalsKey2", - presetValue: map[string]interface{}{}, - command: []string{"HVALS", "HvalsKey2"}, - expectedResponse: []interface{}{}, - expectedValue: map[string]interface{}{}, - expectedError: nil, - }, - { - name: "3. Command too short", - preset: false, - key: "HvalsKey3", - presetValue: map[string]interface{}{}, - command: []string{"HVALS"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - preset: false, - key: "HvalsKey4", - presetValue: map[string]interface{}{}, - command: []string{"HVALS", "HvalsKey4", "HvalsKey4"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - preset: true, - key: "HvalsKey5", - presetValue: "Default value", - command: []string{"HVALS", "HvalsKey5"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, - expectedError: errors.New("value at HvalsKey5 is not a hash"), - }, - } - - for i, test := range tests { - t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HVALS, %d", i)) - - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { - t.Error(err) - } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { - t.Error(err) - } - mockServer.KeyUnlock(ctx, test.key) + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) } - - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return - } - - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) - if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - switch len(test.expectedResponse) { - case 0: - if len(rv.Array()) != 0 { - t.Errorf("expected empty array, got length \"%d\"", len(rv.Array())) - } - default: - for _, v := range rv.Array() { - switch v.Type().String() { - default: - t.Errorf("unexpected error type") - case "Integer": - // Value is an integer, check if it is contained in the expected response - if !slices.ContainsFunc(test.expectedResponse, func(e interface{}) bool { - expectedValue, ok := e.(int) - return ok && expectedValue == v.Integer() - }) { - t.Errorf("couldn't find response value \"%d\" in expected values", v.Integer()) - } - case "BulkString": - // Value is a string, check if it is contained in the expected response - if !slices.ContainsFunc(test.expectedResponse, func(e interface{}) bool { - expectedValue, ok := e.(string) - return ok && expectedValue == v.String() - }) { - t.Errorf("couldn't find response value \"%s\" in expected values", v.String()) - } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) } } } @@ -751,71 +733,199 @@ func Test_HandleHVALS(t *testing.T) { } } -func Test_HandleHRANDFIELD(t *testing.T) { +func Test_HandleHVALS(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Return all the values from a hash", + key: "HvalsKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HVALS", "HvalsKey1"}, + expectedResponse: []string{"value1", "123456789", "3.142"}, + expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Empty array response when trying to get HSTRLEN non-existent key", + key: "HvalsKey2", + presetValue: nil, + command: []string{"HVALS", "HvalsKey2"}, + expectedResponse: []string{}, + expectedValue: nil, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HvalsKey3", + presetValue: nil, + command: []string{"HVALS"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "HvalsKey4", + presetValue: nil, + command: []string{"HVALS", "HvalsKey4", "HvalsKey4"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HvalsKey5", + presetValue: "Default value", + command: []string{"HVALS", "HvalsKey5"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New("value at HvalsKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) + } + return + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + }) + } +} + +func Test_HandleHRANDFIELD(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - withValues bool - expectedCount int expectedResponse []string expectedError error }{ { name: "1. Get a random field", - preset: true, key: "HrandfieldKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HRANDFIELD", "HrandfieldKey1"}, - withValues: false, - expectedCount: 1, expectedResponse: []string{"field1", "field2", "field3"}, expectedError: nil, }, { name: "2. Get a random field with a value", - preset: true, key: "HrandfieldKey2", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HRANDFIELD", "HrandfieldKey2", "1", "WITHVALUES"}, - withValues: true, - expectedCount: 2, expectedResponse: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, expectedError: nil, }, { - name: "3. Get several random fields", - preset: true, - key: "HrandfieldKey3", - presetValue: map[string]interface{}{ + name: "3. Get several random fields", + key: "HrandfieldKey3", + presetValue: map[string]string{ "field1": "value1", - "field2": 123456789, - "field3": 3.142, + "field2": "123456789", + "field3": "3.142", "field4": "value4", "field5": "value5", }, command: []string{"HRANDFIELD", "HrandfieldKey3", "3"}, - withValues: false, - expectedCount: 3, expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, expectedError: nil, }, { - name: "4. Get several random fields with their corresponding values", - preset: true, - key: "HrandfieldKey4", - presetValue: map[string]interface{}{ + name: "4. Get several random fields with their corresponding values", + key: "HrandfieldKey4", + presetValue: map[string]string{ "field1": "value1", - "field2": 123456789, - "field3": 3.142, + "field2": "123456789", + "field3": "3.142", "field4": "value4", "field5": "value5", }, - command: []string{"HRANDFIELD", "HrandfieldKey4", "3", "WITHVALUES"}, - withValues: true, - expectedCount: 6, + command: []string{"HRANDFIELD", "HrandfieldKey4", "3", "WITHVALUES"}, expectedResponse: []string{ "field1", "value1", "field2", "123456789", "field3", "3.142", "field4", "value4", "field5", "value5", @@ -823,36 +933,30 @@ func Test_HandleHRANDFIELD(t *testing.T) { expectedError: nil, }, { - name: "5. Get the entire hash", - preset: true, - key: "HrandfieldKey5", - presetValue: map[string]interface{}{ + name: "5. Get the entire hash", + key: "HrandfieldKey5", + presetValue: map[string]string{ "field1": "value1", - "field2": 123456789, - "field3": 3.142, + "field2": "123456789", + "field3": "3.142", "field4": "value4", "field5": "value5", }, command: []string{"HRANDFIELD", "HrandfieldKey5", "5"}, - withValues: false, - expectedCount: 5, expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, expectedError: nil, }, { - name: "6. Get the entire hash with values", - preset: true, - key: "HrandfieldKey5", - presetValue: map[string]interface{}{ + name: "6. Get the entire hash with values", + key: "HrandfieldKey5", + presetValue: map[string]string{ "field1": "value1", - "field2": 123456789, - "field3": 3.142, + "field2": "123456789", + "field3": "3.142", "field4": "value4", "field5": "value5", }, - command: []string{"HRANDFIELD", "HrandfieldKey5", "5", "WITHVALUES"}, - withValues: true, - expectedCount: 10, + command: []string{"HRANDFIELD", "HrandfieldKey5", "5", "WITHVALUES"}, expectedResponse: []string{ "field1", "value1", "field2", "123456789", "field3", "3.142", "field4", "value4", "field5", "value5", @@ -861,23 +965,20 @@ func Test_HandleHRANDFIELD(t *testing.T) { }, { name: "7. Command too short", - preset: false, key: "HrandfieldKey10", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HRANDFIELD"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "8. Command too long", - preset: false, key: "HrandfieldKey11", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HRANDFIELD", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "9. Trying to get random field on a non hash map returns error", - preset: true, key: "HrandfieldKey12", presetValue: "Default value", command: []string{"HRANDFIELD", "HrandfieldKey12"}, @@ -885,7 +986,6 @@ func Test_HandleHRANDFIELD(t *testing.T) { }, { name: "10. Throw error when count provided is not an integer", - preset: true, key: "HrandfieldKey12", presetValue: "Default value", command: []string{"HRANDFIELD", "HrandfieldKey12", "COUNT"}, @@ -893,7 +993,6 @@ func Test_HandleHRANDFIELD(t *testing.T) { }, { name: "11. If fourth argument is provided, it must be \"WITHVALUES\"", - preset: true, key: "HrandfieldKey12", presetValue: "Default value", command: []string{"HRANDFIELD", "HrandfieldKey12", "10", "FLAG"}, @@ -901,67 +1000,74 @@ func Test_HandleHRANDFIELD(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HRANDFIELD, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if len(rv.Array()) != test.expectedCount { - t.Errorf("expected response array of length \"%d\", got length \"%d\"", test.expectedCount, len(rv.Array())) - } - switch test.withValues { - case false: - for _, v := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected string) bool { - return expected == v.String() - }) { - t.Errorf("could not find response element \"%s\" in expected response", v.String()) - } + + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) } - case true: - responseArray := rv.Array() - for i := 0; i < len(responseArray); i++ { - if i%2 == 0 { - field := responseArray[i].String() - value := responseArray[i+1].String() + return + } - expectedFieldIndex := slices.Index(test.expectedResponse, field) - if expectedFieldIndex == -1 { - t.Errorf("could not find response value \"%s\" in expected values", field) - } - expectedValue := test.expectedResponse[expectedFieldIndex+1] - - if value != expectedValue { - t.Errorf("expected value \"%s\", got \"%s\"", expectedValue, value) - } - } + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) } } }) @@ -969,567 +1075,642 @@ func Test_HandleHRANDFIELD(t *testing.T) { } func Test_HandleHLEN(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} // Change count - expectedValue map[string]interface{} + expectedResponse int // Change count expectedError error }{ { name: "1. Return the correct length of the hash", - preset: true, key: "HlenKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HLEN", "HlenKey1"}, expectedResponse: 3, - expectedValue: map[string]interface{}{}, expectedError: nil, }, { name: "2. 0 response when trying to call HLEN on non-existent key", - preset: false, key: "HlenKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HLEN", "HlenKey2"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, expectedError: nil, }, { name: "3. Command too short", - preset: false, key: "HlenKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HLEN"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "4. Command too long", - preset: false, - key: "HlenKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HLEN", "HlenKey4", "HlenKey4"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Trying to get lengths on a non hash map returns error", - preset: true, key: "HlenKey5", presetValue: "Default value", command: []string{"HLEN", "HlenKey5"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, expectedError: errors.New("value at HlenKey5 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HLEN, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } - if expectedResponse, ok := test.expectedResponse.(int); ok { - if rv.Integer() != expectedResponse { - t.Errorf("expected ineger \"%d\", got \"%d\"", expectedResponse, rv.Integer()) - } - return - } - t.Error("expected integer response, got another type") }) } } func Test_HandleHKeys(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} // Change count - expectedValue map[string]interface{} + expectedResponse []string expectedError error }{ { name: "1. Return an array containing all the keys of the hash", - preset: true, key: "HkeysKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HKEYS", "HkeysKey1"}, expectedResponse: []string{"field1", "field2", "field3"}, - expectedValue: map[string]interface{}{}, expectedError: nil, }, { name: "2. Empty array response when trying to call HKEYS on non-existent key", - preset: false, key: "HkeysKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HKEYS", "HkeysKey2"}, expectedResponse: []string{}, - expectedValue: map[string]interface{}{}, expectedError: nil, }, { name: "3. Command too short", - preset: false, key: "HkeysKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HKEYS"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "4. Command too long", - preset: false, key: "HkeysKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HKEYS", "HkeysKey4", "HkeysKey4"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "5. Trying to get lengths on a non hash map returns error", - preset: true, - key: "HkeysKey5", - presetValue: "Default value", - command: []string{"HKEYS", "HkeysKey5"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, - expectedError: errors.New("value at HkeysKey5 is not a hash"), + name: "5. Trying to get lengths on a non hash map returns error", + key: "HkeysKey5", + presetValue: "Default value", + command: []string{"HKEYS", "HkeysKey5"}, + expectedError: errors.New("value at HkeysKey5 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HKEYS, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if expectedResponse, ok := test.expectedResponse.([]string); ok { - if len(rv.Array()) != len(expectedResponse) { - t.Errorf("expected length \"%d\", got \"%d\"", len(expectedResponse), len(rv.Array())) + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected value \"%s\" in response", item.String()) } - for _, field := range expectedResponse { - if !slices.ContainsFunc(rv.Array(), func(value resp.Value) bool { - return value.String() == field - }) { - t.Errorf("could not find expected to find key \"%s\" in response", field) - } - } - return } - t.Error("expected array response, got another type") }) } } func Test_HandleHGETALL(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse []string - expectedValue map[string]interface{} + expectedResponse map[string]string expectedError error }{ { name: "1. Return an array containing all the fields and values of the hash", - preset: true, key: "HGetAllKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HGETALL", "HGetAllKey1"}, - expectedResponse: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, - expectedValue: map[string]interface{}{}, + expectedResponse: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, expectedError: nil, }, { name: "2. Empty array response when trying to call HGETALL on non-existent key", - preset: false, key: "HGetAllKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HGETALL", "HGetAllKey2"}, - expectedResponse: []string{}, - expectedValue: map[string]interface{}{}, + expectedResponse: nil, expectedError: nil, }, { name: "3. Command too short", - preset: false, key: "HGetAllKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HGETALL"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "4. Command too long", - preset: false, key: "HGetAllKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HGETALL", "HGetAllKey4", "HGetAllKey4"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Trying to get lengths on a non hash map returns error", - preset: true, key: "HGetAllKey5", presetValue: "Default value", command: []string{"HGETALL", "HGetAllKey5"}, expectedResponse: nil, - expectedValue: map[string]interface{}{}, expectedError: errors.New("value at HGetAllKey5 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HGETALL, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if len(rv.Array()) != len(test.expectedResponse) { - t.Errorf("expected length \"%d\", got \"%d\"", len(test.expectedResponse), len(rv.Array())) - } - // In the response: - // The order of results is not guaranteed, - // However, each field in the array will be reliably followed by its corresponding value - responseArray := rv.Array() - for i := 0; i < len(responseArray); i++ { - if i%2 == 0 { - // We're on a field in the response - field := responseArray[i].String() - value := responseArray[i+1].String() - expectedFieldIndex := slices.Index(test.expectedResponse, field) - if expectedFieldIndex == -1 { - t.Errorf("received unexpected field \"%s\" in response", field) - } - expectedValue := test.expectedResponse[expectedFieldIndex+1] - if expectedValue != value { - t.Errorf("expected entry \"%s\", got \"%s\"", expectedValue, value) + if test.expectedResponse == nil { + if len(res.Array()) != 0 { + t.Errorf("expected response to be empty array, got %+v", res) + } + return + } + + for i, item := range res.Array() { + if i%2 == 0 { + field := item.String() + value := res.Array()[i+1].String() + if test.expectedResponse[field] != value { + t.Errorf("expected value at field \"%s\" to be \"%s\", got \"%s\"", field, test.expectedResponse[field], value) } } - } - return + }) } } func Test_HandleHEXISTS(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} - expectedValue map[string]interface{} + expectedResponse bool expectedError error }{ { name: "1. Return 1 if the field exists in the hash", - preset: true, key: "HexistsKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, command: []string{"HEXISTS", "HexistsKey1", "field1"}, - expectedResponse: 1, - expectedValue: map[string]interface{}{}, + expectedResponse: true, expectedError: nil, }, { name: "2. 0 response when trying to call HEXISTS on non-existent key", - preset: false, key: "HexistsKey2", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HEXISTS", "HexistsKey2", "field1"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: false, expectedError: nil, }, { name: "3. Command too short", - preset: false, key: "HexistsKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HEXISTS", "HexistsKey3"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, + expectedResponse: false, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "4. Command too long", - preset: false, key: "HexistsKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HEXISTS", "HexistsKey4", "field1", "field2"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, + expectedResponse: false, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Trying to get lengths on a non hash map returns error", - preset: true, key: "HexistsKey5", presetValue: "Default value", command: []string{"HEXISTS", "HexistsKey5", "field1"}, - expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedResponse: false, expectedError: errors.New("value at HexistsKey5 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HEXISTS, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if res.Bool() != test.expectedResponse { + t.Errorf("expected response to be %v, got %v", test.expectedResponse, res.Bool()) } - if expectedResponse, ok := test.expectedResponse.(int); ok { - if rv.Integer() != expectedResponse { - t.Errorf("expected \"%d\", got \"%d\"", expectedResponse, rv.Integer()) - } - return - } - t.Error("expected integer response, got another type") }) } } func Test_HandleHDEL(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string - expectedResponse interface{} - expectedValue map[string]interface{} + expectedResponse int + expectedValue map[string]string expectedError error }{ { name: "1. Return count of deleted fields in the specified hash", - preset: true, key: "HdelKey1", - presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142, "field7": "value7"}, + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142", "field7": "value7"}, command: []string{"HDEL", "HdelKey1", "field1", "field2", "field3", "field4", "field5", "field6"}, expectedResponse: 3, - expectedValue: map[string]interface{}{"field1": nil, "field2": nil, "field3": nil, "field7": "value1"}, + expectedValue: map[string]string{"field7": "value7"}, expectedError: nil, }, { name: "2. 0 response when passing delete fields that are non-existent on valid hash", - preset: true, key: "HdelKey2", - presetValue: map[string]interface{}{"field1": "value1", "field2": "value2", "field3": "value3"}, + presetValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, command: []string{"HDEL", "HdelKey2", "field4", "field5", "field6"}, expectedResponse: 0, - expectedValue: map[string]interface{}{"field1": "value1", "field2": "value2", "field3": "value3"}, + expectedValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, expectedError: nil, }, { name: "3. 0 response when trying to call HDEL on non-existent key", - preset: false, key: "HdelKey3", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HDEL", "HdelKey3", "field1"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedValue: nil, expectedError: nil, }, { name: "4. Command too short", - preset: false, key: "HdelKey4", - presetValue: map[string]interface{}{}, + presetValue: nil, command: []string{"HDEL", "HdelKey4"}, - expectedResponse: nil, - expectedValue: map[string]interface{}{}, + expectedResponse: 0, + expectedValue: nil, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Trying to get lengths on a non hash map returns error", - preset: true, key: "HdelKey5", presetValue: "Default value", command: []string{"HDEL", "HdelKey5", "field1"}, expectedResponse: 0, - expectedValue: map[string]interface{}{}, + expectedValue: nil, expectedError: errors.New("value at HdelKey5 is not a hash"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HDEL, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } - if expectedResponse, ok := test.expectedResponse.(int); ok { - if rv.Integer() != expectedResponse { - t.Errorf("expected \"%d\", got \"%d\"", expectedResponse, rv.Integer()) - } - return - } - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - if h, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok { - for field, value := range h { - if value != test.expectedValue[field] { - t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value) + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) } } - return } - t.Error("expected hash value but got another type") }) } }