Implemented unit test for PUBSUB NUMSUB command handler

This commit is contained in:
Kelvin Mwinuka
2024-03-20 13:27:58 +08:00
parent a19bfa8f73
commit d08231f82f
3 changed files with 121 additions and 6 deletions

View File

@@ -40,7 +40,7 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
} }
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePublish(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -52,7 +52,7 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OkResponse), nil return []byte(utils.OkResponse), nil
} }
func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 { if len(cmd) > 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
@@ -70,7 +70,7 @@ func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server,
return pubsub.Channels(pattern), nil return pubsub.Channels(pattern), nil
} }
func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePubSubNumPat(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")
@@ -79,7 +79,7 @@ func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server,
return []byte(fmt.Sprintf(":%d\r\n", num)), nil return []byte(fmt.Sprintf(":%d\r\n", num)), nil
} }
func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePubSubNumSubs(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub module") return nil, errors.New("could not load pubsub module")

View File

@@ -715,10 +715,124 @@ func Test_HandleNumPat(t *testing.T) {
}() }()
select { select {
case <-time.After(300 * time.Millisecond): case <-time.After(200 * time.Millisecond):
t.Error("timeout") t.Error("timeout")
case <-done: case <-done:
} }
} }
func Test_HandleNumSub(t *testing.T) {} func Test_HandleNumSub(t *testing.T) {
done := make(chan struct{})
go func() {
// Create separate mock server for this test
var port uint16 = 7591
pubsub = NewPubSub()
mockServer := server.NewServer(server.Opts{
PubSub: pubsub,
Commands: Commands(),
Config: utils.Config{
BindAddr: bindAddr,
Port: port,
DataDir: "",
EvictionPolicy: utils.NoEviction,
},
})
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB")
channels := []string{"channel_1", "channel_2", "channel_3"}
connections := make([]struct {
w *net.Conn
r *resp.Conn
}, 3)
for i := 0; i < len(connections); i++ {
w, r := net.Pipe()
connections[i] = struct {
w *net.Conn
r *resp.Conn
}{w: &w, r: resp.NewConn(r)}
go func() {
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &w); err != nil {
t.Error(err)
}
}()
for j := 0; j < len(channels); j++ {
v, _, err := connections[i].r.ReadValue()
if err != nil {
t.Error(err)
}
arr := v.Array()
if !slices.ContainsFunc(channels, func(s string) bool {
return s == arr[1].String()
}) {
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
}
}
}
tests := []struct {
cmd []string
expectedResponse [][]string
}{
{ // 1. Get all subscriptions on existing channels
cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...),
expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
},
{ // 2. Get all the subscriptions of on existing channels and a few non-existent ones
cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
expectedResponse: [][]string{
{"non_existent_channel_1", "0"},
{"non_existent_channel_2", "0"},
{"channel_1", "3"},
{"channel_2", "3"},
{"channel_3", "3"},
},
},
{ // 3. Get an empty array when channels are not provided in the command
cmd: []string{"PUBSUB", "NUMSUB"},
expectedResponse: make([][]string, 0),
},
}
for i, test := range tests {
ctx = context.WithValue(ctx, "test_index", i)
res, err := handlePubSubNumSubs(ctx, test.cmd, mockServer, nil)
if err != nil {
t.Error(err)
}
rd := resp.NewReader(bytes.NewReader(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
arr := rv.Array()
if len(arr) != len(test.expectedResponse) {
t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
}
for _, item := range arr {
itemArr := item.Array()
if len(itemArr) != 2 {
t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
}
if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
}) {
t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
itemArr[0].String(), itemArr[1].Integer())
}
}
}
done <- struct{}{}
}()
select {
case <-time.After(200 * time.Millisecond):
t.Error("timeout")
case <-done:
}
}

View File

@@ -237,6 +237,7 @@ func (ps *PubSub) NumSub(channels []string) []byte {
res := fmt.Sprintf("*%d\r\n", len(channels)) res := fmt.Sprintf("*%d\r\n", len(channels))
for _, channel := range channels { for _, channel := range channels {
// If it's a pattern channel, skip it
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool { chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
return c.name == channel return c.name == channel
}) })