diff --git a/.gitignore b/.gitignore index 50ad5e2..f7e24c4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ volumes /coverage/ dist/ -src/modules/*/aof \ No newline at end of file +src/modules/*/aof +dump.rdb \ No newline at end of file diff --git a/src/modules/pubsub/channel.go b/src/modules/pubsub/channel.go index 41ae70f..3aeedcc 100644 --- a/src/modules/pubsub/channel.go +++ b/src/modules/pubsub/channel.go @@ -1,12 +1,10 @@ package pubsub import ( - "fmt" "github.com/gobwas/glob" - "io" + "github.com/tidwall/resp" "log" "net" - "slices" "sync" ) @@ -17,7 +15,7 @@ type Channel struct { name string pattern glob.Glob subscribersRWMut sync.RWMutex - subscribers []*net.Conn + subscribers map[*net.Conn]*resp.Conn messageChan *chan string } @@ -41,7 +39,7 @@ func NewChannel(options ...func(channel *Channel)) *Channel { name: "", pattern: nil, subscribersRWMut: sync.RWMutex{}, - subscribers: []*net.Conn{}, + subscribers: make(map[*net.Conn]*resp.Conn), messageChan: &messageChan, } @@ -60,10 +58,12 @@ func (ch *Channel) Start() { ch.subscribersRWMut.RLock() for _, conn := range ch.subscribers { - go func(conn *net.Conn) { - w := io.Writer(*conn) - - if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil { + go func(conn *resp.Conn) { + if err := conn.WriteArray([]resp.Value{ + resp.StringValue("message"), + resp.StringValue(ch.name), + resp.StringValue(message), + }); err != nil { log.Println(err) } }(conn) @@ -74,30 +74,24 @@ func (ch *Channel) Start() { }() } -func (ch *Channel) Subscribe(conn *net.Conn) { - if !slices.Contains(ch.subscribers, conn) { - ch.subscribersRWMut.Lock() - defer ch.subscribersRWMut.Unlock() - - ch.subscribers = append(ch.subscribers, conn) +func (ch *Channel) Subscribe(conn *net.Conn) bool { + ch.subscribersRWMut.Lock() + defer ch.subscribersRWMut.Unlock() + if _, ok := ch.subscribers[conn]; !ok { + ch.subscribers[conn] = resp.NewConn(*conn) } + _, ok := ch.subscribers[conn] + return ok } func (ch *Channel) Unsubscribe(conn *net.Conn) bool { ch.subscribersRWMut.Lock() defer ch.subscribersRWMut.Unlock() - - var removed bool - - ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool { - if c == conn { - removed = true - return true - } + if _, ok := ch.subscribers[conn]; !ok { return false - }) - - return removed + } + delete(ch.subscribers, conn) + return true } func (ch *Channel) Publish(message string) { diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go index 9907132..05f21ae 100644 --- a/src/modules/pubsub/commands.go +++ b/src/modules/pubsub/commands.go @@ -22,8 +22,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con } withPattern := strings.EqualFold(cmd[0], "psubscribe") + pubsub.Subscribe(ctx, conn, channels, withPattern) - return pubsub.Subscribe(ctx, conn, channels, withPattern), nil + return nil, nil } func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -48,6 +49,7 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn return nil, errors.New(utils.WrongArgsResponse) } pubsub.Publish(ctx, cmd[2], cmd[1]) + fmt.Println("PUBLISHED:", cmd[2]) return []byte(utils.OkResponse), nil } diff --git a/src/modules/pubsub/commands_test.go b/src/modules/pubsub/commands_test.go index feb6e70..fe5e9f3 100644 --- a/src/modules/pubsub/commands_test.go +++ b/src/modules/pubsub/commands_test.go @@ -10,6 +10,7 @@ import ( "net" "slices" "testing" + "time" ) var pubsub *PubSub @@ -21,7 +22,8 @@ var port uint16 = 7490 func init() { pubsub = NewPubSub() mockServer = server.NewServer(server.Opts{ - PubSub: pubsub, + PubSub: pubsub, + Commands: Commands(), Config: utils.Config{ BindAddr: bindAddr, Port: port, @@ -37,7 +39,7 @@ func init() { func Test_HandleSubscribe(t *testing.T) { ctx := context.WithValue(context.Background(), "test_name", "SUBSCRIBE/PSUBSCRIBE") - numOfConnection := 100 + numOfConnection := 20 connections := make([]*net.Conn, numOfConnection) for i := 0; i < numOfConnection; i++ { @@ -47,6 +49,13 @@ func Test_HandleSubscribe(t *testing.T) { } connections[i] = &conn } + defer func() { + for _, conn := range connections { + if err := (*conn).Close(); err != nil { + t.Error(err) + } + } + }() // Test subscribe to channels channels := []string{"sub_channel1", "sub_channel2", "sub_channel3"} @@ -70,7 +79,7 @@ func Test_HandleSubscribe(t *testing.T) { } // Check if the channel has all the connections from above for _, conn := range connections { - if !slices.Contains(c.subscribers, conn) { + if _, ok := c.subscribers[conn]; !ok { t.Errorf("could not find all expected connection in the \"%s\"", channel) } } @@ -100,7 +109,7 @@ func Test_HandleSubscribe(t *testing.T) { } // Check if the channel has all the connections from above for _, conn := range connections { - if !slices.Contains(c.subscribers, conn) { + if _, ok := c.subscribers[conn]; !ok { t.Errorf("could not find all expected connection in the \"%s\"", pattern) } } @@ -122,6 +131,14 @@ func Test_HandleUnsubscribe(t *testing.T) { return connections } + closeConnections := func(conns []*net.Conn) { + for _, conn := range conns { + if err := (*conn).Close(); err != nil { + t.Error(err) + } + } + } + verifyResponse := func(res []byte, expectedResponse [][]string) { rd := resp.NewReader(bytes.NewReader(res)) rv, _, err := rd.ReadValue() @@ -130,6 +147,7 @@ func Test_HandleUnsubscribe(t *testing.T) { } v := rv.Array() if len(v) != len(expectedResponse) { + fmt.Println(v) t.Errorf("expected subscribe response of length %d, but got %d", len(expectedResponse), len(v)) } for _, item := range v { @@ -247,11 +265,11 @@ func Test_HandleUnsubscribe(t *testing.T) { for _, pubsubChannel := range pubsub.channels { if pubsubChannel.name == channel { // Assert that target connection is no longer in the unsub channels and patterns - if slices.Contains(pubsubChannel.subscribers, test.targetConn) { + if _, ok := pubsubChannel.subscribers[test.targetConn]; ok { t.Errorf("found unexpected target connection after unsubscrining in channel \"%s\"", channel) } for _, conn := range test.otherConnections { - if !slices.Contains(pubsubChannel.subscribers, conn) { + if _, ok := pubsubChannel.subscribers[conn]; !ok { t.Errorf("did not find expected other connection in channel \"%s\"", channel) } } @@ -263,16 +281,198 @@ func Test_HandleUnsubscribe(t *testing.T) { for _, channel := range append(test.remainChannels, test.remainPatterns...) { for _, pubsubChannel := range pubsub.channels { if pubsubChannel.name == channel { - if !slices.Contains(pubsubChannel.subscribers, test.targetConn) { - t.Errorf("cound not find expected target connection in channel \"%s\"", channel) + if _, ok := pubsubChannel.subscribers[test.targetConn]; !ok { + t.Errorf("could not find expected target connection in channel \"%s\"", channel) } } } } } + + for _, test := range tests { + // Close all the connections + closeConnections(append(test.otherConnections, test.targetConn)) + } } -func Test_HandlePublish(t *testing.T) {} +func Test_HandlePublish(t *testing.T) { + ctx := context.WithValue(context.Background(), "test_name", "PUBLISH") + + // verifyChannelMessage reads the message from the connection and asserts whether + // it's the message we expect to read as a subscriber of a channel or pattern. + verifyEvent := func(c *net.Conn, r *resp.Conn, expected []string) { + if err := (*c).SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { + t.Error(err) + } + rv, _, err := r.ReadValue() + if err != nil { + t.Error(err) + } + v := rv.Array() + for i := 0; i < len(v); i++ { + if v[i].String() != expected[i] { + t.Errorf("expected item at index %d to be \"%s\", got \"%s\"", i, expected[i], v[i].String()) + } + } + fmt.Println(v) + } + + // The subscribe function handles subscribing the connection to the given + // channels and patterns and reading/verifying the message sent by the server after + // subscription. + subscribe := func(ctx context.Context, channels []string, patterns []string, c *net.Conn, r *resp.Conn) { + // Subscribe to channels + go func() { + _, _ = handleSubscribe(ctx, append([]string{"SUBSCRIBE"}, channels...), mockServer, c) + }() + // Verify all the responses for each channel subscription + for i := 0; i < len(channels); i++ { + verifyEvent(c, r, []string{"subscribe", channels[i], fmt.Sprintf("%d", i+1)}) + } + // Subscribe to all the patterns + go func() { + _, _ = handleSubscribe(ctx, append([]string{"PSUBSCRIBE"}, patterns...), mockServer, c) + }() + // Verify all the responses for each pattern subscription + for i := 0; i < len(patterns); i++ { + verifyEvent(c, r, []string{"subscribe", patterns[i], fmt.Sprintf("%d", i+1)}) + } + } + + subscriptions := map[string]map[string][]string{ + "subscriber1": { + "channels": {"pub_channel_1", "pub_channel_2", "pub_channel_3"}, // Channels to subscribe to + "patterns": {"pub_channel_[456]"}, // Patterns to subscribe to + }, + "subscriber2": { + "channels": {"pub_channel_6", "pub_channel_7"}, // Channels to subscribe to + "patterns": {"pub_channel_[891]"}, // Patterns to subscribe to + }, + } + + // Create subscriber one and subscribe to channels and patterns + r1, w1 := net.Pipe() + rc1 := resp.NewConn(r1) + subscribe(ctx, subscriptions["subscriber1"]["channels"], subscriptions["subscriber1"]["patterns"], &w1, rc1) + + // Create subscriber two and subscribe to channels and patterns + r2, w2 := net.Pipe() + rc2 := resp.NewConn(r2) + subscribe(ctx, subscriptions["subscriber2"]["channels"], subscriptions["subscriber2"]["patterns"], &w2, rc2) + + type SubscriberType struct { + c *net.Conn + r *resp.Conn + l string + } + + tests := []struct { + channel string + message string + subscribers []SubscriberType + }{ + { + channel: "pub_channel_1", + message: "Test both subscribers 1", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_1"}, + {c: &r2, r: rc2, l: "pub_channel_[891]"}, + }, + }, + { + channel: "pub_channel_6", + message: "Test both subscribers 2", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_[456]"}, + {c: &r2, r: rc2, l: "pub_channel_6"}, + }, + }, + { + channel: "pub_channel_2", + message: "Test subscriber 1 1", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_2"}, + }, + }, + { + channel: "pub_channel_3", + message: "Test subscriber 1 2", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_3"}, + }, + }, + { + channel: "pub_channel_4", + message: "Test both subscribers 2", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_[456]"}, + }, + }, + { + channel: "pub_channel_5", + message: "Test subscriber 1 3", + subscribers: []SubscriberType{ + {c: &r1, r: rc1, l: "pub_channel_[456]"}, + }, + }, + { + channel: "pub_channel_7", + message: "Test subscriber 2 1", + subscribers: []SubscriberType{ + {c: &r2, r: rc2, l: "pub_channel_7"}, + }, + }, + { + channel: "pub_channel_8", + message: "Test subscriber 2 2", + subscribers: []SubscriberType{ + {c: &r1, r: rc2, l: "pub_channel_[891]"}, + }, + }, + { + channel: "pub_channel_9", + message: "Test subscriber 2 3", + subscribers: []SubscriberType{ + {c: &r2, r: rc2, l: "pub_channel_[891]"}, + }, + }, + } + + // Dial server to make publisher connection + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) + if err != nil { + t.Error(err) + } + defer func() { + if err = conn.Close(); err != nil { + t.Error(err) + } + }() + w := resp.NewConn(conn) + + for _, test := range tests { + err = w.WriteArray([]resp.Value{ + resp.StringValue("PUBLISH"), + resp.StringValue(test.channel), + resp.StringValue(test.message), + }) + if err != nil { + t.Error(err) + } + + rv, _, err := w.ReadValue() + if err != nil { + t.Error(err) + } + if rv.String() != "OK" { + t.Errorf("Expected publish response to be \"OK\", got \"%s\"", rv.String()) + } + + for _, sub := range test.subscribers { + verifyEvent(sub.c, sub.r, []string{"message", sub.l, test.message}) + } + } +} func Test_HandlePubSubChannels(t *testing.T) {} diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index 813d4c1..9fecddc 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "github.com/gobwas/glob" + "github.com/tidwall/resp" + "log" "net" "slices" "sync" @@ -22,9 +24,8 @@ func NewPubSub() *PubSub { } } -func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { - res := fmt.Sprintf("*%d\r\n", len(channels)) - +func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) { + r := resp.NewConn(*conn) for i := 0; i < len(channels); i++ { // Check if channel with given name exists // If it does, subscribe the connection to the channel @@ -42,23 +43,29 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri newChan = NewChannel(WithName(channels[i])) } newChan.Start() - newChan.Subscribe(conn) + if newChan.Subscribe(conn) { + if err := r.WriteArray([]resp.Value{ + resp.StringValue("subscribe"), + resp.StringValue(newChan.name), + resp.IntegerValue(i + 1), + }); err != nil { + log.Println(err) + } + } ps.channels = append(ps.channels, newChan) } else { // Subscribe to existing channel - ps.channels[channelIdx].Subscribe(conn) - } - - if len(channels) > 1 { - // If subscribing to more than one channel, write array to verify the subscription of this channel - res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1) - } else { - // Ony one channel, simply send "subscribe" simple string response - res = "+subscribe\r\n" + if ps.channels[channelIdx].Subscribe(conn) { + if err := r.WriteArray([]resp.Value{ + resp.StringValue("subscribe"), + resp.StringValue(ps.channels[channelIdx].name), + resp.IntegerValue(i + 1), + }); err != nil { + log.Println(err) + } + } } } - - return []byte(res) } func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { @@ -138,8 +145,6 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []st } } - fmt.Println("UNSUBBED: ", unsubscribed) - res := fmt.Sprintf("*%d\r\n", len(unsubscribed)) for key, value := range unsubscribed { res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key) diff --git a/src/server/modules.go b/src/server/modules.go index 2f86889..a31a74a 100644 --- a/src/server/modules.go +++ b/src/server/modules.go @@ -51,9 +51,11 @@ func (server *Server) handleCommand(ctx context.Context, message []byte, conn *n } if conn != nil { - // Authorize connection if it's provided - if err = server.ACL.AuthorizeConnection(conn, cmd, command, subCommand); err != nil { - return nil, err + // Authorize connection if it's provided and if ACL module is present + if server.ACL != nil { + if err = server.ACL.AuthorizeConnection(conn, cmd, command, subCommand); err != nil { + return nil, err + } } }