Added test for PUBLISH command handler

This commit is contained in:
Kelvin Mwinuka
2024-03-18 14:05:49 +08:00
parent dbfa398543
commit 59a03aec2c
6 changed files with 261 additions and 57 deletions

View File

@@ -10,6 +10,7 @@ import (
"net"
"slices"
"testing"
"time"
)
var pubsub *PubSub
@@ -21,7 +22,8 @@ var port uint16 = 7490
func init() {
pubsub = NewPubSub()
mockServer = server.NewServer(server.Opts{
PubSub: pubsub,
PubSub: pubsub,
Commands: Commands(),
Config: utils.Config{
BindAddr: bindAddr,
Port: port,
@@ -37,7 +39,7 @@ func init() {
func Test_HandleSubscribe(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
numOfConnection := 100
numOfConnection := 20
connections := make([]*net.Conn, numOfConnection)
for i := 0; i < numOfConnection; i++ {
@@ -47,6 +49,13 @@ func Test_HandleSubscribe(t *testing.T) {
}
connections[i] = &conn
}
defer func() {
for _, conn := range connections {
if err := (*conn).Close(); err != nil {
t.Error(err)
}
}
}()
// Test subscribe to channels
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
@@ -70,7 +79,7 @@ func Test_HandleSubscribe(t *testing.T) {
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if !slices.Contains(c.subscribers, conn) {
if _, ok := c.subscribers[conn]; !ok {
t.Errorf("could not find all expected connection in the \"%s\"", channel)
}
}
@@ -100,7 +109,7 @@ func Test_HandleSubscribe(t *testing.T) {
}
// Check if the channel has all the connections from above
for _, conn := range connections {
if !slices.Contains(c.subscribers, conn) {
if _, ok := c.subscribers[conn]; !ok {
t.Errorf("could not find all expected connection in the \"%s\"", pattern)
}
}
@@ -122,6 +131,14 @@ func Test_HandleUnsubscribe(t *testing.T) {
return connections
}
closeConnections := func(conns []*net.Conn) {
for _, conn := range conns {
if err := (*conn).Close(); err != nil {
t.Error(err)
}
}
}
verifyResponse := func(res []byte, expectedResponse [][]string) {
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
@@ -130,6 +147,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
}
v := rv.Array()
if len(v) != len(expectedResponse) {
fmt.Println(v)
t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
}
for _, item := range v {
@@ -247,11 +265,11 @@ func Test_HandleUnsubscribe(t *testing.T) {
for _, pubsubChannel := range pubsub.channels {
if pubsubChannel.name == channel {
// Assert that target connection is no longer in the unsub channels and patterns
if slices.Contains(pubsubChannel.subscribers, test.targetConn) {
if _, ok := pubsubChannel.subscribers[test.targetConn]; ok {
t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel)
}
for _, conn := range test.otherConnections {
if !slices.Contains(pubsubChannel.subscribers, conn) {
if _, ok := pubsubChannel.subscribers[conn]; !ok {
t.Errorf("did not find expected other connection in channel \"%s\"", channel)
}
}
@@ -263,16 +281,198 @@ func Test_HandleUnsubscribe(t *testing.T) {
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
for _, pubsubChannel := range pubsub.channels {
if pubsubChannel.name == channel {
if !slices.Contains(pubsubChannel.subscribers, test.targetConn) {
t.Errorf("cound not find expected target connection in channel \"%s\"", channel)
if _, ok := pubsubChannel.subscribers[test.targetConn]; !ok {
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
}
}
}
}
}
for _, test := range tests {
// Close all the connections
closeConnections(append(test.otherConnections, test.targetConn))
}
}
func Test_HandlePublish(t *testing.T) {}
func Test_HandlePublish(t *testing.T) {
ctx := context.WithValue(context.Background(), "test_name", "PUBLISH")
// verifyChannelMessage reads the message from the connection and asserts whether
// it's the message we expect to read as a subscriber of a channel or pattern.
verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) {
if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
v := rv.Array()
for i := 0; i < len(v); i++ {
if v[i].String() != expected[i] {
t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
}
}
fmt.Println(v)
}
// The subscribe function handles subscribing the connection to the given
// channels and patterns and reading/verifying the message sent by the server after
// subscription.
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
// Subscribe to channels
go func() {
_, _ = handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, c)
}()
// Verify all the responses for each channel subscription
for i := 0; i < len(channels); i++ {
verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)})
}
// Subscribe to all the patterns
go func() {
_, _ = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, c)
}()
// Verify all the responses for each pattern subscription
for i := 0; i < len(patterns); i++ {
verifyEvent(c, r, []string{"subscribe", patterns[i], fmt.Sprintf("%d", i+1)})
}
}
subscriptions := map[string]map[string][]string{
"subscriber1": {
"channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to
"patterns": {"pub_channel_[456]"}, // Patterns to subscribe to
},
"subscriber2": {
"channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to
"patterns": {"pub_channel_[891]"}, // Patterns to subscribe to
},
}
// Create subscriber one and subscribe to channels and patterns
r1, w1 := net.Pipe()
rc1 := resp.NewConn(r1)
subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1)
// Create subscriber two and subscribe to channels and patterns
r2, w2 := net.Pipe()
rc2 := resp.NewConn(r2)
subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2)
type SubscriberType struct {
c *net.Conn
r *resp.Conn
l string
}
tests := []struct {
channel string
message string
subscribers []SubscriberType
}{
{
channel: "pub_channel_1",
message: "Test both subscribers 1",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_1"},
{c: &r2, r: rc2, l: "pub_channel_[891]"},
},
},
{
channel: "pub_channel_6",
message: "Test both subscribers 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
{c: &r2, r: rc2, l: "pub_channel_6"},
},
},
{
channel: "pub_channel_2",
message: "Test subscriber 1 1",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_2"},
},
},
{
channel: "pub_channel_3",
message: "Test subscriber 1 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_3"},
},
},
{
channel: "pub_channel_4",
message: "Test both subscribers 2",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
},
},
{
channel: "pub_channel_5",
message: "Test subscriber 1 3",
subscribers: []SubscriberType{
{c: &r1, r: rc1, l: "pub_channel_[456]"},
},
},
{
channel: "pub_channel_7",
message: "Test subscriber 2 1",
subscribers: []SubscriberType{
{c: &r2, r: rc2, l: "pub_channel_7"},
},
},
{
channel: "pub_channel_8",
message: "Test subscriber 2 2",
subscribers: []SubscriberType{
{c: &r1, r: rc2, l: "pub_channel_[891]"},
},
},
{
channel: "pub_channel_9",
message: "Test subscriber 2 3",
subscribers: []SubscriberType{
{c: &r2, r: rc2, l: "pub_channel_[891]"},
},
},
}
// Dial server to make publisher connection
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
defer func() {
if err = conn.Close(); err != nil {
t.Error(err)
}
}()
w := resp.NewConn(conn)
for _, test := range tests {
err = w.WriteArray([]resp.Value{
resp.StringValue("PUBLISH"),
resp.StringValue(test.channel),
resp.StringValue(test.message),
})
if err != nil {
t.Error(err)
}
rv, _, err := w.ReadValue()
if err != nil {
t.Error(err)
}
if rv.String() != "OK" {
t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
}
for _, sub := range test.subscribers {
verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message})
}
}
}
func Test_HandlePubSubChannels(t *testing.T) {}