mirror of
				https://github.com/EchoVault/SugarDB.git
				synced 2025-10-31 11:26:52 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			1014 lines
		
	
	
		
			30 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			1014 lines
		
	
	
		
			30 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2024 Kelvin Clement Mwinuka
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package pubsub_test
 | |
| 
 | |
| import (
 | |
| 	"github.com/echovault/echovault/echovault"
 | |
| 	"github.com/echovault/echovault/internal"
 | |
| 	"github.com/echovault/echovault/internal/config"
 | |
| 	"github.com/echovault/echovault/internal/constants"
 | |
| 	"github.com/tidwall/resp"
 | |
| 	"net"
 | |
| 	"slices"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| )
 | |
| 
 | |
| func setUpServer(port int) (*echovault.EchoVault, error) {
 | |
| 	return echovault.NewEchoVault(
 | |
| 		echovault.WithConfig(config.Config{
 | |
| 			BindAddr:       "localhost",
 | |
| 			Port:           uint16(port),
 | |
| 			DataDir:        "",
 | |
| 			EvictionPolicy: constants.NoEviction,
 | |
| 		}),
 | |
| 	)
 | |
| }
 | |
| 
 | |
| func Test_PubSub(t *testing.T) {
 | |
| 	port, err := internal.GetFreePort()
 | |
| 	if err != nil {
 | |
| 		t.Error(err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	mockServer, err := setUpServer(port)
 | |
| 	if err != nil {
 | |
| 		t.Error(err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	go func() {
 | |
| 		mockServer.Start()
 | |
| 	}()
 | |
| 
 | |
| 	t.Cleanup(func() {
 | |
| 		mockServer.ShutDown()
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandleSubscribe", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		// Establish connections.
 | |
| 		numOfConnections := 20
 | |
| 		rawConnections := make([]net.Conn, numOfConnections)
 | |
| 		connections := make([]*resp.Conn, numOfConnections)
 | |
| 		for i := 0; i < numOfConnections; i++ {
 | |
| 			conn, err := internal.GetConnection("localhost", port)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			rawConnections[i] = conn
 | |
| 			connections[i] = resp.NewConn(conn)
 | |
| 		}
 | |
| 		defer func() {
 | |
| 			for _, conn := range rawConnections {
 | |
| 				_ = conn.Close()
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		// Test subscribe to channels
 | |
| 		channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
 | |
| 		command := []resp.Value{resp.StringValue("SUBSCRIBE")}
 | |
| 		for _, channel := range channels {
 | |
| 			command = append(command, resp.StringValue(channel))
 | |
| 		}
 | |
| 		for _, conn := range connections {
 | |
| 			if err := conn.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			for i := 0; i < len(channels); i++ {
 | |
| 				// Read all the subscription confirmations from the connection.
 | |
| 				if _, _, err := conn.ReadValue(); err != nil {
 | |
| 					t.Error(err)
 | |
| 					return
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		activeChannels, err := mockServer.PubSubChannels("*")
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		numSubs, err := mockServer.PubSubNumSub(channels...)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		for _, channel := range channels {
 | |
| 			// Check if the channel exists in the pubsub module.
 | |
| 			if !slices.Contains(activeChannels, channel) {
 | |
| 				t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
 | |
| 				return
 | |
| 			}
 | |
| 			// Check if the channel has the right number of subscribers.
 | |
| 			if numSubs[channel] != len(connections) {
 | |
| 				t.Errorf("expected channel \"%s\" to have %d subscribers, got %d",
 | |
| 					channel, len(connections), numSubs[channel])
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Test subscribe to patterns
 | |
| 		patterns := []string{"psub_channel*"}
 | |
| 		command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
 | |
| 		for _, pattern := range patterns {
 | |
| 			command = append(command, resp.StringValue(pattern))
 | |
| 		}
 | |
| 		for _, conn := range connections {
 | |
| 			if err := conn.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			for i := 0; i < len(patterns); i++ {
 | |
| 				// Read all the pattern subscription confirmations from the connection.
 | |
| 				if _, _, err := conn.ReadValue(); err != nil {
 | |
| 					t.Error(err)
 | |
| 					return
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		numSubs, err = mockServer.PubSubNumSub(patterns...)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		for _, pattern := range patterns {
 | |
| 			activePatterns, err := mockServer.PubSubChannels(pattern)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			// Check if pattern channel exists in pubsub module.
 | |
| 			if !slices.Contains(activePatterns, pattern) {
 | |
| 				t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
 | |
| 				return
 | |
| 			}
 | |
| 			// Check if the channel has all the connections from above.
 | |
| 			if numSubs[pattern] != len(connections) {
 | |
| 				t.Errorf("expected pattern channel \"%s\" to have %d subscribers, got %d",
 | |
| 					pattern, len(connections), numSubs[pattern])
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandleUnsubscribe", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		var rawConnections []net.Conn
 | |
| 		generateConnections := func(noOfConnections int) []*resp.Conn {
 | |
| 			connections := make([]*resp.Conn, noOfConnections)
 | |
| 			for i := 0; i < noOfConnections; i++ {
 | |
| 				conn, err := internal.GetConnection("localhost", port)
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 				rawConnections = append(rawConnections, conn)
 | |
| 				connections[i] = resp.NewConn(conn)
 | |
| 			}
 | |
| 			return connections
 | |
| 		}
 | |
| 		defer func() {
 | |
| 			for _, conn := range rawConnections {
 | |
| 				_ = conn.Close()
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		verifyResponse := func(res resp.Value, expectedResponse [][]string) {
 | |
| 			v := res.Array()
 | |
| 			if len(v) != len(expectedResponse) {
 | |
| 				t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
 | |
| 			}
 | |
| 			for _, item := range v {
 | |
| 				arr := item.Array()
 | |
| 				if len(arr) != 3 {
 | |
| 					t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr))
 | |
| 				}
 | |
| 				if !slices.ContainsFunc(expectedResponse, func(strings []string) bool {
 | |
| 					return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String()
 | |
| 				}) {
 | |
| 					t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String())
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		tests := []struct {
 | |
| 			subChannels       []string              // All channels to subscribe to
 | |
| 			subPatterns       []string              // All patterns to subscribe to
 | |
| 			unSubChannels     []string              // Channels to unsubscribe from
 | |
| 			unSubPatterns     []string              // Patterns to unsubscribe from
 | |
| 			remainChannels    []string              // Channels to remain subscribed to
 | |
| 			remainPatterns    []string              // Patterns to remain subscribed to
 | |
| 			targetConn        *resp.Conn            // Connection used to test unsubscribe functionality
 | |
| 			otherConnections  []*resp.Conn          // Connections to fill the subscribers list for channels and patterns
 | |
| 			expectedResponses map[string][][]string // The expected response from the handler
 | |
| 		}{
 | |
| 			{ // 1. Unsubscribe from channels and patterns
 | |
| 				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
 | |
| 				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
 | |
| 				unSubChannels:    []string{"xx_channel_one", "xx_channel_two"},
 | |
| 				unSubPatterns:    []string{"xx_pattern_[ab]"},
 | |
| 				remainChannels:   []string{"xx_channel_three", "xx_channel_four"},
 | |
| 				remainPatterns:   []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
 | |
| 				targetConn:       generateConnections(1)[0],
 | |
| 				otherConnections: generateConnections(20),
 | |
| 				expectedResponses: map[string][][]string{
 | |
| 					"channel": {
 | |
| 						{"unsubscribe", "xx_channel_one", "1"},
 | |
| 						{"unsubscribe", "xx_channel_two", "2"},
 | |
| 					},
 | |
| 					"pattern": {
 | |
| 						{"punsubscribe", "xx_pattern_[ab]", "1"},
 | |
| 					},
 | |
| 				},
 | |
| 			},
 | |
| 			{ // 2. Unsubscribe from all channels no channel or pattern is passed to command
 | |
| 				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
 | |
| 				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
 | |
| 				unSubChannels:    []string{},
 | |
| 				unSubPatterns:    []string{},
 | |
| 				remainChannels:   []string{},
 | |
| 				remainPatterns:   []string{},
 | |
| 				targetConn:       generateConnections(1)[0],
 | |
| 				otherConnections: generateConnections(20),
 | |
| 				expectedResponses: map[string][][]string{
 | |
| 					"channel": {
 | |
| 						{"unsubscribe", "xx_channel_one", "1"},
 | |
| 						{"unsubscribe", "xx_channel_two", "2"},
 | |
| 						{"unsubscribe", "xx_channel_three", "3"},
 | |
| 						{"unsubscribe", "xx_channel_four", "4"},
 | |
| 					},
 | |
| 					"pattern": {
 | |
| 						{"punsubscribe", "xx_pattern_[ab]", "1"},
 | |
| 						{"punsubscribe", "xx_pattern_[cd]", "2"},
 | |
| 						{"punsubscribe", "xx_pattern_[ef]", "3"},
 | |
| 						{"punsubscribe", "xx_pattern_[gh]", "4"},
 | |
| 					},
 | |
| 				},
 | |
| 			},
 | |
| 			{ // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent
 | |
| 				subChannels:      []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
 | |
| 				subPatterns:      []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
 | |
| 				unSubChannels:    []string{"xx_channel_non_existent_channel"},
 | |
| 				unSubPatterns:    []string{"xx_channel_non_existent_pattern_[ae]"},
 | |
| 				remainChannels:   []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
 | |
| 				remainPatterns:   []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
 | |
| 				targetConn:       generateConnections(1)[0],
 | |
| 				otherConnections: generateConnections(20),
 | |
| 				expectedResponses: map[string][][]string{
 | |
| 					"channel": {},
 | |
| 					"pattern": {},
 | |
| 				},
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		for _, test := range tests {
 | |
| 			// Subscribe to channels.
 | |
| 			for _, conn := range append(test.otherConnections, test.targetConn) {
 | |
| 				command := []resp.Value{resp.StringValue("SUBSCRIBE")}
 | |
| 				for _, channel := range test.subChannels {
 | |
| 					command = append(command, resp.StringValue(channel))
 | |
| 				}
 | |
| 				if err := conn.WriteArray(command); err != nil {
 | |
| 					t.Error(err)
 | |
| 					return
 | |
| 				}
 | |
| 				for i := 0; i < len(test.subChannels); i++ {
 | |
| 					// Read channel subscription confirmations from connection.
 | |
| 					if _, _, err := conn.ReadValue(); err != nil {
 | |
| 						t.Error(err)
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				// Subscribe to patterns.
 | |
| 				command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
 | |
| 				for _, pattern := range test.subPatterns {
 | |
| 					command = append(command, resp.StringValue(pattern))
 | |
| 				}
 | |
| 				if err := conn.WriteArray(command); err != nil {
 | |
| 					t.Error(err)
 | |
| 					return
 | |
| 				}
 | |
| 				for i := 0; i < len(test.subPatterns); i++ {
 | |
| 					// Read pattern subscription confirmations from connection.
 | |
| 					if _, _, err := conn.ReadValue(); err != nil {
 | |
| 						t.Error(err)
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 			}
 | |
| 
 | |
| 			// Unsubscribe the target connection from the unsub channels.
 | |
| 			command := []resp.Value{resp.StringValue("UNSUBSCRIBE")}
 | |
| 			for _, channel := range test.unSubChannels {
 | |
| 				command = append(command, resp.StringValue(channel))
 | |
| 			}
 | |
| 			if err := test.targetConn.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			res, _, err := test.targetConn.ReadValue()
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			verifyResponse(res, test.expectedResponses["channel"])
 | |
| 
 | |
| 			// Unsubscribe the target connection from the unsub patterns.
 | |
| 			command = []resp.Value{resp.StringValue("PUNSUBSCRIBE")}
 | |
| 			for _, pattern := range test.unSubPatterns {
 | |
| 				command = append(command, resp.StringValue(pattern))
 | |
| 			}
 | |
| 			if err = test.targetConn.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			res, _, err = test.targetConn.ReadValue()
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			verifyResponse(res, test.expectedResponses["pattern"])
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandlePublish", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		var rawConnections []net.Conn
 | |
| 		establishConnection := func() *resp.Conn {
 | |
| 			conn, err := internal.GetConnection("localhost", port)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			rawConnections = append(rawConnections, conn)
 | |
| 			return resp.NewConn(conn)
 | |
| 		}
 | |
| 		defer func() {
 | |
| 			for _, conn := range rawConnections {
 | |
| 				_ = conn.Close()
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		// verifyChannelMessage reads the message from the connection and asserts whether
 | |
| 		// it's the message we expect to read as a subscriber of a channel or pattern.
 | |
| 		verifyEvent := func(client *resp.Conn, expected []string) {
 | |
| 			rv, _, err := client.ReadValue()
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			v := rv.Array()
 | |
| 			for i := 0; i < len(v); i++ {
 | |
| 				if v[i].String() != expected[i] {
 | |
| 					t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// The subscribe function handles subscribing the connection to the given
 | |
| 		// channels and patterns and reading/verifying the message sent by the server after
 | |
| 		// subscription.
 | |
| 		subscribe := func(client *resp.Conn, channels []string, patterns []string) {
 | |
| 			// Subscribe to channels
 | |
| 			command := []resp.Value{resp.StringValue("SUBSCRIBE")}
 | |
| 			for _, channel := range channels {
 | |
| 				command = append(command, resp.StringValue(channel))
 | |
| 			}
 | |
| 			if err := client.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			for i := 0; i < len(channels); i++ {
 | |
| 				// Read channel subscription confirmations.
 | |
| 				if _, _, err := client.ReadValue(); err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// Subscribe to all the patterns
 | |
| 			command = []resp.Value{resp.StringValue("PSUBSCRIBE")}
 | |
| 			for _, pattern := range patterns {
 | |
| 				command = append(command, resp.StringValue(pattern))
 | |
| 			}
 | |
| 			if err := client.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			for i := 0; i < len(patterns); i++ {
 | |
| 				// Read pattern subscription confirmations.
 | |
| 				if _, _, err := client.ReadValue(); err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		subscriptions := []struct {
 | |
| 			client   *resp.Conn
 | |
| 			channels []string
 | |
| 			patterns []string
 | |
| 		}{
 | |
| 			{
 | |
| 				client:   establishConnection(),
 | |
| 				channels: []string{"pub_channel_1", "pub_channel_2", "pub_channel_3"},
 | |
| 				patterns: []string{"pub_channel_[456]"},
 | |
| 			},
 | |
| 			{
 | |
| 				client:   establishConnection(),
 | |
| 				channels: []string{"pub_channel_6", "pub_channel_7"},
 | |
| 				patterns: []string{"pub_channel_[891]"},
 | |
| 			},
 | |
| 		}
 | |
| 		for _, subscription := range subscriptions {
 | |
| 			// Subscribe to channels and patterns.
 | |
| 			subscribe(subscription.client, subscription.channels, subscription.patterns)
 | |
| 		}
 | |
| 
 | |
| 		type Subscriber struct {
 | |
| 			client  *resp.Conn
 | |
| 			channel string
 | |
| 		}
 | |
| 
 | |
| 		tests := []struct {
 | |
| 			channel     string
 | |
| 			message     string
 | |
| 			subscribers []Subscriber
 | |
| 		}{
 | |
| 			{
 | |
| 				channel: "pub_channel_1",
 | |
| 				message: "Test both subscribers 1",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_1"},
 | |
| 					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_6",
 | |
| 				message: "Test both subscribers 2",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
 | |
| 					{client: subscriptions[1].client, channel: "pub_channel_6"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_2",
 | |
| 				message: "Test subscriber 1 1",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_2"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_3",
 | |
| 				message: "Test subscriber 1 2",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_3"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_4",
 | |
| 				message: "Test both subscribers 2",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_5",
 | |
| 				message: "Test subscriber 1 3",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[0].client, channel: "pub_channel_[456]"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_7",
 | |
| 				message: "Test subscriber 2 1",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[1].client, channel: "pub_channel_7"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_8",
 | |
| 				message: "Test subscriber 2 2",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				channel: "pub_channel_9",
 | |
| 				message: "Test subscriber 2 3",
 | |
| 				subscribers: []Subscriber{
 | |
| 					{client: subscriptions[1].client, channel: "pub_channel_[891]"},
 | |
| 				},
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		// Dial echovault to make publisher connection
 | |
| 		conn, err := internal.GetConnection("localhost", port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		publisher := resp.NewConn(conn)
 | |
| 
 | |
| 		for _, test := range tests {
 | |
| 			err = publisher.WriteArray([]resp.Value{
 | |
| 				resp.StringValue("PUBLISH"),
 | |
| 				resp.StringValue(test.channel),
 | |
| 				resp.StringValue(test.message),
 | |
| 			})
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 
 | |
| 			rv, _, err := publisher.ReadValue()
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			if rv.String() != "OK" {
 | |
| 				t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
 | |
| 			}
 | |
| 
 | |
| 			for _, sub := range test.subscribers {
 | |
| 				verifyEvent(sub.client, []string{"message", sub.channel, test.message})
 | |
| 			}
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandlePubSubChannels", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		verifyExpectedResponse := func(res resp.Value, expected []string) {
 | |
| 			if len(res.Array()) != len(expected) {
 | |
| 				t.Errorf("expected response array of length %d, got %d", len(expected), len(res.Array()))
 | |
| 			}
 | |
| 			for _, e := range expected {
 | |
| 				if !slices.ContainsFunc(res.Array(), func(v resp.Value) bool {
 | |
| 					return e == v.String()
 | |
| 				}) {
 | |
| 					t.Errorf("expected to find element \"%s\" in response array, could not find it", e)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		port, err := internal.GetFreePort()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		mockServer, err := setUpServer(port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		wg := sync.WaitGroup{}
 | |
| 		wg.Add(1)
 | |
| 		go func() {
 | |
| 			wg.Done()
 | |
| 			mockServer.Start()
 | |
| 		}()
 | |
| 		wg.Wait()
 | |
| 
 | |
| 		subscribers := make([]*resp.Conn, 2)
 | |
| 		for i := 0; i < len(subscribers); i++ {
 | |
| 			conn, err := internal.GetConnection("localhost", port)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			subscribers[i] = resp.NewConn(conn)
 | |
| 		}
 | |
| 
 | |
| 		channels := []string{"channel_1", "channel_2", "channel_3"}
 | |
| 		patterns := []string{"channel_[123]", "channel_[456]"}
 | |
| 
 | |
| 		subscriptions := []struct {
 | |
| 			client   *resp.Conn
 | |
| 			action   string
 | |
| 			channels []string
 | |
| 			patterns []string
 | |
| 		}{
 | |
| 			{
 | |
| 				client:   subscribers[0],
 | |
| 				action:   "SUBSCRIBE",
 | |
| 				channels: channels,
 | |
| 				patterns: make([]string, 0),
 | |
| 			},
 | |
| 			{
 | |
| 				client:   subscribers[1],
 | |
| 				action:   "PSUBSCRIBE",
 | |
| 				channels: make([]string, 0),
 | |
| 				patterns: patterns,
 | |
| 			},
 | |
| 		}
 | |
| 		for _, subscription := range subscriptions {
 | |
| 			command := []resp.Value{resp.StringValue(subscription.action)}
 | |
| 			if len(subscription.channels) > 0 {
 | |
| 				for _, channel := range subscription.channels {
 | |
| 					command = append(command, resp.StringValue(channel))
 | |
| 				}
 | |
| 			} else if len(subscription.patterns) > 0 {
 | |
| 				for _, pattern := range subscription.patterns {
 | |
| 					command = append(command, resp.StringValue(pattern))
 | |
| 				}
 | |
| 			}
 | |
| 			if err := subscription.client.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			if len(subscription.channels) > 0 {
 | |
| 				for i := 0; i < len(subscription.channels); i++ {
 | |
| 					_, _, _ = subscription.client.ReadValue()
 | |
| 				}
 | |
| 				return
 | |
| 			}
 | |
| 			for i := 0; i < len(subscription.patterns); i++ {
 | |
| 				_, _, _ = subscription.client.ReadValue()
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Get fresh connection for the next phase.
 | |
| 		conn, err := internal.GetConnection("localhost", port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		client := resp.NewConn(conn)
 | |
| 
 | |
| 		// Check if all subscriptions are returned.
 | |
| 		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err := client.ReadValue()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		verifyExpectedResponse(res, append(channels, patterns...))
 | |
| 
 | |
| 		// Unsubscribe from one pattern and one channel before checking against a new slice of
 | |
| 		// expected channels/patterns in the response of the "PUBSUB CHANNELS" command.
 | |
| 		for _, unsubscribe := range []struct {
 | |
| 			client  *resp.Conn
 | |
| 			command []resp.Value
 | |
| 		}{
 | |
| 			{
 | |
| 				client:  subscribers[0],
 | |
| 				command: []resp.Value{resp.StringValue("UNSUBSCRIBE"), resp.StringValue("channel_2"), resp.StringValue("channel_3")},
 | |
| 			},
 | |
| 			{
 | |
| 				client:  subscribers[1],
 | |
| 				command: []resp.Value{resp.StringValue("UNSUBSCRIBE"), resp.StringValue("channel_[456]")},
 | |
| 			},
 | |
| 		} {
 | |
| 			if err = unsubscribe.client.WriteArray(unsubscribe.command); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			for i := 0; i < len(unsubscribe.command[1:]); i++ {
 | |
| 				_, _, err = unsubscribe.client.ReadValue()
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Return all the remaining channels.
 | |
| 		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err = client.ReadValue()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
 | |
| 
 | |
| 		// Return only one of the remaining channels when passed a pattern that matches it.
 | |
| 		if err = client.WriteArray([]resp.Value{
 | |
| 			resp.StringValue("PUBSUB"),
 | |
| 			resp.StringValue("CHANNELS"),
 | |
| 			resp.StringValue("channel_[189]"),
 | |
| 		}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		verifyExpectedResponse(res, []string{"channel_1"})
 | |
| 
 | |
| 		// Return both remaining channels when passed a pattern that matches them.
 | |
| 		if err := client.WriteArray([]resp.Value{
 | |
| 			resp.StringValue("PUBSUB"),
 | |
| 			resp.StringValue("CHANNELS"),
 | |
| 			resp.StringValue("channel_[123]"),
 | |
| 		}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err = client.ReadValue()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"})
 | |
| 
 | |
| 		// Return no channels when passed a pattern that does not match either channel.
 | |
| 		if err = client.WriteArray([]resp.Value{
 | |
| 			resp.StringValue("PUBSUB"),
 | |
| 			resp.StringValue("CHANNELS"),
 | |
| 			resp.StringValue("channel_[456]"),
 | |
| 		}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err = client.ReadValue()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		verifyExpectedResponse(res, []string{})
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandleNumPat", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		port, err := internal.GetFreePort()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		mockServer, err := setUpServer(port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		wg := sync.WaitGroup{}
 | |
| 		wg.Add(1)
 | |
| 		go func() {
 | |
| 			wg.Done()
 | |
| 			mockServer.Start()
 | |
| 		}()
 | |
| 		wg.Wait()
 | |
| 
 | |
| 		// Create subscribers.
 | |
| 		subscribers := make([]*resp.Conn, 3)
 | |
| 		for i := 0; i < len(subscribers); i++ {
 | |
| 			conn, err := internal.GetConnection("localhost", port)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			subscribers[i] = resp.NewConn(conn)
 | |
| 		}
 | |
| 
 | |
| 		patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"}
 | |
| 
 | |
| 		// Subscribe to all patterns
 | |
| 		for _, client := range subscribers {
 | |
| 			command := []resp.Value{resp.StringValue("PSUBSCRIBE")}
 | |
| 			for _, pattern := range patterns {
 | |
| 				command = append(command, resp.StringValue(pattern))
 | |
| 			}
 | |
| 			if err := client.WriteArray(command); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			// Read subscription responses to make sure we've subscribed to all the channels.
 | |
| 			for i := 0; i < len(patterns); i++ {
 | |
| 				res, _, err := client.ReadValue()
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 				if len(res.Array()) != 3 {
 | |
| 					t.Errorf("expected array response of length %d, got %d", 3, len(res.Array()))
 | |
| 				}
 | |
| 				if !strings.EqualFold(res.Array()[0].String(), "psubscribe") {
 | |
| 					t.Errorf("expected the first array item to be \"psubscribe\", got \"%s\"", res.Array()[0].String())
 | |
| 				}
 | |
| 				if !slices.Contains(patterns, res.Array()[1].String()) {
 | |
| 					t.Errorf("unexpected channel name \"%s\", expected %v", res.Array()[1].String(), patterns)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Get fresh connection for the next phase.
 | |
| 		conn, err := internal.GetConnection("localhost", port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		client := resp.NewConn(conn)
 | |
| 
 | |
| 		// Check that we receive all the patterns with NUMPAT commands.
 | |
| 		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err := client.ReadValue()
 | |
| 		if res.Integer() != len(patterns) {
 | |
| 			t.Errorf("expected response \"%d\", got \"%d\"", len(patterns), res.Integer())
 | |
| 		}
 | |
| 
 | |
| 		// Unsubscribe all subscribers from one pattern and check if the response is updated.
 | |
| 		for _, subscriber := range subscribers {
 | |
| 			if err = subscriber.WriteArray([]resp.Value{
 | |
| 				resp.StringValue("PUNSUBSCRIBE"),
 | |
| 				resp.StringValue(patterns[0]),
 | |
| 			}); err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			res, _, err = subscriber.ReadValue()
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 			if len(res.Array()[0].Array()) != 3 {
 | |
| 				t.Errorf("expected array response of length %d, got %d", 3, len(res.Array()[0].Array()))
 | |
| 			}
 | |
| 			if !strings.EqualFold(res.Array()[0].Array()[0].String(), "punsubscribe") {
 | |
| 				t.Errorf("expected the first array item to be \"punsubscribe\", got \"%s\"", res.Array()[0].Array()[0].String())
 | |
| 			}
 | |
| 			if res.Array()[0].Array()[1].String() != patterns[0] {
 | |
| 				t.Errorf("unexpected channel name \"%s\", expected %s", res.Array()[0].Array()[1].String(), patterns[0])
 | |
| 			}
 | |
| 		}
 | |
| 		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err = client.ReadValue()
 | |
| 		if res.Integer() != len(patterns)-1 {
 | |
| 			t.Errorf("expected response \"%d\", got \"%d\"", len(patterns)-1, res.Integer())
 | |
| 		}
 | |
| 
 | |
| 		// Unsubscribe from all the channels and check if we get a 0 response
 | |
| 		for _, subscriber := range subscribers {
 | |
| 			for _, pattern := range patterns[1:] {
 | |
| 				if err = subscriber.WriteArray([]resp.Value{
 | |
| 					resp.StringValue("PUNSUBSCRIBE"),
 | |
| 					resp.StringValue(pattern),
 | |
| 				}); err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 				res, _, err = subscriber.ReadValue()
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 				if len(res.Array()[0].Array()) != 3 {
 | |
| 					t.Errorf("expected array response of length %d, got %d", 3,
 | |
| 						len(res.Array()[0].Array()))
 | |
| 				}
 | |
| 				if !strings.EqualFold(res.Array()[0].Array()[0].String(), "punsubscribe") {
 | |
| 					t.Errorf("expected the first array item to be \"punsubscribe\", got \"%s\"",
 | |
| 						res.Array()[0].Array()[0].String())
 | |
| 				}
 | |
| 				if res.Array()[0].Array()[1].String() != pattern {
 | |
| 					t.Errorf("unexpected channel name \"%s\", expected %s",
 | |
| 						res.Array()[0].Array()[1].String(), pattern)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 		res, _, err = client.ReadValue()
 | |
| 		if res.Integer() != 0 {
 | |
| 			t.Errorf("expected response \"%d\", got \"%d\"", 0, res.Integer())
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Test_HandleNumSub", func(t *testing.T) {
 | |
| 		t.Parallel()
 | |
| 
 | |
| 		port, err := internal.GetFreePort()
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		mockServer, err := setUpServer(port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		wg := sync.WaitGroup{}
 | |
| 		wg.Add(1)
 | |
| 		go func() {
 | |
| 			wg.Done()
 | |
| 			mockServer.Start()
 | |
| 		}()
 | |
| 		wg.Wait()
 | |
| 
 | |
| 		channels := []string{"channel_1", "channel_2", "channel_3"}
 | |
| 
 | |
| 		for i := 0; i < 3; i++ {
 | |
| 			conn, err := internal.GetConnection("localhost", port)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 				return
 | |
| 			}
 | |
| 			client := resp.NewConn(conn)
 | |
| 			command := []resp.Value{
 | |
| 				resp.StringValue("SUBSCRIBE"),
 | |
| 			}
 | |
| 			for _, channel := range channels {
 | |
| 				command = append(command, resp.StringValue(channel))
 | |
| 			}
 | |
| 			err = client.WriteArray(command)
 | |
| 			if err != nil {
 | |
| 				t.Error(err)
 | |
| 			}
 | |
| 
 | |
| 			// Read subscription responses to make sure we've subscribed to all the channels.
 | |
| 			for i := 0; i < len(channels); i++ {
 | |
| 				res, _, err := client.ReadValue()
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 				if len(res.Array()) != 3 {
 | |
| 					t.Errorf("expected array response of length %d, got %d", 3, len(res.Array()))
 | |
| 				}
 | |
| 				if !strings.EqualFold(res.Array()[0].String(), "subscribe") {
 | |
| 					t.Errorf("expected the first array item to be \"subscribe\", got \"%s\"", res.Array()[0].String())
 | |
| 				}
 | |
| 				if !slices.Contains(channels, res.Array()[1].String()) {
 | |
| 					t.Errorf("unexpected channel name \"%s\", expected %v", res.Array()[1].String(), channels)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// Get fresh connection for the next phase.
 | |
| 		conn, err := internal.GetConnection("localhost", port)
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 			return
 | |
| 		}
 | |
| 		client := resp.NewConn(conn)
 | |
| 
 | |
| 		tests := []struct {
 | |
| 			name             string
 | |
| 			cmd              []string
 | |
| 			expectedResponse [][]string
 | |
| 		}{
 | |
| 			{
 | |
| 				name:             "1. Get all subscriptions on existing channels",
 | |
| 				cmd:              append([]string{"PUBSUB", "NUMSUB"}, channels...),
 | |
| 				expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
 | |
| 			},
 | |
| 			{
 | |
| 				name: "2. Get all the subscriptions of on existing channels and a few non-existent ones",
 | |
| 				cmd:  append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
 | |
| 				expectedResponse: [][]string{
 | |
| 					{"non_existent_channel_1", "0"},
 | |
| 					{"non_existent_channel_2", "0"},
 | |
| 					{"channel_1", "3"},
 | |
| 					{"channel_2", "3"},
 | |
| 					{"channel_3", "3"},
 | |
| 				},
 | |
| 			},
 | |
| 			{
 | |
| 				name:             "3. Get an empty array when channels are not provided in the command",
 | |
| 				cmd:              []string{"PUBSUB", "NUMSUB"},
 | |
| 				expectedResponse: make([][]string, 0),
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		for _, test := range tests {
 | |
| 			t.Run(test.name, func(t *testing.T) {
 | |
| 				var command []resp.Value
 | |
| 				for _, token := range test.cmd {
 | |
| 					command = append(command, resp.StringValue(token))
 | |
| 				}
 | |
| 
 | |
| 				if err = client.WriteArray(command); err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 
 | |
| 				res, _, err := client.ReadValue()
 | |
| 				if err != nil {
 | |
| 					t.Error(err)
 | |
| 				}
 | |
| 
 | |
| 				arr := res.Array()
 | |
| 				if len(arr) != len(test.expectedResponse) {
 | |
| 					t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
 | |
| 				}
 | |
| 
 | |
| 				for _, item := range arr {
 | |
| 					itemArr := item.Array()
 | |
| 					if len(itemArr) != 2 {
 | |
| 						t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
 | |
| 					}
 | |
| 					if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
 | |
| 						return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
 | |
| 					}) {
 | |
| 						t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
 | |
| 							itemArr[0].String(), itemArr[1].Integer())
 | |
| 					}
 | |
| 				}
 | |
| 			})
 | |
| 		}
 | |
| 	})
 | |
| }
 | 
