Removed Consumer Group in PubSub module and made the module more compatible with redis client

This commit is contained in:
Kelvin Clement Mwinuka
2024-02-27 17:45:20 +08:00
parent 160c701c3a
commit fc8d301525
4 changed files with 71 additions and 233 deletions

View File

@@ -31,17 +31,5 @@ func Commands() []utils.Command {
},
HandlerFunc: handlePing,
},
{
Command: "ack",
Categories: []string{},
Description: "",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
return []byte{}, nil
},
},
}
}

View File

@@ -10,41 +10,46 @@ import (
func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub")
return nil, errors.New("could not load pubsub module")
}
switch len(cmd) {
case 2:
// Subscribe to specified channel
pubsub.Subscribe(ctx, conn, cmd[1], nil)
case 3:
// Subscribe to specified channel and specified consumer group
pubsub.Subscribe(ctx, conn, cmd[1], cmd[2])
default:
channels := cmd[1:]
if len(channels) == 0 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
for i := 0; i < len(channels); i++ {
pubsub.Subscribe(ctx, conn, channels[i], i)
}
return []byte{}, nil
}
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub")
return nil, errors.New("could not load pubsub module")
}
switch len(cmd) {
case 1:
pubsub.Unsubscribe(ctx, conn, nil)
case 2:
pubsub.Unsubscribe(ctx, conn, cmd[1])
default:
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
channels := cmd[1:]
if len(channels) == 0 {
pubsub.Unsubscribe(ctx, conn, "*")
return []byte(utils.OK_RESPONSE), nil
}
for _, channel := range channels {
pubsub.Unsubscribe(ctx, conn, channel)
}
return []byte(utils.OK_RESPONSE), nil
}
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
return nil, errors.New("could not load pubsub")
return nil, errors.New("could not load pubsub module")
}
if len(cmd) != 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
@@ -72,28 +77,27 @@ func Commands() []utils.Command {
{
Command: "subscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: "(SUBSCRIBE channel [consumer_group]) Subscribe to a channel with an option to join a consumer group on the channel.",
Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channel as a key
if len(cmd) < 2 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []string{cmd[1]}, nil
return cmd[1:], nil
},
HandlerFunc: handleSubscribe,
},
{
Command: "unsubscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
Description: "(UNSUBSCRIBE channel) Unsubscribe from a channel.",
Description: `(UNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels.
If the channel list is not provided, then the connection will be unsubscribed from all the channels that
it's currently subscribe to.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channel as a key
if len(cmd) != 2 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []string{cmd[1]}, nil
// Treat the channels as keys
return cmd[1:], nil
},
HandlerFunc: handleUnsubscribe,
},

View File

@@ -1,140 +1,16 @@
package pubsub
import (
"bytes"
"container/ring"
"context"
"fmt"
"github.com/echovault/echovault/src/utils"
"io"
"log"
"net"
"slices"
"sync"
"time"
)
// ConsumerGroup allows multiple subscribers to share the consumption load of a channel.
// Only one subscriber in the consumer group will receive messages published to the channel.
type ConsumerGroup struct {
name string
subscribersRWMut sync.RWMutex
subscribers *ring.Ring
messageChan *chan string
}
func NewConsumerGroup(name string) *ConsumerGroup {
messageChan := make(chan string)
return &ConsumerGroup{
name: name,
subscribersRWMut: sync.RWMutex{},
subscribers: nil,
messageChan: &messageChan,
}
}
func (cg *ConsumerGroup) SendMessage(message string) {
cg.subscribersRWMut.RLock()
conn := cg.subscribers.Value.(*net.Conn)
cg.subscribersRWMut.RUnlock()
w, r := io.Writer(*conn), io.Reader(*conn)
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Wait for an ACK
// If no ACK is received within a time limit, remove this connection from subscribers and retry
if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if msg, err := utils.ReadMessage(r); err != nil {
// Remove the connection from subscribers list
cg.Unsubscribe(conn)
// Reset the deadline
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Retry sending the message
cg.SendMessage(message)
} else {
if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
cg.Unsubscribe(conn)
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.SendMessage(message)
}
}
if err := (*conn).SetDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.subscribers = cg.subscribers.Next()
}
func (cg *ConsumerGroup) Start() {
go func() {
for {
message := <-*cg.messageChan
if cg.subscribers != nil {
cg.SendMessage(message)
}
}
}()
}
func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
cg.subscribersRWMut.Lock()
defer cg.subscribersRWMut.Unlock()
r := ring.New(1)
for i := 0; i < r.Len(); i++ {
r.Value = conn
r = r.Next()
}
if cg.subscribers == nil {
cg.subscribers = r
return
}
cg.subscribers = cg.subscribers.Link(r)
}
func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
cg.subscribersRWMut.Lock()
defer cg.subscribersRWMut.Unlock()
// If length is 1 and the connection passed is the one contained within, unlink it
if cg.subscribers.Len() == 1 {
if cg.subscribers.Value == conn {
cg.subscribers = nil
}
return
}
for i := 0; i < cg.subscribers.Len(); i++ {
if cg.subscribers.Value == conn {
cg.subscribers = cg.subscribers.Prev()
cg.subscribers.Unlink(1)
break
}
cg.subscribers = cg.subscribers.Next()
}
}
func (cg *ConsumerGroup) Publish(message string) {
*cg.messageChan <- message
}
// Channel - A channel can be subscribed to directly, or via a consumer group.
// All direct subscribers to the channel will receive any message published to the channel.
// Only one subscriber of a channel's consumer group will receive a message posted to the channel.
@@ -142,18 +18,16 @@ type Channel struct {
name string
subscribersRWMut sync.RWMutex
subscribers []*net.Conn
consumerGroups []*ConsumerGroup
messageChan *chan string
}
func NewChannel(name string) *Channel {
messageChan := make(chan string)
messageChan := make(chan string, 4096)
return &Channel{
name: name,
subscribersRWMut: sync.RWMutex{},
subscribers: []*net.Conn{},
consumerGroups: []*ConsumerGroup{},
messageChan: &messageChan,
}
}
@@ -167,32 +41,10 @@ func (ch *Channel) Start() {
for _, conn := range ch.subscribers {
go func(conn *net.Conn) {
w, r := io.Writer(*conn), io.Reader(*conn)
w := io.Writer(*conn)
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
defer func() {
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
}()
if msg, err := utils.ReadMessage(r); err != nil {
ch.Unsubscribe(conn)
} else {
if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
ch.Unsubscribe(conn)
}
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil {
log.Println(err)
}
}(conn)
}
@@ -202,50 +54,36 @@ func (ch *Channel) Start() {
}()
}
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
func (ch *Channel) Subscribe(conn *net.Conn, index int) {
if !slices.Contains(ch.subscribers, conn) {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
ch.subscribers = append(ch.subscribers, conn)
return
// Write array to verify the subscription of this channel
res := fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(ch.name), ch.name, index+1)
w := io.Writer(*conn)
if _, err := w.Write([]byte(res)); err != nil {
log.Println(err)
}
groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool {
return group.name == consumerGroupName.(string)
})
if len(groups) == 0 {
go func() {
newGroup := NewConsumerGroup(consumerGroupName.(string))
newGroup.Start()
newGroup.Subscribe(conn)
ch.consumerGroups = append(ch.consumerGroups, newGroup)
}()
return
}
for _, group := range groups {
go group.Subscribe(conn)
}
}
func (ch *Channel) Unsubscribe(conn *net.Conn) {
func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool {
return c != conn
ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
return c == conn
})
for _, group := range ch.consumerGroups {
go group.Unsubscribe(conn)
if waitGroup != nil {
waitGroup.Done()
}
}
func (ch *Channel) Publish(message string) {
for _, group := range ch.consumerGroups {
go group.Publish(message)
}
*ch.messageChan <- message
}
@@ -260,7 +98,7 @@ func NewPubSub() *PubSub {
}
}
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, index int) {
// Check if channel with given name exists
// If it does, subscribe the connection to the channel
// If it does not, create the channel and subscribe to it
@@ -272,29 +110,32 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName str
go func() {
newChan := NewChannel(channelName)
newChan.Start()
newChan.Subscribe(conn, consumerGroup)
newChan.Subscribe(conn, index)
ps.channels = append(ps.channels, newChan)
}()
return
}
go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
ps.channels[channelIdx].Subscribe(conn, index)
}
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
if channelName == nil {
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) {
if channelName == "*" {
wg := &sync.WaitGroup{}
for _, channel := range ps.channels {
go channel.Unsubscribe(conn)
wg.Add(1)
go channel.Unsubscribe(conn, wg)
}
wg.Wait()
return
}
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
return c.name == channelName
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
return channel.name == channelName
})
for _, channel := range channels {
go channel.Unsubscribe(conn)
if channelIdx != -1 {
ps.channels[channelIdx].Unsubscribe(conn, nil)
}
}

View File

@@ -228,6 +228,11 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
chunkSize := 1024
// If the length of the response is 0, return nothing to the client
if len(res) == 0 {
continue
}
if len(res) <= chunkSize {
_, err = w.Write(res)
if err != nil {