diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go index b0a4639..e2d453b 100644 --- a/src/modules/pubsub/commands.go +++ b/src/modules/pubsub/commands.go @@ -49,7 +49,6 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn return nil, errors.New(utils.WrongArgsResponse) } pubsub.Publish(ctx, cmd[2], cmd[1]) - fmt.Println("PUBLISHED:", cmd[2]) return []byte(utils.OkResponse), nil } diff --git a/src/modules/pubsub/commands_test.go b/src/modules/pubsub/commands_test.go index 7082eb7..d0004c2 100644 --- a/src/modules/pubsub/commands_test.go +++ b/src/modules/pubsub/commands_test.go @@ -147,7 +147,6 @@ func Test_HandleUnsubscribe(t *testing.T) { } v := rv.Array() if len(v) != len(expectedResponse) { - fmt.Println(v) t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v)) } for _, item := range v { @@ -314,7 +313,6 @@ func Test_HandlePublish(t *testing.T) { t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String()) } } - fmt.Println(v) } // The subscribe function handles subscribing the connection to the given @@ -474,7 +472,150 @@ func Test_HandlePublish(t *testing.T) { } } -func Test_HandlePubSubChannels(t *testing.T) {} +func Test_HandlePubSubChannels(t *testing.T) { + done := make(chan struct{}) + go func() { + // Create separate mock server for this test + var port uint16 = 7590 + pubsub = NewPubSub() + mockServer := server.NewServer(server.Opts{ + PubSub: pubsub, + Commands: Commands(), + Config: utils.Config{ + BindAddr: bindAddr, + Port: port, + DataDir: "", + EvictionPolicy: utils.NoEviction, + }, + }) + + ctx := context.WithValue(context.Background(), "test_name", "PUBSUB CHANNELS") + + channels := []string{"channel_1", "channel_2", "channel_3"} + patterns := []string{"channel_[123]", "channel_[456]"} + + rConn1, wConn1 := net.Pipe() + rc1 := resp.NewConn(rConn1) + + rConn2, wConn2 := net.Pipe() + rc2 := resp.NewConn(rConn2) + + // Subscribe connections to channels + go func() { + _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &wConn1) + if err != nil { + t.Error(err) + } + }() + for i := 0; i < len(channels); i++ { + v, _, err := rc1.ReadValue() + if err != nil { + t.Error(err) + } + if !slices.ContainsFunc(channels, func(s string) bool { + return s == v.Array()[1].String() + }) { + t.Errorf("unexpected channel %s in response", v.Array()[1].String()) + } + } + go func() { + _, err := handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, &wConn2) + if err != nil { + t.Error(err) + } + }() + for i := 0; i < len(patterns); i++ { + v, _, err := rc2.ReadValue() + if err != nil { + t.Error(err) + } + if !slices.ContainsFunc(patterns, func(s string) bool { + return s == v.Array()[1].String() + }) { + t.Errorf("unexpected pattern %s in response", v.Array()[1].String()) + } + } + + verifyExpectedResponse := func(res []byte, expected []string) { + rd := resp.NewReader(bytes.NewReader(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + if len(rv.Array()) != len(expected) { + t.Errorf("expected response array of length %d, got %d", len(expected), len(rv.Array())) + } + for _, e := range expected { + if !slices.ContainsFunc(rv.Array(), func(v resp.Value) bool { + return e == v.String() + }) { + t.Errorf("expected to find element \"%s\" in response array, could not find it", e) + } + } + } + + // Check if all subscriptions are returned + res, err := handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyExpectedResponse(res, append(channels, patterns...)) + + // Unsubscribe from one pattern and one channel before checking against a new slice of + // expected channels/patterns in the response of the "PUBSUB CHANNELS" command + _, err = handleUnsubscribe( + ctx, + append([]string{"UNSUBSCRIBE"}, []string{"channel_2", "channel_3"}...), + mockServer, + &wConn1, + ) + if err != nil { + t.Error(err) + } + _, err = handleUnsubscribe( + ctx, + append([]string{"UNSUBSCRIBE"}, "channel_[456]"), + mockServer, + &wConn2, + ) + if err != nil { + t.Error(err) + } + + // Return all the remaining channels + res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) + // Return only one of the remaining channels when passed a pattern that matches it + res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[189]"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyExpectedResponse(res, []string{"channel_1"}) + // Return both remaining channels when passed a pattern that matches them + res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[123]"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyExpectedResponse(res, []string{"channel_1", "channel_[123]"}) + // Return none channels when passed a pattern that does not match either channel + res, err = handlePubSubChannels(ctx, []string{"PUBSUB", "CHANNELS", "channel_[456]"}, mockServer, nil) + if err != nil { + t.Error(err) + } + verifyExpectedResponse(res, []string{}) + + done <- struct{}{} + }() + + select { + case <-time.After(200 * time.Millisecond): + t.Error("timeout") + case <-done: + } +} func Test_HandleNumPat(t *testing.T) {} diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index 3678c23..b7f4b7c 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -25,6 +25,9 @@ func NewPubSub() *PubSub { } func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) { + ps.channelsRWMut.Lock() + defer ps.channelsRWMut.Unlock() + r := resp.NewConn(*conn) action := "subscribe" @@ -76,7 +79,7 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { ps.channelsRWMut.RLock() - ps.channelsRWMut.RUnlock() + defer ps.channelsRWMut.RUnlock() action := "unsubscribe" if withPattern { @@ -179,6 +182,9 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin } func (ps *PubSub) Channels(pattern string) []byte { + ps.channelsRWMut.RLock() + defer ps.channelsRWMut.RUnlock() + var count int var res string