Removed logic to get unexported methods from the echovault package in all tests.

This commit is contained in:
Kelvin Clement Mwinuka
2024-05-30 19:33:01 +08:00
parent e1d5e8203f
commit 502e804459
14 changed files with 3050 additions and 3053 deletions

View File

@@ -15,8 +15,6 @@
package admin_test
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/echovault/echovault/echovault"
@@ -36,12 +34,10 @@ import (
"net"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"testing"
"unsafe"
)
func setupServer(port uint16) (*echovault.EchoVault, error) {
@@ -53,45 +49,7 @@ func setupServer(port uint16) (*echovault.EchoVault, error) {
return echovault.NewEchoVault(echovault.WithConfig(cfg))
}
func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(mockServer *echovault.EchoVault, commands ...string) internal.HandlerFunc {
if len(commands) == 0 {
return nil
}
getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, mockServer *echovault.EchoVault, cmd []string, conn *net.Conn) internal.HandlerFuncParams {
getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
return internal.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
GetAllCommands: getCommands,
}
}
func Test_AdminCommand(t *testing.T) {
func Test_AdminCommands(t *testing.T) {
t.Cleanup(func() {
_ = os.RemoveAll("./testdata")
})
@@ -104,23 +62,37 @@ func Test_AdminCommand(t *testing.T) {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
res, err := getHandler(mockServer, "COMMANDS")(
getHandlerFuncParams(context.Background(), mockServer, []string{"commands"}, nil),
)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMANDS")}); err != nil {
t.Error(err)
return
}
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
// Get all the commands from the existing modules.
@@ -148,8 +120,8 @@ func Test_AdminCommand(t *testing.T) {
}
}
if len(allCommands) != len(rv.Array()) {
t.Errorf("expected commands list to be of length %d, got %d", len(allCommands), len(rv.Array()))
if len(allCommands) != len(res.Array()) {
t.Errorf("expected commands list to be of length %d, got %d", len(allCommands), len(res.Array()))
}
})
@@ -161,23 +133,37 @@ func Test_AdminCommand(t *testing.T) {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
res, err := getHandler(mockServer, "COMMAND", "COUNT")(
getHandlerFuncParams(context.Background(), mockServer, []string{"command", "count"}, nil),
)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMAND"), resp.StringValue("COUNT")}); err != nil {
t.Error(err)
return
}
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
// Get all the commands from the existing modules.
@@ -205,8 +191,8 @@ func Test_AdminCommand(t *testing.T) {
}
}
if len(allCommands) != rv.Integer() {
t.Errorf("expected COMMAND COUNT to return %d, got %d", len(allCommands), rv.Integer())
if len(allCommands) != res.Integer() {
t.Errorf("expected COMMAND COUNT to return %d, got %d", len(allCommands), res.Integer())
}
})
@@ -225,6 +211,21 @@ func Test_AdminCommand(t *testing.T) {
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
// Get all the commands from the existing modules.
var allCommands []internal.Command
allCommands = append(allCommands, acl.Commands()...)
@@ -305,24 +306,26 @@ func Test_AdminCommand(t *testing.T) {
}
for _, test := range tests {
res, err := getHandler(mockServer, test.cmd...)(
getHandlerFuncParams(context.Background(), mockServer, test.cmd, nil),
)
command := make([]resp.Value, len(test.cmd))
for i, c := range test.cmd {
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
if len(res.Array()) != len(test.want) {
t.Errorf("expected response of length %d, got %d", len(test.want), len(res.Array()))
}
if len(rv.Array()) != len(test.want) {
t.Errorf("expected response of length %d, got %d", len(test.want), len(rv.Array()))
}
for _, command := range rv.Array() {
for _, command := range res.Array() {
if !slices.ContainsFunc(test.want, func(c string) bool {
return strings.EqualFold(c, command.String())
}) {
@@ -441,6 +444,7 @@ func Test_AdminCommand(t *testing.T) {
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Error(err)
return
}
respConn := resp.NewConn(conn)