mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-06 00:16:53 +08:00
Added test for PUBLISH command handler
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ volumes
|
||||
|
||||
dist/
|
||||
src/modules/*/aof
|
||||
dump.rdb
|
@@ -1,12 +1,10 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gobwas/glob"
|
||||
"io"
|
||||
"github.com/tidwall/resp"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -17,7 +15,7 @@ type Channel struct {
|
||||
name string
|
||||
pattern glob.Glob
|
||||
subscribersRWMut sync.RWMutex
|
||||
subscribers []*net.Conn
|
||||
subscribers map[*net.Conn]*resp.Conn
|
||||
messageChan *chan string
|
||||
}
|
||||
|
||||
@@ -41,7 +39,7 @@ func NewChannel(options ...func(channel *Channel)) *Channel {
|
||||
name: "",
|
||||
pattern: nil,
|
||||
subscribersRWMut: sync.RWMutex{},
|
||||
subscribers: []*net.Conn{},
|
||||
subscribers: make(map[*net.Conn]*resp.Conn),
|
||||
messageChan: &messageChan,
|
||||
}
|
||||
|
||||
@@ -60,10 +58,12 @@ func (ch *Channel) Start() {
|
||||
ch.subscribersRWMut.RLock()
|
||||
|
||||
for _, conn := range ch.subscribers {
|
||||
go func(conn *net.Conn) {
|
||||
w := io.Writer(*conn)
|
||||
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil {
|
||||
go func(conn *resp.Conn) {
|
||||
if err := conn.WriteArray([]resp.Value{
|
||||
resp.StringValue("message"),
|
||||
resp.StringValue(ch.name),
|
||||
resp.StringValue(message),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(conn)
|
||||
@@ -74,30 +74,24 @@ func (ch *Channel) Start() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (ch *Channel) Subscribe(conn *net.Conn) {
|
||||
if !slices.Contains(ch.subscribers, conn) {
|
||||
func (ch *Channel) Subscribe(conn *net.Conn) bool {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
|
||||
ch.subscribers = append(ch.subscribers, conn)
|
||||
if _, ok := ch.subscribers[conn]; !ok {
|
||||
ch.subscribers[conn] = resp.NewConn(*conn)
|
||||
}
|
||||
_, ok := ch.subscribers[conn]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
|
||||
var removed bool
|
||||
|
||||
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
|
||||
if c == conn {
|
||||
removed = true
|
||||
return true
|
||||
}
|
||||
if _, ok := ch.subscribers[conn]; !ok {
|
||||
return false
|
||||
})
|
||||
|
||||
return removed
|
||||
}
|
||||
delete(ch.subscribers, conn)
|
||||
return true
|
||||
}
|
||||
|
||||
func (ch *Channel) Publish(message string) {
|
||||
|
@@ -22,8 +22,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
|
||||
}
|
||||
|
||||
withPattern := strings.EqualFold(cmd[0], "psubscribe")
|
||||
pubsub.Subscribe(ctx, conn, channels, withPattern)
|
||||
|
||||
return pubsub.Subscribe(ctx, conn, channels, withPattern), nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
@@ -48,6 +49,7 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
return nil, errors.New(utils.WrongArgsResponse)
|
||||
}
|
||||
pubsub.Publish(ctx, cmd[2], cmd[1])
|
||||
fmt.Println("PUBLISHED:", cmd[2])
|
||||
return []byte(utils.OkResponse), nil
|
||||
}
|
||||
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var pubsub *PubSub
|
||||
@@ -22,6 +23,7 @@ func init() {
|
||||
pubsub = NewPubSub()
|
||||
mockServer = server.NewServer(server.Opts{
|
||||
PubSub: pubsub,
|
||||
Commands: Commands(),
|
||||
Config: utils.Config{
|
||||
BindAddr: bindAddr,
|
||||
Port: port,
|
||||
@@ -37,7 +39,7 @@ func init() {
|
||||
func Test_HandleSubscribe(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE")
|
||||
|
||||
numOfConnection := 100
|
||||
numOfConnection := 20
|
||||
connections := make([]*net.Conn, numOfConnection)
|
||||
|
||||
for i := 0; i < numOfConnection; i++ {
|
||||
@@ -47,6 +49,13 @@ func Test_HandleSubscribe(t *testing.T) {
|
||||
}
|
||||
connections[i] = &conn
|
||||
}
|
||||
defer func() {
|
||||
for _, conn := range connections {
|
||||
if err := (*conn).Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Test subscribe to channels
|
||||
channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"}
|
||||
@@ -70,7 +79,7 @@ func Test_HandleSubscribe(t *testing.T) {
|
||||
}
|
||||
// Check if the channel has all the connections from above
|
||||
for _, conn := range connections {
|
||||
if !slices.Contains(c.subscribers, conn) {
|
||||
if _, ok := c.subscribers[conn]; !ok {
|
||||
t.Errorf("could not find all expected connection in the \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
@@ -100,7 +109,7 @@ func Test_HandleSubscribe(t *testing.T) {
|
||||
}
|
||||
// Check if the channel has all the connections from above
|
||||
for _, conn := range connections {
|
||||
if !slices.Contains(c.subscribers, conn) {
|
||||
if _, ok := c.subscribers[conn]; !ok {
|
||||
t.Errorf("could not find all expected connection in the \"%s\"", pattern)
|
||||
}
|
||||
}
|
||||
@@ -122,6 +131,14 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
||||
return connections
|
||||
}
|
||||
|
||||
closeConnections := func(conns []*net.Conn) {
|
||||
for _, conn := range conns {
|
||||
if err := (*conn).Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifyResponse := func(res []byte, expectedResponse [][]string) {
|
||||
rd := resp.NewReader(bytes.NewReader(res))
|
||||
rv, _, err := rd.ReadValue()
|
||||
@@ -130,6 +147,7 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
||||
}
|
||||
v := rv.Array()
|
||||
if len(v) != len(expectedResponse) {
|
||||
fmt.Println(v)
|
||||
t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v))
|
||||
}
|
||||
for _, item := range v {
|
||||
@@ -247,11 +265,11 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
||||
for _, pubsubChannel := range pubsub.channels {
|
||||
if pubsubChannel.name == channel {
|
||||
// Assert that target connection is no longer in the unsub channels and patterns
|
||||
if slices.Contains(pubsubChannel.subscribers, test.targetConn) {
|
||||
if _, ok := pubsubChannel.subscribers[test.targetConn]; ok {
|
||||
t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel)
|
||||
}
|
||||
for _, conn := range test.otherConnections {
|
||||
if !slices.Contains(pubsubChannel.subscribers, conn) {
|
||||
if _, ok := pubsubChannel.subscribers[conn]; !ok {
|
||||
t.Errorf("did not find expected other connection in channel \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
@@ -263,16 +281,198 @@ func Test_HandleUnsubscribe(t *testing.T) {
|
||||
for _, channel := range append(test.remainChannels, test.remainPatterns...) {
|
||||
for _, pubsubChannel := range pubsub.channels {
|
||||
if pubsubChannel.name == channel {
|
||||
if !slices.Contains(pubsubChannel.subscribers, test.targetConn) {
|
||||
t.Errorf("cound not find expected target connection in channel \"%s\"", channel)
|
||||
}
|
||||
if _, ok := pubsubChannel.subscribers[test.targetConn]; !ok {
|
||||
t.Errorf("could not find expected target connection in channel \"%s\"", channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePublish(t *testing.T) {}
|
||||
for _, test := range tests {
|
||||
// Close all the connections
|
||||
closeConnections(append(test.otherConnections, test.targetConn))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePublish(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "test_name", "PUBLISH")
|
||||
|
||||
// verifyChannelMessage reads the message from the connection and asserts whether
|
||||
// it's the message we expect to read as a subscriber of a channel or pattern.
|
||||
verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) {
|
||||
if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
v := rv.Array()
|
||||
for i := 0; i < len(v); i++ {
|
||||
if v[i].String() != expected[i] {
|
||||
t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String())
|
||||
}
|
||||
}
|
||||
fmt.Println(v)
|
||||
}
|
||||
|
||||
// The subscribe function handles subscribing the connection to the given
|
||||
// channels and patterns and reading/verifying the message sent by the server after
|
||||
// subscription.
|
||||
subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) {
|
||||
// Subscribe to channels
|
||||
go func() {
|
||||
_, _ = handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, c)
|
||||
}()
|
||||
// Verify all the responses for each channel subscription
|
||||
for i := 0; i < len(channels); i++ {
|
||||
verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)})
|
||||
}
|
||||
// Subscribe to all the patterns
|
||||
go func() {
|
||||
_, _ = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, c)
|
||||
}()
|
||||
// Verify all the responses for each pattern subscription
|
||||
for i := 0; i < len(patterns); i++ {
|
||||
verifyEvent(c, r, []string{"subscribe", patterns[i], fmt.Sprintf("%d", i+1)})
|
||||
}
|
||||
}
|
||||
|
||||
subscriptions := map[string]map[string][]string{
|
||||
"subscriber1": {
|
||||
"channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to
|
||||
"patterns": {"pub_channel_[456]"}, // Patterns to subscribe to
|
||||
},
|
||||
"subscriber2": {
|
||||
"channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to
|
||||
"patterns": {"pub_channel_[891]"}, // Patterns to subscribe to
|
||||
},
|
||||
}
|
||||
|
||||
// Create subscriber one and subscribe to channels and patterns
|
||||
r1, w1 := net.Pipe()
|
||||
rc1 := resp.NewConn(r1)
|
||||
subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1)
|
||||
|
||||
// Create subscriber two and subscribe to channels and patterns
|
||||
r2, w2 := net.Pipe()
|
||||
rc2 := resp.NewConn(r2)
|
||||
subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2)
|
||||
|
||||
type SubscriberType struct {
|
||||
c *net.Conn
|
||||
r *resp.Conn
|
||||
l string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
channel string
|
||||
message string
|
||||
subscribers []SubscriberType
|
||||
}{
|
||||
{
|
||||
channel: "pub_channel_1",
|
||||
message: "Test both subscribers 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_1"},
|
||||
{c: &r2, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_6",
|
||||
message: "Test both subscribers 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
{c: &r2, r: rc2, l: "pub_channel_6"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_2",
|
||||
message: "Test subscriber 1 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_3",
|
||||
message: "Test subscriber 1 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_4",
|
||||
message: "Test both subscribers 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_5",
|
||||
message: "Test subscriber 1 3",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc1, l: "pub_channel_[456]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_7",
|
||||
message: "Test subscriber 2 1",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r2, r: rc2, l: "pub_channel_7"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_8",
|
||||
message: "Test subscriber 2 2",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r1, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
channel: "pub_channel_9",
|
||||
message: "Test subscriber 2 3",
|
||||
subscribers: []SubscriberType{
|
||||
{c: &r2, r: rc2, l: "pub_channel_[891]"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Dial server to make publisher connection
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
w := resp.NewConn(conn)
|
||||
|
||||
for _, test := range tests {
|
||||
err = w.WriteArray([]resp.Value{
|
||||
resp.StringValue("PUBLISH"),
|
||||
resp.StringValue(test.channel),
|
||||
resp.StringValue(test.message),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rv, _, err := w.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rv.String() != "OK" {
|
||||
t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String())
|
||||
}
|
||||
|
||||
for _, sub := range test.subscribers {
|
||||
verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandlePubSubChannels(t *testing.T) {}
|
||||
|
||||
|
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/tidwall/resp"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -22,9 +24,8 @@ func NewPubSub() *PubSub {
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) {
|
||||
r := resp.NewConn(*conn)
|
||||
for i := 0; i < len(channels); i++ {
|
||||
// Check if channel with given name exists
|
||||
// If it does, subscribe the connection to the channel
|
||||
@@ -42,23 +43,29 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri
|
||||
newChan = NewChannel(WithName(channels[i]))
|
||||
}
|
||||
newChan.Start()
|
||||
newChan.Subscribe(conn)
|
||||
if newChan.Subscribe(conn) {
|
||||
if err := r.WriteArray([]resp.Value{
|
||||
resp.StringValue("subscribe"),
|
||||
resp.StringValue(newChan.name),
|
||||
resp.IntegerValue(i + 1),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
ps.channels = append(ps.channels, newChan)
|
||||
} else {
|
||||
// Subscribe to existing channel
|
||||
ps.channels[channelIdx].Subscribe(conn)
|
||||
}
|
||||
|
||||
if len(channels) > 1 {
|
||||
// If subscribing to more than one channel, write array to verify the subscription of this channel
|
||||
res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
|
||||
} else {
|
||||
// Ony one channel, simply send "subscribe" simple string response
|
||||
res = "+subscribe\r\n"
|
||||
if ps.channels[channelIdx].Subscribe(conn) {
|
||||
if err := r.WriteArray([]resp.Value{
|
||||
resp.StringValue("subscribe"),
|
||||
resp.StringValue(ps.channels[channelIdx].name),
|
||||
resp.IntegerValue(i + 1),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||
@@ -138,8 +145,6 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []st
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("UNSUBBED: ", unsubscribed)
|
||||
|
||||
res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
|
||||
for key, value := range unsubscribed {
|
||||
res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key)
|
||||
|
@@ -51,11 +51,13 @@ func (server *Server) handleCommand(ctx context.Context, message []byte, conn *n
|
||||
}
|
||||
|
||||
if conn != nil {
|
||||
// Authorize connection if it's provided
|
||||
// Authorize connection if it's provided and if ACL module is present
|
||||
if server.ACL != nil {
|
||||
if err = server.ACL.AuthorizeConnection(conn, cmd, command, subCommand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the command is a write command, wait for state copy to finish.
|
||||
if utils.IsWriteCommand(command, subCommand) {
|
||||
|
Reference in New Issue
Block a user