Files
SugarDB/src/command_modules/pubsub/pubsub.go
2024-01-10 02:37:48 +03:00

303 lines
7.0 KiB
Go

package pubsub
import (
"bufio"
"container/ring"
"context"
"fmt"
"github.com/kelvinmwinuka/memstore/src/utils"
"net"
"strings"
"sync"
"time"
)
// ConsumerGroup is a named set of subscribers that share the consumption load
// of a channel: a message published to the group is handed to only one of its
// members rather than to all of them.
type ConsumerGroup struct {
	// name identifies the group.
	name string

	// messageChan points at the channel on which published messages are
	// queued for delivery.
	messageChan *chan string

	// subscribersRWMut guards subscribers.
	subscribersRWMut sync.RWMutex
	// subscribers is a ring whose values are *net.Conn; it is nil while the
	// group has no members.
	subscribers *ring.Ring
}
// NewConsumerGroup builds an empty consumer group with the given name and a
// freshly made unbuffered message channel. The returned group has no
// subscribers.
func NewConsumerGroup(name string) *ConsumerGroup {
	queue := make(chan string)

	group := new(ConsumerGroup)
	group.name = name
	group.messageChan = &queue
	// subscribersRWMut and subscribers keep their zero values: an unlocked
	// mutex and a nil (empty) ring.
	return group
}
// SendMessage delivers message to one subscriber of the group, chosen
// round-robin, and waits for that subscriber to acknowledge it with "+ACK".
//
// A subscriber that cannot be written to, does not answer within the
// acknowledgement window, or answers with anything other than "+ACK" is
// removed from the group and the message is offered to the next subscriber.
// If the group has no subscribers (or runs out of them), the message is
// dropped.
func (cg *ConsumerGroup) SendMessage(message string) {
	// How long a subscriber gets to acknowledge a message.
	const ackTimeout = 250 * time.Millisecond

	payload := fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message)

	// Iterate instead of recursing: every failed attempt removes a subscriber,
	// so the loop ends once a delivery succeeds or the ring is empty.
	for {
		// Pick the current subscriber and advance the ring inside a single
		// critical section, so the ring is never touched outside the lock and
		// a nil (empty) ring is never dereferenced.
		cg.subscribersRWMut.Lock()
		if cg.subscribers == nil {
			cg.subscribersRWMut.Unlock()
			return
		}
		conn := cg.subscribers.Value.(*net.Conn)
		cg.subscribers = cg.subscribers.Next()
		cg.subscribersRWMut.Unlock()

		rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))

		acked := false
		_, err := rw.WriteString(payload)
		if err == nil {
			err = rw.Flush()
		}
		if err == nil {
			// Best effort: if the deadline cannot be set the connection is
			// unusable and the read below fails as well.
			_ = (*conn).SetReadDeadline(time.Now().Add(ackTimeout))
			var reply string
			reply, err = utils.ReadMessage(rw)
			acked = err == nil && strings.TrimSpace(reply) == "+ACK"
		}

		// Clear the deadlines so later use of the connection is not cut short
		// by the acknowledgement window.
		_ = (*conn).SetDeadline(time.Time{})

		if acked {
			return
		}

		// Drop the unresponsive subscriber and try the next one.
		cg.Unsubscribe(conn)
	}
}
// Start launches the group's delivery goroutine. The goroutine takes messages
// off the group's message channel one at a time and passes each to
// SendMessage; a message that arrives while the group has no subscribers is
// discarded. The goroutine runs until the message channel is closed.
func (cg *ConsumerGroup) Start() {
	go func() {
		// Ranging over the channel (instead of receiving in a bare for loop)
		// lets the goroutine exit if the channel is ever closed rather than
		// spinning on zero-value messages.
		for message := range *cg.messageChan {
			// subscribers is written by Subscribe and Unsubscribe on other
			// goroutines, so it has to be read under the lock.
			cg.subscribersRWMut.RLock()
			empty := cg.subscribers == nil
			cg.subscribersRWMut.RUnlock()

			if empty {
				continue
			}
			cg.SendMessage(message)
		}
	}()
}
// Subscribe adds conn to the group's round-robin ring.
//
// A nil conn is ignored. A connection that is already a member is not added a
// second time, so it never receives more than its share of messages. A new
// member is inserted directly behind the current position (the end of the
// rotation) and the current position itself is left where it was.
func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
	if conn == nil {
		return
	}

	cg.subscribersRWMut.Lock()
	defer cg.subscribersRWMut.Unlock()

	if cg.subscribers != nil {
		alreadyMember := false
		cg.subscribers.Do(func(member interface{}) {
			if member == conn {
				alreadyMember = true
			}
		})
		if alreadyMember {
			return
		}
	}

	node := ring.New(1)
	node.Value = conn

	if cg.subscribers == nil {
		cg.subscribers = node
		return
	}

	// Link returns the element that used to follow its receiver; assigning
	// that result to cg.subscribers would silently move the round-robin
	// position, so the return value is deliberately not used.
	cg.subscribers.Prev().Link(node)
}
// Unsubscribe removes every occurrence of conn from the group's ring.
//
// The round-robin position is preserved: it only moves (to the following
// member) when the member it points at is the one being removed. Removing the
// last member leaves the group empty (nil ring). Unsubscribing a connection
// that is not a member, or calling this on an empty group, is a no-op.
func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
	cg.subscribersRWMut.Lock()
	defer cg.subscribersRWMut.Unlock()

	if cg.subscribers == nil {
		return
	}

	// Walk the ring once with a private cursor so that searching does not
	// disturb the shared round-robin position. The length is taken up front
	// because Len itself walks the whole ring.
	total := cg.subscribers.Len()
	node := cg.subscribers
	for i := 0; i < total; i++ {
		// Remember the successor first: unlinking turns node into a
		// one-element ring of its own.
		next := node.Next()

		if node.Value == conn {
			if next == node {
				// node is the only element left, so the group is now empty.
				cg.subscribers = nil
				return
			}
			if node == cg.subscribers {
				cg.subscribers = next
			}
			node.Prev().Unlink(1)
		}

		node = next
	}
}
// Publish hands message to the group by sending it on the group's message
// channel. The send blocks until the channel accepts the message.
func (cg *ConsumerGroup) Publish(message string) {
	queue := *cg.messageChan
	queue <- message
}
// Channel is a named pub/sub topic. Connections can subscribe to it directly
// or through one of its consumer groups:
//   - every direct subscriber receives every message published to the channel;
//   - within a consumer group, only one member receives each message.
type Channel struct {
	// name identifies the channel.
	name string

	// messageChan points at the channel on which published messages are
	// queued for delivery to the direct subscribers.
	messageChan *chan string

	// subscribersRWMut guards subscribers.
	subscribersRWMut sync.RWMutex
	// subscribers holds the directly subscribed connections.
	subscribers []*net.Conn

	// consumerGroups holds the consumer groups attached to this channel.
	consumerGroups []*ConsumerGroup
}
// NewChannel builds a channel with the given name, a freshly made unbuffered
// message channel, and empty (non-nil) subscriber and consumer-group lists.
func NewChannel(name string) *Channel {
	queue := make(chan string)

	channel := &Channel{name: name, messageChan: &queue}
	channel.subscribers = []*net.Conn{}
	channel.consumerGroups = []*ConsumerGroup{}
	// subscribersRWMut keeps its zero value, which is an unlocked mutex.
	return channel
}
// Start launches the channel's broadcast goroutine. For every message taken
// off the channel's message channel it writes the message to all direct
// subscribers concurrently and waits for each of them to answer "+ACK".
//
// A subscriber that cannot be written to, does not answer within the
// acknowledgement window, or answers with anything other than "+ACK" is
// unsubscribed from the channel. The next message is only picked up once all
// deliveries of the current one have finished, so two deliveries never read
// from (or reset the deadline of) the same connection at the same time. The
// goroutine runs until the message channel is closed.
func (ch *Channel) Start() {
	// How long a subscriber gets to acknowledge a message.
	const ackTimeout = 200 * time.Millisecond

	go func() {
		// Ranging over the channel (instead of receiving in a bare for loop)
		// lets the goroutine exit if the channel is ever closed rather than
		// spinning on zero-value messages.
		for message := range *ch.messageChan {
			payload := fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message)

			// Copy the subscriber list under the read lock and do all I/O
			// outside of it, so failed deliveries can take the write lock in
			// Unsubscribe while this goroutine waits for them.
			ch.subscribersRWMut.RLock()
			targets := append([]*net.Conn(nil), ch.subscribers...)
			ch.subscribersRWMut.RUnlock()

			var deliveries sync.WaitGroup
			for _, conn := range targets {
				deliveries.Add(1)
				go func(conn *net.Conn) {
					defer deliveries.Done()

					rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))

					acked := false
					_, err := rw.WriteString(payload)
					if err == nil {
						err = rw.Flush()
					}
					if err == nil {
						// Best effort: if the deadline cannot be set the
						// connection is unusable and the read below fails too.
						_ = (*conn).SetReadDeadline(time.Now().Add(ackTimeout))
						var reply string
						reply, err = utils.ReadMessage(rw)
						acked = err == nil && strings.TrimSpace(reply) == "+ACK"
					}

					// Clear the read deadline so later reads on the connection
					// are not cut short by the acknowledgement window.
					_ = (*conn).SetReadDeadline(time.Time{})

					if !acked {
						ch.Unsubscribe(conn)
					}
				}(conn)
			}
			deliveries.Wait()
		}
	}()
}
// Subscribe registers conn with the channel.
//
// When consumerGroupName is nil, conn becomes a direct subscriber; subscribing
// a connection that is already a direct subscriber is a no-op. Otherwise
// consumerGroupName must hold a string naming a consumer group: conn joins
// every existing group with that name, or a new group is created, started and
// registered if none exists. A non-nil consumerGroupName that is not a string
// is ignored.
//
// The whole operation runs under the channel's lock so that concurrent calls
// cannot race on the subscriber and consumer-group lists or create the same
// group twice.
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
	ch.subscribersRWMut.Lock()
	defer ch.subscribersRWMut.Unlock()

	// Direct subscription. This case returns unconditionally: falling through
	// with a nil group name would hit the string assertion below.
	if consumerGroupName == nil {
		if !utils.Contains[*net.Conn](ch.subscribers, conn) {
			ch.subscribers = append(ch.subscribers, conn)
		}
		return
	}

	groupName, ok := consumerGroupName.(string)
	if !ok {
		// There is no error return to report this through; ignoring the
		// request is preferable to panicking on a failed assertion.
		return
	}

	joined := false
	for _, group := range ch.consumerGroups {
		if group.name == groupName {
			group.Subscribe(conn)
			joined = true
		}
	}
	if joined {
		return
	}

	newGroup := NewConsumerGroup(groupName)
	newGroup.Start()
	newGroup.Subscribe(conn)
	ch.consumerGroups = append(ch.consumerGroups, newGroup)
}
// Unsubscribe detaches conn from the channel. Under the channel's write lock
// it replaces the direct-subscriber list with the result of filtering it for
// entries other than conn, and then asks every consumer group of the channel
// to drop conn as well. Each group removal runs on its own goroutine, so it
// may complete after this method has returned.
func (ch *Channel) Unsubscribe(conn *net.Conn) {
	ch.subscribersRWMut.Lock()
	defer ch.subscribersRWMut.Unlock()

	isOtherConn := func(candidate *net.Conn) bool {
		return candidate != conn
	}
	ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, isOtherConn)

	for i := range ch.consumerGroups {
		go ch.consumerGroups[i].Unsubscribe(conn)
	}
}
// Publish sends message to the channel. Every consumer group of the channel
// is handed the message on its own goroutine, and the message is then sent on
// the channel's own message channel for the direct subscribers. That final
// send blocks until the channel accepts the message.
func (ch *Channel) Publish(message string) {
	// Read the group list under the lock, as Unsubscribe does. Taking just the
	// slice header is enough because groups are only ever appended, and the
	// lock is released before the blocking send below.
	ch.subscribersRWMut.RLock()
	groups := ch.consumerGroups
	ch.subscribersRWMut.RUnlock()

	for _, group := range groups {
		go group.Publish(message)
	}
	*ch.messageChan <- message
}
// PubSub is the top-level container of the pub/sub module. It owns every
// channel that connections can subscribe to, unsubscribe from, or publish on.
type PubSub struct {
	// channels lists all channels known to this container.
	channels []*Channel
}
// NewPubSub creates a pub/sub container that is pre-populated with a single
// default channel named "chan". The default channel is started before the
// container is returned.
func NewPubSub() *PubSub {
	defaultChan := NewChannel("chan")
	// Channels created by Subscribe are started before they are registered.
	// The default channel needs the same treatment: without it nothing ever
	// receives from its message channel, so publishing to it would block
	// forever and its direct subscribers would never be served.
	defaultChan.Start()

	return &PubSub{
		channels: []*Channel{defaultChan},
	}
}
// Subscribe registers conn with one or all channels.
//
// When channelName is nil, conn is subscribed directly (without a consumer
// group) to every existing channel. Otherwise channelName must hold a string:
// conn is subscribed, with the given consumerGroup, to every channel of that
// name, or a new channel is created, started, subscribed to and registered if
// none exists. A non-nil channelName that is not a string is ignored.
//
// The work is done before the method returns, so a channel created here is
// visible to a Publish or Subscribe call the same caller makes afterwards.
// ctx is currently unused.
//
// NOTE(review): ps.channels is not protected by any lock visible in this
// file; concurrent callers still need external synchronization — confirm how
// the server invokes this.
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName interface{}, consumerGroup interface{}) {
	// No channel name given: subscribe to all channels.
	if channelName == nil {
		for _, channel := range ps.channels {
			channel.Subscribe(conn, nil)
		}
		return
	}

	name, ok := channelName.(string)
	if !ok {
		// There is no error return to report this through; ignoring the
		// request is preferable to panicking on a failed assertion.
		return
	}

	// Subscribe to the channel(s) that already carry this name.
	found := false
	for _, channel := range ps.channels {
		if channel.name == name {
			channel.Subscribe(conn, consumerGroup)
			found = true
		}
	}
	if found {
		return
	}

	// No such channel yet: create it, start it, then subscribe.
	newChan := NewChannel(name)
	newChan.Start()
	newChan.Subscribe(conn, consumerGroup)
	ps.channels = append(ps.channels, newChan)
}
// Unsubscribe detaches conn from channels. With a nil channelName every
// channel is targeted; otherwise only the channels selected by filtering on
// name == channelName are. Each channel's Unsubscribe runs on its own
// goroutine, so removals may complete after this method returns. ctx is
// currently unused.
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
	targets := ps.channels
	if channelName != nil {
		hasName := func(candidate *Channel) bool {
			return candidate.name == channelName
		}
		targets = utils.Filter[*Channel](ps.channels, hasName)
	}

	for i := range targets {
		go targets[i].Unsubscribe(conn)
	}
}
// Publish sends message to channels. With a nil channelName every channel is
// targeted; otherwise only the channels selected by filtering on
// name == channelName are. Each channel's Publish runs on its own goroutine,
// so this method returns without waiting for delivery. ctx is currently
// unused.
func (ps *PubSub) Publish(ctx context.Context, message string, channelName interface{}) {
	targets := ps.channels
	if channelName != nil {
		hasName := func(candidate *Channel) bool {
			return candidate.name == channelName
		}
		targets = utils.Filter[*Channel](ps.channels, hasName)
	}

	for i := range targets {
		go targets[i].Publish(message)
	}
}