mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-06 08:27:04 +08:00
quit
This commit is contained in:
@@ -15,10 +15,15 @@
|
||||
package connection_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal/modules/connection"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -362,4 +367,352 @@ func Test_Connection(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_HandleHello", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
mockServer, err := setUpServer(port, true, "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
command []resp.Value
|
||||
wantRes []byte
|
||||
}{
|
||||
{
|
||||
name: "1. Hello",
|
||||
command: []resp.Value{resp.StringValue("HELLO")},
|
||||
wantRes: connection.BuildHelloResponse(
|
||||
internal.ServerInfo{
|
||||
Server: "echovault",
|
||||
Version: constants.Version,
|
||||
Id: "",
|
||||
Mode: "standalone",
|
||||
Role: "master",
|
||||
Modules: mockServer.ListModules(),
|
||||
},
|
||||
internal.ConnectionInfo{
|
||||
Id: 1,
|
||||
Name: "",
|
||||
Protocol: 2,
|
||||
Database: 0,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "2. Hello 2",
|
||||
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("2")},
|
||||
wantRes: connection.BuildHelloResponse(
|
||||
internal.ServerInfo{
|
||||
Server: "echovault",
|
||||
Version: constants.Version,
|
||||
Id: "",
|
||||
Mode: "standalone",
|
||||
Role: "master",
|
||||
Modules: mockServer.ListModules(),
|
||||
},
|
||||
internal.ConnectionInfo{
|
||||
Id: 2,
|
||||
Name: "",
|
||||
Protocol: 2,
|
||||
Database: 0,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "3. Hello 3",
|
||||
command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("3")},
|
||||
wantRes: connection.BuildHelloResponse(
|
||||
internal.ServerInfo{
|
||||
Server: "echovault",
|
||||
Version: constants.Version,
|
||||
Id: "",
|
||||
Mode: "standalone",
|
||||
Role: "master",
|
||||
Modules: mockServer.ListModules(),
|
||||
},
|
||||
internal.ConnectionInfo{
|
||||
Id: 3,
|
||||
Name: "",
|
||||
Protocol: 3,
|
||||
Database: 0,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "4. Hello with auth success",
|
||||
command: []resp.Value{
|
||||
resp.StringValue("HELLO"),
|
||||
resp.StringValue("3"),
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("default"),
|
||||
resp.StringValue("password1"),
|
||||
},
|
||||
wantRes: connection.BuildHelloResponse(
|
||||
internal.ServerInfo{
|
||||
Server: "echovault",
|
||||
Version: constants.Version,
|
||||
Id: "",
|
||||
Mode: "standalone",
|
||||
Role: "master",
|
||||
Modules: mockServer.ListModules(),
|
||||
},
|
||||
internal.ConnectionInfo{
|
||||
Id: 4,
|
||||
Name: "",
|
||||
Protocol: 3,
|
||||
Database: 0,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "5. Hello with auth failure",
|
||||
command: []resp.Value{
|
||||
resp.StringValue("HELLO"),
|
||||
resp.StringValue("3"),
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("default"),
|
||||
resp.StringValue("password2"),
|
||||
},
|
||||
wantRes: []byte("-Error could not authenticate user\r\n"),
|
||||
},
|
||||
{
|
||||
name: "6. Hello with auth and set client name",
|
||||
command: []resp.Value{
|
||||
resp.StringValue("HELLO"),
|
||||
resp.StringValue("3"),
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("default"),
|
||||
resp.StringValue("password1"),
|
||||
resp.StringValue("SETNAME"),
|
||||
resp.StringValue("client6"),
|
||||
},
|
||||
wantRes: connection.BuildHelloResponse(
|
||||
internal.ServerInfo{
|
||||
Server: "echovault",
|
||||
Version: constants.Version,
|
||||
Id: "",
|
||||
Mode: "standalone",
|
||||
Role: "master",
|
||||
Modules: mockServer.ListModules(),
|
||||
},
|
||||
internal.ConnectionInfo{
|
||||
Id: 6,
|
||||
Name: "",
|
||||
Protocol: 3,
|
||||
Database: 0,
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "7. Command too long",
|
||||
command: []resp.Value{
|
||||
resp.StringValue("HELLO"),
|
||||
resp.StringValue("3"),
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("default"),
|
||||
resp.StringValue("password1"),
|
||||
resp.StringValue("SETNAME"),
|
||||
resp.StringValue("client6"),
|
||||
resp.StringValue("extra_arg"),
|
||||
},
|
||||
wantRes: []byte(fmt.Sprintf("-Error %s\r\n", constants.WrongArgsResponse)),
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < len(tests); i++ {
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray(tests[i].command); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
buf := bufio.NewReader(conn)
|
||||
res, err := internal.ReadMessage(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(tests[i].wantRes, res) {
|
||||
t.Errorf("expected byte resposne:\n%s, \n\ngot:\n%s", string(tests[i].wantRes), string(res))
|
||||
return
|
||||
}
|
||||
|
||||
// Close connection
|
||||
_ = conn.Close()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_HandleSelect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
database int
|
||||
wantDBErr error
|
||||
setCommand []resp.Value
|
||||
getCommand []resp.Value
|
||||
getWantRes []resp.Value
|
||||
}{
|
||||
{
|
||||
name: "1. Default database 0",
|
||||
database: 0,
|
||||
wantDBErr: nil,
|
||||
setCommand: []resp.Value{
|
||||
resp.StringValue("MSET"),
|
||||
resp.StringValue("key1"), resp.StringValue("value-01"),
|
||||
resp.StringValue("key2"), resp.StringValue("value-02"),
|
||||
resp.StringValue("key3"), resp.StringValue("value-03"),
|
||||
},
|
||||
getCommand: []resp.Value{
|
||||
resp.StringValue("MGET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("key2"),
|
||||
resp.StringValue("key3"),
|
||||
},
|
||||
getWantRes: []resp.Value{
|
||||
resp.StringValue("value-01"),
|
||||
resp.StringValue("value-02"),
|
||||
resp.StringValue("value-03"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2. Select database 1",
|
||||
database: 1,
|
||||
wantDBErr: nil,
|
||||
setCommand: []resp.Value{
|
||||
resp.StringValue("MSET"),
|
||||
resp.StringValue("key1"), resp.StringValue("value-11"),
|
||||
resp.StringValue("key2"), resp.StringValue("value-12"),
|
||||
resp.StringValue("key3"), resp.StringValue("value-13"),
|
||||
},
|
||||
getCommand: []resp.Value{
|
||||
resp.StringValue("MGET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("key2"),
|
||||
resp.StringValue("key3"),
|
||||
},
|
||||
getWantRes: []resp.Value{
|
||||
resp.StringValue("value-11"),
|
||||
resp.StringValue("value-12"),
|
||||
resp.StringValue("value-13"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "3. Error when selecting database < 0",
|
||||
database: -1,
|
||||
wantDBErr: errors.New("database must be >= 0"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
// Authenticate the connection
|
||||
if err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("password1"),
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected OK auth response, got \"%s\"", res.String())
|
||||
return
|
||||
}
|
||||
|
||||
// If database is not 0, execute the select command
|
||||
if test.database != 0 {
|
||||
if err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SELECT"),
|
||||
resp.StringValue(strconv.Itoa(test.database)),
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if test.wantDBErr != nil {
|
||||
// If we expect a select error, check that it's the expected error.
|
||||
if !strings.Contains(res.Error().Error(), test.wantDBErr.Error()) {
|
||||
t.Errorf("expected error response to contain \"%s\", \"%s\"", test.wantDBErr.Error(), res.Error().Error())
|
||||
return
|
||||
}
|
||||
return
|
||||
} else {
|
||||
// We do not expect an error, check if it's an OK response.
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected OK response, got \"%s\"", res.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute command to set values
|
||||
if err = client.WriteArray(test.setCommand); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected OK set response, got \"%s\"", res.String())
|
||||
return
|
||||
}
|
||||
|
||||
// Execute commands to get values.
|
||||
if err = client.WriteArray(test.getCommand); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(res.Array(), test.getWantRes) {
|
||||
t.Errorf("expected response %+v, got %+v", test.getWantRes, res.Array())
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user