diff --git a/echovault/echovault.go b/echovault/echovault.go index 07bce2d..8e5ac40 100644 --- a/echovault/echovault.go +++ b/echovault/echovault.go @@ -36,6 +36,7 @@ import ( "github.com/echovault/echovault/internal/modules/pubsub" "github.com/echovault/echovault/internal/modules/set" "github.com/echovault/echovault/internal/modules/sorted_set" + str "github.com/echovault/echovault/internal/modules/string" "github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/snapshot" "io" @@ -143,7 +144,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { commands = append(commands, pubsub.Commands()...) commands = append(commands, set.Commands()...) commands = append(commands, sorted_set.Commands()...) - // commands = append(commands, str.Commands()...) + commands = append(commands, str.Commands()...) return commands }(), } diff --git a/internal/modules/string/commands.go b/internal/modules/string/commands.go index b466228..e933b61 100644 --- a/internal/modules/string/commands.go +++ b/internal/modules/string/commands.go @@ -28,6 +28,7 @@ func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] offset, ok := internal.AdaptType(params.Command[2]).(int) if !ok { @@ -36,23 +37,11 @@ func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) { newStr := params.Command[3] - if !params.KeyExists(params.Context, key) { - if _, err = params.CreateKeyAndLock(params.Context, key); err != nil { - return nil, err - } - if err = params.SetValue(params.Context, key, newStr); err != nil { - return nil, err - } - params.KeyUnlock(params.Context, key) + if !keyExists { return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil } - if _, err := params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - str, ok := params.GetValue(params.Context, key).(string) + str, ok := params.GetValues(params.Context, []string{key})[key].(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } @@ -60,7 +49,7 @@ func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) { // If the offset >= length of the string, append the new string to the old one. if offset >= len(str) { newStr = str + newStr - if err = params.SetValue(params.Context, key, newStr); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: newStr}); err != nil { return nil, err } return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil @@ -69,7 +58,7 @@ func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) { // If the offset is < 0, prepend the new string to the old one. if offset < 0 { newStr = newStr + str - if err = params.SetValue(params.Context, key, newStr); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: newStr}); err != nil { return nil, err } return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil @@ -89,7 +78,7 @@ func handleSetRange(params internal.HandlerFuncParams) ([]byte, error) { break } - if err = params.SetValue(params.Context, key, string(strRunes)); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: string(strRunes)}); err != nil { return nil, err } @@ -103,17 +92,13 @@ func handleStrLen(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err := params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - value, ok := params.GetValue(params.Context, key).(string) + value, ok := params.GetValues(params.Context, []string{key})[key].(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) @@ -129,6 +114,7 @@ func handleSubStr(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] start, startOk := internal.AdaptType(params.Command[2]).(int) end, endOk := internal.AdaptType(params.Command[3]).(int) @@ -138,16 +124,11 @@ func handleSubStr(params internal.HandlerFuncParams) ([]byte, error) { return nil, errors.New("start and end indices must be integers") } - if !params.KeyExists(params.Context, key) { + if !keyExists { return nil, fmt.Errorf("key %s does not exist", key) } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - value, ok := params.GetValue(params.Context, key).(string) + value, ok := params.GetValues(params.Context, []string{key})[key].(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } diff --git a/internal/modules/string/commands_test.go b/internal/modules/string/commands_test.go index a393008..6acd90e 100644 --- a/internal/modules/string/commands_test.go +++ b/internal/modules/string/commands_test.go @@ -15,8 +15,6 @@ package str_test import ( - "bytes" - "context" "errors" "fmt" "github.com/echovault/echovault/echovault" @@ -25,71 +23,44 @@ import ( "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" "net" - "reflect" "strconv" "strings" + "sync" "testing" - "unsafe" ) var mockServer *echovault.EchoVault +var addr = "localhost" +var port int func init() { + port, _ = internal.GetFreePort() mockServer, _ = echovault.NewEchoVault( echovault.WithConfig(config.Config{ + BindAddr: addr, + Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) -} - -func getUnexportedField(field reflect.Value) interface{} { - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() -} - -func getHandler(commands ...string) internal.HandlerFunc { - if len(commands) == 0 { - return nil - } - getCommands := - getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) - for _, c := range getCommands() { - if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { - // Get command handler - return c.HandlerFunc - } - if strings.EqualFold(commands[0], c.Command) { - // Get sub-command handler - for _, sc := range c.SubCommands { - if strings.EqualFold(commands[1], sc.Command) { - return sc.HandlerFunc - } - } - } - } - return nil -} - -func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams { - return internal.HandlerFuncParams{ - Context: ctx, - Command: cmd, - Connection: conn, - KeyExists: mockServer.KeyExists, - CreateKeyAndLock: mockServer.CreateKeyAndLock, - KeyLock: mockServer.KeyLock, - KeyRLock: mockServer.KeyRLock, - KeyUnlock: mockServer.KeyUnlock, - KeyRUnlock: mockServer.KeyRUnlock, - GetValue: mockServer.GetValue, - SetValue: mockServer.SetValue, - } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + mockServer.Start() + }() + wg.Wait() } func Test_HandleSetRange(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue string command []string @@ -99,7 +70,6 @@ func Test_HandleSetRange(t *testing.T) { }{ { name: "Test that SETRANGE on non-existent string creates new string", - preset: false, key: "SetRangeKey1", presetValue: "", command: []string{"SETRANGE", "SetRangeKey1", "10", "New String Value"}, @@ -109,7 +79,6 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "Test SETRANGE with an offset that leads to a longer resulting string", - preset: true, key: "SetRangeKey2", presetValue: "Original String Value", command: []string{"SETRANGE", "SetRangeKey2", "16", "Portion Replaced With This New String"}, @@ -119,7 +88,6 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "SETRANGE with negative offset prepends the string", - preset: true, key: "SetRangeKey3", presetValue: "This is a preset value", command: []string{"SETRANGE", "SetRangeKey3", "-10", "Prepended "}, @@ -129,7 +97,6 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "SETRANGE with offset that embeds new string inside the old string", - preset: true, key: "SetRangeKey4", presetValue: "This is a preset value", command: []string{"SETRANGE", "SetRangeKey4", "0", "That"}, @@ -139,7 +106,6 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "SETRANGE with offset longer than original lengths appends the string", - preset: true, key: "SetRangeKey5", presetValue: "This is a preset value", command: []string{"SETRANGE", "SetRangeKey5", "100", " Appended"}, @@ -149,7 +115,6 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "SETRANGE with offset on the last character replaces last character with new string", - preset: true, key: "SetRangeKey6", presetValue: "This is a preset value", command: []string{"SETRANGE", "SetRangeKey6", strconv.Itoa(len("This is a preset value") - 1), " replaced"}, @@ -159,14 +124,12 @@ func Test_HandleSetRange(t *testing.T) { }, { name: " Offset not integer", - preset: false, command: []string{"SETRANGE", "key", "offset", "value"}, expectedResponse: 0, expectedError: errors.New("offset must be an integer"), }, { name: "SETRANGE target is not a string", - preset: true, key: "test-int", presetValue: "10", command: []string{"SETRANGE", "test-int", "10", "value"}, @@ -175,81 +138,74 @@ func Test_HandleSetRange(t *testing.T) { }, { name: "Command too short", - preset: false, command: []string{"SETRANGE", "key"}, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "Command too long", - preset: false, command: []string{"SETRANGE", "key", "offset", "value", "value1"}, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SETRANGE, %d", i)) - - // If there's a preset step, carry it out here - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, internal.AdaptType(test.presetValue)); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, rv.Integer()) - } - // Get the value from the echovault and check against the expected value - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } - value, ok := mockServer.GetValue(ctx, test.key).(string) - if !ok { - t.Error("expected string data type, got another type") - } - if value != test.expectedValue { - t.Errorf("expected value \"%s\", got \"%s\"", test.expectedValue, value) - } - mockServer.KeyRUnlock(ctx, test.key) }) } } func Test_HandleStrLen(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue string command []string @@ -258,7 +214,6 @@ func Test_HandleStrLen(t *testing.T) { }{ { name: "Return the correct string length for an existing string", - preset: true, key: "StrLenKey1", presetValue: "Test String", command: []string{"STRLEN", "StrLenKey1"}, @@ -267,7 +222,6 @@ func Test_HandleStrLen(t *testing.T) { }, { name: "If the string does not exist, return 0", - preset: false, key: "StrLenKey2", presetValue: "", command: []string{"STRLEN", "StrLenKey2"}, @@ -276,7 +230,6 @@ func Test_HandleStrLen(t *testing.T) { }, { name: "Too few args", - preset: false, key: "StrLenKey3", presetValue: "", command: []string{"STRLEN"}, @@ -285,7 +238,6 @@ func Test_HandleStrLen(t *testing.T) { }, { name: "Too many args", - preset: false, key: "StrLenKey4", presetValue: "", command: []string{"STRLEN", "StrLenKey4", "StrLenKey5"}, @@ -294,51 +246,62 @@ func Test_HandleStrLen(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("STRLEN, %d", i)) - - if test.preset { - _, err := mockServer.CreateKeyAndLock(ctx, test.key) + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { - t.Error(err) + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) } - mockServer.KeyUnlock(ctx, test.key) } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected respons \"%d\", got \"%d\"", test.expectedResponse, rv.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } }) } } func Test_HandleSubStr(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error(err) + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue string command []string @@ -347,7 +310,6 @@ func Test_HandleSubStr(t *testing.T) { }{ { name: "Return substring within the range of the string", - preset: true, key: "SubStrKey1", presetValue: "Test String One", command: []string{"SUBSTR", "SubStrKey1", "5", "10"}, @@ -356,7 +318,6 @@ func Test_HandleSubStr(t *testing.T) { }, { name: "Return substring at the end of the string with exact end index", - preset: true, key: "SubStrKey2", presetValue: "Test String Two", command: []string{"SUBSTR", "SubStrKey2", "12", "14"}, @@ -365,7 +326,6 @@ func Test_HandleSubStr(t *testing.T) { }, { name: "Return substring at the end of the string with end index greater than length", - preset: true, key: "SubStrKey3", presetValue: "Test String Three", command: []string{"SUBSTR", "SubStrKey3", "12", "75"}, @@ -374,7 +334,6 @@ func Test_HandleSubStr(t *testing.T) { }, { name: "Return the substring at the start of the string with 0 start index", - preset: true, key: "SubStrKey4", presetValue: "Test String Four", command: []string{"SUBSTR", "SubStrKey4", "0", "3"}, @@ -385,7 +344,6 @@ func Test_HandleSubStr(t *testing.T) { // Return the substring with negative start index. // Substring should begin abs(start) from the end of the string when start is negative. name: "Return the substring with negative start index", - preset: true, key: "SubStrKey5", presetValue: "Test String Five", command: []string{"SUBSTR", "SubStrKey5", "-11", "10"}, @@ -396,7 +354,6 @@ func Test_HandleSubStr(t *testing.T) { // Return reverse substring with end index smaller than start index. // When end index is smaller than start index, the 2 indices are reversed. name: "Return reverse substring with end index smaller than start index", - preset: true, key: "SubStrKey6", presetValue: "Test String Six", command: []string{"SUBSTR", "SubStrKey6", "4", "0"}, @@ -430,42 +387,48 @@ func Test_HandleSubStr(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUBSTR, %d", i)) - - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, rv.String()) + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) } }) }