mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-06 16:36:54 +08:00
Implemented unit test for PUBSUB NUMSUB command handler
This commit is contained in:
@@ -40,7 +40,7 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
|
|||||||
return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
|
return pubsub.Unsubscribe(ctx, conn, channels, withPattern), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handlePublish(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
@@ -52,7 +52,7 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
|
|||||||
return []byte(utils.OkResponse), nil
|
return []byte(utils.OkResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
if len(cmd) > 3 {
|
if len(cmd) > 3 {
|
||||||
return nil, errors.New(utils.WrongArgsResponse)
|
return nil, errors.New(utils.WrongArgsResponse)
|
||||||
}
|
}
|
||||||
@@ -70,7 +70,7 @@ func handlePubSubChannels(_ context.Context, cmd []string, server utils.Server,
|
|||||||
return pubsub.Channels(pattern), nil
|
return pubsub.Channels(pattern), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handlePubSubNumPat(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
@@ -79,7 +79,7 @@ func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server,
|
|||||||
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
|
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handlePubSubNumSubs(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("could not load pubsub module")
|
return nil, errors.New("could not load pubsub module")
|
||||||
|
@@ -715,10 +715,124 @@ func Test_HandleNumPat(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-time.After(300 * time.Millisecond):
|
case <-time.After(200 * time.Millisecond):
|
||||||
t.Error("timeout")
|
t.Error("timeout")
|
||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_HandleNumSub(t *testing.T) {}
|
func Test_HandleNumSub(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// Create separate mock server for this test
|
||||||
|
var port uint16 = 7591
|
||||||
|
pubsub = NewPubSub()
|
||||||
|
mockServer := server.NewServer(server.Opts{
|
||||||
|
PubSub: pubsub,
|
||||||
|
Commands: Commands(),
|
||||||
|
Config: utils.Config{
|
||||||
|
BindAddr: bindAddr,
|
||||||
|
Port: port,
|
||||||
|
DataDir: "",
|
||||||
|
EvictionPolicy: utils.NoEviction,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), "test_name", "PUBSUB NUMSUB")
|
||||||
|
|
||||||
|
channels := []string{"channel_1", "channel_2", "channel_3"}
|
||||||
|
connections := make([]struct {
|
||||||
|
w *net.Conn
|
||||||
|
r *resp.Conn
|
||||||
|
}, 3)
|
||||||
|
for i := 0; i < len(connections); i++ {
|
||||||
|
w, r := net.Pipe()
|
||||||
|
connections[i] = struct {
|
||||||
|
w *net.Conn
|
||||||
|
r *resp.Conn
|
||||||
|
}{w: &w, r: resp.NewConn(r)}
|
||||||
|
go func() {
|
||||||
|
if _, err := handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, &w); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for j := 0; j < len(channels); j++ {
|
||||||
|
v, _, err := connections[i].r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
arr := v.Array()
|
||||||
|
if !slices.ContainsFunc(channels, func(s string) bool {
|
||||||
|
return s == arr[1].String()
|
||||||
|
}) {
|
||||||
|
t.Errorf("found unexpected pattern in response \"%s\"", arr[1].String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
cmd []string
|
||||||
|
expectedResponse [][]string
|
||||||
|
}{
|
||||||
|
{ // 1. Get all subscriptions on existing channels
|
||||||
|
cmd: append([]string{"PUBSUB", "NUMSUB"}, channels...),
|
||||||
|
expectedResponse: [][]string{{"channel_1", "3"}, {"channel_2", "3"}, {"channel_3", "3"}},
|
||||||
|
},
|
||||||
|
{ // 2. Get all the subscriptions of on existing channels and a few non-existent ones
|
||||||
|
cmd: append([]string{"PUBSUB", "NUMSUB", "non_existent_channel_1", "non_existent_channel_2"}, channels...),
|
||||||
|
expectedResponse: [][]string{
|
||||||
|
{"non_existent_channel_1", "0"},
|
||||||
|
{"non_existent_channel_2", "0"},
|
||||||
|
{"channel_1", "3"},
|
||||||
|
{"channel_2", "3"},
|
||||||
|
{"channel_3", "3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // 3. Get an empty array when channels are not provided in the command
|
||||||
|
cmd: []string{"PUBSUB", "NUMSUB"},
|
||||||
|
expectedResponse: make([][]string, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
ctx = context.WithValue(ctx, "test_index", i)
|
||||||
|
|
||||||
|
res, err := handlePubSubNumSubs(ctx, test.cmd, mockServer, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rd := resp.NewReader(bytes.NewReader(res))
|
||||||
|
rv, _, err := rd.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
arr := rv.Array()
|
||||||
|
if len(arr) != len(test.expectedResponse) {
|
||||||
|
t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(arr))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range arr {
|
||||||
|
itemArr := item.Array()
|
||||||
|
if len(itemArr) != 2 {
|
||||||
|
t.Errorf("expected each response item to be of length 2, got %d", len(itemArr))
|
||||||
|
}
|
||||||
|
if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool {
|
||||||
|
return expected[0] == itemArr[0].String() && expected[1] == itemArr[1].String()
|
||||||
|
}) {
|
||||||
|
t.Errorf("could not find entry with channel \"%s\", with %d subscribers in expected response",
|
||||||
|
itemArr[0].String(), itemArr[1].Integer())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("timeout")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -237,6 +237,7 @@ func (ps *PubSub) NumSub(channels []string) []byte {
|
|||||||
|
|
||||||
res := fmt.Sprintf("*%d\r\n", len(channels))
|
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
|
// If it's a pattern channel, skip it
|
||||||
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
|
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
|
||||||
return c.name == channel
|
return c.name == channel
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user