mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-05 16:06:57 +08:00
Removed test folder and moved all commands tests to their respective internal modules. Moved api tests into echovault package. This change has been made because the speratate test folder is not idiomatic and caused test coverage report to not be generated.
This commit is contained in:
884
internal/modules/pubsub/commands_test.go
Normal file
884
internal/modules/pubsub/commands_test.go
Normal file
@@ -0,0 +1,884 @@
|
||||
// Copyright 2024 Kelvin Clement Mwinuka
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/echovault"
|
||||
"github.com/echovault/echovault/internal"
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/echovault/echovault/internal/modules/pubsub"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var ps *pubsub.PubSub
|
||||
var mockServer *echovault.EchoVault
|
||||
|
||||
var bindAddr = "localhost"
|
||||
var port uint16
|
||||
|
||||
func init() {
|
||||
p, _ := internal.GetFreePort()
|
||||
port = uint16(p)
|
||||
|
||||
mockServer = setUpServer(bindAddr, port)
|
||||
|
||||
getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
|
||||
ps = getPubSub().(*pubsub.PubSub)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Done()
|
||||
mockServer.Start()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
|
||||
server, _ := echovault.NewEchoVault(
|
||||
echovault.WithConfig(config.Config{
|
||||
BindAddr: bindAddr,
|
||||
Port: port,
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
return server
|
||||
}
|
||||
|
||||
func getUnexportedField(field reflect.Value) interface{} {
|
||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
|
||||
}
|
||||
|
||||
func getHandler(commands ...string) internal.HandlerFunc {
|
||||
if len(commands) == 0 {
|
||||
return nil
|
||||
}
|
||||
getCommands :=
|
||||
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
|
||||
for _, c := range getCommands() {
|
||||
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
|
||||
// Get command handler
|
||||
return c.HandlerFunc
|
||||
}
|
||||
if strings.EqualFold(commands[0], c.Command) {
|
||||
// Get sub-command handler
|
||||
for _, sc := range c.SubCommands {
|
||||
if strings.EqualFold(commands[1], sc.Command) {
|
||||
return sc.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) internal.HandlerFuncParams {
|
||||
getPubSub :=
|
||||
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
|
||||
return internal.HandlerFuncParams{
|
||||
Context: ctx,
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
GetPubSub: getPubSub,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSubscribe(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
|
||||
|
||||
numOfConnection := 20
|
||||
connections := make([]*net.Conn, numOfConnection)
|
||||
|
||||
for i := 0; i < numOfConnection; i++ {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
connections[i] = &conn
|
||||
}
|
||||
defer func() {
|
||||
for _, conn := range connections {
|
||||
if err := (*conn).Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Test subscribe to channels
|
||||
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
|
||||
for _, conn := range connections {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
for _, channel := range channels {
|
||||
// Check if the channel exists in the pubsub module
|
||||
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
|
||||
return c.Name() == channel
|
||||
}) {
|
||||
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
|
||||
}
|
||||
for _, c := range ps.GetAllChannels() {
|
||||
if c.Name() == channel {
|
||||
// Check if channel has nil pattern
|
||||
if c.Pattern() != nil {
|
||||
t.Errorf("expected channel \"%s\" to have nil pattern, found pattern \"%s\"", channel, c.Name())
|
||||
}
|
||||
// Check if the channel has all the connections from above
|
||||
for _, conn := range connections {
|
||||
if _, ok := c.Subscribers()[conn]; !ok {
|
||||
t.Errorf("could not find all expected connection in the \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test subscribe to patterns
|
||||
patterns := []string{"psub_channel*"}
|
||||
for _, conn := range connections {
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
// Check if pattern channel exists in pubsub module
|
||||
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
|
||||
return c.Name() == pattern
|
||||
}) {
|
||||
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
|
||||
}
|
||||
for _, c := range ps.GetAllChannels() {
|
||||
if c.Name() == pattern {
|
||||
// Check if channel has non-nil pattern
|
||||
if c.Pattern() == nil {
|
||||
t.Errorf("expected channel \"%s\" to have pattern \"%s\", found nil pattern", pattern, c.Name())
|
||||
}
|
||||
// Check if the channel has all the connections from above
|
||||
for _, conn := range connections {
|
||||
if _, ok := c.Subscribers()[conn]; !ok {
|
||||
t.Errorf("could not find all expected connection in the \"%s\"", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUnsubscribe(t *testing.T) {
|
||||
generateConnections := func(noOfConnections int) []*net.Conn {
|
||||
connections := make([]*net.Conn, noOfConnections)
|
||||
for i := 0; i < noOfConnections; i++ {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
connections[i] = &conn
|
||||
}
|
||||
return connections
|
||||
}
|
||||
|
||||
closeConnections := func(conns []*net.Conn) {
|
||||
for _, conn := range conns {
|
||||
if err := (*conn).Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse := func(res []byte, expectedResponse [][]string) {
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
v := rv.Array()
|
||||
if len(v) != len(expectedResponse) {
|
||||
t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
|
||||
}
|
||||
for _, item := range v {
|
||||
arr := item.Array()
|
||||
if len(arr) != 3 {
|
||||
t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr))
|
||||
}
|
||||
if !slices.ContainsFunc(expectedResponse, func(strings []string) bool {
|
||||
return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String()
|
||||
}) {
|
||||
t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
subChannels []string // All channels to subscribe to
|
||||
subPatterns []string // All patterns to subscribe to
|
||||
unSubChannels []string // Channels to unsubscribe from
|
||||
unSubPatterns []string // Patterns to unsubscribe from
|
||||
remainChannels []string // Channels to remain subscribed to
|
||||
remainPatterns []string // Patterns to remain subscribed to
|
||||
targetConn *net.Conn // Connection used to test unsubscribe functionality
|
||||
otherConnections []*net.Conn // Connections to fill the subscribers list for channels and patterns
|
||||
expectedResponses map[string][][]string // The expected response from the handler
|
||||
}{
|
||||
{ // 1. Unsubscribe from channels and patterns
|
||||
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
|
||||
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
|
||||
unSubChannels: []string{"xx_channel_one", "xx_channel_two"},
|
||||
unSubPatterns: []string{"xx_pattern_[ab]"},
|
||||
remainChannels: []string{"xx_channel_three", "xx_channel_four"},
|
||||
remainPatterns: []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
|
||||
targetConn: generateConnections(1)[0],
|
||||
otherConnections: generateConnections(20),
|
||||
expectedResponses: map[string][][]string{
|
||||
"channel": {
|
||||
{"unsubscribe", "xx_channel_one", "1"},
|
||||
{"unsubscribe", "xx_channel_two", "2"},
|
||||
},
|
||||
"pattern": {
|
||||
{"punsubscribe", "xx_pattern_[ab]", "1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{ // 2. Unsubscribe from all channels no channel or pattern is passed to command
|
||||
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
|
||||
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
|
||||
unSubChannels: []string{},
|
||||
unSubPatterns: []string{},
|
||||
remainChannels: []string{},
|
||||
remainPatterns: []string{},
|
||||
targetConn: generateConnections(1)[0],
|
||||
otherConnections: generateConnections(20),
|
||||
expectedResponses: map[string][][]string{
|
||||
"channel": {
|
||||
{"unsubscribe", "xx_channel_one", "1"},
|
||||
{"unsubscribe", "xx_channel_two", "2"},
|
||||
{"unsubscribe", "xx_channel_three", "3"},
|
||||
{"unsubscribe", "xx_channel_four", "4"},
|
||||
},
|
||||
"pattern": {
|
||||
{"punsubscribe", "xx_pattern_[ab]", "1"},
|
||||
{"punsubscribe", "xx_pattern_[cd]", "2"},
|
||||
{"punsubscribe", "xx_pattern_[ef]", "3"},
|
||||
{"punsubscribe", "xx_pattern_[gh]", "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{ // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent
|
||||
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
|
||||
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
|
||||
unSubChannels: []string{"xx_channel_non_existent_channel"},
|
||||
unSubPatterns: []string{"xx_channel_non_existent_pattern_[ae]"},
|
||||
remainChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
|
||||
remainPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
|
||||
targetConn: generateConnections(1)[0],
|
||||
otherConnections: generateConnections(20),
|
||||
expectedResponses: map[string][][]string{
|
||||
"channel": {},
|
||||
"pattern": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("UNSUBSCRIBE/PUNSUBSCRIBE, %d", i))
|
||||
|
||||
// Subscribe all the connections to the channels and patterns
|
||||
for _, conn := range append(test.otherConnections, test.targetConn) {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe the target connection from the unsub channels and patterns
|
||||
res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyResponse(res, test.expectedResponses["channel"])
|
||||
|
||||
res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyResponse(res, test.expectedResponses["pattern"])
|
||||
|
||||
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
|
||||
for _, pubsubChannel := range ps.GetAllChannels() {
|
||||
if pubsubChannel.Name() == channel {
|
||||
// Assert that target connection is no longer in the unsub channels and patterns
|
||||
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
|
||||
t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel)
|
||||
}
|
||||
for _, conn := range test.otherConnections {
|
||||
if _, ok := pubsubChannel.Subscribers()[conn]; !ok {
|
||||
t.Errorf("did not find expected other connection in channel \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that the target connection is still in the remain channels and patterns
|
||||
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
|
||||
for _, pubsubChannel := range ps.GetAllChannels() {
|
||||
if pubsubChannel.Name() == channel {
|
||||
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
|
||||
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
// Close all the connections
|
||||
closeConnections(append(test.otherConnections, test.targetConn))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePublish(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", "PUBLISH")
|
||||
|
||||
// verifyChannelMessage reads the message from the connection and asserts whether
|
||||
// it's the message we expect to read as a subscriber of a channel or pattern.
|
||||
verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) {
|
||||
if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
v := rv.Array()
|
||||
for i := 0; i < len(v); i++ {
|
||||
if v[i].String() != expected[i] {
|
||||
t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The subscribe function handles subscribing the connection to the given
|
||||
// channels and patterns and reading/verifying the message sent by the echovault after
|
||||
// subscription.
|
||||
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
|
||||
// Subscribe to channels
|
||||
go func() {
|
||||
_, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer))
|
||||
}()
|
||||
// Verify all the responses for each channel subscription
|
||||
for i := 0; i < len(channels); i++ {
|
||||
verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)})
|
||||
}
|
||||
// Subscribe to all the patterns
|
||||
go func() {
|
||||
_, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer))
|
||||
}()
|
||||
// Verify all the responses for each pattern subscription
|
||||
for i := 0; i < len(patterns); i++ {
|
||||
verifyEvent(c, r, []string{"psubscribe", patterns[i], fmt.Sprintf("%d", i+1)})
|
||||
}
|
||||
}
|
||||
|
||||
subscriptions := map[string]map[string][]string{
|
||||
"subscriber1": {
|
||||
"channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to
|
||||
"patterns": {"pub_channel_[456]"}, // Patterns to subscribe to
|
||||
},
|
||||
"subscriber2": {
|
||||
"channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to
|
||||
"patterns": {"pub_channel_[891]"}, // Patterns to subscribe to
|
||||
},
|
||||
}
|
||||
|
||||
// Create subscriber one and subscribe to channels and patterns
|
||||
r1, w1 := net.Pipe()
|
||||
rc1 := resp.NewConn(r1)
|
||||
subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1)
|
||||
|
||||
// Create subscriber two and subscribe to channels and patterns
|
||||
r2, w2 := net.Pipe()
|
||||
rc2 := resp.NewConn(r2)
|
||||
subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2)
|
||||
|
||||
type SubscriberType struct {
|
||||
c *net.Conn
|
||||
r *resp.Conn
|
||||
l string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
channel string
|
||||
message string
|
||||
subscribers []SubscriberType
|
||||
}{
|
||||
{
|
||||
channel: "pub_channel_1",
|
||||
message: "Test both subscribers 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_1"},
|
||||
{c: &r2, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_6",
|
||||
message: "Test both subscribers 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
{c: &r2, r: rc2, l: "pub_channel_6"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_2",
|
||||
message: "Test subscriber 1 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_3",
|
||||
message: "Test subscriber 1 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_4",
|
||||
message: "Test both subscribers 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_5",
|
||||
message: "Test subscriber 1 3",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_7",
|
||||
message: "Test subscriber 2 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r2, r: rc2, l: "pub_channel_7"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_8",
|
||||
message: "Test subscriber 2 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_9",
|
||||
message: "Test subscriber 2 3",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r2, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Dial echovault to make publisher connection
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
w := resp.NewConn(conn)
|
||||
|
||||
for _, test := range tests {
|
||||
err = w.WriteArray([]resp.Value{
|
||||
resp.StringValue("PUBLISH"),
|
||||
resp.StringValue(test.channel),
|
||||
resp.StringValue(test.message),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rv, _, err := w.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rv.String() != "OK" {
|
||||
t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
|
||||
}
|
||||
|
||||
for _, sub := range test.subscribers {
|
||||
verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePubSubChannels(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Create separate mock echovault for this test
|
||||
port, _ := internal.GetFreePort()
|
||||
mockServer := setUpServer(bindAddr, uint16(port))
|
||||
|
||||
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB CHANNELS")
|
||||
|
||||
channels := []string{"channel_1", "channel_2", "channel_3"}
|
||||
patterns := []string{"channel_[123]", "channel_[456]"}
|
||||
|
||||
rConn1, wConn1 := net.Pipe()
|
||||
rc1 := resp.NewConn(rConn1)
|
||||
|
||||
rConn2, wConn2 := net.Pipe()
|
||||
rc2 := resp.NewConn(rConn2)
|
||||
|
||||
// Subscribe connections to channels
|
||||
go func() {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &wConn1, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
for i := 0; i < len(channels); i++ {
|
||||
v, _, err := rc1.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !slices.ContainsFunc(channels, func(s string) bool {
|
||||
return s == v.Array()[1].String()
|
||||
}) {
|
||||
t.Errorf("unexpected channel %s in response", v.Array()[1].String())
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &wConn2, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
for i := 0; i < len(patterns); i++ {
|
||||
v, _, err := rc2.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !slices.ContainsFunc(patterns, func(s string) bool {
|
||||
return s == v.Array()[1].String()
|
||||
}) {
|
||||
t.Errorf("unexpected pattern %s in response", v.Array()[1].String())
|
||||
}
|
||||
}
|
||||
|
||||
verifyExpectedResponse := func(res []byte, expected []string) {
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(rv.Array()) != len(expected) {
|
||||
t.Errorf("expected response array of length %d, got %d", len(expected), len(rv.Array()))
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.ContainsFunc(rv.Array(), func(v resp.Value) bool {
|
||||
return e == v.String()
|
||||
}) {
|
||||
t.Errorf("expected to find element \"%s\" in response array, could not find it", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all subscriptions are returned
|
||||
res, err := getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, append(channels, patterns...))
|
||||
|
||||
// Unsubscribe from one pattern and one channel before checking against a new slice of
|
||||
// expected channels/patterns in the response of the "PUBSUB CHANNELS" command
|
||||
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
|
||||
ctx,
|
||||
append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...),
|
||||
&wConn1,
|
||||
mockServer,
|
||||
))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
|
||||
ctx,
|
||||
append([]string{"UNSUBSCRIBE"}, "channel_[456]"),
|
||||
&wConn2,
|
||||
mockServer,
|
||||
))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Return all the remaining channels
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
|
||||
// Return only one of the remaining channels when passed a pattern that matches it
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1"})
|
||||
// Return both remaining channels when passed a pattern that matches them
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
|
||||
// Return none channels when passed a pattern that does not match either channel
|
||||
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyExpectedResponse(res, []string{})
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("timeout")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleNumPat(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Create separate mock echovault for this test
|
||||
port, _ := internal.GetFreePort()
|
||||
mockServer := setUpServer(bindAddr, uint16(port))
|
||||
|
||||
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMPAT")
|
||||
|
||||
patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"}
|
||||
|
||||
connections := make([]struct {
|
||||
w *net.Conn
|
||||
r *resp.Conn
|
||||
}, 3)
|
||||
for i := 0; i < len(connections); i++ {
|
||||
w, r := net.Pipe()
|
||||
connections[i] = struct {
|
||||
w *net.Conn
|
||||
r *resp.Conn
|
||||
}{w: &w, r: resp.NewConn(r)}
|
||||
go func() {
|
||||
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
for j := 0; j < len(patterns); j++ {
|
||||
v, _, err := connections[i].r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
arr := v.Array()
|
||||
if !slices.ContainsFunc(patterns, func(s string) bool {
|
||||
return s == arr[1].String()
|
||||
}) {
|
||||
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifyNumPatResponse := func(res []byte, expected int) {
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rv.Integer() != expected {
|
||||
t.Errorf("expected first NUMPAT response to be %d, got %d", expected, rv.Integer())
|
||||
}
|
||||
}
|
||||
|
||||
// Check that we receive all the patterns with NUMPAT commands
|
||||
res, err := getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyNumPatResponse(res, len(patterns))
|
||||
|
||||
// Unsubscribe from a channel and check if the number of active channels is updated
|
||||
for _, conn := range connections {
|
||||
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, conn.w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyNumPatResponse(res, len(patterns)-1)
|
||||
|
||||
// Unsubscribe from all the channels and check if we get a 0 response
|
||||
for _, conn := range connections {
|
||||
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE"}, conn.w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
verifyNumPatResponse(res, 0)
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("timeout")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleNumSub(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Create separate mock echovault for this test
|
||||
port, _ := internal.GetFreePort()
|
||||
mockServer := setUpServer(bindAddr, uint16(port))
|
||||
|
||||
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB")
|
||||
|
||||
channels := []string{"channel_1", "channel_2", "channel_3"}
|
||||
connections := make([]struct {
|
||||
w *net.Conn
|
||||
r *resp.Conn
|
||||
}, 3)
|
||||
for i := 0; i < len(connections); i++ {
|
||||
w, r := net.Pipe()
|
||||
connections[i] = struct {
|
||||
w *net.Conn
|
||||
r *resp.Conn
|
||||
}{w: &w, r: resp.NewConn(r)}
|
||||
go func() {
|
||||
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &w, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
for j := 0; j < len(channels); j++ {
|
||||
v, _, err := connections[i].r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
arr := v.Array()
|
||||
if !slices.ContainsFunc(channels, func(s string) bool {
|
||||
return s == arr[1].String()
|
||||
}) {
|
||||
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cmd []string
|
||||
expectedResponse [][]string
|
||||
}{
|
||||
{ // 1. Get all subscriptions on existing channels
|
||||
cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...),
|
||||
expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
|
||||
},
|
||||
{ // 2. Get all the subscriptions of on existing channels and a few non-existent ones
|
||||
cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
|
||||
expectedResponse: [][]string{
|
||||
{"non_existent_channel_1", "0"},
|
||||
{"non_existent_channel_2", "0"},
|
||||
{"channel_1", "3"},
|
||||
{"channel_2", "3"},
|
||||
{"channel_3", "3"},
|
||||
},
|
||||
},
|
||||
{ // 3. Get an empty array when channels are not provided in the command
|
||||
cmd: []string{"PUBSUB", "NUMSUB"},
|
||||
expectedResponse: make([][]string, 0),
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
ctx = context.WithValue(ctx, "test_index", i)
|
||||
|
||||
res, err := getHandler("PUBSUB", "NUMSUB")(getHandlerFuncParams(ctx, test.cmd, nil, mockServer))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
arr := rv.Array()
|
||||
if len(arr) != len(test.expectedResponse) {
|
||||
t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
|
||||
}
|
||||
|
||||
for _, item := range arr {
|
||||
itemArr := item.Array()
|
||||
if len(itemArr) != 2 {
|
||||
t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
|
||||
}
|
||||
if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
|
||||
return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
|
||||
}) {
|
||||
t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
|
||||
itemArr[0].String(), itemArr[1].Integer())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("timeout")
|
||||
case <-done:
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user