This commit is contained in:
Kelvin Mwinuka
2024-06-27 04:03:02 +08:00
parent 0bb4ce6756
commit 4dd3aa40b2
4 changed files with 361 additions and 8 deletions

View File

@@ -326,7 +326,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
}
// Skip certain commands from authorization
if slices.Contains([]string{"ping", "echo"}, strings.ToLower(comm)) {
if slices.Contains([]string{"ping", "echo", "hello"}, strings.ToLower(comm)) {
return nil
}
@@ -421,7 +421,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
}
// 8. Check if readKeys are in IncludedReadKeys
if !slices.ContainsFunc(readKeys, func(key string) bool {
if len(readKeys) > 0 && !slices.ContainsFunc(readKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
if acl.GlobPatterns[readKeyGlob].Match(key) {
return true
@@ -433,12 +433,12 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
})
}) {
if len(notAllowed) > 0 {
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
return fmt.Errorf("not authorised to access the following read keys: %+v", notAllowed)
}
}
// 9. Check if write keys are in IncludedWriteKeys
if !slices.ContainsFunc(writeKeys, func(key string) bool {
if len(writeKeys) > 0 && !slices.ContainsFunc(writeKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true
@@ -449,7 +449,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
return false
})
}) {
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
return fmt.Errorf("not authorised to access the following write keys: %+v", notAllowed)
}
}

View File

@@ -68,7 +68,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) == 1 {
serverInfo := params.GetServerInfo()
connectionInfo := params.GetConnectionInfo(params.Connection)
return buildHelloResponse(serverInfo, connectionInfo), nil
return BuildHelloResponse(serverInfo, connectionInfo), nil
}
options, err := getHelloOptions(
@@ -125,7 +125,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
// Get the new connection details and server info to return to the client.
serverInfo := params.GetServerInfo()
connectionInfo = params.GetConnectionInfo(params.Connection)
return buildHelloResponse(serverInfo, connectionInfo), nil
return BuildHelloResponse(serverInfo, connectionInfo), nil
}
func handleSelect(params internal.HandlerFuncParams) ([]byte, error) {

View File

@@ -15,10 +15,15 @@
package connection_test
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"github.com/echovault/echovault/internal/modules/connection"
"reflect"
"strconv"
"strings"
"testing"
@@ -362,4 +367,352 @@ func Test_Connection(t *testing.T) {
}
})
t.Run("Test_HandleHello", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setUpServer(port, true, "")
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
t.Cleanup(func() {
mockServer.ShutDown()
})
tests := []struct {
name string
command []resp.Value
wantRes []byte
}{
{
name: "1. Hello",
command: []resp.Value{resp.StringValue("HELLO")},
wantRes: connection.BuildHelloResponse(
internal.ServerInfo{
Server: "echovault",
Version: constants.Version,
Id: "",
Mode: "standalone",
Role: "master",
Modules: mockServer.ListModules(),
},
internal.ConnectionInfo{
Id: 1,
Name: "",
Protocol: 2,
Database: 0,
},
),
},
{
name: "2. Hello 2",
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("2")},
wantRes: connection.BuildHelloResponse(
internal.ServerInfo{
Server: "echovault",
Version: constants.Version,
Id: "",
Mode: "standalone",
Role: "master",
Modules: mockServer.ListModules(),
},
internal.ConnectionInfo{
Id: 2,
Name: "",
Protocol: 2,
Database: 0,
},
),
},
{
name: "3. Hello 3",
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("3")},
wantRes: connection.BuildHelloResponse(
internal.ServerInfo{
Server: "echovault",
Version: constants.Version,
Id: "",
Mode: "standalone",
Role: "master",
Modules: mockServer.ListModules(),
},
internal.ConnectionInfo{
Id: 3,
Name: "",
Protocol: 3,
Database: 0,
},
),
},
{
name: "4. Hello with auth success",
command: []resp.Value{
resp.StringValue("HELLO"),
resp.StringValue("3"),
resp.StringValue("AUTH"),
resp.StringValue("default"),
resp.StringValue("password1"),
},
wantRes: connection.BuildHelloResponse(
internal.ServerInfo{
Server: "echovault",
Version: constants.Version,
Id: "",
Mode: "standalone",
Role: "master",
Modules: mockServer.ListModules(),
},
internal.ConnectionInfo{
Id: 4,
Name: "",
Protocol: 3,
Database: 0,
},
),
},
{
name: "5. Hello with auth failure",
command: []resp.Value{
resp.StringValue("HELLO"),
resp.StringValue("3"),
resp.StringValue("AUTH"),
resp.StringValue("default"),
resp.StringValue("password2"),
},
wantRes: []byte("-Error could not authenticate user\r\n"),
},
{
name: "6. Hello with auth and set client name",
command: []resp.Value{
resp.StringValue("HELLO"),
resp.StringValue("3"),
resp.StringValue("AUTH"),
resp.StringValue("default"),
resp.StringValue("password1"),
resp.StringValue("SETNAME"),
resp.StringValue("client6"),
},
wantRes: connection.BuildHelloResponse(
internal.ServerInfo{
Server: "echovault",
Version: constants.Version,
Id: "",
Mode: "standalone",
Role: "master",
Modules: mockServer.ListModules(),
},
internal.ConnectionInfo{
Id: 6,
Name: "",
Protocol: 3,
Database: 0,
},
),
},
{
name: "7. Command too long",
command: []resp.Value{
resp.StringValue("HELLO"),
resp.StringValue("3"),
resp.StringValue("AUTH"),
resp.StringValue("default"),
resp.StringValue("password1"),
resp.StringValue("SETNAME"),
resp.StringValue("client6"),
resp.StringValue("extra_arg"),
},
wantRes: []byte(fmt.Sprintf("-Error %s\r\n", constants.WrongArgsResponse)),
},
}
for i := 0; i < len(tests); i++ {
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
if err = client.WriteArray(tests[i].command); err != nil {
t.Error(err)
return
}
buf := bufio.NewReader(conn)
res, err := internal.ReadMessage(buf)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(tests[i].wantRes, res) {
t.Errorf("expected byte resposne:\n%s, \n\ngot:\n%s", string(tests[i].wantRes), string(res))
return
}
// Close connection
_ = conn.Close()
}
})
t.Run("Test_HandleSelect", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
database int
wantDBErr error
setCommand []resp.Value
getCommand []resp.Value
getWantRes []resp.Value
}{
{
name: "1. Default database 0",
database: 0,
wantDBErr: nil,
setCommand: []resp.Value{
resp.StringValue("MSET"),
resp.StringValue("key1"), resp.StringValue("value-01"),
resp.StringValue("key2"), resp.StringValue("value-02"),
resp.StringValue("key3"), resp.StringValue("value-03"),
},
getCommand: []resp.Value{
resp.StringValue("MGET"),
resp.StringValue("key1"),
resp.StringValue("key2"),
resp.StringValue("key3"),
},
getWantRes: []resp.Value{
resp.StringValue("value-01"),
resp.StringValue("value-02"),
resp.StringValue("value-03"),
},
},
{
name: "2. Select database 1",
database: 1,
wantDBErr: nil,
setCommand: []resp.Value{
resp.StringValue("MSET"),
resp.StringValue("key1"), resp.StringValue("value-11"),
resp.StringValue("key2"), resp.StringValue("value-12"),
resp.StringValue("key3"), resp.StringValue("value-13"),
},
getCommand: []resp.Value{
resp.StringValue("MGET"),
resp.StringValue("key1"),
resp.StringValue("key2"),
resp.StringValue("key3"),
},
getWantRes: []resp.Value{
resp.StringValue("value-11"),
resp.StringValue("value-12"),
resp.StringValue("value-13"),
},
},
{
name: "3. Error when selecting database < 0",
database: -1,
wantDBErr: errors.New("database must be >= 0"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
// Authenticate the connection
if err = client.WriteArray([]resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("password1"),
}); err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK auth response, got \"%s\"", res.String())
return
}
// If database is not 0, execute the select command
if test.database != 0 {
if err = client.WriteArray([]resp.Value{
resp.StringValue("SELECT"),
resp.StringValue(strconv.Itoa(test.database)),
}); err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if test.wantDBErr != nil {
// If we expect a select error, check that it's the expected error.
if !strings.Contains(res.Error().Error(), test.wantDBErr.Error()) {
t.Errorf("expected error response to contain \"%s\", \"%s\"", test.wantDBErr.Error(), res.Error().Error())
return
}
return
} else {
// We do not expect an error, check if it's an OK response.
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK response, got \"%s\"", res.String())
return
}
}
}
// Execute command to set values
if err = client.WriteArray(test.setCommand); err != nil {
t.Error(err)
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK set response, got \"%s\"", res.String())
return
}
// Execute commands to get values.
if err = client.WriteArray(test.getCommand); err != nil {
t.Error(err)
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !reflect.DeepEqual(res.Array(), test.getWantRes) {
t.Errorf("expected response %+v, got %+v", test.getWantRes, res.Array())
return
}
})
}
})
}

View File

@@ -41,7 +41,7 @@ func getHelloOptions(cmd []string, options helloOptions) (helloOptions, error) {
}
}
func buildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte {
func BuildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte {
var res []byte
if connectionInfo.Protocol == 2 {