mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-01 06:12:08 +08:00
Updated test suites to retry connection to tcp server. Use concurrency to setup cluster in Test_Cluster suite.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,7 @@ type ClientServerPair struct {
|
||||
raftPort int
|
||||
mlPort int
|
||||
bootstrapCluster bool
|
||||
joinAddr string
|
||||
raw net.Conn
|
||||
client *resp.Conn
|
||||
server *EchoVault
|
||||
@@ -92,9 +93,57 @@ func setupServer(
|
||||
)
|
||||
}
|
||||
|
||||
func setupNode(node *ClientServerPair, isLeader bool, errChan *chan error) {
|
||||
server, err := setupServer(
|
||||
node.serverId,
|
||||
node.bootstrapCluster,
|
||||
node.bindAddr,
|
||||
node.joinAddr,
|
||||
node.port,
|
||||
node.raftPort,
|
||||
node.mlPort,
|
||||
)
|
||||
if err != nil {
|
||||
*errChan <- fmt.Errorf("could not start server; %v", err)
|
||||
}
|
||||
|
||||
// Start the server.
|
||||
go func() {
|
||||
server.Start()
|
||||
}()
|
||||
|
||||
if isLeader {
|
||||
// If node is a leader, wait until it's established itself as a leader of the raft cluster.
|
||||
for {
|
||||
if server.raft.IsRaftLeader() {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If the node is a follower, wait until it's joined the raft cluster before moving forward.
|
||||
for {
|
||||
if server.raft.HasJoinedCluster() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup client connection.
|
||||
conn, err := internal.GetConnection(node.bindAddr, node.port)
|
||||
if err != nil {
|
||||
*errChan <- fmt.Errorf("could not open tcp connection: %v", err)
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
node.raw = conn
|
||||
node.client = client
|
||||
node.server = server
|
||||
}
|
||||
|
||||
func makeCluster(size int) ([]ClientServerPair, error) {
|
||||
pairs := make([]ClientServerPair, size)
|
||||
|
||||
// Set up node metadata.
|
||||
for i := 0; i < len(pairs); i++ {
|
||||
serverId := fmt.Sprintf("SERVER-%d", i)
|
||||
bindAddr := getBindAddr().String()
|
||||
@@ -115,42 +164,6 @@ func makeCluster(size int) ([]ClientServerPair, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get free memberlist port: %v", err)
|
||||
}
|
||||
server, err := setupServer(serverId, bootstrapCluster, bindAddr, joinAddr, port, raftPort, memberlistPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not start server; %v", err)
|
||||
}
|
||||
|
||||
// Start the server
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
server.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if i == 0 {
|
||||
// If node is a leader, wait until it's established itself as a leader of the raft cluster.
|
||||
for {
|
||||
if server.raft.IsRaftLeader() {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If the node is a follower, wait until it's joined the raft cluster before moving forward.
|
||||
for {
|
||||
if server.raft.HasJoinedCluster() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup client connection.
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open tcp connection: %v", err)
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
pairs[i] = ClientServerPair{
|
||||
serverId: serverId,
|
||||
@@ -159,12 +172,37 @@ func makeCluster(size int) ([]ClientServerPair, error) {
|
||||
raftPort: raftPort,
|
||||
mlPort: memberlistPort,
|
||||
bootstrapCluster: bootstrapCluster,
|
||||
raw: conn,
|
||||
client: client,
|
||||
server: server,
|
||||
joinAddr: joinAddr,
|
||||
}
|
||||
}
|
||||
|
||||
errChan := make(chan error)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
// Set up nodes.
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < len(pairs); i++ {
|
||||
if i == 0 {
|
||||
setupNode(&pairs[i], pairs[i].bootstrapCluster, &errChan)
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
setupNode(&pairs[idx], pairs[idx].bootstrapCluster, &errChan)
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-doneChan:
|
||||
}
|
||||
|
||||
return pairs, nil
|
||||
}
|
||||
|
||||
@@ -428,275 +466,283 @@ func Test_Cluster(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// t.Run("Test_ForwardCommand", func(t *testing.T) {
|
||||
// tests := tests["forward"]
|
||||
// // Write all the data a random cluster follower.
|
||||
// for i, test := range tests {
|
||||
// // Send write command to follower node.
|
||||
// node := nodes[1]
|
||||
// if err := node.client.WriteArray([]resp.Value{
|
||||
// resp.StringValue("SET"),
|
||||
// resp.StringValue(test.key),
|
||||
// resp.StringValue(test.value),
|
||||
// }); err != nil {
|
||||
// t.Errorf("could not write data to follower node (test %d): %v", i, err)
|
||||
// }
|
||||
// // Read response and make sure we received "ok" response.
|
||||
// rd, _, err := node.client.ReadValue()
|
||||
// if err != nil {
|
||||
// t.Errorf("could not read response from follower node (test %d): %v", i, err)
|
||||
// }
|
||||
// if !strings.EqualFold(rd.String(), "ok") {
|
||||
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// <-time.After(200 * time.Millisecond) // Short yield to allow change to take effect.
|
||||
//
|
||||
// // Check if the data has been replicated on a quorum (majority of the cluster).
|
||||
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
|
||||
// for i, test := range tests {
|
||||
// count := 0
|
||||
// for j := 0; j < len(nodes); j++ {
|
||||
// node := nodes[j]
|
||||
// if err := node.client.WriteArray([]resp.Value{
|
||||
// resp.StringValue("GET"),
|
||||
// resp.StringValue(test.key),
|
||||
// }); err != nil {
|
||||
// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
|
||||
// }
|
||||
// rd, _, err := node.client.ReadValue()
|
||||
// if err != nil {
|
||||
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
|
||||
// }
|
||||
// if rd.String() == test.value {
|
||||
// count += 1 // If the expected value is found, increment the count.
|
||||
// }
|
||||
// }
|
||||
// // Fail if count is less than quorum.
|
||||
// if count < quorum {
|
||||
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
t.Run("Test_ForwardCommand", func(t *testing.T) {
|
||||
tests := tests["forward"]
|
||||
// Write all the data a random cluster follower.
|
||||
for i, test := range tests {
|
||||
// Send write command to follower node.
|
||||
node := nodes[1]
|
||||
if err := node.client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SET"),
|
||||
resp.StringValue(test.key),
|
||||
resp.StringValue(test.value),
|
||||
}); err != nil {
|
||||
t.Errorf("could not write data to follower node (test %d): %v", i, err)
|
||||
}
|
||||
// Read response and make sure we received "ok" response.
|
||||
rd, _, err := node.client.ReadValue()
|
||||
if err != nil {
|
||||
t.Errorf("could not read response from follower node (test %d): %v", i, err)
|
||||
}
|
||||
if !strings.EqualFold(rd.String(), "ok") {
|
||||
t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
|
||||
}
|
||||
}
|
||||
|
||||
<-time.After(1 * time.Second) // Short yield to allow change to take effect.
|
||||
|
||||
// Check if the data has been replicated on a quorum (majority of the cluster).
|
||||
quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
|
||||
for i, test := range tests {
|
||||
count := 0
|
||||
for j := 0; j < len(nodes); j++ {
|
||||
node := nodes[j]
|
||||
if err := node.client.WriteArray([]resp.Value{
|
||||
resp.StringValue("GET"),
|
||||
resp.StringValue(test.key),
|
||||
}); err != nil {
|
||||
t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
|
||||
}
|
||||
rd, _, err := node.client.ReadValue()
|
||||
if err != nil {
|
||||
t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
|
||||
}
|
||||
if rd.String() == test.value {
|
||||
count += 1 // If the expected value is found, increment the count.
|
||||
}
|
||||
}
|
||||
// Fail if count is less than quorum.
|
||||
if count < quorum {
|
||||
t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_TLS(t *testing.T) {
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
func Test_Standalone(t *testing.T) {
|
||||
t.Run("Test_TLS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := DefaultConfig()
|
||||
conf.DataDir = ""
|
||||
conf.BindAddr = "localhost"
|
||||
conf.Port = uint16(port)
|
||||
conf.TLS = true
|
||||
conf.CertKeyPairs = [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server1.crt"),
|
||||
path.Join("..", "openssl", "server", "server1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server2.crt"),
|
||||
path.Join("..", "openssl", "server", "server2.key"),
|
||||
},
|
||||
}
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
server, err := NewEchoVault(WithConfig(conf))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
conf := DefaultConfig()
|
||||
conf.DataDir = ""
|
||||
conf.BindAddr = "localhost"
|
||||
conf.Port = uint16(port)
|
||||
conf.TLS = true
|
||||
conf.CertKeyPairs = [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server1.crt"),
|
||||
path.Join("..", "openssl", "server", "server1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server2.crt"),
|
||||
path.Join("..", "openssl", "server", "server2.key"),
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
server.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
server, err := NewEchoVault(WithConfig(conf))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Dial with ServerCAs
|
||||
serverCAs := x509.NewCertPool()
|
||||
f, err := os.Open(path.Join("..", "openssl", "server", "rootCA.crt"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cert, err := io.ReadAll(bufio.NewReader(f))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ok := serverCAs.AppendCertsFromPEM(cert)
|
||||
if !ok {
|
||||
t.Error("could not load server CA")
|
||||
}
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
server.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("localhost:%d", port), &tls.Config{
|
||||
RootCAs: serverCAs,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
server.ShutDown()
|
||||
}()
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
// Test that we can set and get a value from the server.
|
||||
key := "key1"
|
||||
value := "value1"
|
||||
err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected response OK, got \"%s\"", res.String())
|
||||
}
|
||||
|
||||
err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if res.String() != value {
|
||||
t.Errorf("expected response at key \"%s\" to be \"%s\", got \"%s\"", key, value, res.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MTLS(t *testing.T) {
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
conf := DefaultConfig()
|
||||
conf.DataDir = ""
|
||||
conf.BindAddr = "localhost"
|
||||
conf.Port = uint16(port)
|
||||
conf.TLS = true
|
||||
conf.MTLS = true
|
||||
conf.ClientCAs = []string{
|
||||
path.Join("..", "openssl", "client", "rootCA.crt"),
|
||||
}
|
||||
conf.CertKeyPairs = [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server1.crt"),
|
||||
path.Join("..", "openssl", "server", "server1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server2.crt"),
|
||||
path.Join("..", "openssl", "server", "server2.key"),
|
||||
},
|
||||
}
|
||||
|
||||
server, err := NewEchoVault(WithConfig(conf))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
server.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
// Dial with ServerCAs and client certificates
|
||||
clientCertKeyPairs := [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "client", "client1.crt"),
|
||||
path.Join("..", "openssl", "client", "client1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "client", "client2.crt"),
|
||||
path.Join("..", "openssl", "client", "client2.key"),
|
||||
},
|
||||
}
|
||||
var certificates []tls.Certificate
|
||||
for _, pair := range clientCertKeyPairs {
|
||||
c, err := tls.LoadX509KeyPair(pair[0], pair[1])
|
||||
// Dial with ServerCAs
|
||||
serverCAs := x509.NewCertPool()
|
||||
f, err := os.Open(path.Join("..", "openssl", "server", "rootCA.crt"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
certificates = append(certificates, c)
|
||||
}
|
||||
cert, err := io.ReadAll(bufio.NewReader(f))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ok := serverCAs.AppendCertsFromPEM(cert)
|
||||
if !ok {
|
||||
t.Error("could not load server CA")
|
||||
}
|
||||
|
||||
serverCAs := x509.NewCertPool()
|
||||
f, err := os.Open(path.Join("..", "openssl", "server", "rootCA.crt"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cert, err := io.ReadAll(bufio.NewReader(f))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ok := serverCAs.AppendCertsFromPEM(cert)
|
||||
if !ok {
|
||||
t.Error("could not load server CA")
|
||||
}
|
||||
conn, err := internal.GetTLSConnection("localhost", port, &tls.Config{
|
||||
RootCAs: serverCAs,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
server.ShutDown()
|
||||
}()
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("localhost:%d", port), &tls.Config{
|
||||
RootCAs: serverCAs,
|
||||
Certificates: certificates,
|
||||
// Test that we can set and get a value from the server.
|
||||
key := "key1"
|
||||
value := "value1"
|
||||
err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected response OK, got \"%s\"", res.String())
|
||||
}
|
||||
|
||||
err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if res.String() != value {
|
||||
t.Errorf("expected response at key \"%s\" to be \"%s\", got \"%s\"", key, value, res.String())
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
server.ShutDown()
|
||||
}()
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
// Test that we can set and get a value from the server.
|
||||
key := "key1"
|
||||
value := "value1"
|
||||
err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value),
|
||||
t.Run("Test_MTLS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conf := DefaultConfig()
|
||||
conf.DataDir = ""
|
||||
conf.BindAddr = "localhost"
|
||||
conf.Port = uint16(port)
|
||||
conf.TLS = true
|
||||
conf.MTLS = true
|
||||
conf.ClientCAs = []string{
|
||||
path.Join("..", "openssl", "client", "rootCA.crt"),
|
||||
}
|
||||
conf.CertKeyPairs = [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server1.crt"),
|
||||
path.Join("..", "openssl", "server", "server1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "server", "server2.crt"),
|
||||
path.Join("..", "openssl", "server", "server2.key"),
|
||||
},
|
||||
}
|
||||
|
||||
server, err := NewEchoVault(WithConfig(conf))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
server.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
// Dial with ServerCAs and client certificates
|
||||
clientCertKeyPairs := [][]string{
|
||||
{
|
||||
path.Join("..", "openssl", "client", "client1.crt"),
|
||||
path.Join("..", "openssl", "client", "client1.key"),
|
||||
},
|
||||
{
|
||||
path.Join("..", "openssl", "client", "client2.crt"),
|
||||
path.Join("..", "openssl", "client", "client2.key"),
|
||||
},
|
||||
}
|
||||
var certificates []tls.Certificate
|
||||
for _, pair := range clientCertKeyPairs {
|
||||
c, err := tls.LoadX509KeyPair(pair[0], pair[1])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
certificates = append(certificates, c)
|
||||
}
|
||||
|
||||
serverCAs := x509.NewCertPool()
|
||||
f, err := os.Open(path.Join("..", "openssl", "server", "rootCA.crt"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cert, err := io.ReadAll(bufio.NewReader(f))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ok := serverCAs.AppendCertsFromPEM(cert)
|
||||
if !ok {
|
||||
t.Error("could not load server CA")
|
||||
}
|
||||
|
||||
conn, err := internal.GetTLSConnection("localhost", port, &tls.Config{
|
||||
RootCAs: serverCAs,
|
||||
Certificates: certificates,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
server.ShutDown()
|
||||
}()
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
// Test that we can set and get a value from the server.
|
||||
key := "key1"
|
||||
value := "value1"
|
||||
err = client.WriteArray([]resp.Value{
|
||||
resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected response OK, got \"%s\"", res.String())
|
||||
}
|
||||
|
||||
err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if res.String() != value {
|
||||
t.Errorf("expected response at key \"%s\" to be \"%s\", got \"%s\"", key, value, res.String())
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected response OK, got \"%s\"", res.String())
|
||||
}
|
||||
|
||||
err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if res.String() != value {
|
||||
t.Errorf("expected response at key \"%s\" to be \"%s\", got \"%s\"", key, value, res.String())
|
||||
}
|
||||
}
|
||||
|
@@ -23,9 +23,7 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -56,13 +54,9 @@ func Test_Generic(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -70,7 +64,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -475,7 +469,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleMSET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -567,7 +561,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleGET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -686,7 +680,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleMGET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -793,7 +787,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleDEL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -905,7 +899,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandlePERSIST", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1062,7 +1056,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleEXPIRETIME", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1183,7 +1177,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleTTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1304,7 +1298,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleEXPIRE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1591,7 +1585,7 @@ func Test_Generic(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleEXPIREAT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@@ -16,17 +16,14 @@ package hash_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -50,13 +47,9 @@ func Test_Hash(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -64,7 +57,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHSET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -248,7 +241,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHINCRBY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -450,7 +443,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHGET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -606,7 +599,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHSTRLEN", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -764,7 +757,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHVALS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -906,7 +899,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHRANDFIELD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1115,7 +1108,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHLEN", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1241,7 +1234,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1369,7 +1362,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHGETALL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1510,7 +1503,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHEXISTS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1637,7 +1630,7 @@ func Test_Hash(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleHDEL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@@ -16,17 +16,14 @@ package list_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -50,13 +47,9 @@ func Test_List(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -64,7 +57,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLLEN", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -188,7 +181,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLINDEX", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -352,7 +345,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLRANGE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -541,7 +534,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLSET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -730,7 +723,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLTRIM", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -930,7 +923,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLREM", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1095,7 +1088,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLMOVE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1326,7 +1319,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleLPUSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1481,7 +1474,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleRPUSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -1636,7 +1629,7 @@ func Test_List(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandlePOP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -16,18 +16,15 @@ package set_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/echovault/echovault/internal/modules/set"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -51,13 +48,9 @@ func Test_Set(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -65,7 +58,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSADD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -213,7 +206,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSCARD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -340,7 +333,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSDIFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -490,7 +483,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSDIFFSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -672,7 +665,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSINTER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -822,7 +815,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSINTERCARD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -977,7 +970,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSINTERSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1159,7 +1152,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSISMEMBER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1281,7 +1274,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSMEMBERS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1408,7 +1401,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSMISMEMBER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1535,7 +1528,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSMOVE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1731,7 +1724,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSPOP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1873,7 +1866,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSRANDMEMBER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2038,7 +2031,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSREM", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2179,7 +2172,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSUNION", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2322,7 +2315,7 @@ func Test_Set(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSUNIONSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
|
@@ -16,7 +16,6 @@ package sorted_set_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
@@ -24,11 +23,9 @@ import (
|
||||
"github.com/echovault/echovault/internal/modules/sorted_set"
|
||||
"github.com/tidwall/resp"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -52,13 +49,9 @@ func Test_SortedSet(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -66,7 +59,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZADD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -282,7 +275,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZCARD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -413,7 +406,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZCOUNT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -588,7 +581,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZLEXCOUNT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -739,7 +732,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZDIFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -959,7 +952,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZDIFFSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1202,7 +1195,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZINCRBY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1485,7 +1478,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZMPOP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -1800,7 +1793,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZPOP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2061,7 +2054,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZMSCORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2189,7 +2182,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZSCORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2326,7 +2319,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZRANDMEMBER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2523,7 +2516,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZRANK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2680,7 +2673,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZREM", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -2853,7 +2846,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZREMRANGEBYSCORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -3029,7 +3022,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZREMRANGEBYRANK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -3259,7 +3252,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZREMRANGEBYLEX", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -3460,7 +3453,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZRANGE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -3761,7 +3754,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZRANGESTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -4112,7 +4105,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZINTER", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -4480,7 +4473,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZINTERSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -4895,7 +4888,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZUNION", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
@@ -5288,7 +5281,7 @@ func Test_SortedSet(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleZUNIONSTORE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error()
|
||||
return
|
||||
|
@@ -16,16 +16,13 @@ package str_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -49,13 +46,9 @@ func Test_String(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
@@ -63,7 +56,7 @@ func Test_String(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSetRange", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -213,7 +206,7 @@ func Test_String(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleStrLen", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -314,7 +307,7 @@ func Test_String(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleSubStr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"cmp"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
@@ -429,15 +431,51 @@ func GetFreePort() (int, error) {
|
||||
}
|
||||
|
||||
func GetConnection(addr string, port int) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
// Wait until connection is no longer nil.
|
||||
if conn != nil {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port))
|
||||
if err != nil && errors.Is(err.(*net.OpError), syscall.ECONNREFUSED) {
|
||||
// If we get a "connection refused error, try again."
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
return nil, errors.New("connection timeout")
|
||||
case <-done:
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
|
||||
func GetTLSConnection(addr string, port int, config *tls.Config) (net.Conn, error) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", addr, port), config)
|
||||
if err != nil && errors.Is(err.(*net.OpError), syscall.ECONNREFUSED) {
|
||||
// If we get a "connection refused error, try again."
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
return nil, errors.New("connection timeout")
|
||||
case <-done:
|
||||
return conn, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user