Instead of passing in EchoVault instance to commands handler, we now pass a struct of params containing all the variables and functions used within the handler function. This removes the modules' dependency on the echovault package. Moved string command and api tests to test/modules/string

This commit is contained in:
Kelvin Clement Mwinuka
2024-04-24 16:34:59 +08:00
parent 2b01c7342c
commit fbf4782b7c
6 changed files with 205 additions and 76 deletions

View File

@@ -86,33 +86,34 @@ func (fsm *FSM) Apply(log *raft.Log) interface{} {
}
case "command":
// TODO: Re-Implement Command handling with dependency injection
// Handle command
command, err := fsm.options.GetCommand(request.CMD[0])
if err != nil {
return internal.ApplyResponse{
Error: err,
Response: nil,
}
}
handler := command.HandlerFunc
subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
if ok {
handler = subCommand.HandlerFunc
}
if res, err := handler(ctx, request.CMD, fsm.options.EchoVault, nil); err != nil {
return internal.ApplyResponse{
Error: err,
Response: nil,
}
} else {
return internal.ApplyResponse{
Error: nil,
Response: res,
}
}
// command, err := fsm.options.GetCommand(request.CMD[0])
// if err != nil {
// return internal.ApplyResponse{
// Error: err,
// Response: nil,
// }
// }
//
// handler := command.HandlerFunc
//
// subCommand, ok := internal.GetSubCommand(command, request.CMD).(types.SubCommand)
// if ok {
// handler = subCommand.HandlerFunc
// }
//
// if res, err := handler(ctx, request.CMD, fsm.options.EchoVault, nil); err != nil {
// return internal.ApplyResponse{
// Error: err,
// Response: nil,
// }
// } else {
// return internal.ApplyResponse{
// Error: nil,
// Response: res,
// }
// }
}
}

View File

@@ -46,6 +46,23 @@ func (server *EchoVault) getCommand(cmd string) (types.Command, error) {
return types.Command{}, fmt.Errorf("command %s not supported", cmd)
}
func (server *EchoVault) getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) types.HandlerFuncParams {
return types.HandlerFuncParams{
// TODO: Add all the required methods here
Context: ctx,
Command: cmd,
Connection: conn,
KeyExists: server.KeyExists,
CreateKeyAndLock: server.CreateKeyAndLock,
KeyLock: server.KeyLock,
KeyRLock: server.KeyRLock,
KeyUnlock: server.KeyUnlock,
KeyRUnlock: server.KeyRUnlock,
GetValue: server.GetValue,
SetValue: server.SetValue,
}
}
func (server *EchoVault) handleCommand(ctx context.Context, message []byte, conn *net.Conn, replay bool, embedded bool) ([]byte, error) {
cmd, err := internal.Decode(message)
if err != nil {
@@ -85,7 +102,7 @@ func (server *EchoVault) handleCommand(ctx context.Context, message []byte, conn
}
if !server.isInCluster() || !synchronize {
res, err := handler(ctx, cmd, server, conn)
res, err := handler(server.getHandlerFuncParams(ctx, cmd, conn))
if err != nil {
return nil, err
}

View File

@@ -15,47 +15,45 @@
package str
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/types"
"net"
)
func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := setRangeKeyFunc(cmd)
func handleSetRange(params types.HandlerFuncParams) ([]byte, error) {
keys, err := setRangeKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.WriteKeys[0]
offset, ok := internal.AdaptType(cmd[2]).(int)
offset, ok := internal.AdaptType(params.Command[2]).(int)
if !ok {
return nil, errors.New("offset must be an integer")
}
newStr := cmd[3]
newStr := params.Command[3]
if !server.KeyExists(ctx, key) {
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
if !params.KeyExists(params.Context, key) {
if _, err = params.CreateKeyAndLock(params.Context, key); err != nil {
return nil, err
}
if err = server.SetValue(ctx, key, newStr); err != nil {
if err = params.SetValue(params.Context, key, newStr); err != nil {
return nil, err
}
server.KeyUnlock(ctx, key)
params.KeyUnlock(params.Context, key)
return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
}
if _, err := server.KeyLock(ctx, key); err != nil {
if _, err := params.KeyLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyUnlock(ctx, key)
defer params.KeyUnlock(params.Context, key)
str, ok := server.GetValue(ctx, key).(string)
str, ok := params.GetValue(params.Context, key).(string)
if !ok {
return nil, fmt.Errorf("value at key %s is not a string", key)
}
@@ -63,7 +61,7 @@ func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _
// If the offset >= length of the string, append the new string to the old one.
if offset >= len(str) {
newStr = str + newStr
if err = server.SetValue(ctx, key, newStr); err != nil {
if err = params.SetValue(params.Context, key, newStr); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
@@ -72,7 +70,7 @@ func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _
// If the offset is < 0, prepend the new string to the old one.
if offset < 0 {
newStr = newStr + str
if err = server.SetValue(ctx, key, newStr); err != nil {
if err = params.SetValue(params.Context, key, newStr); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
@@ -92,31 +90,31 @@ func handleSetRange(ctx context.Context, cmd []string, server types.EchoVault, _
break
}
if err = server.SetValue(ctx, key, string(strRunes)); err != nil {
if err = params.SetValue(params.Context, key, string(strRunes)); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", len(strRunes))), nil
}
func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, conn *net.Conn) ([]byte, error) {
keys, err := strLenKeyFunc(cmd)
func handleStrLen(params types.HandlerFuncParams) ([]byte, error) {
keys, err := strLenKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return []byte(":0\r\n"), nil
}
if _, err := server.KeyRLock(ctx, key); err != nil {
if _, err := params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
value, ok := server.GetValue(ctx, key).(string)
value, ok := params.GetValue(params.Context, key).(string)
if !ok {
return nil, fmt.Errorf("value at key %s is not a string", key)
@@ -125,32 +123,32 @@ func handleStrLen(ctx context.Context, cmd []string, server types.EchoVault, con
return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil
}
func handleSubStr(ctx context.Context, cmd []string, server types.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := subStrKeyFunc(cmd)
func handleSubStr(params types.HandlerFuncParams) ([]byte, error) {
keys, err := subStrKeyFunc(params.Command)
if err != nil {
return nil, err
}
key := keys.ReadKeys[0]
start, startOk := internal.AdaptType(cmd[2]).(int)
end, endOk := internal.AdaptType(cmd[3]).(int)
start, startOk := internal.AdaptType(params.Command[2]).(int)
end, endOk := internal.AdaptType(params.Command[3]).(int)
reversed := false
if !startOk || !endOk {
return nil, errors.New("start and end indices must be integers")
}
if !server.KeyExists(ctx, key) {
if !params.KeyExists(params.Context, key) {
return nil, fmt.Errorf("key %s does not exist", key)
}
if _, err = server.KeyRLock(ctx, key); err != nil {
if _, err = params.KeyRLock(params.Context, key); err != nil {
return nil, err
}
defer server.KeyRUnlock(ctx, key)
defer params.KeyRUnlock(params.Context, key)
value, ok := server.GetValue(ctx, key).(string)
value, ok := params.GetValue(params.Context, key).(string)
if !ok {
return nil, fmt.Errorf("value at key %s is not a string", key)
}

View File

@@ -50,7 +50,31 @@ type AccessKeys struct {
}
type KeyExtractionFunc func(cmd []string) (AccessKeys, error)
type HandlerFunc func(ctx context.Context, cmd []string, echovault EchoVault, conn *net.Conn) ([]byte, error)
type HandlerFuncParams struct {
Context context.Context
Command []string
Connection *net.Conn
KeyLock func(ctx context.Context, key string) (bool, error)
KeyUnlock func(ctx context.Context, key string)
KeyRLock func(ctx context.Context, key string) (bool, error)
KeyRUnlock func(ctx context.Context, key string)
KeyExists func(ctx context.Context, key string) bool
CreateKeyAndLock func(ctx context.Context, key string) (bool, error)
GetValue func(ctx context.Context, key string) interface{}
SetValue func(ctx context.Context, key string, value interface{}) error
GetExpiry func(ctx context.Context, key string) time.Time
SetExpiry func(ctx context.Context, key string, expire time.Time, touch bool)
RemoveExpiry func(key string)
DeleteKey func(ctx context.Context, key string) error
GetClock func() clock.Clock
GetAllCommands func() []Command
GetACL func() interface{}
GetPubSub func() interface{}
TakeSnapshot func() error
RewriteAOF func() error
GetLatestSnapshotTime func() int64
}
type HandlerFunc func(params HandlerFuncParams) ([]byte, error)
type SubCommand struct {
Command string

View File

@@ -12,19 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
package str
import (
"context"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/commands"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"testing"
)
func presetValue(server *echovault.EchoVault, ctx context.Context, key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err
}
if err := server.SetValue(ctx, key, value); err != nil {
return err
}
server.KeyUnlock(ctx, key)
return nil
}
func TestEchoVault_SUBSTR(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
@@ -168,7 +181,11 @@ func TestEchoVault_SUBSTR(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := tt.substrFunc(tt.key, tt.start, tt.end)
if (err != nil) != tt.wantErr {
@@ -183,9 +200,9 @@ func TestEchoVault_SUBSTR(t *testing.T) {
}
func TestEchoVault_SETRANGE(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
@@ -258,7 +275,11 @@ func TestEchoVault_SETRANGE(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.SETRANGE(tt.key, tt.offset, tt.new)
if (err != nil) != tt.wantErr {
@@ -273,9 +294,9 @@ func TestEchoVault_SETRANGE(t *testing.T) {
}
func TestEchoVault_STRLEN(t *testing.T) {
server, _ := NewEchoVault(
WithCommands(commands.All()),
WithConfig(config.Config{
server, _ := echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
@@ -306,7 +327,11 @@ func TestEchoVault_STRLEN(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.presetValue != nil {
presetValue(server, tt.key, tt.presetValue)
err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil {
t.Error(err)
return
}
}
got, err := server.STRLEN(tt.key)
if (err != nil) != tt.wantErr {

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package str
package str_test
import (
"bytes"
@@ -23,8 +23,11 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/pkg/constants"
"github.com/echovault/echovault/pkg/echovault"
str "github.com/echovault/echovault/pkg/modules/string"
"github.com/echovault/echovault/pkg/types"
"github.com/tidwall/resp"
"strconv"
"strings"
"testing"
)
@@ -32,6 +35,7 @@ var mockServer *echovault.EchoVault
func init() {
mockServer, _ = echovault.NewEchoVault(
echovault.WithCommands(str.Commands()),
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
@@ -39,6 +43,15 @@ func init() {
)
}
func getHandler(command string) types.HandlerFunc {
for _, c := range mockServer.GetAllCommands() {
if strings.EqualFold(command, c.Command) {
return c.HandlerFunc
}
}
return nil
}
func Test_HandleSetRange(t *testing.T) {
tests := []struct {
name string
@@ -157,7 +170,24 @@ func Test_HandleSetRange(t *testing.T) {
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSetRange(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
CreateKeyAndLock: mockServer.CreateKeyAndLock,
KeyLock: mockServer.KeyLock,
KeyUnlock: mockServer.KeyUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -254,7 +284,24 @@ func Test_HandleStrLen(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleStrLen(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
KeyRLock: mockServer.KeyRLock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -382,7 +429,24 @@ func Test_HandleSubStr(t *testing.T) {
}
mockServer.KeyUnlock(ctx, test.key)
}
res, err := handleSubStr(ctx, test.command, mockServer, nil)
handler := getHandler(test.command[0])
if handler == nil {
t.Errorf("no handler found for command %s", test.command[0])
return
}
res, err := handler(types.HandlerFuncParams{
Context: ctx,
Command: test.command,
Connection: nil,
KeyExists: mockServer.KeyExists,
KeyRLock: mockServer.KeyRLock,
KeyRUnlock: mockServer.KeyRUnlock,
GetValue: mockServer.GetValue,
SetValue: mockServer.SetValue,
})
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())