diff --git a/internal/modules/acl/acl.go b/internal/modules/acl/acl.go index d9226f6..8387b89 100644 --- a/internal/modules/acl/acl.go +++ b/internal/modules/acl/acl.go @@ -326,7 +326,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern } // Skip certain commands from authorization - if slices.Contains([]string{"ping", "echo"}, strings.ToLower(comm)) { + if slices.Contains([]string{"ping", "echo", "hello"}, strings.ToLower(comm)) { return nil } @@ -421,7 +421,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern } // 8. Check if readKeys are in IncludedReadKeys - if !slices.ContainsFunc(readKeys, func(key string) bool { + if len(readKeys) > 0 && !slices.ContainsFunc(readKeys, func(key string) bool { return slices.ContainsFunc(connection.User.IncludedReadKeys, func(readKeyGlob string) bool { if acl.GlobPatterns[readKeyGlob].Match(key) { return true @@ -433,12 +433,12 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern }) }) { if len(notAllowed) > 0 { - return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed) + return fmt.Errorf("not authorised to access the following read keys: %+v", notAllowed) } } // 9. Check if write keys are in IncludedWriteKeys - if !slices.ContainsFunc(writeKeys, func(key string) bool { + if len(writeKeys) > 0 && !slices.ContainsFunc(writeKeys, func(key string) bool { return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool { if acl.GlobPatterns[writeKeyGlob].Match(key) { return true @@ -449,7 +449,7 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern return false }) }) { - return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed) + return fmt.Errorf("not authorised to access the following write keys: %+v", notAllowed) } } diff --git a/internal/modules/connection/commands.go b/internal/modules/connection/commands.go index 6345a63..6682827 100644 --- a/internal/modules/connection/commands.go +++ b/internal/modules/connection/commands.go @@ -68,7 +68,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) { if len(params.Command) == 1 { serverInfo := params.GetServerInfo() connectionInfo := params.GetConnectionInfo(params.Connection) - return buildHelloResponse(serverInfo, connectionInfo), nil + return BuildHelloResponse(serverInfo, connectionInfo), nil } options, err := getHelloOptions( @@ -125,7 +125,7 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) { // Get the new connection details and server info to return to the client. serverInfo := params.GetServerInfo() connectionInfo = params.GetConnectionInfo(params.Connection) - return buildHelloResponse(serverInfo, connectionInfo), nil + return BuildHelloResponse(serverInfo, connectionInfo), nil } func handleSelect(params internal.HandlerFuncParams) ([]byte, error) { diff --git a/internal/modules/connection/commands_test.go b/internal/modules/connection/commands_test.go index f175ef3..ac8b479 100644 --- a/internal/modules/connection/commands_test.go +++ b/internal/modules/connection/commands_test.go @@ -15,10 +15,15 @@ package connection_test import ( + "bufio" + "bytes" "crypto/sha256" "encoding/hex" "errors" "fmt" + "github.com/echovault/echovault/internal/modules/connection" + "reflect" + "strconv" "strings" "testing" @@ -362,4 +367,352 @@ func Test_Connection(t *testing.T) { } }) + t.Run("Test_HandleHello", func(t *testing.T) { + t.Parallel() + + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + mockServer, err := setUpServer(port, true, "") + if err != nil { + t.Error(err) + return + } + go func() { + mockServer.Start() + }() + t.Cleanup(func() { + mockServer.ShutDown() + }) + + tests := []struct { + name string + command []resp.Value + wantRes []byte + }{ + { + name: "1. Hello", + command: []resp.Value{resp.StringValue("HELLO")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 1, + Name: "", + Protocol: 2, + Database: 0, + }, + ), + }, + { + name: "2. Hello 2", + command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("2")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 2, + Name: "", + Protocol: 2, + Database: 0, + }, + ), + }, + { + name: "3. Hello 3", + command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("3")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 3, + Name: "", + Protocol: 3, + Database: 0, + }, + ), + }, + { + name: "4. Hello with auth success", + command: []resp.Value{ + resp.StringValue("HELLO"), + resp.StringValue("3"), + resp.StringValue("AUTH"), + resp.StringValue("default"), + resp.StringValue("password1"), + }, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 4, + Name: "", + Protocol: 3, + Database: 0, + }, + ), + }, + { + name: "5. Hello with auth failure", + command: []resp.Value{ + resp.StringValue("HELLO"), + resp.StringValue("3"), + resp.StringValue("AUTH"), + resp.StringValue("default"), + resp.StringValue("password2"), + }, + wantRes: []byte("-Error could not authenticate user\r\n"), + }, + { + name: "6. Hello with auth and set client name", + command: []resp.Value{ + resp.StringValue("HELLO"), + resp.StringValue("3"), + resp.StringValue("AUTH"), + resp.StringValue("default"), + resp.StringValue("password1"), + resp.StringValue("SETNAME"), + resp.StringValue("client6"), + }, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 6, + Name: "", + Protocol: 3, + Database: 0, + }, + ), + }, + { + name: "7. Command too long", + command: []resp.Value{ + resp.StringValue("HELLO"), + resp.StringValue("3"), + resp.StringValue("AUTH"), + resp.StringValue("default"), + resp.StringValue("password1"), + resp.StringValue("SETNAME"), + resp.StringValue("client6"), + resp.StringValue("extra_arg"), + }, + wantRes: []byte(fmt.Sprintf("-Error %s\r\n", constants.WrongArgsResponse)), + }, + } + + for i := 0; i < len(tests); i++ { + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + client := resp.NewConn(conn) + + if err = client.WriteArray(tests[i].command); err != nil { + t.Error(err) + return + } + + buf := bufio.NewReader(conn) + res, err := internal.ReadMessage(buf) + if err != nil { + t.Error(err) + return + } + + if !bytes.Equal(tests[i].wantRes, res) { + t.Errorf("expected byte resposne:\n%s, \n\ngot:\n%s", string(tests[i].wantRes), string(res)) + return + } + + // Close connection + _ = conn.Close() + } + }) + + t.Run("Test_HandleSelect", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + database int + wantDBErr error + setCommand []resp.Value + getCommand []resp.Value + getWantRes []resp.Value + }{ + { + name: "1. Default database 0", + database: 0, + wantDBErr: nil, + setCommand: []resp.Value{ + resp.StringValue("MSET"), + resp.StringValue("key1"), resp.StringValue("value-01"), + resp.StringValue("key2"), resp.StringValue("value-02"), + resp.StringValue("key3"), resp.StringValue("value-03"), + }, + getCommand: []resp.Value{ + resp.StringValue("MGET"), + resp.StringValue("key1"), + resp.StringValue("key2"), + resp.StringValue("key3"), + }, + getWantRes: []resp.Value{ + resp.StringValue("value-01"), + resp.StringValue("value-02"), + resp.StringValue("value-03"), + }, + }, + { + name: "2. Select database 1", + database: 1, + wantDBErr: nil, + setCommand: []resp.Value{ + resp.StringValue("MSET"), + resp.StringValue("key1"), resp.StringValue("value-11"), + resp.StringValue("key2"), resp.StringValue("value-12"), + resp.StringValue("key3"), resp.StringValue("value-13"), + }, + getCommand: []resp.Value{ + resp.StringValue("MGET"), + resp.StringValue("key1"), + resp.StringValue("key2"), + resp.StringValue("key3"), + }, + getWantRes: []resp.Value{ + resp.StringValue("value-11"), + resp.StringValue("value-12"), + resp.StringValue("value-13"), + }, + }, + { + name: "3. Error when selecting database < 0", + database: -1, + wantDBErr: errors.New("database must be >= 0"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + client := resp.NewConn(conn) + + // Authenticate the connection + if err = client.WriteArray([]resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("password1"), + }); err != nil { + t.Error(err) + return + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + return + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected OK auth response, got \"%s\"", res.String()) + return + } + + // If database is not 0, execute the select command + if test.database != 0 { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SELECT"), + resp.StringValue(strconv.Itoa(test.database)), + }); err != nil { + t.Error(err) + return + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + return + } + if test.wantDBErr != nil { + // If we expect a select error, check that it's the expected error. + if !strings.Contains(res.Error().Error(), test.wantDBErr.Error()) { + t.Errorf("expected error response to contain \"%s\", \"%s\"", test.wantDBErr.Error(), res.Error().Error()) + return + } + return + } else { + // We do not expect an error, check if it's an OK response. + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected OK response, got \"%s\"", res.String()) + return + } + } + } + + // Execute command to set values + if err = client.WriteArray(test.setCommand); err != nil { + t.Error(err) + return + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + return + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected OK set response, got \"%s\"", res.String()) + return + } + + // Execute commands to get values. + if err = client.WriteArray(test.getCommand); err != nil { + t.Error(err) + return + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(res.Array(), test.getWantRes) { + t.Errorf("expected response %+v, got %+v", test.getWantRes, res.Array()) + return + } + }) + } + }) } diff --git a/internal/modules/connection/utils.go b/internal/modules/connection/utils.go index 474bf37..d584993 100644 --- a/internal/modules/connection/utils.go +++ b/internal/modules/connection/utils.go @@ -41,7 +41,7 @@ func getHelloOptions(cmd []string, options helloOptions) (helloOptions, error) { } } -func buildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte { +func BuildHelloResponse(serverInfo internal.ServerInfo, connectionInfo internal.ConnectionInfo) []byte { var res []byte if connectionInfo.Protocol == 2 {