mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-07 00:43:37 +08:00
quit
This commit is contained in:
@@ -326,7 +326,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Skip certain commands from authorization
|
// Skip certain commands from authorization
|
||||||
if slices.Contains([]string{"ping", "echo"}, strings.ToLower(comm)) {
|
if slices.Contains([]string{"ping", "echo", "hello"}, strings.ToLower(comm)) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -421,7 +421,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 8. Check if readKeys are in IncludedReadKeys
|
// 8. Check if readKeys are in IncludedReadKeys
|
||||||
if !slices.ContainsFunc(readKeys, func(key string) bool {
|
if len(readKeys) > 0 && !slices.ContainsFunc(readKeys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
|
return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool {
|
||||||
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
@@ -433,12 +433,12 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
if len(notAllowed) > 0 {
|
if len(notAllowed) > 0 {
|
||||||
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
return fmt.Errorf("not authorised to access the following read keys: %+v", notAllowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. Check if write keys are in IncludedWriteKeys
|
// 9. Check if write keys are in IncludedWriteKeys
|
||||||
if !slices.ContainsFunc(writeKeys, func(key string) bool {
|
if len(writeKeys) > 0 && !slices.ContainsFunc(writeKeys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
||||||
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
@@ -449,7 +449,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
return fmt.Errorf("not authorised to access the following write keys: %+v", notAllowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -68,7 +68,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
|
|||||||
if len(params.Command) == 1 {
|
if len(params.Command) == 1 {
|
||||||
serverInfo := params.GetServerInfo()
|
serverInfo := params.GetServerInfo()
|
||||||
connectionInfo := params.GetConnectionInfo(params.Connection)
|
connectionInfo := params.GetConnectionInfo(params.Connection)
|
||||||
return buildHelloResponse(serverInfo, connectionInfo), nil
|
return BuildHelloResponse(serverInfo, connectionInfo), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
options, err := getHelloOptions(
|
options, err := getHelloOptions(
|
||||||
@@ -125,7 +125,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
|
|||||||
// Get the new connection details and server info to return to the client.
|
// Get the new connection details and server info to return to the client.
|
||||||
serverInfo := params.GetServerInfo()
|
serverInfo := params.GetServerInfo()
|
||||||
connectionInfo = params.GetConnectionInfo(params.Connection)
|
connectionInfo = params.GetConnectionInfo(params.Connection)
|
||||||
return buildHelloResponse(serverInfo, connectionInfo), nil
|
return BuildHelloResponse(serverInfo, connectionInfo), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSelect(params internal.HandlerFuncParams) ([]byte, error) {
|
func handleSelect(params internal.HandlerFuncParams) ([]byte, error) {
|
||||||
|
@@ -15,10 +15,15 @@
|
|||||||
package connection_test
|
package connection_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/echovault/echovault/internal/modules/connection"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -362,4 +367,352 @@ func Test_Connection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Test_HandleHello", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
port, err := internal.GetFreePort()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mockServer, err := setUpServer(port, true, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
mockServer.Start()
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mockServer.ShutDown()
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command []resp.Value
|
||||||
|
wantRes []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1. Hello",
|
||||||
|
command: []resp.Value{resp.StringValue("HELLO")},
|
||||||
|
wantRes: connection.BuildHelloResponse(
|
||||||
|
internal.ServerInfo{
|
||||||
|
Server: "echovault",
|
||||||
|
Version: constants.Version,
|
||||||
|
Id: "",
|
||||||
|
Mode: "standalone",
|
||||||
|
Role: "master",
|
||||||
|
Modules: mockServer.ListModules(),
|
||||||
|
},
|
||||||
|
internal.ConnectionInfo{
|
||||||
|
Id: 1,
|
||||||
|
Name: "",
|
||||||
|
Protocol: 2,
|
||||||
|
Database: 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2. Hello 2",
|
||||||
|
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("2")},
|
||||||
|
wantRes: connection.BuildHelloResponse(
|
||||||
|
internal.ServerInfo{
|
||||||
|
Server: "echovault",
|
||||||
|
Version: constants.Version,
|
||||||
|
Id: "",
|
||||||
|
Mode: "standalone",
|
||||||
|
Role: "master",
|
||||||
|
Modules: mockServer.ListModules(),
|
||||||
|
},
|
||||||
|
internal.ConnectionInfo{
|
||||||
|
Id: 2,
|
||||||
|
Name: "",
|
||||||
|
Protocol: 2,
|
||||||
|
Database: 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3. Hello 3",
|
||||||
|
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("3")},
|
||||||
|
wantRes: connection.BuildHelloResponse(
|
||||||
|
internal.ServerInfo{
|
||||||
|
Server: "echovault",
|
||||||
|
Version: constants.Version,
|
||||||
|
Id: "",
|
||||||
|
Mode: "standalone",
|
||||||
|
Role: "master",
|
||||||
|
Modules: mockServer.ListModules(),
|
||||||
|
},
|
||||||
|
internal.ConnectionInfo{
|
||||||
|
Id: 3,
|
||||||
|
Name: "",
|
||||||
|
Protocol: 3,
|
||||||
|
Database: 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "4. Hello with auth success",
|
||||||
|
command: []resp.Value{
|
||||||
|
resp.StringValue("HELLO"),
|
||||||
|
resp.StringValue("3"),
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("default"),
|
||||||
|
resp.StringValue("password1"),
|
||||||
|
},
|
||||||
|
wantRes: connection.BuildHelloResponse(
|
||||||
|
internal.ServerInfo{
|
||||||
|
Server: "echovault",
|
||||||
|
Version: constants.Version,
|
||||||
|
Id: "",
|
||||||
|
Mode: "standalone",
|
||||||
|
Role: "master",
|
||||||
|
Modules: mockServer.ListModules(),
|
||||||
|
},
|
||||||
|
internal.ConnectionInfo{
|
||||||
|
Id: 4,
|
||||||
|
Name: "",
|
||||||
|
Protocol: 3,
|
||||||
|
Database: 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "5. Hello with auth failure",
|
||||||
|
command: []resp.Value{
|
||||||
|
resp.StringValue("HELLO"),
|
||||||
|
resp.StringValue("3"),
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("default"),
|
||||||
|
resp.StringValue("password2"),
|
||||||
|
},
|
||||||
|
wantRes: []byte("-Error could not authenticate user\r\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "6. Hello with auth and set client name",
|
||||||
|
command: []resp.Value{
|
||||||
|
resp.StringValue("HELLO"),
|
||||||
|
resp.StringValue("3"),
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("default"),
|
||||||
|
resp.StringValue("password1"),
|
||||||
|
resp.StringValue("SETNAME"),
|
||||||
|
resp.StringValue("client6"),
|
||||||
|
},
|
||||||
|
wantRes: connection.BuildHelloResponse(
|
||||||
|
internal.ServerInfo{
|
||||||
|
Server: "echovault",
|
||||||
|
Version: constants.Version,
|
||||||
|
Id: "",
|
||||||
|
Mode: "standalone",
|
||||||
|
Role: "master",
|
||||||
|
Modules: mockServer.ListModules(),
|
||||||
|
},
|
||||||
|
internal.ConnectionInfo{
|
||||||
|
Id: 6,
|
||||||
|
Name: "",
|
||||||
|
Protocol: 3,
|
||||||
|
Database: 0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "7. Command too long",
|
||||||
|
command: []resp.Value{
|
||||||
|
resp.StringValue("HELLO"),
|
||||||
|
resp.StringValue("3"),
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("default"),
|
||||||
|
resp.StringValue("password1"),
|
||||||
|
resp.StringValue("SETNAME"),
|
||||||
|
resp.StringValue("client6"),
|
||||||
|
resp.StringValue("extra_arg"),
|
||||||
|
},
|
||||||
|
wantRes: []byte(fmt.Sprintf("-Error %s\r\n", constants.WrongArgsResponse)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(tests); i++ {
|
||||||
|
conn, err := internal.GetConnection("localhost", port)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client := resp.NewConn(conn)
|
||||||
|
|
||||||
|
if err = client.WriteArray(tests[i].command); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bufio.NewReader(conn)
|
||||||
|
res, err := internal.ReadMessage(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(tests[i].wantRes, res) {
|
||||||
|
t.Errorf("expected byte resposne:\n%s, \n\ngot:\n%s", string(tests[i].wantRes), string(res))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close connection
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test_HandleSelect", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
database int
|
||||||
|
wantDBErr error
|
||||||
|
setCommand []resp.Value
|
||||||
|
getCommand []resp.Value
|
||||||
|
getWantRes []resp.Value
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1. Default database 0",
|
||||||
|
database: 0,
|
||||||
|
wantDBErr: nil,
|
||||||
|
setCommand: []resp.Value{
|
||||||
|
resp.StringValue("MSET"),
|
||||||
|
resp.StringValue("key1"), resp.StringValue("value-01"),
|
||||||
|
resp.StringValue("key2"), resp.StringValue("value-02"),
|
||||||
|
resp.StringValue("key3"), resp.StringValue("value-03"),
|
||||||
|
},
|
||||||
|
getCommand: []resp.Value{
|
||||||
|
resp.StringValue("MGET"),
|
||||||
|
resp.StringValue("key1"),
|
||||||
|
resp.StringValue("key2"),
|
||||||
|
resp.StringValue("key3"),
|
||||||
|
},
|
||||||
|
getWantRes: []resp.Value{
|
||||||
|
resp.StringValue("value-01"),
|
||||||
|
resp.StringValue("value-02"),
|
||||||
|
resp.StringValue("value-03"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2. Select database 1",
|
||||||
|
database: 1,
|
||||||
|
wantDBErr: nil,
|
||||||
|
setCommand: []resp.Value{
|
||||||
|
resp.StringValue("MSET"),
|
||||||
|
resp.StringValue("key1"), resp.StringValue("value-11"),
|
||||||
|
resp.StringValue("key2"), resp.StringValue("value-12"),
|
||||||
|
resp.StringValue("key3"), resp.StringValue("value-13"),
|
||||||
|
},
|
||||||
|
getCommand: []resp.Value{
|
||||||
|
resp.StringValue("MGET"),
|
||||||
|
resp.StringValue("key1"),
|
||||||
|
resp.StringValue("key2"),
|
||||||
|
resp.StringValue("key3"),
|
||||||
|
},
|
||||||
|
getWantRes: []resp.Value{
|
||||||
|
resp.StringValue("value-11"),
|
||||||
|
resp.StringValue("value-12"),
|
||||||
|
resp.StringValue("value-13"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3. Error when selecting database < 0",
|
||||||
|
database: -1,
|
||||||
|
wantDBErr: errors.New("database must be >= 0"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
conn, err := internal.GetConnection("localhost", port)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client := resp.NewConn(conn)
|
||||||
|
|
||||||
|
// Authenticate the connection
|
||||||
|
if err = client.WriteArray([]resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("password1"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, _, err := client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(res.String(), "ok") {
|
||||||
|
t.Errorf("expected OK auth response, got \"%s\"", res.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If database is not 0, execute the select command
|
||||||
|
if test.database != 0 {
|
||||||
|
if err = client.WriteArray([]resp.Value{
|
||||||
|
resp.StringValue("SELECT"),
|
||||||
|
resp.StringValue(strconv.Itoa(test.database)),
|
||||||
|
}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, _, err := client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if test.wantDBErr != nil {
|
||||||
|
// If we expect a select error, check that it's the expected error.
|
||||||
|
if !strings.Contains(res.Error().Error(), test.wantDBErr.Error()) {
|
||||||
|
t.Errorf("expected error response to contain \"%s\", \"%s\"", test.wantDBErr.Error(), res.Error().Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// We do not expect an error, check if it's an OK response.
|
||||||
|
if !strings.EqualFold(res.String(), "ok") {
|
||||||
|
t.Errorf("expected OK response, got \"%s\"", res.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute command to set values
|
||||||
|
if err = client.WriteArray(test.setCommand); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, _, err = client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(res.String(), "ok") {
|
||||||
|
t.Errorf("expected OK set response, got \"%s\"", res.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute commands to get values.
|
||||||
|
if err = client.WriteArray(test.getCommand); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, _, err = client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(res.Array(), test.getWantRes) {
|
||||||
|
t.Errorf("expected response %+v, got %+v", test.getWantRes, res.Array())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@@ -41,7 +41,7 @@ func getHelloOptions(cmd []string, options helloOptions) (helloOptions, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte {
|
func BuildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte {
|
||||||
var res []byte
|
var res []byte
|
||||||
|
|
||||||
if connectionInfo.Protocol == 2 {
|
if connectionInfo.Protocol == 2 {
|
||||||
|
Reference in New Issue
Block a user