From a19bfa8f7380603df1d563b83c3daef5c394595c Mon Sep 17 00:00:00 2001 From: Kelvin Mwinuka Date: Wed, 20 Mar 2024 12:36:03 +0800 Subject: [PATCH] Implemented rlocking/runlocking of rwmutex in Channel IsActive and NumSubs receiver functions. Implemented PUBSUB NUMPAT handler unit test --- src/modules/pubsub/channel.go | 4 ++ src/modules/pubsub/commands_test.go | 104 +++++++++++++++++++++++++++- src/modules/pubsub/pubsub.go | 2 +- 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/src/modules/pubsub/channel.go b/src/modules/pubsub/channel.go index 3aeedcc..7d88f06 100644 --- a/src/modules/pubsub/channel.go +++ b/src/modules/pubsub/channel.go @@ -99,9 +99,13 @@ func (ch *Channel) Publish(message string) { } func (ch *Channel) IsActive() bool { + ch.subscribersRWMut.RLock() + defer ch.subscribersRWMut.RUnlock() return len(ch.subscribers) > 0 } func (ch *Channel) NumSubs() int { + ch.subscribersRWMut.RLock() + defer ch.subscribersRWMut.RUnlock() return len(ch.subscribers) } diff --git a/src/modules/pubsub/commands_test.go b/src/modules/pubsub/commands_test.go index d0004c2..8628c87 100644 --- a/src/modules/pubsub/commands_test.go +++ b/src/modules/pubsub/commands_test.go @@ -617,6 +617,108 @@ func Test_HandlePubSubChannels(t *testing.T) { } } -func Test_HandleNumPat(t *testing.T) {} +func Test_HandleNumPat(t *testing.T) { + done := make(chan struct{}) + go func() { + // Create separate mock server for this test + var port uint16 = 7591 + pubsub = NewPubSub() + mockServer := server.NewServer(server.Opts{ + PubSub: pubsub, + Commands: Commands(), + Config: utils.Config{ + BindAddr: bindAddr, + Port: port, + DataDir: "", + EvictionPolicy: utils.NoEviction, + }, + }) + + ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMPAT") + + patterns := []string{"pattern_[123]", "pattern_[456]", "pattern_[789]"} + + connections := make([]struct { + w *net.Conn + r *resp.Conn + }, 3) + for i := 0; i < len(connections); i++ { + w, r := net.Pipe() + connections[i] = struct { + w *net.Conn + r *resp.Conn + }{w: &w, r: resp.NewConn(r)} + go func() { + if _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &w); err != nil { + t.Error(err) + } + }() + for j := 0; j < len(patterns); j++ { + v, _, err := connections[i].r.ReadValue() + if err != nil { + t.Error(err) + } + arr := v.Array() + if !slices.ContainsFunc(patterns, func(s string) bool { + return s == arr[1].String() + }) { + t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String()) + } + } + } + + verifyNumPatResponse := func(res []byte, expected int) { + rd := resp.NewReader(bytes.NewReader(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + if rv.Integer() != expected { + t.Errorf("expected first NUMPAT response to be %d, got %d", expected, rv.Integer()) + } + } + + // Check that we receive all the patterns with NUMPAT commands + res, err := handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyNumPatResponse(res, len(patterns)) + + // Unsubscribe from a channel and check if the number of active channels is updated + for _, conn := range connections { + _, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE", patterns[0]}, mockServer, conn.w) + if err != nil { + t.Error(err) + } + } + res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyNumPatResponse(res, len(patterns)-1) + + // Unsubscribe from all the channels and check if we get a 0 response + for _, conn := range connections { + _, err = handleUnsubscribe(ctx, []string{"PUNSUBSCRIBE"}, mockServer, conn.w) + if err != nil { + t.Error(err) + } + } + res, err = handlePubSubNumPat(ctx, []string{"PUBSUB", "NUMPAT"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyNumPatResponse(res, 0) + + done <- struct{}{} + }() + + select { + case <-time.After(300 * time.Millisecond): + t.Error("timeout") + case <-done: + } +} func Test_HandleNumSub(t *testing.T) {} diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index 04e0d7c..f61a28b 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -224,7 +224,7 @@ func (ps *PubSub) NumPat() int { var count int for _, channel := range ps.channels { - if channel.pattern != nil { + if channel.pattern != nil && channel.IsActive() { count += 1 } }