Added PUBSUB commands and made pubsub module more compatible with redis-cli client

This commit is contained in:
Kelvin Clement Mwinuka
2024-02-28 10:14:46 +08:00
parent 6566bb41c4
commit e18257e600
5 changed files with 241 additions and 47 deletions

View File

@@ -144,7 +144,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
} else {
i--
}
}
return bytes, nil

View File

@@ -83,19 +83,31 @@ func (ch *Channel) Subscribe(conn *net.Conn) {
}
}
func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
var removed bool
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
return c == conn
if c == conn {
removed = true
return true
}
return false
})
if waitGroup != nil {
waitGroup.Done()
}
return removed
}
func (ch *Channel) Publish(message string) {
*ch.messageChan <- message
}
func (ch *Channel) IsActive() bool {
return len(ch.subscribers) > 0
}
func (ch *Channel) NumSubs() int {
return len(ch.subscribers)
}

View File

@@ -3,6 +3,7 @@ package pubsub
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/src/utils"
"net"
"strings"
@@ -22,9 +23,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
switch strings.ToLower(cmd[0]) {
case "subscribe":
pubsub.Subscribe(ctx, conn, channels, false)
return pubsub.Subscribe(ctx, conn, channels, false), nil
case "psubscribe":
pubsub.Subscribe(ctx, conn, channels, true)
return pubsub.Subscribe(ctx, conn, channels, true), nil
}
return []byte{}, nil
@@ -38,16 +39,14 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
channels := cmd[1:]
if len(channels) == 0 {
pubsub.Unsubscribe(ctx, conn, "*")
return []byte(utils.OK_RESPONSE), nil
switch strings.ToLower(cmd[0]) {
case "unsubscribe":
return pubsub.Unsubscribe(ctx, conn, channels, false), nil
case "punsubscribe":
return pubsub.Unsubscribe(ctx, conn, channels, true), nil
default:
return []byte{}, nil
}
for _, channel := range channels {
pubsub.Unsubscribe(ctx, conn, channel)
}
return []byte(utils.OK_RESPONSE), nil
}
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -62,6 +61,41 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OK_RESPONSE), nil
}
func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
if len(cmd) > 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
pattern := ""
if len(cmd) == 3 {
pattern = cmd[2]
}
return pubsub.Channels(ctx, pattern), nil
}
func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
num := pubsub.NumPat(ctx)
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
}
func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub module")
}
return pubsub.NumSub(ctx, cmd[2:]), nil
}
func Commands() []utils.Command {
return []utils.Command{
{
@@ -119,5 +153,57 @@ it's currently subscribe to.`,
},
HandlerFunc: handleUnsubscribe,
},
{
Command: "punsubscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns.
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
it's currently subscribe to.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channels as keys
return cmd[1:], nil
},
HandlerFunc: handleUnsubscribe,
},
{
Command: "pubsub",
Categories: []string{},
Description: "",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) {
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
},
SubCommands: []utils.SubCommand{
{
Command: "channels",
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
channels with 1 or more subscribers.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handlePubSubChannels,
},
{
Command: "numpat",
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handlePubSubNumPat,
},
{
Command: "numsub",
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
channel name and how many clients are currently subscribed to the channel.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil },
HandlerFunc: handlePubSubNumSubs,
},
},
},
}
}

View File

@@ -3,8 +3,7 @@ package pubsub
import (
"context"
"fmt"
"io"
"log"
"github.com/gobwas/glob"
"net"
"slices"
"sync"
@@ -23,7 +22,8 @@ func NewPubSub() *PubSub {
}
}
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) {
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
res := fmt.Sprintf("*%d\r\n", len(channels))
for i := 0; i < len(channels); i++ {
// Check if channel with given name exists
@@ -49,52 +49,86 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri
ps.channels[channelIdx].Subscribe(conn)
}
var res string
if len(channels) > 1 {
// If subscribing to more than one channel, write array to verify the subscription of this channel
res = fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
} else {
// Ony one channel, simply send "subscribe" simple string response
res = "+subscribe\r\n"
}
w := io.Writer(*conn)
if _, err := w.Write([]byte(res)); err != nil {
log.Println(err)
}
}
}
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
return []byte(res)
}
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
ps.channelsRWMut.RLock()
ps.channelsRWMut.RUnlock()
if channelName == "*" {
wg := &sync.WaitGroup{}
action := "unsubscribe"
if withPattern {
action = "subscribe"
}
unsubscribed := make(map[int]string)
count := 1
// If the channels slice is empty, unsubscribe from all channels.
if len(channels) <= 0 {
for _, channel := range ps.channels {
wg.Add(1)
go channel.Unsubscribe(conn, wg)
if channel.Unsubscribe(conn) {
unsubscribed[1] = channel.name
count += 1
}
}
wg.Wait()
return
}
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
return channel.name == channelName
})
if channelIdx != -1 {
ps.channels[channelIdx].Unsubscribe(conn, nil)
// If withPattern is false, unsubscribe from channels where the name exactly matches channel name.
if !withPattern {
for _, channel := range ps.channels { // For each channel in PubSub
for _, c := range channels { // For each channel name provided
if channel.name == c && channel.Unsubscribe(conn) {
unsubscribed[count] = channel.name
count += 1
}
}
}
}
// If withPattern is true, unsubscribe from channels where pattern matches pattern provided,
// also unsubscribe from channels where the name matches the given pattern.
if withPattern {
for _, pattern := range channels {
g := glob.MustCompile(pattern)
for _, channel := range ps.channels {
// If it's a pattern channel, directly compare the patterns
if channel.pattern != nil && channel.name == pattern {
unsubscribed[count] = channel.name
count += 1
continue
}
// If this is a regular channel, check if the channel name matches the pattern given
if g.Match(channel.name) {
unsubscribed[count] = channel.name
count += 1
}
}
}
}
res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
for key, value := range unsubscribed {
res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key)
}
return []byte(res)
}
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
ps.channelsRWMut.RLock()
defer ps.channelsRWMut.RUnlock()
for _, channel := range ps.channels {
fmt.Println(channel.name, channel.pattern)
// If it's a regular channel, check if the channel name matches the name given
if channel.pattern == nil {
if channel.name == channelName {
@@ -108,3 +142,62 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin
}
}
}
func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte {
var count int
var res string
if pattern == "" {
for _, channel := range ps.channels {
if channel.IsActive() {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
count += 1
}
}
res = fmt.Sprintf("*%d\r\n%s", count, res)
return []byte(res)
}
g := glob.MustCompile(pattern)
for _, channel := range ps.channels {
// If channel is a pattern channel, then directly compare the channel name to pattern
if channel.pattern != nil && channel.name == pattern && channel.IsActive() {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
count += 1
continue
}
if g.Match(channel.name) && channel.IsActive() {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
count += 1
}
}
return []byte(res)
}
func (ps *PubSub) NumPat(ctx context.Context) int {
var count int
for _, channel := range ps.channels {
if channel.pattern != nil {
count += 1
}
}
return count
}
func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte {
res := fmt.Sprintf("*%d\r\n", len(channels))
for _, channel := range channels {
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
return c.name == channel
})
if chanIdx == -1 {
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel)
continue
}
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs())
}
return []byte(res)
}

View File

@@ -219,6 +219,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
res, err := server.handleCommand(ctx, message, &conn, false)
if err != nil && errors.Is(err, io.EOF) {
break
}
if err != nil {
if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
log.Println(err)
@@ -234,10 +238,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
}
if len(res) <= chunkSize {
_, err = w.Write(res)
if err != nil {
log.Println(err)
}
_, _ = w.Write(res)
continue
}
@@ -246,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
for {
// If the current start index is less than chunkSize from length, return the remaining bytes.
if len(res)-1-startIndex < chunkSize {
_, _ = w.Write(res[startIndex:])
_, err = w.Write(res[startIndex:])
if err != nil {
log.Println(err)
}
break
}
n, _ := w.Write(res[startIndex : startIndex+chunkSize])