Removed test folder and moved all commands tests to their respective internal modules. Moved api tests into echovault package. This change has been made because the speratate test folder is not idiomatic and caused test coverage report to not be generated.

This commit is contained in:
Kelvin Clement Mwinuka
2024-05-04 17:45:10 +08:00
parent eb386d5b8f
commit 193871ec72
22 changed files with 4035 additions and 318 deletions

View File

@@ -0,0 +1,884 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub_test
import (
"bytes"
"context"
"fmt"
"github.com/echovault/echovault/echovault"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/constants"
"github.com/echovault/echovault/internal/modules/pubsub"
"github.com/tidwall/resp"
"net"
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
"unsafe"
)
var ps *pubsub.PubSub
var mockServer *echovault.EchoVault
var bindAddr = "localhost"
var port uint16
func init() {
p, _ := internal.GetFreePort()
port = uint16(p)
mockServer = setUpServer(bindAddr, port)
getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
ps = getPubSub().(*pubsub.PubSub)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
}
func setUpServer(bindAddr string, port uint16) *echovault.EchoVault {
server, _ := echovault.NewEchoVault(
echovault.WithConfig(config.Config{
BindAddr: bindAddr,
Port: port,
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
return server
}
func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
func getHandler(commands ...string) internal.HandlerFunc {
if len(commands) == 0 {
return nil
}
getCommands :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command)
for _, c := range getCommands() {
if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 {
// Get command handler
return c.HandlerFunc
}
if strings.EqualFold(commands[0], c.Command) {
// Get sub-command handler
for _, sc := range c.SubCommands {
if strings.EqualFold(commands[1], sc.Command) {
return sc.HandlerFunc
}
}
}
}
return nil
}
func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) internal.HandlerFuncParams {
getPubSub :=
getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{})
return internal.HandlerFuncParams{
Context: ctx,
Command: cmd,
Connection: conn,
GetPubSub: getPubSub,
}
}
func Test_HandleSubscribe(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
numOfConnection := 20
connections := make([]*net.Conn, numOfConnection)
for i := 0; i < numOfConnection; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
connections[i] = &conn
}
defer func() {
for _, conn := range connections {
if err := (*conn).Close(); err != nil {
t.Error(err)
}
}
}()
// Test subscribe to channels
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
for _, conn := range connections {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
for _, channel := range channels {
// Check if the channel exists in the pubsub module
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == channel
}) {
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
}
for _, c := range ps.GetAllChannels() {
if c.Name() == channel {
// Check if channel has nil pattern
if c.Pattern() != nil {
t.Errorf("expected channel \"%s\" to have nil pattern, found pattern \"%s\"", channel, c.Name())
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if _, ok := c.Subscribers()[conn]; !ok {
t.Errorf("could not find all expected connection in the \"%s\"", channel)
}
}
}
}
}
// Test subscribe to patterns
patterns := []string{"psub_channel*"}
for _, conn := range connections {
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
for _, pattern := range patterns {
// Check if pattern channel exists in pubsub module
if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool {
return c.Name() == pattern
}) {
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
}
for _, c := range ps.GetAllChannels() {
if c.Name() == pattern {
// Check if channel has non-nil pattern
if c.Pattern() == nil {
t.Errorf("expected channel \"%s\" to have pattern \"%s\", found nil pattern", pattern, c.Name())
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if _, ok := c.Subscribers()[conn]; !ok {
t.Errorf("could not find all expected connection in the \"%s\"", pattern)
}
}
}
}
}
}
func Test_HandleUnsubscribe(t *testing.T) {
generateConnections := func(noOfConnections int) []*net.Conn {
connections := make([]*net.Conn, noOfConnections)
for i := 0; i < noOfConnections; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
connections[i] = &conn
}
return connections
}
closeConnections := func(conns []*net.Conn) {
for _, conn := range conns {
if err := (*conn).Close(); err != nil {
t.Error(err)
}
}
}
verifyResponse := func(res []byte, expectedResponse [][]string) {
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
v := rv.Array()
if len(v) != len(expectedResponse) {
t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
}
for _, item := range v {
arr := item.Array()
if len(arr) != 3 {
t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr))
}
if !slices.ContainsFunc(expectedResponse, func(strings []string) bool {
return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String()
}) {
t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String())
}
}
}
tests := []struct {
subChannels []string // All channels to subscribe to
subPatterns []string // All patterns to subscribe to
unSubChannels []string // Channels to unsubscribe from
unSubPatterns []string // Patterns to unsubscribe from
remainChannels []string // Channels to remain subscribed to
remainPatterns []string // Patterns to remain subscribed to
targetConn *net.Conn // Connection used to test unsubscribe functionality
otherConnections []*net.Conn // Connections to fill the subscribers list for channels and patterns
expectedResponses map[string][][]string // The expected response from the handler
}{
{ // 1. Unsubscribe from channels and patterns
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{"xx_channel_one", "xx_channel_two"},
unSubPatterns: []string{"xx_pattern_[ab]"},
remainChannels: []string{"xx_channel_three", "xx_channel_four"},
remainPatterns: []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {
{"unsubscribe", "xx_channel_one", "1"},
{"unsubscribe", "xx_channel_two", "2"},
},
"pattern": {
{"punsubscribe", "xx_pattern_[ab]", "1"},
},
},
},
{ // 2. Unsubscribe from all channels no channel or pattern is passed to command
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{},
unSubPatterns: []string{},
remainChannels: []string{},
remainPatterns: []string{},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {
{"unsubscribe", "xx_channel_one", "1"},
{"unsubscribe", "xx_channel_two", "2"},
{"unsubscribe", "xx_channel_three", "3"},
{"unsubscribe", "xx_channel_four", "4"},
},
"pattern": {
{"punsubscribe", "xx_pattern_[ab]", "1"},
{"punsubscribe", "xx_pattern_[cd]", "2"},
{"punsubscribe", "xx_pattern_[ef]", "3"},
{"punsubscribe", "xx_pattern_[gh]", "4"},
},
},
},
{ // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{"xx_channel_non_existent_channel"},
unSubPatterns: []string{"xx_channel_non_existent_pattern_[ae]"},
remainChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
remainPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {},
"pattern": {},
},
},
}
for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("UNSUBSCRIBE/PUNSUBSCRIBE, %d", i))
// Subscribe all the connections to the channels and patterns
for _, conn := range append(test.otherConnections, test.targetConn) {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer))
if err != nil {
t.Error(err)
}
_, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer))
if err != nil {
t.Error(err)
}
}
// Unsubscribe the target connection from the unsub channels and patterns
res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer))
if err != nil {
t.Error(err)
}
verifyResponse(res, test.expectedResponses["channel"])
res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer))
if err != nil {
t.Error(err)
}
verifyResponse(res, test.expectedResponses["pattern"])
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel {
// Assert that target connection is no longer in the unsub channels and patterns
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok {
t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel)
}
for _, conn := range test.otherConnections {
if _, ok := pubsubChannel.Subscribers()[conn]; !ok {
t.Errorf("did not find expected other connection in channel \"%s\"", channel)
}
}
}
}
}
// Assert that the target connection is still in the remain channels and patterns
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
for _, pubsubChannel := range ps.GetAllChannels() {
if pubsubChannel.Name() == channel {
if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok {
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
}
}
}
}
}
for _, test := range tests {
// Close all the connections
closeConnections(append(test.otherConnections, test.targetConn))
}
}
func Test_HandlePublish(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "PUBLISH")
// verifyChannelMessage reads the message from the connection and asserts whether
// it's the message we expect to read as a subscriber of a channel or pattern.
verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) {
if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
v := rv.Array()
for i := 0; i < len(v); i++ {
if v[i].String() != expected[i] {
t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
}
}
}
// The subscribe function handles subscribing the connection to the given
// channels and patterns and reading/verifying the message sent by the echovault after
// subscription.
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
// Subscribe to channels
go func() {
_, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer))
}()
// Verify all the responses for each channel subscription
for i := 0; i < len(channels); i++ {
verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)})
}
// Subscribe to all the patterns
go func() {
_, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer))
}()
// Verify all the responses for each pattern subscription
for i := 0; i < len(patterns); i++ {
verifyEvent(c, r, []string{"psubscribe", patterns[i], fmt.Sprintf("%d", i+1)})
}
}
subscriptions := map[string]map[string][]string{
"subscriber1": {
"channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to
"patterns": {"pub_channel_[456]"}, // Patterns to subscribe to
},
"subscriber2": {
"channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to
"patterns": {"pub_channel_[891]"}, // Patterns to subscribe to
},
}
// Create subscriber one and subscribe to channels and patterns
r1, w1 := net.Pipe()
rc1 := resp.NewConn(r1)
subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1)
// Create subscriber two and subscribe to channels and patterns
r2, w2 := net.Pipe()
rc2 := resp.NewConn(r2)
subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2)
type SubscriberType struct {
c *net.Conn
r *resp.Conn
l string
}
tests := []struct {
channel string
message string
subscribers []SubscriberType
}{
{
channel: "pub_channel_1",
message: "Test both subscribers 1",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_1"},
{c: &r2, r: rc2, l: "pub_channel_[891]"},
},
},
{
channel: "pub_channel_6",
message: "Test both subscribers 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
{c: &r2, r: rc2, l: "pub_channel_6"},
},
},
{
channel: "pub_channel_2",
message: "Test subscriber 1 1",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_2"},
},
},
{
channel: "pub_channel_3",
message: "Test subscriber 1 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_3"},
},
},
{
channel: "pub_channel_4",
message: "Test both subscribers 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
},
},
{
channel: "pub_channel_5",
message: "Test subscriber 1 3",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
},
},
{
channel: "pub_channel_7",
message: "Test subscriber 2 1",
subscribers: []SubscriberType{
{c: &r2, r: rc2, l: "pub_channel_7"},
},
},
{
channel: "pub_channel_8",
message: "Test subscriber 2 2",
subscribers: []SubscriberType{
{c: &r1, r: rc2, l: "pub_channel_[891]"},
},
},
{
channel: "pub_channel_9",
message: "Test subscriber 2 3",
subscribers: []SubscriberType{
{c: &r2, r: rc2, l: "pub_channel_[891]"},
},
},
}
// Dial echovault to make publisher connection
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
defer func() {
if err = conn.Close(); err != nil {
t.Error(err)
}
}()
w := resp.NewConn(conn)
for _, test := range tests {
err = w.WriteArray([]resp.Value{
resp.StringValue("PUBLISH"),
resp.StringValue(test.channel),
resp.StringValue(test.message),
})
if err != nil {
t.Error(err)
}
rv, _, err := w.ReadValue()
if err != nil {
t.Error(err)
}
if rv.String() != "OK" {
t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
}
for _, sub := range test.subscribers {
verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message})
}
}
}
func Test_HandlePubSubChannels(t *testing.T) {
done := make(chan struct{})
go func() {
// Create separate mock echovault for this test
port, _ := internal.GetFreePort()
mockServer := setUpServer(bindAddr, uint16(port))
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB CHANNELS")
channels := []string{"channel_1", "channel_2", "channel_3"}
patterns := []string{"channel_[123]", "channel_[456]"}
rConn1, wConn1 := net.Pipe()
rc1 := resp.NewConn(rConn1)
rConn2, wConn2 := net.Pipe()
rc2 := resp.NewConn(rConn2)
// Subscribe connections to channels
go func() {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &wConn1, mockServer))
if err != nil {
t.Error(err)
}
}()
for i := 0; i < len(channels); i++ {
v, _, err := rc1.ReadValue()
if err != nil {
t.Error(err)
}
if !slices.ContainsFunc(channels, func(s string) bool {
return s == v.Array()[1].String()
}) {
t.Errorf("unexpected channel %s in response", v.Array()[1].String())
}
}
go func() {
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &wConn2, mockServer))
if err != nil {
t.Error(err)
}
}()
for i := 0; i < len(patterns); i++ {
v, _, err := rc2.ReadValue()
if err != nil {
t.Error(err)
}
if !slices.ContainsFunc(patterns, func(s string) bool {
return s == v.Array()[1].String()
}) {
t.Errorf("unexpected pattern %s in response", v.Array()[1].String())
}
}
verifyExpectedResponse := func(res []byte, expected []string) {
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
if len(rv.Array()) != len(expected) {
t.Errorf("expected response array of length %d, got %d", len(expected), len(rv.Array()))
}
for _, e := range expected {
if !slices.ContainsFunc(rv.Array(), func(v resp.Value) bool {
return e == v.String()
}) {
t.Errorf("expected to find element \"%s\" in response array, could not find it", e)
}
}
}
// Check if all subscriptions are returned
res, err := getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, append(channels, patterns...))
// Unsubscribe from one pattern and one channel before checking against a new slice of
// expected channels/patterns in the response of the "PUBSUB CHANNELS" command
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
ctx,
append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...),
&wConn1,
mockServer,
))
if err != nil {
t.Error(err)
}
_, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams(
ctx,
append([]string{"UNSUBSCRIBE"}, "channel_[456]"),
&wConn2,
mockServer,
))
if err != nil {
t.Error(err)
}
// Return all the remaining channels
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
// Return only one of the remaining channels when passed a pattern that matches it
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1"})
// Return both remaining channels when passed a pattern that matches them
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
// Return none channels when passed a pattern that does not match either channel
res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyExpectedResponse(res, []string{})
done <- struct{}{}
}()
select {
case <-time.After(200 * time.Millisecond):
t.Error("timeout")
case <-done:
}
}
func Test_HandleNumPat(t *testing.T) {
done := make(chan struct{})
go func() {
// Create separate mock echovault for this test
port, _ := internal.GetFreePort()
mockServer := setUpServer(bindAddr, uint16(port))
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMPAT")
patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"}
connections := make([]struct {
w *net.Conn
r *resp.Conn
}, 3)
for i := 0; i < len(connections); i++ {
w, r := net.Pipe()
connections[i] = struct {
w *net.Conn
r *resp.Conn
}{w: &w, r: resp.NewConn(r)}
go func() {
_, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &w, mockServer))
if err != nil {
t.Error(err)
}
}()
for j := 0; j < len(patterns); j++ {
v, _, err := connections[i].r.ReadValue()
if err != nil {
t.Error(err)
}
arr := v.Array()
if !slices.ContainsFunc(patterns, func(s string) bool {
return s == arr[1].String()
}) {
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
}
}
}
verifyNumPatResponse := func(res []byte, expected int) {
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
if rv.Integer() != expected {
t.Errorf("expected first NUMPAT response to be %d, got %d", expected, rv.Integer())
}
}
// Check that we receive all the patterns with NUMPAT commands
res, err := getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyNumPatResponse(res, len(patterns))
// Unsubscribe from a channel and check if the number of active channels is updated
for _, conn := range connections {
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, conn.w, mockServer))
if err != nil {
t.Error(err)
}
}
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyNumPatResponse(res, len(patterns)-1)
// Unsubscribe from all the channels and check if we get a 0 response
for _, conn := range connections {
_, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE"}, conn.w, mockServer))
if err != nil {
t.Error(err)
}
}
res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer))
if err != nil {
t.Error(err)
}
verifyNumPatResponse(res, 0)
done <- struct{}{}
}()
select {
case <-time.After(200 * time.Millisecond):
t.Error("timeout")
case <-done:
}
}
func Test_HandleNumSub(t *testing.T) {
done := make(chan struct{})
go func() {
// Create separate mock echovault for this test
port, _ := internal.GetFreePort()
mockServer := setUpServer(bindAddr, uint16(port))
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB")
channels := []string{"channel_1", "channel_2", "channel_3"}
connections := make([]struct {
w *net.Conn
r *resp.Conn
}, 3)
for i := 0; i < len(connections); i++ {
w, r := net.Pipe()
connections[i] = struct {
w *net.Conn
r *resp.Conn
}{w: &w, r: resp.NewConn(r)}
go func() {
_, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &w, mockServer))
if err != nil {
t.Error(err)
}
}()
for j := 0; j < len(channels); j++ {
v, _, err := connections[i].r.ReadValue()
if err != nil {
t.Error(err)
}
arr := v.Array()
if !slices.ContainsFunc(channels, func(s string) bool {
return s == arr[1].String()
}) {
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
}
}
}
tests := []struct {
cmd []string
expectedResponse [][]string
}{
{ // 1. Get all subscriptions on existing channels
cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...),
expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
},
{ // 2. Get all the subscriptions of on existing channels and a few non-existent ones
cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
expectedResponse: [][]string{
{"non_existent_channel_1", "0"},
{"non_existent_channel_2", "0"},
{"channel_1", "3"},
{"channel_2", "3"},
{"channel_3", "3"},
},
},
{ // 3. Get an empty array when channels are not provided in the command
cmd: []string{"PUBSUB", "NUMSUB"},
expectedResponse: make([][]string, 0),
},
}
for i, test := range tests {
ctx = context.WithValue(ctx, "test_index", i)
res, err := getHandler("PUBSUB", "NUMSUB")(getHandlerFuncParams(ctx, test.cmd, nil, mockServer))
if err != nil {
t.Error(err)
}
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
arr := rv.Array()
if len(arr) != len(test.expectedResponse) {
t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
}
for _, item := range arr {
itemArr := item.Array()
if len(itemArr) != 2 {
t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
}
if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
}) {
t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
itemArr[0].String(), itemArr[1].Integer())
}
}
}
done <- struct{}{}
}()
select {
case <-time.After(200 * time.Millisecond):
t.Error("timeout")
case <-done:
}
}