mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-21 06:59:24 +08:00
Added PUBSUB commands and made pubsub module more compatible with redis-cli client
This commit is contained in:
@@ -144,7 +144,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
|
||||
} else {
|
||||
i--
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return bytes, nil
|
||||
|
@@ -83,19 +83,31 @@ func (ch *Channel) Subscribe(conn *net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
|
||||
var removed bool
|
||||
|
||||
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
|
||||
return c == conn
|
||||
if c == conn {
|
||||
removed = true
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if waitGroup != nil {
|
||||
waitGroup.Done()
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (ch *Channel) Publish(message string) {
|
||||
*ch.messageChan <- message
|
||||
}
|
||||
|
||||
func (ch *Channel) IsActive() bool {
|
||||
return len(ch.subscribers) > 0
|
||||
}
|
||||
|
||||
func (ch *Channel) NumSubs() int {
|
||||
return len(ch.subscribers)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package pubsub
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/src/utils"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -22,9 +23,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
|
||||
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
case "subscribe":
|
||||
pubsub.Subscribe(ctx, conn, channels, false)
|
||||
return pubsub.Subscribe(ctx, conn, channels, false), nil
|
||||
case "psubscribe":
|
||||
pubsub.Subscribe(ctx, conn, channels, true)
|
||||
return pubsub.Subscribe(ctx, conn, channels, true), nil
|
||||
}
|
||||
|
||||
return []byte{}, nil
|
||||
@@ -38,16 +39,14 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
|
||||
|
||||
channels := cmd[1:]
|
||||
|
||||
if len(channels) == 0 {
|
||||
pubsub.Unsubscribe(ctx, conn, "*")
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
switch strings.ToLower(cmd[0]) {
|
||||
case "unsubscribe":
|
||||
return pubsub.Unsubscribe(ctx, conn, channels, false), nil
|
||||
case "punsubscribe":
|
||||
return pubsub.Unsubscribe(ctx, conn, channels, true), nil
|
||||
default:
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
pubsub.Unsubscribe(ctx, conn, channel)
|
||||
}
|
||||
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
}
|
||||
|
||||
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
@@ -62,6 +61,41 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
}
|
||||
|
||||
func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
if len(cmd) > 3 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
|
||||
pattern := ""
|
||||
if len(cmd) == 3 {
|
||||
pattern = cmd[2]
|
||||
}
|
||||
|
||||
return pubsub.Channels(ctx, pattern), nil
|
||||
}
|
||||
|
||||
func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
num := pubsub.NumPat(ctx)
|
||||
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
|
||||
}
|
||||
|
||||
func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
return pubsub.NumSub(ctx, cmd[2:]), nil
|
||||
}
|
||||
|
||||
func Commands() []utils.Command {
|
||||
return []utils.Command{
|
||||
{
|
||||
@@ -119,5 +153,57 @@ it's currently subscribe to.`,
|
||||
},
|
||||
HandlerFunc: handleUnsubscribe,
|
||||
},
|
||||
{
|
||||
Command: "punsubscribe",
|
||||
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
|
||||
Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns.
|
||||
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
|
||||
it's currently subscribe to.`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||
// Treat the channels as keys
|
||||
return cmd[1:], nil
|
||||
},
|
||||
HandlerFunc: handleUnsubscribe,
|
||||
},
|
||||
{
|
||||
Command: "pubsub",
|
||||
Categories: []string{},
|
||||
Description: "",
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||
HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
|
||||
},
|
||||
SubCommands: []utils.SubCommand{
|
||||
{
|
||||
Command: "channels",
|
||||
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||
Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that
|
||||
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
|
||||
channels with 1 or more subscribers.`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||
HandlerFunc: handlePubSubChannels,
|
||||
},
|
||||
{
|
||||
Command: "numpat",
|
||||
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||
HandlerFunc: handlePubSubNumPat,
|
||||
},
|
||||
{
|
||||
Command: "numsub",
|
||||
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
|
||||
channel name and how many clients are currently subscribed to the channel.`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil },
|
||||
HandlerFunc: handlePubSubNumSubs,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -3,8 +3,7 @@ package pubsub
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"github.com/gobwas/glob"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -23,7 +22,8 @@ func NewPubSub() *PubSub {
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) {
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||
|
||||
for i := 0; i < len(channels); i++ {
|
||||
// Check if channel with given name exists
|
||||
@@ -49,43 +49,79 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri
|
||||
ps.channels[channelIdx].Subscribe(conn)
|
||||
}
|
||||
|
||||
var res string
|
||||
if len(channels) > 1 {
|
||||
// If subscribing to more than one channel, write array to verify the subscription of this channel
|
||||
res = fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
|
||||
res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
|
||||
} else {
|
||||
// Ony one channel, simply send "subscribe" simple string response
|
||||
res = "+subscribe\r\n"
|
||||
}
|
||||
}
|
||||
|
||||
w := io.Writer(*conn)
|
||||
if _, err := w.Write([]byte(res)); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||
ps.channelsRWMut.RLock()
|
||||
ps.channelsRWMut.RUnlock()
|
||||
|
||||
if channelName == "*" {
|
||||
wg := &sync.WaitGroup{}
|
||||
action := "unsubscribe"
|
||||
if withPattern {
|
||||
action = "subscribe"
|
||||
}
|
||||
|
||||
unsubscribed := make(map[int]string)
|
||||
count := 1
|
||||
|
||||
// If the channels slice is empty, unsubscribe from all channels.
|
||||
if len(channels) <= 0 {
|
||||
for _, channel := range ps.channels {
|
||||
wg.Add(1)
|
||||
go channel.Unsubscribe(conn, wg)
|
||||
if channel.Unsubscribe(conn) {
|
||||
unsubscribed[1] = channel.name
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
|
||||
return channel.name == channelName
|
||||
})
|
||||
|
||||
if channelIdx != -1 {
|
||||
ps.channels[channelIdx].Unsubscribe(conn, nil)
|
||||
// If withPattern is false, unsubscribe from channels where the name exactly matches channel name.
|
||||
if !withPattern {
|
||||
for _, channel := range ps.channels { // For each channel in PubSub
|
||||
for _, c := range channels { // For each channel name provided
|
||||
if channel.name == c && channel.Unsubscribe(conn) {
|
||||
unsubscribed[count] = channel.name
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If withPattern is true, unsubscribe from channels where pattern matches pattern provided,
|
||||
// also unsubscribe from channels where the name matches the given pattern.
|
||||
if withPattern {
|
||||
for _, pattern := range channels {
|
||||
g := glob.MustCompile(pattern)
|
||||
for _, channel := range ps.channels {
|
||||
// If it's a pattern channel, directly compare the patterns
|
||||
if channel.pattern != nil && channel.name == pattern {
|
||||
unsubscribed[count] = channel.name
|
||||
count += 1
|
||||
continue
|
||||
}
|
||||
// If this is a regular channel, check if the channel name matches the pattern given
|
||||
if g.Match(channel.name) {
|
||||
unsubscribed[count] = channel.name
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
|
||||
for key, value := range unsubscribed {
|
||||
res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key)
|
||||
}
|
||||
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
|
||||
@@ -93,8 +129,6 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin
|
||||
defer ps.channelsRWMut.RUnlock()
|
||||
|
||||
for _, channel := range ps.channels {
|
||||
fmt.Println(channel.name, channel.pattern)
|
||||
|
||||
// If it's a regular channel, check if the channel name matches the name given
|
||||
if channel.pattern == nil {
|
||||
if channel.name == channelName {
|
||||
@@ -108,3 +142,62 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte {
|
||||
var count int
|
||||
var res string
|
||||
|
||||
if pattern == "" {
|
||||
for _, channel := range ps.channels {
|
||||
if channel.IsActive() {
|
||||
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
|
||||
res = fmt.Sprintf("*%d\r\n%s", count, res)
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
g := glob.MustCompile(pattern)
|
||||
|
||||
for _, channel := range ps.channels {
|
||||
// If channel is a pattern channel, then directly compare the channel name to pattern
|
||||
if channel.pattern != nil && channel.name == pattern && channel.IsActive() {
|
||||
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||
count += 1
|
||||
continue
|
||||
}
|
||||
if g.Match(channel.name) && channel.IsActive() {
|
||||
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
func (ps *PubSub) NumPat(ctx context.Context) int {
|
||||
var count int
|
||||
for _, channel := range ps.channels {
|
||||
if channel.pattern != nil {
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte {
|
||||
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||
for _, channel := range channels {
|
||||
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
|
||||
return c.name == channel
|
||||
})
|
||||
if chanIdx == -1 {
|
||||
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel)
|
||||
continue
|
||||
}
|
||||
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs())
|
||||
}
|
||||
return []byte(res)
|
||||
}
|
||||
|
@@ -219,6 +219,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
|
||||
res, err := server.handleCommand(ctx, message, &conn, false)
|
||||
|
||||
if err != nil && errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
|
||||
log.Println(err)
|
||||
@@ -234,10 +238,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
if len(res) <= chunkSize {
|
||||
_, err = w.Write(res)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
_, _ = w.Write(res)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -246,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
for {
|
||||
// If the current start index is less than chunkSize from length, return the remaining bytes.
|
||||
if len(res)-1-startIndex < chunkSize {
|
||||
_, _ = w.Write(res[startIndex:])
|
||||
_, err = w.Write(res[startIndex:])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
n, _ := w.Write(res[startIndex : startIndex+chunkSize])
|
||||
|
Reference in New Issue
Block a user