Removed Consumer Group in PubSub module and made the module more compatible with redis client

This commit is contained in:
Kelvin Clement Mwinuka
2024-02-27 17:45:20 +08:00
parent 160c701c3a
commit fc8d301525
4 changed files with 71 additions and 233 deletions

View File

@@ -31,17 +31,5 @@ func Commands() []utils.Command {
}, },
HandlerFunc: handlePing, HandlerFunc: handlePing,
}, },
{
Command: "ack",
Categories: []string{},
Description: "",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
return []byte{}, nil
},
},
} }
} }

View File

@@ -10,41 +10,46 @@ import (
func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub") return nil, errors.New("could not load pubsub module")
} }
switch len(cmd) {
case 2: channels := cmd[1:]
// Subscribe to specified channel
pubsub.Subscribe(ctx, conn, cmd[1], nil) if len(channels) == 0 {
case 3:
// Subscribe to specified channel and specified consumer group
pubsub.Subscribe(ctx, conn, cmd[1], cmd[2])
default:
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
for i := 0; i < len(channels); i++ {
pubsub.Subscribe(ctx, conn, channels[i], i)
}
return []byte{}, nil
} }
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub") return nil, errors.New("could not load pubsub module")
} }
switch len(cmd) {
case 1: channels := cmd[1:]
pubsub.Unsubscribe(ctx, conn, nil)
case 2: if len(channels) == 0 {
pubsub.Unsubscribe(ctx, conn, cmd[1]) pubsub.Unsubscribe(ctx, conn, "*")
default: return []byte(utils.OK_RESPONSE), nil
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
for _, channel := range channels {
pubsub.Unsubscribe(ctx, conn, channel)
}
return []byte(utils.OK_RESPONSE), nil return []byte(utils.OK_RESPONSE), nil
} }
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub) pubsub, ok := server.GetPubSub().(*PubSub)
if !ok { if !ok {
return nil, errors.New("could not load pubsub") return nil, errors.New("could not load pubsub module")
} }
if len(cmd) != 3 { if len(cmd) != 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
@@ -72,28 +77,27 @@ func Commands() []utils.Command {
{ {
Command: "subscribe", Command: "subscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: "(SUBSCRIBE channel [consumer_group]) Subscribe to a channel with an option to join a consumer group on the channel.", Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channel as a key // Treat the channel as a key
if len(cmd) < 2 { if len(cmd) < 2 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
return []string{cmd[1]}, nil return cmd[1:], nil
}, },
HandlerFunc: handleSubscribe, HandlerFunc: handleSubscribe,
}, },
{ {
Command: "unsubscribe", Command: "unsubscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: "(UNSUBSCRIBE channel) Unsubscribe from a channel.", Description: `(UNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels.
If the channel list is not provided, then the connection will be unsubscribed from all the channels that
it's currently subscribe to.`,
Sync: false, Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) { KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channel as a key // Treat the channels as keys
if len(cmd) != 2 { return cmd[1:], nil
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []string{cmd[1]}, nil
}, },
HandlerFunc: handleUnsubscribe, HandlerFunc: handleUnsubscribe,
}, },

View File

@@ -1,140 +1,16 @@
package pubsub package pubsub
import ( import (
"bytes"
"container/ring"
"context" "context"
"fmt" "fmt"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"io" "io"
"log"
"net" "net"
"slices" "slices"
"sync" "sync"
"time"
) )
// ConsumerGroup allows multiple subscribers to share the consumption load of a channel.
// Only one subscriber in the consumer group will receive messages published to the channel.
type ConsumerGroup struct {
name string
subscribersRWMut sync.RWMutex
subscribers *ring.Ring
messageChan *chan string
}
func NewConsumerGroup(name string) *ConsumerGroup {
messageChan := make(chan string)
return &ConsumerGroup{
name: name,
subscribersRWMut: sync.RWMutex{},
subscribers: nil,
messageChan: &messageChan,
}
}
func (cg *ConsumerGroup) SendMessage(message string) {
cg.subscribersRWMut.RLock()
conn := cg.subscribers.Value.(*net.Conn)
cg.subscribersRWMut.RUnlock()
w, r := io.Writer(*conn), io.Reader(*conn)
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Wait for an ACK
// If no ACK is received within a time limit, remove this connection from subscribers and retry
if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if msg, err := utils.ReadMessage(r); err != nil {
// Remove the connection from subscribers list
cg.Unsubscribe(conn)
// Reset the deadline
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Retry sending the message
cg.SendMessage(message)
} else {
if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
cg.Unsubscribe(conn)
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.SendMessage(message)
}
}
if err := (*conn).SetDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.subscribers = cg.subscribers.Next()
}
func (cg *ConsumerGroup) Start() {
go func() {
for {
message := <-*cg.messageChan
if cg.subscribers != nil {
cg.SendMessage(message)
}
}
}()
}
func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
cg.subscribersRWMut.Lock()
defer cg.subscribersRWMut.Unlock()
r := ring.New(1)
for i := 0; i < r.Len(); i++ {
r.Value = conn
r = r.Next()
}
if cg.subscribers == nil {
cg.subscribers = r
return
}
cg.subscribers = cg.subscribers.Link(r)
}
func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
cg.subscribersRWMut.Lock()
defer cg.subscribersRWMut.Unlock()
// If length is 1 and the connection passed is the one contained within, unlink it
if cg.subscribers.Len() == 1 {
if cg.subscribers.Value == conn {
cg.subscribers = nil
}
return
}
for i := 0; i < cg.subscribers.Len(); i++ {
if cg.subscribers.Value == conn {
cg.subscribers = cg.subscribers.Prev()
cg.subscribers.Unlink(1)
break
}
cg.subscribers = cg.subscribers.Next()
}
}
func (cg *ConsumerGroup) Publish(message string) {
*cg.messageChan <- message
}
// Channel - A channel can be subscribed to directly, or via a consumer group. // Channel - A channel can be subscribed to directly, or via a consumer group.
// All direct subscribers to the channel will receive any message published to the channel. // All direct subscribers to the channel will receive any message published to the channel.
// Only one subscriber of a channel's consumer group will receive a message posted to the channel. // Only one subscriber of a channel's consumer group will receive a message posted to the channel.
@@ -142,18 +18,16 @@ type Channel struct {
name string name string
subscribersRWMut sync.RWMutex subscribersRWMut sync.RWMutex
subscribers []*net.Conn subscribers []*net.Conn
consumerGroups []*ConsumerGroup
messageChan *chan string messageChan *chan string
} }
func NewChannel(name string) *Channel { func NewChannel(name string) *Channel {
messageChan := make(chan string) messageChan := make(chan string, 4096)
return &Channel{ return &Channel{
name: name, name: name,
subscribersRWMut: sync.RWMutex{}, subscribersRWMut: sync.RWMutex{},
subscribers: []*net.Conn{}, subscribers: []*net.Conn{},
consumerGroups: []*ConsumerGroup{},
messageChan: &messageChan, messageChan: &messageChan,
} }
} }
@@ -167,32 +41,10 @@ func (ch *Channel) Start() {
for _, conn := range ch.subscribers { for _, conn := range ch.subscribers {
go func(conn *net.Conn) { go func(conn *net.Conn) {
w, r := io.Writer(*conn), io.Reader(*conn) w := io.Writer(*conn)
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil { if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil {
// TODO: Log error at configured logger log.Println(err)
fmt.Println(err)
}
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
defer func() {
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
}()
if msg, err := utils.ReadMessage(r); err != nil {
ch.Unsubscribe(conn)
} else {
if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
ch.Unsubscribe(conn)
}
} }
}(conn) }(conn)
} }
@@ -202,50 +54,36 @@ func (ch *Channel) Start() {
}() }()
} }
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) { func (ch *Channel) Subscribe(conn *net.Conn, index int) {
if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) { if !slices.Contains(ch.subscribers, conn) {
ch.subscribersRWMut.Lock() ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock() defer ch.subscribersRWMut.Unlock()
ch.subscribers = append(ch.subscribers, conn) ch.subscribers = append(ch.subscribers, conn)
return
// Write array to verify the subscription of this channel
res := fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(ch.name), ch.name, index+1)
w := io.Writer(*conn)
if _, err := w.Write([]byte(res)); err != nil {
log.Println(err)
} }
groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool {
return group.name == consumerGroupName.(string)
})
if len(groups) == 0 {
go func() {
newGroup := NewConsumerGroup(consumerGroupName.(string))
newGroup.Start()
newGroup.Subscribe(conn)
ch.consumerGroups = append(ch.consumerGroups, newGroup)
}()
return
}
for _, group := range groups {
go group.Subscribe(conn)
} }
} }
func (ch *Channel) Unsubscribe(conn *net.Conn) { func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
ch.subscribersRWMut.Lock() ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock() defer ch.subscribersRWMut.Unlock()
ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool { ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
return c != conn return c == conn
}) })
for _, group := range ch.consumerGroups { if waitGroup != nil {
go group.Unsubscribe(conn) waitGroup.Done()
} }
} }
func (ch *Channel) Publish(message string) { func (ch *Channel) Publish(message string) {
for _, group := range ch.consumerGroups {
go group.Publish(message)
}
*ch.messageChan <- message *ch.messageChan <- message
} }
@@ -260,7 +98,7 @@ func NewPubSub() *PubSub {
} }
} }
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) { func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, index int) {
// Check if channel with given name exists // Check if channel with given name exists
// If it does, subscribe the connection to the channel // If it does, subscribe the connection to the channel
// If it does not, create the channel and subscribe to it // If it does not, create the channel and subscribe to it
@@ -272,29 +110,32 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName str
go func() { go func() {
newChan := NewChannel(channelName) newChan := NewChannel(channelName)
newChan.Start() newChan.Start()
newChan.Subscribe(conn, consumerGroup) newChan.Subscribe(conn, index)
ps.channels = append(ps.channels, newChan) ps.channels = append(ps.channels, newChan)
}() }()
return return
} }
go ps.channels[channelIdx].Subscribe(conn, consumerGroup) ps.channels[channelIdx].Subscribe(conn, index)
} }
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) { func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
if channelName == nil { if channelName == "*" {
wg := &sync.WaitGroup{}
for _, channel := range ps.channels { for _, channel := range ps.channels {
go channel.Unsubscribe(conn) wg.Add(1)
go channel.Unsubscribe(conn, wg)
} }
wg.Wait()
return return
} }
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool { channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
return c.name == channelName return channel.name == channelName
}) })
for _, channel := range channels { if channelIdx != -1 {
go channel.Unsubscribe(conn) ps.channels[channelIdx].Unsubscribe(conn, nil)
} }
} }

View File

@@ -228,6 +228,11 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
chunkSize := 1024 chunkSize := 1024
// If the length of the response is 0, return nothing to the client
if len(res) == 0 {
continue
}
if len(res) <= chunkSize { if len(res) <= chunkSize {
_, err = w.Write(res) _, err = w.Write(res)
if err != nil { if err != nil {