// Copyright 2024 Kelvin Clement Mwinuka // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pubsub import ( "bytes" "context" "fmt" "github.com/echovault/echovault/constants" "github.com/echovault/echovault/echovault" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/modules/pubsub" "github.com/tidwall/resp" "net" "reflect" "slices" "strings" "sync" "testing" "time" "unsafe" ) var ps *pubsub.PubSub var mockServer *echovault.EchoVault var bindAddr = "localhost" var port uint16 = 7490 func init() { mockServer = setUpServer(bindAddr, port) getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{}) ps = getPubSub().(*pubsub.PubSub) wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() } func setUpServer(bindAddr string, port uint16) *echovault.EchoVault { server, _ := echovault.NewEchoVault( echovault.WithConfig(config.Config{ BindAddr: bindAddr, Port: port, DataDir: "", EvictionPolicy: constants.NoEviction, }), ) return server } func getUnexportedField(field reflect.Value) interface{} { return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() } func getHandler(commands ...string) internal.HandlerFunc { if len(commands) == 0 { return nil } getCommands := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) for _, c := range getCommands() { if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { // Get command handler return c.HandlerFunc } if strings.EqualFold(commands[0], c.Command) { // Get sub-command handler for _, sc := range c.SubCommands { if strings.EqualFold(commands[1], sc.Command) { return sc.HandlerFunc } } } } return nil } func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn, mockServer *echovault.EchoVault) internal.HandlerFuncParams { getPubSub := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getPubSub")).(func() interface{}) return internal.HandlerFuncParams{ Context: ctx, Command: cmd, Connection: conn, GetPubSub: getPubSub, } } func Test_HandleSubscribe(t *testing.T) { ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE") numOfConnection := 20 connections := make([]*net.Conn, numOfConnection) for i := 0; i < numOfConnection; i++ { conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) if err != nil { t.Error(err) } connections[i] = &conn } defer func() { for _, conn := range connections { if err := (*conn).Close(); err != nil { t.Error(err) } } }() // Test subscribe to channels channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"} for _, conn := range connections { _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), conn, mockServer)) if err != nil { t.Error(err) } } for _, channel := range channels { // Check if the channel exists in the pubsub module if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool { return c.Name() == channel }) { t.Errorf("expected pubsub to contain channel \"%s\" but it was not found", channel) } for _, c := range ps.GetAllChannels() { if c.Name() == channel { // Check if channel has nil pattern if c.Pattern() != nil { t.Errorf("expected channel \"%s\" to have nil pattern, found pattern \"%s\"", channel, c.Name()) } // Check if the channel has all the connections from above for _, conn := range connections { if _, ok := c.Subscribers()[conn]; !ok { t.Errorf("could not find all expected connection in the \"%s\"", channel) } } } } } // Test subscribe to patterns patterns := []string{"psub_channel*"} for _, conn := range connections { _, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), conn, mockServer)) if err != nil { t.Error(err) } } for _, pattern := range patterns { // Check if pattern channel exists in pubsub module if !slices.ContainsFunc(ps.GetAllChannels(), func(c *pubsub.Channel) bool { return c.Name() == pattern }) { t.Errorf("expected pubsub to contain pattern channel \"%s\" but it was not found", pattern) } for _, c := range ps.GetAllChannels() { if c.Name() == pattern { // Check if channel has non-nil pattern if c.Pattern() == nil { t.Errorf("expected channel \"%s\" to have pattern \"%s\", found nil pattern", pattern, c.Name()) } // Check if the channel has all the connections from above for _, conn := range connections { if _, ok := c.Subscribers()[conn]; !ok { t.Errorf("could not find all expected connection in the \"%s\"", pattern) } } } } } } func Test_HandleUnsubscribe(t *testing.T) { generateConnections := func(noOfConnections int) []*net.Conn { connections := make([]*net.Conn, noOfConnections) for i := 0; i < noOfConnections; i++ { conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) if err != nil { t.Error(err) } connections[i] = &conn } return connections } closeConnections := func(conns []*net.Conn) { for _, conn := range conns { if err := (*conn).Close(); err != nil { t.Error(err) } } } verifyResponse := func(res []byte, expectedResponse [][]string) { rd := resp.NewReader(bytes.NewReader(res)) rv, _, err := rd.ReadValue() if err != nil { t.Error(err) } v := rv.Array() if len(v) != len(expectedResponse) { t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v)) } for _, item := range v { arr := item.Array() if len(arr) != 3 { t.Errorf("expected subscribe response item to be length %d, but got %d", 3, len(arr)) } if !slices.ContainsFunc(expectedResponse, func(strings []string) bool { return strings[0] == arr[0].String() && strings[1] == arr[1].String() && strings[2] == arr[2].String() }) { t.Errorf("expected to find item \"%s\" in response, did not find it.", arr[1].String()) } } } tests := []struct { subChannels []string // All channels to subscribe to subPatterns []string // All patterns to subscribe to unSubChannels []string // Channels to unsubscribe from unSubPatterns []string // Patterns to unsubscribe from remainChannels []string // Channels to remain subscribed to remainPatterns []string // Patterns to remain subscribed to targetConn *net.Conn // Connection used to test unsubscribe functionality otherConnections []*net.Conn // Connections to fill the subscribers list for channels and patterns expectedResponses map[string][][]string // The expected response from the handler }{ { // 1. Unsubscribe from channels and patterns subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, unSubChannels: []string{"xx_channel_one", "xx_channel_two"}, unSubPatterns: []string{"xx_pattern_[ab]"}, remainChannels: []string{"xx_channel_three", "xx_channel_four"}, remainPatterns: []string{"xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, targetConn: generateConnections(1)[0], otherConnections: generateConnections(20), expectedResponses: map[string][][]string{ "channel": { {"unsubscribe", "xx_channel_one", "1"}, {"unsubscribe", "xx_channel_two", "2"}, }, "pattern": { {"punsubscribe", "xx_pattern_[ab]", "1"}, }, }, }, { // 2. Unsubscribe from all channels no channel or pattern is passed to command subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, unSubChannels: []string{}, unSubPatterns: []string{}, remainChannels: []string{}, remainPatterns: []string{}, targetConn: generateConnections(1)[0], otherConnections: generateConnections(20), expectedResponses: map[string][][]string{ "channel": { {"unsubscribe", "xx_channel_one", "1"}, {"unsubscribe", "xx_channel_two", "2"}, {"unsubscribe", "xx_channel_three", "3"}, {"unsubscribe", "xx_channel_four", "4"}, }, "pattern": { {"punsubscribe", "xx_pattern_[ab]", "1"}, {"punsubscribe", "xx_pattern_[cd]", "2"}, {"punsubscribe", "xx_pattern_[ef]", "3"}, {"punsubscribe", "xx_pattern_[gh]", "4"}, }, }, }, { // 3. Don't unsubscribe from any channels or patterns if the provided ones are non-existent subChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, subPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, unSubChannels: []string{"xx_channel_non_existent_channel"}, unSubPatterns: []string{"xx_channel_non_existent_pattern_[ae]"}, remainChannels: []string{"xx_channel_one", "xx_channel_two", "xx_channel_three", "xx_channel_four"}, remainPatterns: []string{"xx_pattern_[ab]", "xx_pattern_[cd]", "xx_pattern_[ef]", "xx_pattern_[gh]"}, targetConn: generateConnections(1)[0], otherConnections: generateConnections(20), expectedResponses: map[string][][]string{ "channel": {}, "pattern": {}, }, }, } for i, test := range tests { ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("UNSUBSCRIBE/PUNSUBSCRIBE, %d", i)) // Subscribe all the connections to the channels and patterns for _, conn := range append(test.otherConnections, test.targetConn) { _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, test.subChannels...), conn, mockServer)) if err != nil { t.Error(err) } _, err = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, test.subPatterns...), conn, mockServer)) if err != nil { t.Error(err) } } // Unsubscribe the target connection from the unsub channels and patterns res, err := getHandler("UNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"UNSUBSCRIBE"}, test.unSubChannels...), test.targetConn, mockServer)) if err != nil { t.Error(err) } verifyResponse(res, test.expectedResponses["channel"]) res, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PUNSUBSCRIBE"}, test.unSubPatterns...), test.targetConn, mockServer)) if err != nil { t.Error(err) } verifyResponse(res, test.expectedResponses["pattern"]) for _, channel := range append(test.unSubChannels, test.unSubPatterns...) { for _, pubsubChannel := range ps.GetAllChannels() { if pubsubChannel.Name() == channel { // Assert that target connection is no longer in the unsub channels and patterns if _, ok := pubsubChannel.Subscribers()[test.targetConn]; ok { t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel) } for _, conn := range test.otherConnections { if _, ok := pubsubChannel.Subscribers()[conn]; !ok { t.Errorf("did not find expected other connection in channel \"%s\"", channel) } } } } } // Assert that the target connection is still in the remain channels and patterns for _, channel := range append(test.remainChannels, test.remainPatterns...) { for _, pubsubChannel := range ps.GetAllChannels() { if pubsubChannel.Name() == channel { if _, ok := pubsubChannel.Subscribers()[test.targetConn]; !ok { t.Errorf("could not find expected target connection in channel \"%s\"", channel) } } } } } for _, test := range tests { // Close all the connections closeConnections(append(test.otherConnections, test.targetConn)) } } func Test_HandlePublish(t *testing.T) { ctx := context.WithValue(context.Background(), "test_name", "PUBLISH") // verifyChannelMessage reads the message from the connection and asserts whether // it's the message we expect to read as a subscriber of a channel or pattern. verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) { if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { t.Error(err) } rv, _, err := r.ReadValue() if err != nil { t.Error(err) } v := rv.Array() for i := 0; i < len(v); i++ { if v[i].String() != expected[i] { t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String()) } } } // The subscribe function handles subscribing the connection to the given // channels and patterns and reading/verifying the message sent by the echovault after // subscription. subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) { // Subscribe to channels go func() { _, _ = getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), c, mockServer)) }() // Verify all the responses for each channel subscription for i := 0; i < len(channels); i++ { verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)}) } // Subscribe to all the patterns go func() { _, _ = getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), c, mockServer)) }() // Verify all the responses for each pattern subscription for i := 0; i < len(patterns); i++ { verifyEvent(c, r, []string{"psubscribe", patterns[i], fmt.Sprintf("%d", i+1)}) } } subscriptions := map[string]map[string][]string{ "subscriber1": { "channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to "patterns": {"pub_channel_[456]"}, // Patterns to subscribe to }, "subscriber2": { "channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to "patterns": {"pub_channel_[891]"}, // Patterns to subscribe to }, } // Create subscriber one and subscribe to channels and patterns r1, w1 := net.Pipe() rc1 := resp.NewConn(r1) subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1) // Create subscriber two and subscribe to channels and patterns r2, w2 := net.Pipe() rc2 := resp.NewConn(r2) subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2) type SubscriberType struct { c *net.Conn r *resp.Conn l string } tests := []struct { channel string message string subscribers []SubscriberType }{ { channel: "pub_channel_1", message: "Test both subscribers 1", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_1"}, {c: &r2, r: rc2, l: "pub_channel_[891]"}, }, }, { channel: "pub_channel_6", message: "Test both subscribers 2", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_[456]"}, {c: &r2, r: rc2, l: "pub_channel_6"}, }, }, { channel: "pub_channel_2", message: "Test subscriber 1 1", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_2"}, }, }, { channel: "pub_channel_3", message: "Test subscriber 1 2", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_3"}, }, }, { channel: "pub_channel_4", message: "Test both subscribers 2", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_[456]"}, }, }, { channel: "pub_channel_5", message: "Test subscriber 1 3", subscribers: []SubscriberType{ {c: &r1, r: rc1, l: "pub_channel_[456]"}, }, }, { channel: "pub_channel_7", message: "Test subscriber 2 1", subscribers: []SubscriberType{ {c: &r2, r: rc2, l: "pub_channel_7"}, }, }, { channel: "pub_channel_8", message: "Test subscriber 2 2", subscribers: []SubscriberType{ {c: &r1, r: rc2, l: "pub_channel_[891]"}, }, }, { channel: "pub_channel_9", message: "Test subscriber 2 3", subscribers: []SubscriberType{ {c: &r2, r: rc2, l: "pub_channel_[891]"}, }, }, } // Dial echovault to make publisher connection conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) if err != nil { t.Error(err) } defer func() { if err = conn.Close(); err != nil { t.Error(err) } }() w := resp.NewConn(conn) for _, test := range tests { err = w.WriteArray([]resp.Value{ resp.StringValue("PUBLISH"), resp.StringValue(test.channel), resp.StringValue(test.message), }) if err != nil { t.Error(err) } rv, _, err := w.ReadValue() if err != nil { t.Error(err) } if rv.String() != "OK" { t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String()) } for _, sub := range test.subscribers { verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message}) } } } func Test_HandlePubSubChannels(t *testing.T) { done := make(chan struct{}) go func() { // Create separate mock echovault for this test var port uint16 = 7590 mockServer := setUpServer(bindAddr, port) ctx := context.WithValue(context.Background(), "test_name", "PUBSUB CHANNELS") channels := []string{"channel_1", "channel_2", "channel_3"} patterns := []string{"channel_[123]", "channel_[456]"} rConn1, wConn1 := net.Pipe() rc1 := resp.NewConn(rConn1) rConn2, wConn2 := net.Pipe() rc2 := resp.NewConn(rConn2) // Subscribe connections to channels go func() { _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &wConn1, mockServer)) if err != nil { t.Error(err) } }() for i := 0; i < len(channels); i++ { v, _, err := rc1.ReadValue() if err != nil { t.Error(err) } if !slices.ContainsFunc(channels, func(s string) bool { return s == v.Array()[1].String() }) { t.Errorf("unexpected channel %s in response", v.Array()[1].String()) } } go func() { _, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &wConn2, mockServer)) if err != nil { t.Error(err) } }() for i := 0; i < len(patterns); i++ { v, _, err := rc2.ReadValue() if err != nil { t.Error(err) } if !slices.ContainsFunc(patterns, func(s string) bool { return s == v.Array()[1].String() }) { t.Errorf("unexpected pattern %s in response", v.Array()[1].String()) } } verifyExpectedResponse := func(res []byte, expected []string) { rd := resp.NewReader(bytes.NewReader(res)) rv, _, err := rd.ReadValue() if err != nil { t.Error(err) } if len(rv.Array()) != len(expected) { t.Errorf("expected response array of length %d, got %d", len(expected), len(rv.Array())) } for _, e := range expected { if !slices.ContainsFunc(rv.Array(), func(v resp.Value) bool { return e == v.String() }) { t.Errorf("expected to find element \"%s\" in response array, could not find it", e) } } } // Check if all subscriptions are returned res, err := getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer)) if err != nil { t.Error(err) } verifyExpectedResponse(res, append(channels, patterns...)) // Unsubscribe from one pattern and one channel before checking against a new slice of // expected channels/patterns in the response of the "PUBSUB CHANNELS" command _, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams( ctx, append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...), &wConn1, mockServer, )) if err != nil { t.Error(err) } _, err = getHandler("UNSUBSCRIBE")(getHandlerFuncParams( ctx, append([]string{"UNSUBSCRIBE"}, "channel_[456]"), &wConn2, mockServer, )) if err != nil { t.Error(err) } // Return all the remaining channels res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS"}, nil, mockServer)) if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) // Return only one of the remaining channels when passed a pattern that matches it res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, nil, mockServer)) if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1"}) // Return both remaining channels when passed a pattern that matches them res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, nil, mockServer)) if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) // Return none channels when passed a pattern that does not match either channel res, err = getHandler("PUBSUB", "CHANNELS")(getHandlerFuncParams(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, nil, mockServer)) if err != nil { t.Error(err) } verifyExpectedResponse(res, []string{}) done <- struct{}{} }() select { case <-time.After(200 * time.Millisecond): t.Error("timeout") case <-done: } } func Test_HandleNumPat(t *testing.T) { done := make(chan struct{}) go func() { // Create separate mock echovault for this test var port uint16 = 7591 mockServer := setUpServer(bindAddr, port) ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMPAT") patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"} connections := make([]struct { w *net.Conn r *resp.Conn }, 3) for i := 0; i < len(connections); i++ { w, r := net.Pipe() connections[i] = struct { w *net.Conn r *resp.Conn }{w: &w, r: resp.NewConn(r)} go func() { _, err := getHandler("PSUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"PSUBSCRIBE"}, patterns...), &w, mockServer)) if err != nil { t.Error(err) } }() for j := 0; j < len(patterns); j++ { v, _, err := connections[i].r.ReadValue() if err != nil { t.Error(err) } arr := v.Array() if !slices.ContainsFunc(patterns, func(s string) bool { return s == arr[1].String() }) { t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String()) } } } verifyNumPatResponse := func(res []byte, expected int) { rd := resp.NewReader(bytes.NewReader(res)) rv, _, err := rd.ReadValue() if err != nil { t.Error(err) } if rv.Integer() != expected { t.Errorf("expected first NUMPAT response to be %d, got %d", expected, rv.Integer()) } } // Check that we receive all the patterns with NUMPAT commands res, err := getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer)) if err != nil { t.Error(err) } verifyNumPatResponse(res, len(patterns)) // Unsubscribe from a channel and check if the number of active channels is updated for _, conn := range connections { _, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, conn.w, mockServer)) if err != nil { t.Error(err) } } res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer)) if err != nil { t.Error(err) } verifyNumPatResponse(res, len(patterns)-1) // Unsubscribe from all the channels and check if we get a 0 response for _, conn := range connections { _, err = getHandler("PUNSUBSCRIBE")(getHandlerFuncParams(ctx, []string{"PUNSUBSCRIBE"}, conn.w, mockServer)) if err != nil { t.Error(err) } } res, err = getHandler("PUBSUB", "NUMPAT")(getHandlerFuncParams(ctx, []string{"PUBSUB", "NUMPAT"}, nil, mockServer)) if err != nil { t.Error(err) } verifyNumPatResponse(res, 0) done <- struct{}{} }() select { case <-time.After(200 * time.Millisecond): t.Error("timeout") case <-done: } } func Test_HandleNumSub(t *testing.T) { done := make(chan struct{}) go func() { // Create separate mock echovault for this test var port uint16 = 7591 mockServer := setUpServer(bindAddr, port) ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB") channels := []string{"channel_1", "channel_2", "channel_3"} connections := make([]struct { w *net.Conn r *resp.Conn }, 3) for i := 0; i < len(connections); i++ { w, r := net.Pipe() connections[i] = struct { w *net.Conn r *resp.Conn }{w: &w, r: resp.NewConn(r)} go func() { _, err := getHandler("SUBSCRIBE")(getHandlerFuncParams(ctx, append([]string{"SUBSCRIBE"}, channels...), &w, mockServer)) if err != nil { t.Error(err) } }() for j := 0; j < len(channels); j++ { v, _, err := connections[i].r.ReadValue() if err != nil { t.Error(err) } arr := v.Array() if !slices.ContainsFunc(channels, func(s string) bool { return s == arr[1].String() }) { t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String()) } } } tests := []struct { cmd []string expectedResponse [][]string }{ { // 1. Get all subscriptions on existing channels cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...), expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}}, }, { // 2. Get all the subscriptions of on existing channels and a few non-existent ones cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...), expectedResponse: [][]string{ {"non_existent_channel_1", "0"}, {"non_existent_channel_2", "0"}, {"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}, }, }, { // 3. Get an empty array when channels are not provided in the command cmd: []string{"PUBSUB", "NUMSUB"}, expectedResponse: make([][]string, 0), }, } for i, test := range tests { ctx = context.WithValue(ctx, "test_index", i) res, err := getHandler("PUBSUB", "NUMSUB")(getHandlerFuncParams(ctx, test.cmd, nil, mockServer)) if err != nil { t.Error(err) } rd := resp.NewReader(bytes.NewReader(res)) rv, _, err := rd.ReadValue() if err != nil { t.Error(err) } arr := rv.Array() if len(arr) != len(test.expectedResponse) { t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr)) } for _, item := range arr { itemArr := item.Array() if len(itemArr) != 2 { t.Errorf("expected each response item to be of length 2, got %d", len(itemArr)) } if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String() }) { t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response", itemArr[0].String(), itemArr[1].Integer()) } } } done <- struct{}{} }() select { case <-time.After(200 * time.Millisecond): t.Error("timeout") case <-done: } }