mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-04 23:52:42 +08:00
Removed logic to get unexported methods from the echovault package in all tests.
This commit is contained in:
@@ -15,8 +15,6 @@
|
||||
package admin_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
@@ -36,12 +34,10 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func setupServer(port uint16) (*echovault.EchoVault, error) {
|
||||
@@ -53,45 +49,7 @@ func setupServer(port uint16) (*echovault.EchoVault, error) {
|
||||
return echovault.NewEchoVault(echovault.WithConfig(cfg))
|
||||
}
|
||||
|
||||
func getUnexportedField(field reflect.Value) interface{} {
|
||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
|
||||
}
|
||||
|
||||
func getHandler(mockServer *echovault.EchoVault, commands ...string) internal.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
getCommands :=
|
||||
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
|
||||
for _, c := range getCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, mockServer *echovault.EchoVault, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
|
||||
getCommands :=
|
||||
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
|
||||
return internal.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
GetAllCommands: getCommands,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AdminCommand(t *testing.T) {
|
||||
func Test_AdminCommands(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll("./testdata")
|
||||
})
|
||||
@@ -104,23 +62,37 @@ func Test_AdminCommand(t *testing.T) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
mockServer, err := setupServer(uint16(port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := getHandler(mockServer, "COMMANDS")(
|
||||
getHandlerFuncParams(context.Background(), mockServer, []string{"commands"}, nil),
|
||||
)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMANDS")}); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get all the commands from the existing modules.
|
||||
@@ -148,8 +120,8 @@ func Test_AdminCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(allCommands) != len(rv.Array()) {
|
||||
t.Errorf("expected commands list to be of length %d, got %d", len(allCommands), len(rv.Array()))
|
||||
if len(allCommands) != len(res.Array()) {
|
||||
t.Errorf("expected commands list to be of length %d, got %d", len(allCommands), len(res.Array()))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -161,23 +133,37 @@ func Test_AdminCommand(t *testing.T) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
mockServer, err := setupServer(uint16(port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := getHandler(mockServer, "COMMAND", "COUNT")(
|
||||
getHandlerFuncParams(context.Background(), mockServer, []string{"command", "count"}, nil),
|
||||
)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMAND"), resp.StringValue("COUNT")}); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get all the commands from the existing modules.
|
||||
@@ -205,8 +191,8 @@ func Test_AdminCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(allCommands) != rv.Integer() {
|
||||
t.Errorf("expected COMMAND COUNT to return %d, got %d", len(allCommands), rv.Integer())
|
||||
if len(allCommands) != res.Integer() {
|
||||
t.Errorf("expected COMMAND COUNT to return %d, got %d", len(allCommands), res.Integer())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -225,6 +211,21 @@ func Test_AdminCommand(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
// Get all the commands from the existing modules.
|
||||
var allCommands []internal.Command
|
||||
allCommands = append(allCommands, acl.Commands()...)
|
||||
@@ -305,24 +306,26 @@ func Test_AdminCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
res, err := getHandler(mockServer, test.cmd...)(
|
||||
getHandlerFuncParams(context.Background(), mockServer, test.cmd, nil),
|
||||
)
|
||||
command := make([]resp.Value, len(test.cmd))
|
||||
for i, c := range test.cmd {
|
||||
command[i] = resp.StringValue(c)
|
||||
}
|
||||
if err = client.WriteArray(command); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
if len(res.Array()) != len(test.want) {
|
||||
t.Errorf("expected response of length %d, got %d", len(test.want), len(res.Array()))
|
||||
}
|
||||
|
||||
if len(rv.Array()) != len(test.want) {
|
||||
t.Errorf("expected response of length %d, got %d", len(test.want), len(rv.Array()))
|
||||
}
|
||||
|
||||
for _, command := range rv.Array() {
|
||||
for _, command := range res.Array() {
|
||||
if !slices.ContainsFunc(test.want, func(c string) bool {
|
||||
return strings.EqualFold(c, command.String())
|
||||
}) {
|
||||
@@ -441,6 +444,7 @@ func Test_AdminCommand(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
respConn := resp.NewConn(conn)
|
||||
|
Reference in New Issue
Block a user