Implemented tests for SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, and PUNSUBSCRIBE command handlers

This commit is contained in:
Kelvin Mwinuka
2024-03-17 02:41:49 +08:00
parent e685d5041b
commit dbfa398543
4 changed files with 303 additions and 47 deletions

View File

@@ -21,14 +21,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
switch strings.ToLower(cmd[0]) { withPattern := strings.EqualFold(cmd[0], "psubscribe")
case "subscribe":
return pubsub.Subscribe(ctx, conn, channels, false), nil
case "psubscribe":
return pubsub.Subscribe(ctx, conn, channels, true), nil
}
return []byte{}, nil return pubsub.Subscribe(ctx, conn, channels, withPattern), nil
} }
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -39,14 +34,9 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
channels := cmd[1:] channels := cmd[1:]
switch strings.ToLower(cmd[0]) { withPattern := strings.EqualFold(cmd[0], "punsubscribe")
case "unsubscribe":
return pubsub.Unsubscribe(ctx, conn, channels, false), nil return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
case "punsubscribe":
return pubsub.Unsubscribe(ctx, conn, channels, true), nil
default:
return []byte{}, nil
}
} }
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -156,7 +146,7 @@ it's currently subscribe to.`,
{ {
Command: "punsubscribe", Command: "punsubscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns. Description: `(PUNSUBSCRIBE [pattern [pattern ...]]) Unsubscribe from a list of channels using patterns.
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
it's currently subscribe to.`, it's currently subscribe to.`,
Sync: false, Sync: false,

View File

@@ -1,42 +1,281 @@
package pubsub package pubsub
import ( import (
"bytes"
"context"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp"
"net"
"slices"
"testing" "testing"
) )
var pubsub *PubSub
var mockServer *server.Server var mockServer *server.Server
var bindAddr = "localhost"
var port uint16 = 7490
func init() { func init() {
pubsub = NewPubSub()
mockServer = server.NewServer(server.Opts{ mockServer = server.NewServer(server.Opts{
PubSub: pubsub,
Config: utils.Config{ Config: utils.Config{
BindAddr: bindAddr,
Port: port,
DataDir: "", DataDir: "",
EvictionPolicy: utils.NoEviction, EvictionPolicy: utils.NoEviction,
}, },
}) })
go func() {
mockServer.Start(context.Background())
}()
} }
func Test_HandleSubscribe(t *testing.T) { func Test_HandleSubscribe(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
numOfConnection := 100
connections := make([]*net.Conn, numOfConnection)
for i := 0; i < numOfConnection; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
connections[i] = &conn
}
// Test subscribe to channels
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
for _, conn := range connections {
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, conn); err != nil {
t.Error(err)
}
}
for _, channel := range channels {
// Check if the channel exists in the pubsub module
if !slices.ContainsFunc(pubsub.channels, func(c *Channel) bool {
return c.name == channel
}) {
t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel)
}
for _, c := range pubsub.channels {
if c.name == channel {
// Check if channel has nil pattern
if c.pattern != nil {
t.Errorf("expected channel \"%s\" to have nil pattern, found pattern \"%s\"", channel, c.name)
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if !slices.Contains(c.subscribers, conn) {
t.Errorf("could not find all expected connection in the \"%s\"", channel)
}
}
}
}
}
// Test subscribe to patterns
patterns := []string{"psub_channel*"}
for _, conn := range connections {
if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, conn); err != nil {
t.Error(err)
}
}
for _, pattern := range patterns {
// Check if pattern channel exists in pubsub module
if !slices.ContainsFunc(pubsub.channels, func(c *Channel) bool {
return c.name == pattern
}) {
t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern)
}
for _, c := range pubsub.channels {
if c.name == pattern {
// Check if channel has non-nil pattern
if c.pattern == nil {
t.Errorf("expected channel \"%s\" to have pattern \"%s\", found nil pattern", pattern, c.name)
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if !slices.Contains(c.subscribers, conn) {
t.Errorf("could not find all expected connection in the \"%s\"", pattern)
}
}
}
}
}
} }
func Test_HandleUnsubscribe(t *testing.T) { func Test_HandleUnsubscribe(t *testing.T) {
generateConnections := func(noOfConnections int) []*net.Conn {
connections := make([]*net.Conn, noOfConnections)
for i := 0; i < noOfConnections; i++ {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
connections[i] = &conn
}
return connections
}
verifyResponse := func(res []byte, expectedResponse [][]string) {
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
v := rv.Array()
if len(v) != len(expectedResponse) {
t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
}
for _, item := range v {
arr := item.Array()
if len(arr) != 3 {
t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr))
}
if !slices.ContainsFunc(expectedResponse, func(strings []string) bool {
return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String()
}) {
t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String())
}
}
}
tests := []struct {
subChannels []string // All channels to subscribe to
subPatterns []string // All patterns to subscribe to
unSubChannels []string // Channels to unsubscribe from
unSubPatterns []string // Patterns to unsubscribe from
remainChannels []string // Channels to remain subscribed to
remainPatterns []string // Patterns to remain subscribed to
targetConn *net.Conn // Connection used to test unsubscribe functionality
otherConnections []*net.Conn // Connections to fill the subscribers list for channels and patterns
expectedResponses map[string][][]string // The expected response from the handler
}{
{ // 1. Unsubscribe from channels and patterns
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{"xx_channel_one", "xx_channel_two"},
unSubPatterns: []string{"xx_pattern_[ab]"},
remainChannels: []string{"xx_channel_three", "xx_channel_four"},
remainPatterns: []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {
{"unsubscribe", "xx_channel_one", "1"},
{"unsubscribe", "xx_channel_two", "2"},
},
"pattern": {
{"punsubscribe", "xx_pattern_[ab]", "1"},
},
},
},
{ // 2. Unsubscribe from all channels no channel or pattern is passed to command
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{},
unSubPatterns: []string{},
remainChannels: []string{},
remainPatterns: []string{},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {
{"unsubscribe", "xx_channel_one", "1"},
{"unsubscribe", "xx_channel_two", "2"},
{"unsubscribe", "xx_channel_three", "3"},
{"unsubscribe", "xx_channel_four", "4"},
},
"pattern": {
{"punsubscribe", "xx_pattern_[ab]", "1"},
{"punsubscribe", "xx_pattern_[cd]", "2"},
{"punsubscribe", "xx_pattern_[ef]", "3"},
{"punsubscribe", "xx_pattern_[gh]", "4"},
},
},
},
{ // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent
subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
unSubChannels: []string{"xx_channel_non_existent_channel"},
unSubPatterns: []string{"xx_channel_non_existent_pattern_[ae]"},
remainChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"},
remainPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"},
targetConn: generateConnections(1)[0],
otherConnections: generateConnections(20),
expectedResponses: map[string][][]string{
"channel": {},
"pattern": {},
},
},
}
for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("UNSUBSCRIBE/PUNSUBSCRIBE, %d", i))
// Subscribe all the connections to the channels and patterns
for _, conn := range append(test.otherConnections, test.targetConn) {
_, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), mockServer, conn)
if err != nil {
t.Error(err)
}
_, err = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), mockServer, conn)
if err != nil {
t.Error(err)
}
}
// Unsubscribe the target connection from the unsub channels and patterns
res, err := handleUnsubscribe(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), mockServer, test.targetConn)
if err != nil {
t.Error(err)
}
verifyResponse(res, test.expectedResponses["channel"])
res, err = handleUnsubscribe(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), mockServer, test.targetConn)
if err != nil {
t.Error(err)
}
verifyResponse(res, test.expectedResponses["pattern"])
for _, channel := range append(test.unSubChannels, test.unSubPatterns...) {
for _, pubsubChannel := range pubsub.channels {
if pubsubChannel.name == channel {
// Assert that target connection is no longer in the unsub channels and patterns
if slices.Contains(pubsubChannel.subscribers, test.targetConn) {
t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel)
}
for _, conn := range test.otherConnections {
if !slices.Contains(pubsubChannel.subscribers, conn) {
t.Errorf("did not find expected other connection in channel \"%s\"", channel)
}
}
}
}
}
// Assert that the target connection is still in the remain channels and patterns
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
for _, pubsubChannel := range pubsub.channels {
if pubsubChannel.name == channel {
if !slices.Contains(pubsubChannel.subscribers, test.targetConn) {
t.Errorf("cound not find expected target connection in channel \"%s\"", channel)
}
}
}
}
}
} }
func Test_HandlePublish(t *testing.T) { func Test_HandlePublish(t *testing.T) {}
} func Test_HandlePubSubChannels(t *testing.T) {}
func Test_HandlePubSubChannels(t *testing.T) { func Test_HandleNumPat(t *testing.T) {}
} func Test_HandleNumSub(t *testing.T) {}
func Test_HandleNumPat(t *testing.T) {
}
func Test_HandleNumSub(t *testing.T) {
}

View File

@@ -67,24 +67,43 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []st
action := "unsubscribe" action := "unsubscribe"
if withPattern { if withPattern {
action = "subscribe" action = "punsubscribe"
} }
unsubscribed := make(map[int]string) unsubscribed := make(map[int]string)
count := 1 count := 1
// If the channels slice is empty, unsubscribe from all channels.
if len(channels) <= 0 { if len(channels) <= 0 {
if !withPattern {
// If the channels slice is empty, and no pattern is provided
// only unsubscribe from all channels.
for _, channel := range ps.channels { for _, channel := range ps.channels {
if channel.pattern != nil { // Skip pattern channels
continue
}
if channel.Unsubscribe(conn) { if channel.Unsubscribe(conn) {
unsubscribed[1] = channel.name unsubscribed[count] = channel.name
count += 1
}
}
} else {
// If the channels slice is empty, and pattern is provided
// only unsubscribe from all patterns.
for _, channel := range ps.channels {
if channel.pattern == nil { // Skip non-pattern channels
continue
}
if channel.Unsubscribe(conn) {
unsubscribed[count] = channel.name
count += 1 count += 1
} }
} }
} }
}
// If withPattern is false, unsubscribe from channels where the name exactly matches channel name. // Unsubscribe from channels where the name exactly matches channel name.
if !withPattern { // If unsubscribing from a pattern, also unsubscribe from all channel whose
// names exactly matches the pattern name.
for _, channel := range ps.channels { // For each channel in PubSub for _, channel := range ps.channels { // For each channel in PubSub
for _, c := range channels { // For each channel name provided for _, c := range channels { // For each channel name provided
if channel.name == c && channel.Unsubscribe(conn) { if channel.name == c && channel.Unsubscribe(conn) {
@@ -93,7 +112,6 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []st
} }
} }
} }
}
// If withPattern is true, unsubscribe from channels where pattern matches pattern provided, // If withPattern is true, unsubscribe from channels where pattern matches pattern provided,
// also unsubscribe from channels where the name matches the given pattern. // also unsubscribe from channels where the name matches the given pattern.
@@ -103,18 +121,24 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []st
for _, channel := range ps.channels { for _, channel := range ps.channels {
// If it's a pattern channel, directly compare the patterns // If it's a pattern channel, directly compare the patterns
if channel.pattern != nil && channel.name == pattern { if channel.pattern != nil && channel.name == pattern {
if channel.Unsubscribe(conn) {
unsubscribed[count] = channel.name unsubscribed[count] = channel.name
count += 1 count += 1
}
continue continue
} }
// If this is a regular channel, check if the channel name matches the pattern given // If this is a regular channel, check if the channel name matches the pattern given
if g.Match(channel.name) { if g.Match(channel.name) {
if channel.Unsubscribe(conn) {
unsubscribed[count] = channel.name unsubscribed[count] = channel.name
count += 1 count += 1
} }
} }
} }
} }
}
fmt.Println("UNSUBBED: ", unsubscribed)
res := fmt.Sprintf("*%d\r\n", len(unsubscribed)) res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
for key, value := range unsubscribed { for key, value := range unsubscribed {

View File

@@ -261,7 +261,10 @@ func (server *Server) StartTCP(ctx context.Context) {
} }
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
// If ACL module is loaded, register the connection with the ACL
if server.ACL != nil {
server.ACL.RegisterConnection(&conn) server.ACL.RegisterConnection(&conn)
}
w, r := io.Writer(conn), io.Reader(conn) w, r := io.Writer(conn), io.Reader(conn)