// Copyright 2024 Kelvin Clement Mwinuka // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pubsub_test import ( "context" "github.com/echovault/echovault/echovault" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" "github.com/echovault/echovault/internal/modules/pubsub" "github.com/tidwall/resp" "net" "reflect" "slices" "strings" "sync" "testing" "unsafe" ) var ps *pubsub.PubSub var mockServer *echovault.EchoVault var bindAddr = "localhost" var port uint16 func init() { p, _ := internal.GetFreePort() port = uint16(p) mockServer = setUpServer(bindAddr, port) getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{}) ps = getPubSub().(*pubsub.PubSub) wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() } func setUpServer(bindAddr string, port uint16) *echovault.EchoVault { server, _ := echovault.NewEchoVault( echovault.WithConfig(config.Config{ BindAddr: bindAddr, Port: port, DataDir: "", EvictionPolicy: constants.NoEviction, }), ) return server } func getUnexportedField(field reflect.Value) interface{} { return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() } func getHandler(commands ...string) internal.HandlerFunc { if len(commands) == 0 { return nil } getCommands := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) for _, c := range getCommands() { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { // Get command handler return c.HandlerFunc } if strings.EqualFold(commands[0], c.Command) { // Get sub-command handler for _, sc := range c.SubCommands { if strings.EqualFold(commands[1], sc.Command) { return sc.HandlerFunc } } } } return nil } func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) internal.HandlerFuncParams { getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{}) return internal.HandlerFuncParams{ Context: ctx, Command: cmd, Connection: conn, GetPubSub: getPubSub, } } // func Test_HandleSubscribe(t *testing.T) { // ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE") // // numOfConnection := 20 // connections := make([]*net.Conn, numOfConnection) // // for i := 0; i < numOfConnection; i++ { // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // } // connections[i] = &conn // } // defer func() { // for _, conn := range connections { // if err := (*conn).Close(); err != nil { // t.Error(err) // } // } // }() // // // Test subscribe to channels // channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"} // for _, conn := range connections { // _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer)) // if err != nil { // t.Error(err) // } // } // for _, channel := range channels { // // Check if the channel exists in the pubsub module // if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool { // return c.Name() == channel // }) { // t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel) // } // for _, c := range ps.GetAllChannels() { // if c.Name() == channel { // // Check if channel has nil pattern // if c.Pattern() != nil { // t.Errorf("expected channel \"%s\" to have nil pattern, found pattern \"%s\"", channel, c.Name()) // } // // Check if the channel has all the connections from above // for _, conn := range connections { // if _, ok := c.Subscribers()[conn]; !ok { // t.Errorf("could not find all expected connection in the \"%s\"", channel) // } // } // } // } // } // // // Test subscribe to patterns // patterns := []string{"psub_channel*"} // for _, conn := range connections { // _, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer)) // if err != nil { // t.Error(err) // } // } // for _, pattern := range patterns { // // Check if pattern channel exists in pubsub module // if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool { // return c.Name() == pattern // }) { // t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern) // } // for _, c := range ps.GetAllChannels() { // if c.Name() == pattern { // // Check if channel has non-nil pattern // if c.Pattern() == nil { // t.Errorf("expected channel \"%s\" to have pattern \"%s\", found nil pattern", pattern, c.Name()) // } // // Check if the channel has all the connections from above // for _, conn := range connections { // if _, ok := c.Subscribers()[conn]; !ok { // t.Errorf("could not find all expected connection in the \"%s\"", pattern) // } // } // } // } // } // } // func Test_HandleUnsubscribe(t *testing.T) { // generateConnections := func(noOfConnections int) []*net.Conn { // connections := make([]*net.Conn, noOfConnections) // for i := 0; i < noOfConnections; i++ { // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // } // connections[i] = &conn // } // return connections // } // // closeConnections := func(conns []*net.Conn) { // for _, conn := range conns { // if err := (*conn).Close(); err != nil { // t.Error(err) // } // } // } // // verifyResponse := func(res []byte, expectedResponse [][]string) { // rd := resp.NewReader(bytes.NewReader(res)) // rv, _, err := rd.ReadValue() // if err != nil { // t.Error(err) // } // v := rv.Array() // if len(v) != len(expectedResponse) { // t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v)) // } // for _, item := range v { // arr := item.Array() // if len(arr) != 3 { // t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr)) // } // if !slices.ContainsFunc(expectedResponse, func(strings []string) bool { // return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String() // }) { // t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String()) // } // } // } // // tests := []struct { // subChannels []string // All channels to subscribe to // subPatterns []string // All patterns to subscribe to // unSubChannels []string // Channels to unsubscribe from // unSubPatterns []string // Patterns to unsubscribe from // remainChannels []string // Channels to remain subscribed to // remainPatterns []string // Patterns to remain subscribed to // targetConn *net.Conn // Connection used to test unsubscribe functionality // otherConnections []*net.Conn // Connections to fill the subscribers list for channels and patterns // expectedResponses map[string][][]string // The expected response from the handler // }{ // { // 1. Unsubscribe from channels and patterns // subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, // subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, // unSubChannels: []string{"xx_channel_one", "xx_channel_two"}, // unSubPatterns: []string{"xx_pattern_[ab]"}, // remainChannels: []string{"xx_channel_three", "xx_channel_four"}, // remainPatterns: []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, // targetConn: generateConnections(1)[0], // otherConnections: generateConnections(20), // expectedResponses: map[string][][]string{ // "channel": { // {"unsubscribe", "xx_channel_one", "1"}, // {"unsubscribe", "xx_channel_two", "2"}, // }, // "pattern": { // {"punsubscribe", "xx_pattern_[ab]", "1"}, // }, // }, // }, // { // 2. Unsubscribe from all channels no channel or pattern is passed to command // subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, // subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, // unSubChannels: []string{}, // unSubPatterns: []string{}, // remainChannels: []string{}, // remainPatterns: []string{}, // targetConn: generateConnections(1)[0], // otherConnections: generateConnections(20), // expectedResponses: map[string][][]string{ // "channel": { // {"unsubscribe", "xx_channel_one", "1"}, // {"unsubscribe", "xx_channel_two", "2"}, // {"unsubscribe", "xx_channel_three", "3"}, // {"unsubscribe", "xx_channel_four", "4"}, // }, // "pattern": { // {"punsubscribe", "xx_pattern_[ab]", "1"}, // {"punsubscribe", "xx_pattern_[cd]", "2"}, // {"punsubscribe", "xx_pattern_[ef]", "3"}, // {"punsubscribe", "xx_pattern_[gh]", "4"}, // }, // }, // }, // { // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent // subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, // subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, // unSubChannels: []string{"xx_channel_non_existent_channel"}, // unSubPatterns: []string{"xx_channel_non_existent_pattern_[ae]"}, // remainChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, // remainPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, // targetConn: generateConnections(1)[0], // otherConnections: generateConnections(20), // expectedResponses: map[string][][]string{ // "channel": {}, // "pattern": {}, // }, // }, // } // // for i, test := range tests { // ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("UNSUBSCRIBE/PUNSUBSCRIBE, %d", i)) // // // Subscribe all the connections to the channels and patterns // for _, conn := range append(test.otherConnections, test.targetConn) { // _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer)) // if err != nil { // t.Error(err) // } // _, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer)) // if err != nil { // t.Error(err) // } // } // // // Unsubscribe the target connection from the unsub channels and patterns // res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer)) // if err != nil { // t.Error(err) // } // verifyResponse(res, test.expectedResponses["channel"]) // // res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer)) // if err != nil { // t.Error(err) // } // verifyResponse(res, test.expectedResponses["pattern"]) // // for _, channel := range append(test.unSubChannels, test.unSubPatterns...) { // for _, pubsubChannel := range ps.GetAllChannels() { // if pubsubChannel.Name() == channel { // // Assert that target connection is no longer in the unsub channels and patterns // if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok { // t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel) // } // for _, conn := range test.otherConnections { // if _, ok := pubsubChannel.Subscribers()[conn]; !ok { // t.Errorf("did not find expected other connection in channel \"%s\"", channel) // } // } // } // } // } // // // Assert that the target connection is still in the remain channels and patterns // for _, channel := range append(test.remainChannels, test.remainPatterns...) { // for _, pubsubChannel := range ps.GetAllChannels() { // if pubsubChannel.Name() == channel { // if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok { // t.Errorf("could not find expected target connection in channel \"%s\"", channel) // } // } // } // } // } // // for _, test := range tests { // // Close all the connections // closeConnections(append(test.otherConnections, test.targetConn)) // } // } // func Test_HandlePublish(t *testing.T) { // ctx := context.WithValue(context.Background(), "test_name", "PUBLISH") // // // verifyChannelMessage reads the message from the connection and asserts whether // // it's the message we expect to read as a subscriber of a channel or pattern. // verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) { // if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { // t.Error(err) // } // rv, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // v := rv.Array() // for i := 0; i < len(v); i++ { // if v[i].String() != expected[i] { // t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String()) // } // } // } // // // The subscribe function handles subscribing the connection to the given // // channels and patterns and reading/verifying the message sent by the echovault after // // subscription. // subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) { // // Subscribe to channels // go func() { // _, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer)) // }() // // Verify all the responses for each channel subscription // for i := 0; i < len(channels); i++ { // verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)}) // } // // Subscribe to all the patterns // go func() { // _, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer)) // }() // // Verify all the responses for each pattern subscription // for i := 0; i < len(patterns); i++ { // verifyEvent(c, r, []string{"psubscribe", patterns[i], fmt.Sprintf("%d", i+1)}) // } // } // // subscriptions := map[string]map[string][]string{ // "subscriber1": { // "channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to // "patterns": {"pub_channel_[456]"}, // Patterns to subscribe to // }, // "subscriber2": { // "channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to // "patterns": {"pub_channel_[891]"}, // Patterns to subscribe to // }, // } // // // Create subscriber one and subscribe to channels and patterns // r1, w1 := net.Pipe() // rc1 := resp.NewConn(r1) // subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1) // // // Create subscriber two and subscribe to channels and patterns // r2, w2 := net.Pipe() // rc2 := resp.NewConn(r2) // subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2) // // type SubscriberType struct { // c *net.Conn // r *resp.Conn // l string // } // // tests := []struct { // channel string // message string // subscribers []SubscriberType // }{ // { // channel: "pub_channel_1", // message: "Test both subscribers 1", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_1"}, // {c: &r2, r: rc2, l: "pub_channel_[891]"}, // }, // }, // { // channel: "pub_channel_6", // message: "Test both subscribers 2", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_[456]"}, // {c: &r2, r: rc2, l: "pub_channel_6"}, // }, // }, // { // channel: "pub_channel_2", // message: "Test subscriber 1 1", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_2"}, // }, // }, // { // channel: "pub_channel_3", // message: "Test subscriber 1 2", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_3"}, // }, // }, // { // channel: "pub_channel_4", // message: "Test both subscribers 2", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_[456]"}, // }, // }, // { // channel: "pub_channel_5", // message: "Test subscriber 1 3", // subscribers: []SubscriberType{ // {c: &r1, r: rc1, l: "pub_channel_[456]"}, // }, // }, // { // channel: "pub_channel_7", // message: "Test subscriber 2 1", // subscribers: []SubscriberType{ // {c: &r2, r: rc2, l: "pub_channel_7"}, // }, // }, // { // channel: "pub_channel_8", // message: "Test subscriber 2 2", // subscribers: []SubscriberType{ // {c: &r1, r: rc2, l: "pub_channel_[891]"}, // }, // }, // { // channel: "pub_channel_9", // message: "Test subscriber 2 3", // subscribers: []SubscriberType{ // {c: &r2, r: rc2, l: "pub_channel_[891]"}, // }, // }, // } // // // Dial echovault to make publisher connection // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // } // defer func() { // if err = conn.Close(); err != nil { // t.Error(err) // } // }() // w := resp.NewConn(conn) // // for _, test := range tests { // err = w.WriteArray([]resp.Value{ // resp.StringValue("PUBLISH"), // resp.StringValue(test.channel), // resp.StringValue(test.message), // }) // if err != nil { // t.Error(err) // } // // rv, _, err := w.ReadValue() // if err != nil { // t.Error(err) // } // if rv.String() != "OK" { // t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String()) // } // // for _, sub := range test.subscribers { // verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message}) // } // } // } func Test_HandlePubSubChannels(t *testing.T) { verifyExpectedResponse := func(res resp.Value, expected []string) { if len(res.Array()) != len(expected) { t.Errorf("expected response array of length %d, got %d", len(expected), len(res.Array())) } for _, e := range expected { if !slices.ContainsFunc(res.Array(), func(v resp.Value) bool { return e == v.String() }) { t.Errorf("expected to find element \"%s\" in response array, could not find it", e) } } } port, _ := internal.GetFreePort() mockServer := setUpServer(bindAddr, uint16(port)) wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() subscribers := make([]*resp.Conn, 2) for i := 0; i < len(subscribers); i++ { conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } subscribers[i] = resp.NewConn(conn) } channels := []string{"channel_1", "channel_2", "channel_3"} patterns := []string{"channel_[123]", "channel_[456]"} subscriptions := []struct { client *resp.Conn action string channels []string patterns []string }{ { client: subscribers[0], action: "SUBSCRIBE", channels: channels, patterns: make([]string, 0), }, { client: subscribers[1], action: "PSUBSCRIBE", channels: make([]string, 0), patterns: patterns, }, } for _, subscription := range subscriptions { command := []resp.Value{resp.StringValue(subscription.action)} if len(subscription.channels) > 0 { for _, channel := range subscription.channels { command = append(command, resp.StringValue(channel)) } } else if len(subscription.patterns) > 0 { for _, pattern := range subscription.patterns { command = append(command, resp.StringValue(pattern)) } } if err := subscription.client.WriteArray(command); err != nil { t.Error(err) } if len(subscription.channels) > 0 { for i := 0; i < len(subscription.channels); i++ { _, _, _ = subscription.client.ReadValue() } return } for i := 0; i < len(subscription.patterns); i++ { _, _, _ = subscription.client.ReadValue() } } // Get fresh connection for the next phase. conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } client := resp.NewConn(conn) // Check if all subscriptions are returned. if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil { t.Error(err) } res, _, err := client.ReadValue() if err != nil { t.Error(err) } verifyExpectedResponse(res, append(channels, patterns...)) // Unsubscribe from one pattern and one channel before checking against a new slice of // expected channels/patterns in the response of the "PUBSUB CHANNELS" command. for _, unsubscribe := range []struct { client *resp.Conn command []resp.Value }{ { client: subscribers[0], command: []resp.Value{resp.StringValue("UNSUBSCRIBE"), resp.StringValue("channel_2"), resp.StringValue("channel_3")}, }, { client: subscribers[1], command: []resp.Value{resp.StringValue("UNSUBSCRIBE"), resp.StringValue("channel_[456]")}, }, } { if err = unsubscribe.client.WriteArray(unsubscribe.command); err != nil { t.Error(err) } for i := 0; i < len(unsubscribe.command[1:]); i++ { _, _, err = unsubscribe.client.ReadValue() if err != nil { t.Error(err) } } } // Return all the remaining channels. if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS")}); err != nil { t.Error(err) } res, _, err = client.ReadValue() if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) // Return only one of the remaining channels when passed a pattern that matches it. if err = client.WriteArray([]resp.Value{ resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS"), resp.StringValue("channel_[189]"), }); err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1"}) // Return both remaining channels when passed a pattern that matches them. if err := client.WriteArray([]resp.Value{ resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS"), resp.StringValue("channel_[123]"), }); err != nil { t.Error(err) } res, _, err = client.ReadValue() if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) // Return no channels when passed a pattern that does not match either channel. if err = client.WriteArray([]resp.Value{ resp.StringValue("PUBSUB"), resp.StringValue("CHANNELS"), resp.StringValue("channel_[456]"), }); err != nil { t.Error(err) } res, _, err = client.ReadValue() if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{}) } func Test_HandleNumPat(t *testing.T) { port, _ := internal.GetFreePort() mockServer := setUpServer(bindAddr, uint16(port)) wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() // Create subscribers. subscribers := make([]*resp.Conn, 3) for i := 0; i < len(subscribers); i++ { conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } subscribers[i] = resp.NewConn(conn) } patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"} // Subscribe to all patterns for _, client := range subscribers { command := []resp.Value{resp.StringValue("PSUBSCRIBE")} for _, pattern := range patterns { command = append(command, resp.StringValue(pattern)) } if err := client.WriteArray(command); err != nil { t.Error(err) } // Read subscription responses to make sure we've subscribed to all the channels. for i := 0; i < len(patterns); i++ { res, _, err := client.ReadValue() if err != nil { t.Error(err) } if len(res.Array()) != 3 { t.Errorf("expected array response of length %d, got %d", 3, len(res.Array())) } if !strings.EqualFold(res.Array()[0].String(), "psubscribe") { t.Errorf("expected the first array item to be \"psubscribe\", got \"%s\"", res.Array()[0].String()) } if !slices.Contains(patterns, res.Array()[1].String()) { t.Errorf("unexpected channel name \"%s\", expected %v", res.Array()[1].String(), patterns) } } } // Get fresh connection for the next phase. conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } client := resp.NewConn(conn) // Check that we receive all the patterns with NUMPAT commands. if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil { t.Error(err) } res, _, err := client.ReadValue() if res.Integer() != len(patterns) { t.Errorf("expected response \"%d\", got \"%d\"", len(patterns), res.Integer()) } // Unsubscribe all subscribers from one pattern and check if the response is updated. for _, subscriber := range subscribers { if err = subscriber.WriteArray([]resp.Value{ resp.StringValue("PUNSUBSCRIBE"), resp.StringValue(patterns[0]), }); err != nil { t.Error(err) } res, _, err = subscriber.ReadValue() if err != nil { t.Error(err) } if len(res.Array()[0].Array()) != 3 { t.Errorf("expected array response of length %d, got %d", 3, len(res.Array()[0].Array())) } if !strings.EqualFold(res.Array()[0].Array()[0].String(), "punsubscribe") { t.Errorf("expected the first array item to be \"punsubscribe\", got \"%s\"", res.Array()[0].Array()[0].String()) } if res.Array()[0].Array()[1].String() != patterns[0] { t.Errorf("unexpected channel name \"%s\", expected %s", res.Array()[0].Array()[1].String(), patterns[0]) } } if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil { t.Error(err) } res, _, err = client.ReadValue() if res.Integer() != len(patterns)-1 { t.Errorf("expected response \"%d\", got \"%d\"", len(patterns)-1, res.Integer()) } // Unsubscribe from all the channels and check if we get a 0 response for _, subscriber := range subscribers { for _, pattern := range patterns[1:] { if err = subscriber.WriteArray([]resp.Value{ resp.StringValue("PUNSUBSCRIBE"), resp.StringValue(pattern), }); err != nil { t.Error(err) } res, _, err = subscriber.ReadValue() if err != nil { t.Error(err) } if len(res.Array()[0].Array()) != 3 { t.Errorf("expected array response of length %d, got %d", 3, len(res.Array()[0].Array())) } if !strings.EqualFold(res.Array()[0].Array()[0].String(), "punsubscribe") { t.Errorf("expected the first array item to be \"punsubscribe\", got \"%s\"", res.Array()[0].Array()[0].String()) } if res.Array()[0].Array()[1].String() != pattern { t.Errorf("unexpected channel name \"%s\", expected %s", res.Array()[0].Array()[1].String(), pattern) } } } if err = client.WriteArray([]resp.Value{resp.StringValue("PUBSUB"), resp.StringValue("NUMPAT")}); err != nil { t.Error(err) } res, _, err = client.ReadValue() if res.Integer() != 0 { t.Errorf("expected response \"%d\", got \"%d\"", 0, res.Integer()) } } func Test_HandleNumSub(t *testing.T) { port, _ := internal.GetFreePort() mockServer := setUpServer(bindAddr, uint16(port)) wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() channels := []string{"channel_1", "channel_2", "channel_3"} for i := 0; i < 3; i++ { conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } client := resp.NewConn(conn) command := []resp.Value{ resp.StringValue("SUBSCRIBE"), } for _, channel := range channels { command = append(command, resp.StringValue(channel)) } err = client.WriteArray(command) if err != nil { t.Error(err) } // Read subscription responses to make sure we've subscribed to all the channels. for i := 0; i < len(channels); i++ { res, _, err := client.ReadValue() if err != nil { t.Error(err) } if len(res.Array()) != 3 { t.Errorf("expected array response of length %d, got %d", 3, len(res.Array())) } if !strings.EqualFold(res.Array()[0].String(), "subscribe") { t.Errorf("expected the first array item to be \"subscribe\", got \"%s\"", res.Array()[0].String()) } if !slices.Contains(channels, res.Array()[1].String()) { t.Errorf("unexpected channel name \"%s\", expected %v", res.Array()[1].String(), channels) } } } // Get fresh connection for the next phase. conn, err := internal.GetConnection(bindAddr, port) if err != nil { t.Error(err) } client := resp.NewConn(conn) tests := []struct { name string cmd []string expectedResponse [][]string }{ { name: "1. Get all subscriptions on existing channels", cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...), expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}}, }, { name: "2. Get all the subscriptions of on existing channels and a few non-existent ones", cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...), expectedResponse: [][]string{ {"non_existent_channel_1", "0"}, {"non_existent_channel_2", "0"}, {"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}, }, }, { name: "3. Get an empty array when channels are not provided in the command", cmd: []string{"PUBSUB", "NUMSUB"}, expectedResponse: make([][]string, 0), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var command []resp.Value for _, token := range test.cmd { command = append(command, resp.StringValue(token)) } if err = client.WriteArray(command); err != nil { t.Error(err) } res, _, err := client.ReadValue() if err != nil { t.Error(err) } arr := res.Array() if len(arr) != len(test.expectedResponse) { t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr)) } for _, item := range arr { itemArr := item.Array() if len(itemArr) != 2 { t.Errorf("expected each response item to be of length 2, got %d", len(itemArr)) } if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String() }) { t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response", itemArr[0].String(), itemArr[1].Integer()) } } }) } }