diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go index bbbd97d..b95317a 100644 --- a/src/modules/pubsub/commands.go +++ b/src/modules/pubsub/commands.go @@ -40,7 +40,7 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil } -func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handlePublish(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { return nil, errors.New("could not load pubsub module") @@ -52,7 +52,7 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn return []byte(utils.OkResponse), nil } -func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) > 3 { return nil, errors.New(utils.WrongArgsResponse) } @@ -70,7 +70,7 @@ func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, return pubsub.Channels(pattern), nil } -func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handlePubSubNumPat(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { return nil, errors.New("could not load pubsub module") @@ -79,7 +79,7 @@ func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, return []byte(fmt.Sprintf(":%d\r\n", num)), nil } -func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handlePubSubNumSubs(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { return nil, errors.New("could not load pubsub module") diff --git a/src/modules/pubsub/commands_test.go b/src/modules/pubsub/commands_test.go index 8628c87..8792bb3 100644 --- a/src/modules/pubsub/commands_test.go +++ b/src/modules/pubsub/commands_test.go @@ -715,10 +715,124 @@ func Test_HandleNumPat(t *testing.T) { }() select { - case <-time.After(300 * time.Millisecond): + case <-time.After(200 * time.Millisecond): t.Error("timeout") case <-done: } } -func Test_HandleNumSub(t *testing.T) {} +func Test_HandleNumSub(t *testing.T) { + done := make(chan struct{}) + go func() { + // Create separate mock server for this test + var port uint16 = 7591 + pubsub = NewPubSub() + mockServer := server.NewServer(server.Opts{ + PubSub: pubsub, + Commands: Commands(), + Config: utils.Config{ + BindAddr: bindAddr, + Port: port, + DataDir: "", + EvictionPolicy: utils.NoEviction, + }, + }) + + ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB") + + channels := []string{"channel_1", "channel_2", "channel_3"} + connections := make([]struct { + w *net.Conn + r *resp.Conn + }, 3) + for i := 0; i < len(connections); i++ { + w, r := net.Pipe() + connections[i] = struct { + w *net.Conn + r *resp.Conn + }{w: &w, r: resp.NewConn(r)} + go func() { + if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &w); err != nil { + t.Error(err) + } + }() + for j := 0; j < len(channels); j++ { + v, _, err := connections[i].r.ReadValue() + if err != nil { + t.Error(err) + } + arr := v.Array() + if !slices.ContainsFunc(channels, func(s string) bool { + return s == arr[1].String() + }) { + t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String()) + } + } + } + + tests := []struct { + cmd []string + expectedResponse [][]string + }{ + { // 1. Get all subscriptions on existing channels + cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...), + expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}}, + }, + { // 2. Get all the subscriptions of on existing channels and a few non-existent ones + cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...), + expectedResponse: [][]string{ + {"non_existent_channel_1", "0"}, + {"non_existent_channel_2", "0"}, + {"channel_1", "3"}, + {"channel_2", "3"}, + {"channel_3", "3"}, + }, + }, + { // 3. Get an empty array when channels are not provided in the command + cmd: []string{"PUBSUB", "NUMSUB"}, + expectedResponse: make([][]string, 0), + }, + } + + for i, test := range tests { + ctx = context.WithValue(ctx, "test_index", i) + + res, err := handlePubSubNumSubs(ctx, test.cmd, mockServer, nil) + if err != nil { + t.Error(err) + } + + rd := resp.NewReader(bytes.NewReader(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + + arr := rv.Array() + if len(arr) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr)) + } + + for _, item := range arr { + itemArr := item.Array() + if len(itemArr) != 2 { + t.Errorf("expected each response item to be of length 2, got %d", len(itemArr)) + } + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String() + }) { + t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response", + itemArr[0].String(), itemArr[1].Integer()) + } + } + } + + done <- struct{}{} + }() + + select { + case <-time.After(200 * time.Millisecond): + t.Error("timeout") + case <-done: + } +} diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index f61a28b..d66279f 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -237,6 +237,7 @@ func (ps *PubSub) NumSub(channels []string) []byte { res := fmt.Sprintf("*%d\r\n", len(channels)) for _, channel := range channels { + // If it's a pattern channel, skip it chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool { return c.name == channel })