Files
SugarDB/internal/modules/pubsub/commands_test.go
Kelvin Mwinuka 703ad2a802 Rename the project to SugarDB. (#130)
Renames project to "SugarDB" - @kelvinmwinuka
2024-09-22 21:31:12 +08:00

1014 lines
30 KiB
Go

// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub_test
import (
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/config"
"github.com/echovault/sugardb/internal/constants"
"github.com/echovault/sugardb/sugardb"
"github.com/tidwall/resp"
"net"
"slices"
"strings"
"sync"
"testing"
)
// setUpServer constructs a SugarDB instance configured to bind to localhost on
// the given port, with an empty data directory setting and the NoEviction
// eviction policy. The caller is responsible for starting and shutting down
// the returned instance.
func setUpServer(port int) (*sugardb.SugarDB, error) {
	conf := config.Config{
		BindAddr:       "localhost",
		Port:           uint16(port),
		DataDir:        "",
		EvictionPolicy: constants.NoEviction,
	}
	return sugardb.NewSugarDB(sugardb.WithConfig(conf))
}
// Test_PubSub exercises the pubsub commands (SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE,
// PUNSUBSCRIBE, PUBLISH and the PUBSUB CHANNELS / NUMPAT / NUMSUB subcommands)
// over real TCP connections.
//
// The subscribe, unsubscribe and publish subtests run in parallel against one
// shared server and use distinct channel-name prefixes. The PUBSUB subtests
// each start their own server on a separate port.
func Test_PubSub(t *testing.T) {
	port, err := internal.GetFreePort()
	if err != nil {
		t.Error(err)
		return
	}
	mockServer, err := setUpServer(port)
	if err != nil {
		t.Error(err)
		return
	}
	go func() {
		mockServer.Start()
	}()
	// Cleanup runs only after this test and all of its (parallel) subtests have
	// completed, so the shared server outlives every subtest that uses it.
	t.Cleanup(func() {
		mockServer.ShutDown()
	})

	t.Run("Test_HandleSubscribe", func(t *testing.T) {
		t.Parallel()

		// Establish connections. The closing defer is registered before dialing
		// so that connections opened before a mid-loop failure are still closed.
		numOfConnections := 20
		rawConnections := make([]net.Conn, 0, numOfConnections)
		connections := make([]*resp.Conn, 0, numOfConnections)
		defer func() {
			for _, conn := range rawConnections {
				_ = conn.Close()
			}
		}()
		for i := 0; i < numOfConnections; i++ {
			conn, err := internal.GetConnection("localhost", port)
			if err != nil {
				t.Error(err)
				return
			}
			rawConnections = append(rawConnections, conn)
			connections = append(connections, resp.NewConn(conn))
		}

		// Test subscribe to channels
		channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
		command := []resp.Value{resp.StringValue("SUBSCRIBE")}
		for _, channel := range channels {
			command = append(command, resp.StringValue(channel))
		}
		for _, conn := range connections {
			if err := conn.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			for i := 0; i < len(channels); i++ {
				// Read all the subscription confirmations from the connection.
				if _, _, err := conn.ReadValue(); err != nil {
					t.Error(err)
					return
				}
			}
		}

		activeChannels, err := mockServer.PubSubChannels("*")
		if err != nil {
			t.Error(err)
			return
		}
		numSubs, err := mockServer.PubSubNumSub(channels...)
		if err != nil {
			t.Error(err)
			return
		}
		for _, channel := range channels {
			// Check if the channel exists in the pubsub module.
			if !slices.Contains(activeChannels, channel) {
				t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
				return
			}
			// Check if the channel has the right number of subscribers.
			if numSubs[channel] != len(connections) {
				t.Errorf("expected channel \"%s\" to have %d subscribers, got %d",
					channel, len(connections), numSubs[channel])
				return
			}
		}

		// Test subscribe to patterns
		patterns := []string{"psub_channel*"}
		command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
		for _, pattern := range patterns {
			command = append(command, resp.StringValue(pattern))
		}
		for _, conn := range connections {
			if err := conn.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			for i := 0; i < len(patterns); i++ {
				// Read all the pattern subscription confirmations from the connection.
				if _, _, err := conn.ReadValue(); err != nil {
					t.Error(err)
					return
				}
			}
		}

		numSubs, err = mockServer.PubSubNumSub(patterns...)
		if err != nil {
			t.Error(err)
			return
		}
		for _, pattern := range patterns {
			activePatterns, err := mockServer.PubSubChannels(pattern)
			if err != nil {
				t.Error(err)
				return
			}
			// Check if pattern channel exists in pubsub module.
			if !slices.Contains(activePatterns, pattern) {
				t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
				return
			}
			// Check if the channel has all the connections from above.
			if numSubs[pattern] != len(connections) {
				t.Errorf("expected pattern channel \"%s\" to have %d subscribers, got %d",
					pattern, len(connections), numSubs[pattern])
				return
			}
		}
	})

	t.Run("Test_HandleUnsubscribe", func(t *testing.T) {
		t.Parallel()

		var rawConnections []net.Conn
		// generateConnections dials the shared server noOfConnections times.
		// A dial failure aborts the subtest: continuing with a nil connection
		// would only panic later. t.Fatal still runs the closing defer below.
		generateConnections := func(noOfConnections int) []*resp.Conn {
			connections := make([]*resp.Conn, noOfConnections)
			for i := 0; i < noOfConnections; i++ {
				conn, err := internal.GetConnection("localhost", port)
				if err != nil {
					t.Fatal(err)
				}
				rawConnections = append(rawConnections, conn)
				connections[i] = resp.NewConn(conn)
			}
			return connections
		}
		defer func() {
			for _, conn := range rawConnections {
				_ = conn.Close()
			}
		}()

		// verifyResponse asserts that res is an array of 3-element entries and
		// that every entry appears in expectedResponse (order-insensitive).
		verifyResponse := func(res resp.Value, expectedResponse [][]string) {
			v := res.Array()
			if len(v) != len(expectedResponse) {
				t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
			}
			for _, item := range v {
				arr := item.Array()
				if len(arr) != 3 {
					t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr))
					// Indexing a malformed entry below would panic.
					continue
				}
				if !slices.ContainsFunc(expectedResponse, func(want []string) bool {
					return want[0] == arr[0].String() && want[1] == arr[1].String() && want[2] == arr[2].String()
				}) {
					t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String())
				}
			}
		}

		tests := []struct {
			subChannels       []string              // All channels to subscribe to
			subPatterns       []string              // All patterns to subscribe to
			unSubChannels     []string              // Channels to unsubscribe from
			unSubPatterns     []string              // Patterns to unsubscribe from
			remainChannels    []string              // Channels to remain subscribed to
			remainPatterns    []string              // Patterns to remain subscribed to
			targetConn        *resp.Conn            // Connection used to test unsubscribe functionality
			otherConnections  []*resp.Conn          // Connections to fill the subscribers list for channels and patterns
			expectedResponses map[string][][]string // The expected response from the handler
		}{
			{ // 1. Unsubscribe from channels and patterns
				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
				unSubChannels:    []string{"xx_channel_one", "xx_channel_two"},
				unSubPatterns:    []string{"xx_pattern_[ab]"},
				remainChannels:   []string{"xx_channel_three", "xx_channel_four"},
				remainPatterns:   []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
				targetConn:       generateConnections(1)[0],
				otherConnections: generateConnections(20),
				expectedResponses: map[string][][]string{
					"channel": {
						{"unsubscribe", "xx_channel_one", "1"},
						{"unsubscribe", "xx_channel_two", "2"},
					},
					"pattern": {
						{"punsubscribe", "xx_pattern_[ab]", "1"},
					},
				},
			},
			{ // 2. Unsubscribe from all channels no channel or pattern is passed to command
				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
				unSubChannels:    []string{},
				unSubPatterns:    []string{},
				remainChannels:   []string{},
				remainPatterns:   []string{},
				targetConn:       generateConnections(1)[0],
				otherConnections: generateConnections(20),
				expectedResponses: map[string][][]string{
					"channel": {
						{"unsubscribe", "xx_channel_one", "1"},
						{"unsubscribe", "xx_channel_two", "2"},
						{"unsubscribe", "xx_channel_three", "3"},
						{"unsubscribe", "xx_channel_four", "4"},
					},
					"pattern": {
						{"punsubscribe", "xx_pattern_[ab]", "1"},
						{"punsubscribe", "xx_pattern_[cd]", "2"},
						{"punsubscribe", "xx_pattern_[ef]", "3"},
						{"punsubscribe", "xx_pattern_[gh]", "4"},
					},
				},
			},
			{ // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent
				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
				unSubChannels:    []string{"xx_channel_non_existent_channel"},
				unSubPatterns:    []string{"xx_channel_non_existent_pattern_[ae]"},
				remainChannels:   []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
				remainPatterns:   []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
				targetConn:       generateConnections(1)[0],
				otherConnections: generateConnections(20),
				expectedResponses: map[string][][]string{
					"channel": {},
					"pattern": {},
				},
			},
		}

		for _, test := range tests {
			// Subscribe every connection (target included) to the channels and patterns.
			for _, conn := range append(test.otherConnections, test.targetConn) {
				command := []resp.Value{resp.StringValue("SUBSCRIBE")}
				for _, channel := range test.subChannels {
					command = append(command, resp.StringValue(channel))
				}
				if err := conn.WriteArray(command); err != nil {
					t.Error(err)
					return
				}
				for i := 0; i < len(test.subChannels); i++ {
					// Read channel subscription confirmations from connection.
					if _, _, err := conn.ReadValue(); err != nil {
						t.Error(err)
						return
					}
				}
				// Subscribe to patterns.
				command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
				for _, pattern := range test.subPatterns {
					command = append(command, resp.StringValue(pattern))
				}
				if err := conn.WriteArray(command); err != nil {
					t.Error(err)
					return
				}
				for i := 0; i < len(test.subPatterns); i++ {
					// Read pattern subscription confirmations from connection.
					if _, _, err := conn.ReadValue(); err != nil {
						t.Error(err)
						return
					}
				}
			}

			// Unsubscribe the target connection from the unsub channels.
			command := []resp.Value{resp.StringValue("UNSUBSCRIBE")}
			for _, channel := range test.unSubChannels {
				command = append(command, resp.StringValue(channel))
			}
			if err := test.targetConn.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			res, _, err := test.targetConn.ReadValue()
			if err != nil {
				t.Error(err)
				return
			}
			verifyResponse(res, test.expectedResponses["channel"])

			// Unsubscribe the target connection from the unsub patterns.
			command = []resp.Value{resp.StringValue("PUNSUBSCRIBE")}
			for _, pattern := range test.unSubPatterns {
				command = append(command, resp.StringValue(pattern))
			}
			if err = test.targetConn.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			res, _, err = test.targetConn.ReadValue()
			if err != nil {
				t.Error(err)
				return
			}
			verifyResponse(res, test.expectedResponses["pattern"])
		}
	})

	t.Run("Test_HandlePublish", func(t *testing.T) {
		t.Parallel()

		var rawConnections []net.Conn
		// establishConnection dials the shared server. A dial failure aborts the
		// subtest rather than handing out a connection wrapping a nil net.Conn.
		establishConnection := func() *resp.Conn {
			conn, err := internal.GetConnection("localhost", port)
			if err != nil {
				t.Fatal(err)
			}
			rawConnections = append(rawConnections, conn)
			return resp.NewConn(conn)
		}
		defer func() {
			for _, conn := range rawConnections {
				_ = conn.Close()
			}
		}()

		// verifyEvent reads the next message from the connection and asserts that
		// it is exactly the event we expect to read as a subscriber of a channel
		// or pattern. The length is checked first so that a short or empty event
		// cannot pass vacuously and a long one cannot index out of range.
		verifyEvent := func(client *resp.Conn, expected []string) {
			rv, _, err := client.ReadValue()
			if err != nil {
				t.Error(err)
				return
			}
			v := rv.Array()
			if len(v) != len(expected) {
				t.Errorf("expected event of length %d, got %d", len(expected), len(v))
				return
			}
			for i := 0; i < len(v); i++ {
				if v[i].String() != expected[i] {
					t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
				}
			}
		}

		// The subscribe function handles subscribing the connection to the given
		// channels and patterns and reading the confirmations sent by the server
		// after subscription, so that they are not mistaken for published events.
		subscribe := func(client *resp.Conn, channels []string, patterns []string) {
			// Subscribe to channels
			command := []resp.Value{resp.StringValue("SUBSCRIBE")}
			for _, channel := range channels {
				command = append(command, resp.StringValue(channel))
			}
			if err := client.WriteArray(command); err != nil {
				t.Error(err)
			}
			for i := 0; i < len(channels); i++ {
				// Read channel subscription confirmations.
				if _, _, err := client.ReadValue(); err != nil {
					t.Error(err)
				}
			}
			// Subscribe to all the patterns
			command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
			for _, pattern := range patterns {
				command = append(command, resp.StringValue(pattern))
			}
			if err := client.WriteArray(command); err != nil {
				t.Error(err)
			}
			for i := 0; i < len(patterns); i++ {
				// Read pattern subscription confirmations.
				if _, _, err := client.ReadValue(); err != nil {
					t.Error(err)
				}
			}
		}

		subscriptions := []struct {
			client   *resp.Conn
			channels []string
			patterns []string
		}{
			{
				client:   establishConnection(),
				channels: []string{"pub_channel_1", "pub_channel_2", "pub_channel_3"},
				patterns: []string{"pub_channel_[456]"},
			},
			{
				client:   establishConnection(),
				channels: []string{"pub_channel_6", "pub_channel_7"},
				patterns: []string{"pub_channel_[891]"},
			},
		}
		for _, subscription := range subscriptions {
			// Subscribe to channels and patterns.
			subscribe(subscription.client, subscription.channels, subscription.patterns)
		}

		// Subscriber pairs a client with the channel (or pattern) name that the
		// test expects to see in the event delivered to that client.
		type Subscriber struct {
			client  *resp.Conn
			channel string
		}

		tests := []struct {
			channel     string
			message     string
			subscribers []Subscriber
		}{
			{
				channel: "pub_channel_1",
				message: "Test both subscribers 1",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_1"},
					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
				},
			},
			{
				channel: "pub_channel_6",
				message: "Test both subscribers 2",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
					{client: subscriptions[1].client, channel: "pub_channel_6"},
				},
			},
			{
				channel: "pub_channel_2",
				message: "Test subscriber 1 1",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_2"},
				},
			},
			{
				channel: "pub_channel_3",
				message: "Test subscriber 1 2",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_3"},
				},
			},
			{
				channel: "pub_channel_4",
				message: "Test both subscribers 2",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
				},
			},
			{
				channel: "pub_channel_5",
				message: "Test subscriber 1 3",
				subscribers: []Subscriber{
					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
				},
			},
			{
				channel: "pub_channel_7",
				message: "Test subscriber 2 1",
				subscribers: []Subscriber{
					{client: subscriptions[1].client, channel: "pub_channel_7"},
				},
			},
			{
				channel: "pub_channel_8",
				message: "Test subscriber 2 2",
				subscribers: []Subscriber{
					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
				},
			},
			{
				channel: "pub_channel_9",
				message: "Test subscriber 2 3",
				subscribers: []Subscriber{
					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
				},
			},
		}

		// Dial sugardb to make publisher connection
		conn, err := internal.GetConnection("localhost", port)
		if err != nil {
			t.Error(err)
			return
		}
		defer func() {
			_ = conn.Close()
		}()
		publisher := resp.NewConn(conn)

		for _, test := range tests {
			err = publisher.WriteArray([]resp.Value{
				resp.StringValue("PUBLISH"),
				resp.StringValue(test.channel),
				resp.StringValue(test.message),
			})
			if err != nil {
				t.Error(err)
				return
			}
			rv, _, err := publisher.ReadValue()
			if err != nil {
				t.Error(err)
				return
			}
			if rv.String() != "OK" {
				t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
			}
			for _, sub := range test.subscribers {
				verifyEvent(sub.client, []string{"message", sub.channel, test.message})
			}
		}
	})

	t.Run("Test_HandlePubSubChannels", func(t *testing.T) {
		t.Parallel()

		// verifyExpectedResponse asserts that res is an array holding exactly the
		// expected elements, in any order.
		verifyExpectedResponse := func(res resp.Value, expected []string) {
			if len(res.Array()) != len(expected) {
				t.Errorf("expected response array of length %d, got %d", len(expected), len(res.Array()))
			}
			for _, e := range expected {
				if !slices.ContainsFunc(res.Array(), func(v resp.Value) bool {
					return e == v.String()
				}) {
					t.Errorf("expected to find element \"%s\" in response array, could not find it", e)
				}
			}
		}

		// This subtest asserts on the complete channel list, so it runs against
		// its own server instead of the one shared with the other subtests.
		port, err := internal.GetFreePort()
		if err != nil {
			t.Error(err)
			return
		}
		mockServer, err := setUpServer(port)
		if err != nil {
			t.Error(err)
			return
		}
		// NOTE: the WaitGroup only guarantees that the goroutine has been
		// scheduled, not that the server is already accepting connections.
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			wg.Done()
			mockServer.Start()
		}()
		wg.Wait()
		// Cleanups run last-in-first-out: the connection closers registered
		// below run before this shutdown.
		t.Cleanup(func() {
			mockServer.ShutDown()
		})

		subscribers := make([]*resp.Conn, 2)
		for i := 0; i < len(subscribers); i++ {
			conn, err := internal.GetConnection("localhost", port)
			if err != nil {
				t.Error(err)
				return
			}
			t.Cleanup(func() {
				_ = conn.Close()
			})
			subscribers[i] = resp.NewConn(conn)
		}

		channels := []string{"channel_1", "channel_2", "channel_3"}
		patterns := []string{"channel_[123]", "channel_[456]"}

		subscriptions := []struct {
			client   *resp.Conn
			action   string
			channels []string
			patterns []string
		}{
			{
				client:   subscribers[0],
				action:   "SUBSCRIBE",
				channels: channels,
				patterns: make([]string, 0),
			},
			{
				client:   subscribers[1],
				action:   "PSUBSCRIBE",
				channels: make([]string, 0),
				patterns: patterns,
			},
		}
		for _, subscription := range subscriptions {
			// Each entry subscribes to either channels or patterns, never both.
			targets := subscription.channels
			if len(targets) == 0 {
				targets = subscription.patterns
			}
			command := []resp.Value{resp.StringValue(subscription.action)}
			for _, target := range targets {
				command = append(command, resp.StringValue(target))
			}
			if err := subscription.client.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			// Drain one confirmation per subscribed target, then move on to the
			// next subscriber. Returning here instead would end the subtest
			// before any of the assertions below run.
			for i := 0; i < len(targets); i++ {
				if _, _, err := subscription.client.ReadValue(); err != nil {
					t.Error(err)
					return
				}
			}
		}

		// Get fresh connection for the next phase.
		conn, err := internal.GetConnection("localhost", port)
		if err != nil {
			t.Error(err)
			return
		}
		t.Cleanup(func() {
			_ = conn.Close()
		})
		client := resp.NewConn(conn)

		// Check if all subscriptions are returned.
		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil {
			t.Error(err)
			return
		}
		res, _, err := client.ReadValue()
		if err != nil {
			t.Error(err)
			return
		}
		verifyExpectedResponse(res, append(channels, patterns...))

		// Unsubscribe from two channels and one pattern before checking against a new
		// slice of expected channels/patterns in the response of the "PUBSUB CHANNELS" command.
		for _, unsubscribe := range []struct {
			client  *resp.Conn
			command []resp.Value
		}{
			{
				client:  subscribers[0],
				command: []resp.Value{resp.StringValue("UNSUBSCRIBE"), resp.StringValue("channel_2"), resp.StringValue("channel_3")},
			},
			{
				// Pattern subscriptions are removed with PUNSUBSCRIBE, as in
				// the unsubscribe and NUMPAT subtests.
				client:  subscribers[1],
				command: []resp.Value{resp.StringValue("PUNSUBSCRIBE"), resp.StringValue("channel_[456]")},
			},
		} {
			if err = unsubscribe.client.WriteArray(unsubscribe.command); err != nil {
				t.Error(err)
				return
			}
			// The unsubscribe reply is a single array with one entry per
			// channel (see Test_HandleUnsubscribe), so it is read exactly once;
			// reading once per channel would block on a reply that never comes.
			if _, _, err = unsubscribe.client.ReadValue(); err != nil {
				t.Error(err)
				return
			}
		}

		// Return all the remaining channels.
		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil {
			t.Error(err)
			return
		}
		res, _, err = client.ReadValue()
		if err != nil {
			t.Error(err)
			return
		}
		verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})

		// Return only one of the remaining channels when passed a pattern that matches it.
		if err = client.WriteArray([]resp.Value{
			resp.StringValue("PUBSUB"),
			resp.StringValue("CHANNELS"),
			resp.StringValue("channel_[189]"),
		}); err != nil {
			t.Error(err)
			return
		}
		// Read the reply to this command; skipping the read would verify the
		// previous response and leave every later read one reply behind.
		res, _, err = client.ReadValue()
		if err != nil {
			t.Error(err)
			return
		}
		verifyExpectedResponse(res, []string{"channel_1"})

		// Return both remaining channels when passed a pattern that matches them.
		if err := client.WriteArray([]resp.Value{
			resp.StringValue("PUBSUB"),
			resp.StringValue("CHANNELS"),
			resp.StringValue("channel_[123]"),
		}); err != nil {
			t.Error(err)
			return
		}
		res, _, err = client.ReadValue()
		if err != nil {
			t.Error(err)
			return
		}
		verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})

		// Return no channels when passed a pattern that does not match either channel.
		if err = client.WriteArray([]resp.Value{
			resp.StringValue("PUBSUB"),
			resp.StringValue("CHANNELS"),
			resp.StringValue("channel_[456]"),
		}); err != nil {
			t.Error(err)
			return
		}
		res, _, err = client.ReadValue()
		if err != nil {
			t.Error(err)
			return
		}
		verifyExpectedResponse(res, []string{})
	})

	t.Run("Test_HandleNumPat", func(t *testing.T) {
		t.Parallel()

		// NUMPAT is a server-wide count, so this subtest uses its own server.
		port, err := internal.GetFreePort()
		if err != nil {
			t.Error(err)
			return
		}
		mockServer, err := setUpServer(port)
		if err != nil {
			t.Error(err)
			return
		}
		// NOTE: the WaitGroup only guarantees that the goroutine has been
		// scheduled, not that the server is already accepting connections.
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			wg.Done()
			mockServer.Start()
		}()
		wg.Wait()
		t.Cleanup(func() {
			mockServer.ShutDown()
		})

		// Create subscribers.
		subscribers := make([]*resp.Conn, 3)
		for i := 0; i < len(subscribers); i++ {
			conn, err := internal.GetConnection("localhost", port)
			if err != nil {
				t.Error(err)
				return
			}
			t.Cleanup(func() {
				_ = conn.Close()
			})
			subscribers[i] = resp.NewConn(conn)
		}

		patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"}

		// Subscribe to all patterns
		for _, client := range subscribers {
			command := []resp.Value{resp.StringValue("PSUBSCRIBE")}
			for _, pattern := range patterns {
				command = append(command, resp.StringValue(pattern))
			}
			if err := client.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			// Read subscription responses to make sure we've subscribed to all the channels.
			for i := 0; i < len(patterns); i++ {
				res, _, err := client.ReadValue()
				if err != nil {
					t.Error(err)
					return
				}
				arr := res.Array()
				if len(arr) != 3 {
					t.Errorf("expected array response of length %d, got %d", 3, len(arr))
					// Indexing a malformed confirmation below would panic.
					continue
				}
				if !strings.EqualFold(arr[0].String(), "psubscribe") {
					t.Errorf("expected the first array item to be \"psubscribe\", got \"%s\"", arr[0].String())
				}
				if !slices.Contains(patterns, arr[1].String()) {
					t.Errorf("unexpected channel name \"%s\", expected %v", arr[1].String(), patterns)
				}
			}
		}

		// Get fresh connection for the next phase.
		conn, err := internal.GetConnection("localhost", port)
		if err != nil {
			t.Error(err)
			return
		}
		t.Cleanup(func() {
			_ = conn.Close()
		})
		client := resp.NewConn(conn)

		// verifyNumPat issues PUBSUB NUMPAT on the fresh connection and asserts
		// that the integer reply equals expected. It reports whether the round
		// trip itself succeeded so the caller can stop on a broken connection.
		verifyNumPat := func(expected int) bool {
			if err := client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil {
				t.Error(err)
				return false
			}
			res, _, err := client.ReadValue()
			if err != nil {
				t.Error(err)
				return false
			}
			if res.Integer() != expected {
				t.Errorf("expected response \"%d\", got \"%d\"", expected, res.Integer())
			}
			return true
		}

		// punsubscribe removes the subscriber from the pattern and validates the
		// first entry of the reply. It reports whether the round trip succeeded.
		punsubscribe := func(subscriber *resp.Conn, pattern string) bool {
			if err := subscriber.WriteArray([]resp.Value{
				resp.StringValue("PUNSUBSCRIBE"),
				resp.StringValue(pattern),
			}); err != nil {
				t.Error(err)
				return false
			}
			res, _, err := subscriber.ReadValue()
			if err != nil {
				t.Error(err)
				return false
			}
			entries := res.Array()
			if len(entries) == 0 {
				t.Errorf("expected a non-empty punsubscribe response for pattern \"%s\"", pattern)
				return true
			}
			entry := entries[0].Array()
			if len(entry) != 3 {
				t.Errorf("expected array response of length %d, got %d", 3, len(entry))
				return true
			}
			if !strings.EqualFold(entry[0].String(), "punsubscribe") {
				t.Errorf("expected the first array item to be \"punsubscribe\", got \"%s\"", entry[0].String())
			}
			if entry[1].String() != pattern {
				t.Errorf("unexpected channel name \"%s\", expected %s", entry[1].String(), pattern)
			}
			return true
		}

		// Check that we receive all the patterns with NUMPAT commands.
		if !verifyNumPat(len(patterns)) {
			return
		}

		// Unsubscribe all subscribers from one pattern and check if the response is updated.
		for _, subscriber := range subscribers {
			if !punsubscribe(subscriber, patterns[0]) {
				return
			}
		}
		if !verifyNumPat(len(patterns) - 1) {
			return
		}

		// Unsubscribe from all the channels and check if we get a 0 response
		for _, subscriber := range subscribers {
			for _, pattern := range patterns[1:] {
				if !punsubscribe(subscriber, pattern) {
					return
				}
			}
		}
		verifyNumPat(0)
	})

	t.Run("Test_HandleNumSub", func(t *testing.T) {
		t.Parallel()

		// The expected subscriber counts are exact, so this subtest uses its own
		// server to keep other subtests' subscriptions out of the numbers.
		port, err := internal.GetFreePort()
		if err != nil {
			t.Error(err)
			return
		}
		mockServer, err := setUpServer(port)
		if err != nil {
			t.Error(err)
			return
		}
		// NOTE: the WaitGroup only guarantees that the goroutine has been
		// scheduled, not that the server is already accepting connections.
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			wg.Done()
			mockServer.Start()
		}()
		wg.Wait()
		t.Cleanup(func() {
			mockServer.ShutDown()
		})

		channels := []string{"channel_1", "channel_2", "channel_3"}

		// Three connections each subscribe to every channel.
		for i := 0; i < 3; i++ {
			conn, err := internal.GetConnection("localhost", port)
			if err != nil {
				t.Error(err)
				return
			}
			t.Cleanup(func() {
				_ = conn.Close()
			})
			client := resp.NewConn(conn)
			command := []resp.Value{
				resp.StringValue("SUBSCRIBE"),
			}
			for _, channel := range channels {
				command = append(command, resp.StringValue(channel))
			}
			if err = client.WriteArray(command); err != nil {
				t.Error(err)
				return
			}
			// Read subscription responses to make sure we've subscribed to all the channels.
			for j := 0; j < len(channels); j++ {
				res, _, err := client.ReadValue()
				if err != nil {
					t.Error(err)
					return
				}
				arr := res.Array()
				if len(arr) != 3 {
					t.Errorf("expected array response of length %d, got %d", 3, len(arr))
					// Indexing a malformed confirmation below would panic.
					continue
				}
				if !strings.EqualFold(arr[0].String(), "subscribe") {
					t.Errorf("expected the first array item to be \"subscribe\", got \"%s\"", arr[0].String())
				}
				if !slices.Contains(channels, arr[1].String()) {
					t.Errorf("unexpected channel name \"%s\", expected %v", arr[1].String(), channels)
				}
			}
		}

		// Get fresh connection for the next phase.
		conn, err := internal.GetConnection("localhost", port)
		if err != nil {
			t.Error(err)
			return
		}
		t.Cleanup(func() {
			_ = conn.Close()
		})
		client := resp.NewConn(conn)

		tests := []struct {
			name             string
			cmd              []string
			expectedResponse [][]string
		}{
			{
				name:             "1. Get all subscriptions on existing channels",
				cmd:              append([]string{"PUBSUB", "NUMSUB"}, channels...),
				expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
			},
			{
				name: "2. Get all the subscriptions of on existing channels and a few non-existent ones",
				cmd:  append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
				expectedResponse: [][]string{
					{"non_existent_channel_1", "0"},
					{"non_existent_channel_2", "0"},
					{"channel_1", "3"},
					{"channel_2", "3"},
					{"channel_3", "3"},
				},
			},
			{
				name:             "3. Get an empty array when channels are not provided in the command",
				cmd:              []string{"PUBSUB", "NUMSUB"},
				expectedResponse: make([][]string, 0),
			},
		}

		// These nested subtests are sequential (no t.Parallel), so they can
		// safely share the single client connection.
		for _, test := range tests {
			t.Run(test.name, func(t *testing.T) {
				var command []resp.Value
				for _, token := range test.cmd {
					command = append(command, resp.StringValue(token))
				}
				if err := client.WriteArray(command); err != nil {
					t.Error(err)
					return
				}
				res, _, err := client.ReadValue()
				if err != nil {
					t.Error(err)
					return
				}
				arr := res.Array()
				if len(arr) != len(test.expectedResponse) {
					t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
				}
				for _, item := range arr {
					itemArr := item.Array()
					if len(itemArr) != 2 {
						t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
						// Indexing a malformed entry below would panic.
						continue
					}
					if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
						return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
					}) {
						t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
							itemArr[0].String(), itemArr[1].Integer())
					}
				}
			})
		}
	})
}