mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 00:14:08 +08:00
Removed Consumer Group in PubSub module and made the module more compatible with redis client
This commit is contained in:
@@ -31,17 +31,5 @@ func Commands() []utils.Command {
|
||||
},
|
||||
HandlerFunc: handlePing,
|
||||
},
|
||||
{
|
||||
Command: "ack",
|
||||
Categories: []string{},
|
||||
Description: "",
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -10,41 +10,46 @@ import (
|
||||
func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub")
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
switch len(cmd) {
|
||||
case 2:
|
||||
// Subscribe to specified channel
|
||||
pubsub.Subscribe(ctx, conn, cmd[1], nil)
|
||||
case 3:
|
||||
// Subscribe to specified channel and specified consumer group
|
||||
pubsub.Subscribe(ctx, conn, cmd[1], cmd[2])
|
||||
default:
|
||||
|
||||
channels := cmd[1:]
|
||||
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
|
||||
|
||||
for i := 0; i < len(channels); i++ {
|
||||
pubsub.Subscribe(ctx, conn, channels[i], i)
|
||||
}
|
||||
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub")
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
switch len(cmd) {
|
||||
case 1:
|
||||
pubsub.Unsubscribe(ctx, conn, nil)
|
||||
case 2:
|
||||
pubsub.Unsubscribe(ctx, conn, cmd[1])
|
||||
default:
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
|
||||
channels := cmd[1:]
|
||||
|
||||
if len(channels) == 0 {
|
||||
pubsub.Unsubscribe(ctx, conn, "*")
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
pubsub.Unsubscribe(ctx, conn, channel)
|
||||
}
|
||||
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
}
|
||||
|
||||
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
pubsub, ok := server.GetPubSub().(*PubSub)
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub")
|
||||
return nil, errors.New("could not load pubsub module")
|
||||
}
|
||||
if len(cmd) != 3 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
@@ -72,28 +77,27 @@ func Commands() []utils.Command {
|
||||
{
|
||||
Command: "subscribe",
|
||||
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
|
||||
Description: "(SUBSCRIBE channel [consumer_group]) Subscribe to a channel with an option to join a consumer group on the channel.",
|
||||
Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||
// Treat the channel as a key
|
||||
if len(cmd) < 2 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
return []string{cmd[1]}, nil
|
||||
return cmd[1:], nil
|
||||
},
|
||||
HandlerFunc: handleSubscribe,
|
||||
},
|
||||
{
|
||||
Command: "unsubscribe",
|
||||
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
|
||||
Description: "(UNSUBSCRIBE channel) Unsubscribe from a channel.",
|
||||
Description: `(UNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels.
|
||||
If the channel list is not provided, then the connection will be unsubscribed from all the channels that
|
||||
it's currently subscribe to.`,
|
||||
Sync: false,
|
||||
KeyExtractionFunc: func(cmd []string) ([]string, error) {
|
||||
// Treat the channel as a key
|
||||
if len(cmd) != 2 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
return []string{cmd[1]}, nil
|
||||
// Treat the channels as keys
|
||||
return cmd[1:], nil
|
||||
},
|
||||
HandlerFunc: handleUnsubscribe,
|
||||
},
|
||||
|
@@ -1,140 +1,16 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/echovault/echovault/src/utils"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConsumerGroup allows multiple subscribers to share the consumption load of a channel.
|
||||
// Only one subscriber in the consumer group will receive messages published to the channel.
|
||||
type ConsumerGroup struct {
|
||||
name string
|
||||
subscribersRWMut sync.RWMutex
|
||||
subscribers *ring.Ring
|
||||
messageChan *chan string
|
||||
}
|
||||
|
||||
func NewConsumerGroup(name string) *ConsumerGroup {
|
||||
messageChan := make(chan string)
|
||||
|
||||
return &ConsumerGroup{
|
||||
name: name,
|
||||
subscribersRWMut: sync.RWMutex{},
|
||||
subscribers: nil,
|
||||
messageChan: &messageChan,
|
||||
}
|
||||
}
|
||||
|
||||
func (cg *ConsumerGroup) SendMessage(message string) {
|
||||
cg.subscribersRWMut.RLock()
|
||||
|
||||
conn := cg.subscribers.Value.(*net.Conn)
|
||||
|
||||
cg.subscribersRWMut.RUnlock()
|
||||
|
||||
w, r := io.Writer(*conn), io.Reader(*conn)
|
||||
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
// Wait for an ACK
|
||||
// If no ACK is received within a time limit, remove this connection from subscribers and retry
|
||||
if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
if msg, err := utils.ReadMessage(r); err != nil {
|
||||
// Remove the connection from subscribers list
|
||||
cg.Unsubscribe(conn)
|
||||
// Reset the deadline
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
// Retry sending the message
|
||||
cg.SendMessage(message)
|
||||
} else {
|
||||
if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
|
||||
cg.Unsubscribe(conn)
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
cg.SendMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
if err := (*conn).SetDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
cg.subscribers = cg.subscribers.Next()
|
||||
}
|
||||
|
||||
func (cg *ConsumerGroup) Start() {
|
||||
go func() {
|
||||
for {
|
||||
message := <-*cg.messageChan
|
||||
if cg.subscribers != nil {
|
||||
cg.SendMessage(message)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
|
||||
cg.subscribersRWMut.Lock()
|
||||
defer cg.subscribersRWMut.Unlock()
|
||||
|
||||
r := ring.New(1)
|
||||
for i := 0; i < r.Len(); i++ {
|
||||
r.Value = conn
|
||||
r = r.Next()
|
||||
}
|
||||
|
||||
if cg.subscribers == nil {
|
||||
cg.subscribers = r
|
||||
return
|
||||
}
|
||||
|
||||
cg.subscribers = cg.subscribers.Link(r)
|
||||
}
|
||||
|
||||
func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
|
||||
cg.subscribersRWMut.Lock()
|
||||
defer cg.subscribersRWMut.Unlock()
|
||||
|
||||
// If length is 1 and the connection passed is the one contained within, unlink it
|
||||
if cg.subscribers.Len() == 1 {
|
||||
if cg.subscribers.Value == conn {
|
||||
cg.subscribers = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < cg.subscribers.Len(); i++ {
|
||||
if cg.subscribers.Value == conn {
|
||||
cg.subscribers = cg.subscribers.Prev()
|
||||
cg.subscribers.Unlink(1)
|
||||
break
|
||||
}
|
||||
cg.subscribers = cg.subscribers.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (cg *ConsumerGroup) Publish(message string) {
|
||||
*cg.messageChan <- message
|
||||
}
|
||||
|
||||
// Channel - A channel can be subscribed to directly, or via a consumer group.
|
||||
// All direct subscribers to the channel will receive any message published to the channel.
|
||||
// Only one subscriber of a channel's consumer group will receive a message posted to the channel.
|
||||
@@ -142,18 +18,16 @@ type Channel struct {
|
||||
name string
|
||||
subscribersRWMut sync.RWMutex
|
||||
subscribers []*net.Conn
|
||||
consumerGroups []*ConsumerGroup
|
||||
messageChan *chan string
|
||||
}
|
||||
|
||||
func NewChannel(name string) *Channel {
|
||||
messageChan := make(chan string)
|
||||
messageChan := make(chan string, 4096)
|
||||
|
||||
return &Channel{
|
||||
name: name,
|
||||
subscribersRWMut: sync.RWMutex{},
|
||||
subscribers: []*net.Conn{},
|
||||
consumerGroups: []*ConsumerGroup{},
|
||||
messageChan: &messageChan,
|
||||
}
|
||||
}
|
||||
@@ -167,32 +41,10 @@ func (ch *Channel) Start() {
|
||||
|
||||
for _, conn := range ch.subscribers {
|
||||
go func(conn *net.Conn) {
|
||||
w, r := io.Writer(*conn), io.Reader(*conn)
|
||||
w := io.Writer(*conn)
|
||||
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
defer func() {
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
if msg, err := utils.ReadMessage(r); err != nil {
|
||||
ch.Unsubscribe(conn)
|
||||
} else {
|
||||
if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
@@ -202,50 +54,36 @@ func (ch *Channel) Start() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
|
||||
if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
|
||||
func (ch *Channel) Subscribe(conn *net.Conn, index int) {
|
||||
if !slices.Contains(ch.subscribers, conn) {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
|
||||
ch.subscribers = append(ch.subscribers, conn)
|
||||
return
|
||||
|
||||
// Write array to verify the subscription of this channel
|
||||
res := fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(ch.name), ch.name, index+1)
|
||||
w := io.Writer(*conn)
|
||||
if _, err := w.Write([]byte(res)); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool {
|
||||
return group.name == consumerGroupName.(string)
|
||||
})
|
||||
|
||||
if len(groups) == 0 {
|
||||
go func() {
|
||||
newGroup := NewConsumerGroup(consumerGroupName.(string))
|
||||
newGroup.Start()
|
||||
newGroup.Subscribe(conn)
|
||||
ch.consumerGroups = append(ch.consumerGroups, newGroup)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
go group.Subscribe(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn) {
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
|
||||
ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool {
|
||||
return c != conn
|
||||
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
|
||||
return c == conn
|
||||
})
|
||||
|
||||
for _, group := range ch.consumerGroups {
|
||||
go group.Unsubscribe(conn)
|
||||
if waitGroup != nil {
|
||||
waitGroup.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *Channel) Publish(message string) {
|
||||
for _, group := range ch.consumerGroups {
|
||||
go group.Publish(message)
|
||||
}
|
||||
*ch.messageChan <- message
|
||||
}
|
||||
|
||||
@@ -260,7 +98,7 @@ func NewPubSub() *PubSub {
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, index int) {
|
||||
// Check if channel with given name exists
|
||||
// If it does, subscribe the connection to the channel
|
||||
// If it does not, create the channel and subscribe to it
|
||||
@@ -272,29 +110,32 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName str
|
||||
go func() {
|
||||
newChan := NewChannel(channelName)
|
||||
newChan.Start()
|
||||
newChan.Subscribe(conn, consumerGroup)
|
||||
newChan.Subscribe(conn, index)
|
||||
ps.channels = append(ps.channels, newChan)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
|
||||
ps.channels[channelIdx].Subscribe(conn, index)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
|
||||
if channelName == nil {
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
|
||||
if channelName == "*" {
|
||||
wg := &sync.WaitGroup{}
|
||||
for _, channel := range ps.channels {
|
||||
go channel.Unsubscribe(conn)
|
||||
wg.Add(1)
|
||||
go channel.Unsubscribe(conn, wg)
|
||||
}
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
||||
return c.name == channelName
|
||||
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
|
||||
return channel.name == channelName
|
||||
})
|
||||
|
||||
for _, channel := range channels {
|
||||
go channel.Unsubscribe(conn)
|
||||
if channelIdx != -1 {
|
||||
ps.channels[channelIdx].Unsubscribe(conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -228,6 +228,11 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
|
||||
chunkSize := 1024
|
||||
|
||||
// If the length of the response is 0, return nothing to the client
|
||||
if len(res) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(res) <= chunkSize {
|
||||
_, err = w.Write(res)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user