mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 16:30:21 +08:00
Added PUBSUB commands and made pubsub module more compatible with redis-cli client
This commit is contained in:
@@ -144,7 +144,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
|
|||||||
} else {
|
} else {
|
||||||
i--
|
i--
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytes, nil
|
return bytes, nil
|
||||||
|
|||||||
@@ -83,19 +83,31 @@ func (ch *Channel) Subscribe(conn *net.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
|
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
|
||||||
ch.subscribersRWMut.Lock()
|
ch.subscribersRWMut.Lock()
|
||||||
defer ch.subscribersRWMut.Unlock()
|
defer ch.subscribersRWMut.Unlock()
|
||||||
|
|
||||||
|
var removed bool
|
||||||
|
|
||||||
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
|
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
|
||||||
return c == conn
|
if c == conn {
|
||||||
|
removed = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
})
|
})
|
||||||
|
|
||||||
if waitGroup != nil {
|
return removed
|
||||||
waitGroup.Done()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *Channel) Publish(message string) {
|
func (ch *Channel) Publish(message string) {
|
||||||
*ch.messageChan <- message
|
*ch.messageChan <- message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ch *Channel) IsActive() bool {
|
||||||
|
return len(ch.subscribers) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *Channel) NumSubs() int {
|
||||||
|
return len(ch.subscribers)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pubsub
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/echovault/echovault/src/utils"
|
"github.com/echovault/echovault/src/utils"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -22,9 +23,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
|
|||||||
|
|
||||||
switch strings.ToLower(cmd[0]) {
|
switch strings.ToLower(cmd[0]) {
|
||||||
case "subscribe":
|
case "subscribe":
|
||||||
pubsub.Subscribe(ctx, conn, channels, false)
|
return pubsub.Subscribe(ctx, conn, channels, false), nil
|
||||||
case "psubscribe":
|
case "psubscribe":
|
||||||
pubsub.Subscribe(ctx, conn, channels, true)
|
return pubsub.Subscribe(ctx, conn, channels, true), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte{}, nil
|
return []byte{}, nil
|
||||||
@@ -38,16 +39,14 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c
|
|||||||
|
|
||||||
channels := cmd[1:]
|
channels := cmd[1:]
|
||||||
|
|
||||||
if len(channels) == 0 {
|
switch strings.ToLower(cmd[0]) {
|
||||||
pubsub.Unsubscribe(ctx, conn, "*")
|
case "unsubscribe":
|
||||||
return []byte(utils.OK_RESPONSE), nil
|
return pubsub.Unsubscribe(ctx, conn, channels, false), nil
|
||||||
|
case "punsubscribe":
|
||||||
|
return pubsub.Unsubscribe(ctx, conn, channels, true), nil
|
||||||
|
default:
|
||||||
|
return []byte{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, channel := range channels {
|
|
||||||
pubsub.Unsubscribe(ctx, conn, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(utils.OK_RESPONSE), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||||
@@ -62,6 +61,41 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
|
|||||||
return []byte(utils.OK_RESPONSE), nil
|
return []byte(utils.OK_RESPONSE), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||||
|
if len(cmd) > 3 {
|
||||||
|
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not load pubsub module")
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern := ""
|
||||||
|
if len(cmd) == 3 {
|
||||||
|
pattern = cmd[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
return pubsub.Channels(ctx, pattern), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||||
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not load pubsub module")
|
||||||
|
}
|
||||||
|
num := pubsub.NumPat(ctx)
|
||||||
|
return []byte(fmt.Sprintf(":%d\r\n", num)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||||
|
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not load pubsub module")
|
||||||
|
}
|
||||||
|
return pubsub.NumSub(ctx, cmd[2:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
func Commands() []utils.Command {
|
func Commands() []utils.Command {
|
||||||
return []utils.Command{
|
return []utils.Command{
|
||||||
{
|
{
|
||||||
@@ -119,5 +153,57 @@ it's currently subscribe to.`,
|
|||||||
},
|
},
|
||||||
HandlerFunc: handleUnsubscribe,
|
HandlerFunc: handleUnsubscribe,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Command: "punsubscribe",
|
||||||
|
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
|
||||||
|
Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns.
|
||||||
|
If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
|
||||||
|
it's currently subscribe to.`,
|
||||||
|
Sync: false,
|
||||||
|
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||||
|
// Treat the channels as keys
|
||||||
|
return cmd[1:], nil
|
||||||
|
},
|
||||||
|
HandlerFunc: handleUnsubscribe,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Command: "pubsub",
|
||||||
|
Categories: []string{},
|
||||||
|
Description: "",
|
||||||
|
Sync: false,
|
||||||
|
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||||
|
HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
|
return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
|
||||||
|
},
|
||||||
|
SubCommands: []utils.SubCommand{
|
||||||
|
{
|
||||||
|
Command: "channels",
|
||||||
|
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||||
|
Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that
|
||||||
|
match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
|
||||||
|
channels with 1 or more subscribers.`,
|
||||||
|
Sync: false,
|
||||||
|
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||||
|
HandlerFunc: handlePubSubChannels,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Command: "numpat",
|
||||||
|
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||||
|
Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
|
||||||
|
Sync: false,
|
||||||
|
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
|
||||||
|
HandlerFunc: handlePubSubNumPat,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Command: "numsub",
|
||||||
|
Categories: []string{utils.PubSubCategory, utils.SlowCategory},
|
||||||
|
Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
|
||||||
|
channel name and how many clients are currently subscribed to the channel.`,
|
||||||
|
Sync: false,
|
||||||
|
KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil },
|
||||||
|
HandlerFunc: handlePubSubNumSubs,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ package pubsub
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"github.com/gobwas/glob"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -23,7 +22,8 @@ func NewPubSub() *PubSub {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) {
|
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||||
|
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||||
|
|
||||||
for i := 0; i < len(channels); i++ {
|
for i := 0; i < len(channels); i++ {
|
||||||
// Check if channel with given name exists
|
// Check if channel with given name exists
|
||||||
@@ -49,52 +49,86 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri
|
|||||||
ps.channels[channelIdx].Subscribe(conn)
|
ps.channels[channelIdx].Subscribe(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
var res string
|
|
||||||
if len(channels) > 1 {
|
if len(channels) > 1 {
|
||||||
// If subscribing to more than one channel, write array to verify the subscription of this channel
|
// If subscribing to more than one channel, write array to verify the subscription of this channel
|
||||||
res = fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
|
res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
|
||||||
} else {
|
} else {
|
||||||
// Ony one channel, simply send "subscribe" simple string response
|
// Ony one channel, simply send "subscribe" simple string response
|
||||||
res = "+subscribe\r\n"
|
res = "+subscribe\r\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
w := io.Writer(*conn)
|
|
||||||
if _, err := w.Write([]byte(res)); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
|
return []byte(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||||
ps.channelsRWMut.RLock()
|
ps.channelsRWMut.RLock()
|
||||||
ps.channelsRWMut.RUnlock()
|
ps.channelsRWMut.RUnlock()
|
||||||
|
|
||||||
if channelName == "*" {
|
action := "unsubscribe"
|
||||||
wg := &sync.WaitGroup{}
|
if withPattern {
|
||||||
|
action = "subscribe"
|
||||||
|
}
|
||||||
|
|
||||||
|
unsubscribed := make(map[int]string)
|
||||||
|
count := 1
|
||||||
|
|
||||||
|
// If the channels slice is empty, unsubscribe from all channels.
|
||||||
|
if len(channels) <= 0 {
|
||||||
for _, channel := range ps.channels {
|
for _, channel := range ps.channels {
|
||||||
wg.Add(1)
|
if channel.Unsubscribe(conn) {
|
||||||
go channel.Unsubscribe(conn, wg)
|
unsubscribed[1] = channel.name
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
|
// If withPattern is false, unsubscribe from channels where the name exactly matches channel name.
|
||||||
return channel.name == channelName
|
if !withPattern {
|
||||||
})
|
for _, channel := range ps.channels { // For each channel in PubSub
|
||||||
|
for _, c := range channels { // For each channel name provided
|
||||||
if channelIdx != -1 {
|
if channel.name == c && channel.Unsubscribe(conn) {
|
||||||
ps.channels[channelIdx].Unsubscribe(conn, nil)
|
unsubscribed[count] = channel.name
|
||||||
|
count += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If withPattern is true, unsubscribe from channels where pattern matches pattern provided,
|
||||||
|
// also unsubscribe from channels where the name matches the given pattern.
|
||||||
|
if withPattern {
|
||||||
|
for _, pattern := range channels {
|
||||||
|
g := glob.MustCompile(pattern)
|
||||||
|
for _, channel := range ps.channels {
|
||||||
|
// If it's a pattern channel, directly compare the patterns
|
||||||
|
if channel.pattern != nil && channel.name == pattern {
|
||||||
|
unsubscribed[count] = channel.name
|
||||||
|
count += 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If this is a regular channel, check if the channel name matches the pattern given
|
||||||
|
if g.Match(channel.name) {
|
||||||
|
unsubscribed[count] = channel.name
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
|
||||||
|
for key, value := range unsubscribed {
|
||||||
|
res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(res)
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
|
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
|
||||||
ps.channelsRWMut.RLock()
|
ps.channelsRWMut.RLock()
|
||||||
defer ps.channelsRWMut.RUnlock()
|
defer ps.channelsRWMut.RUnlock()
|
||||||
|
|
||||||
for _, channel := range ps.channels {
|
for _, channel := range ps.channels {
|
||||||
fmt.Println(channel.name, channel.pattern)
|
|
||||||
|
|
||||||
// If it's a regular channel, check if the channel name matches the name given
|
// If it's a regular channel, check if the channel name matches the name given
|
||||||
if channel.pattern == nil {
|
if channel.pattern == nil {
|
||||||
if channel.name == channelName {
|
if channel.name == channelName {
|
||||||
@@ -108,3 +142,62 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte {
|
||||||
|
var count int
|
||||||
|
var res string
|
||||||
|
|
||||||
|
if pattern == "" {
|
||||||
|
for _, channel := range ps.channels {
|
||||||
|
if channel.IsActive() {
|
||||||
|
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res = fmt.Sprintf("*%d\r\n%s", count, res)
|
||||||
|
return []byte(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
g := glob.MustCompile(pattern)
|
||||||
|
|
||||||
|
for _, channel := range ps.channels {
|
||||||
|
// If channel is a pattern channel, then directly compare the channel name to pattern
|
||||||
|
if channel.pattern != nil && channel.name == pattern && channel.IsActive() {
|
||||||
|
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||||
|
count += 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if g.Match(channel.name) && channel.IsActive() {
|
||||||
|
res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) NumPat(ctx context.Context) int {
|
||||||
|
var count int
|
||||||
|
for _, channel := range ps.channels {
|
||||||
|
if channel.pattern != nil {
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte {
|
||||||
|
res := fmt.Sprintf("*%d\r\n", len(channels))
|
||||||
|
for _, channel := range channels {
|
||||||
|
chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
|
||||||
|
return c.name == channel
|
||||||
|
})
|
||||||
|
if chanIdx == -1 {
|
||||||
|
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs())
|
||||||
|
}
|
||||||
|
return []byte(res)
|
||||||
|
}
|
||||||
|
|||||||
@@ -219,6 +219,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
|||||||
|
|
||||||
res, err := server.handleCommand(ctx, message, &conn, false)
|
res, err := server.handleCommand(ctx, message, &conn, false)
|
||||||
|
|
||||||
|
if err != nil && errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
|
if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
@@ -234,10 +238,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(res) <= chunkSize {
|
if len(res) <= chunkSize {
|
||||||
_, err = w.Write(res)
|
_, _ = w.Write(res)
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
|||||||
for {
|
for {
|
||||||
// If the current start index is less than chunkSize from length, return the remaining bytes.
|
// If the current start index is less than chunkSize from length, return the remaining bytes.
|
||||||
if len(res)-1-startIndex < chunkSize {
|
if len(res)-1-startIndex < chunkSize {
|
||||||
_, _ = w.Write(res[startIndex:])
|
_, err = w.Write(res[startIndex:])
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
n, _ := w.Write(res[startIndex : startIndex+chunkSize])
|
n, _ := w.Write(res[startIndex : startIndex+chunkSize])
|
||||||
|
|||||||
Reference in New Issue
Block a user