mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 18:00:23 +08:00
303 lines
7.0 KiB
Go
303 lines
7.0 KiB
Go
package pubsub
|
|
|
|
import (
|
|
"bufio"
|
|
"container/ring"
|
|
"context"
|
|
"fmt"
|
|
"github.com/kelvinmwinuka/memstore/src/utils"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ConsumerGroup lets several subscribers share the consumption load of a
// channel: a message published to the group is delivered to one member of the
// group rather than to every member.
type ConsumerGroup struct {
	// name identifies the group; Channel.Subscribe matches groups by this value.
	name string

	// subscribersRWMut guards subscribers.
	subscribersRWMut sync.RWMutex

	// subscribers is a ring whose elements hold *net.Conn values. It is nil
	// while the group has no members.
	subscribers *ring.Ring

	// messageChan points at the channel on which published messages are handed
	// to the goroutine launched by Start.
	messageChan *chan string
}
|
|
|
|
func NewConsumerGroup(name string) *ConsumerGroup {
|
|
messageChan := make(chan string)
|
|
|
|
return &ConsumerGroup{
|
|
name: name,
|
|
subscribersRWMut: sync.RWMutex{},
|
|
subscribers: nil,
|
|
messageChan: &messageChan,
|
|
}
|
|
}
|
|
|
|
func (cg *ConsumerGroup) SendMessage(message string) {
|
|
cg.subscribersRWMut.RLock()
|
|
|
|
conn := cg.subscribers.Value.(*net.Conn)
|
|
|
|
cg.subscribersRWMut.RUnlock()
|
|
|
|
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
|
|
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
|
|
rw.Flush()
|
|
|
|
// Wait for an ACK
|
|
// If no ACK is received within a time limit, remove this connection from subscribers and retry
|
|
(*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond))
|
|
if msg, err := utils.ReadMessage(rw); err != nil {
|
|
// Remove the connection from subscribers list
|
|
cg.Unsubscribe(conn)
|
|
// Reset the deadline
|
|
(*conn).SetReadDeadline(time.Time{})
|
|
// Retry sending the message
|
|
cg.SendMessage(message)
|
|
} else {
|
|
if strings.TrimSpace(msg) != "+ACK" {
|
|
cg.Unsubscribe(conn)
|
|
(*conn).SetReadDeadline(time.Time{})
|
|
cg.SendMessage(message)
|
|
}
|
|
}
|
|
|
|
(*conn).SetDeadline(time.Time{})
|
|
cg.subscribers = cg.subscribers.Next()
|
|
}
|
|
|
|
func (cg *ConsumerGroup) Start() {
|
|
go func() {
|
|
for {
|
|
message := <-*cg.messageChan
|
|
if cg.subscribers != nil {
|
|
cg.SendMessage(message)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
|
|
cg.subscribersRWMut.Lock()
|
|
defer cg.subscribersRWMut.Unlock()
|
|
|
|
r := ring.New(1)
|
|
for i := 0; i < r.Len(); i++ {
|
|
r.Value = conn
|
|
r = r.Next()
|
|
}
|
|
|
|
if cg.subscribers == nil {
|
|
cg.subscribers = r
|
|
return
|
|
}
|
|
|
|
cg.subscribers = cg.subscribers.Link(r)
|
|
}
|
|
|
|
func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
|
|
cg.subscribersRWMut.Lock()
|
|
defer cg.subscribersRWMut.Unlock()
|
|
|
|
// If length is 1 and the connection passed is the one contained within, unlink it
|
|
if cg.subscribers.Len() == 1 {
|
|
if cg.subscribers.Value == conn {
|
|
cg.subscribers = nil
|
|
}
|
|
return
|
|
}
|
|
|
|
for i := 0; i < cg.subscribers.Len(); i++ {
|
|
if cg.subscribers.Value == conn {
|
|
cg.subscribers = cg.subscribers.Prev()
|
|
cg.subscribers.Unlink(1)
|
|
break
|
|
}
|
|
cg.subscribers = cg.subscribers.Next()
|
|
}
|
|
}
|
|
|
|
func (cg *ConsumerGroup) Publish(message string) {
|
|
*cg.messageChan <- message
|
|
}
|
|
|
|
// Channel - A channel can be subscribed to directly, or via a consumer group.
|
|
// All direct subscribers to the channel will receive any message published to the channel.
|
|
// Only one subscriber of a channel's consumer group will receive a message posted to the channel.
|
|
type Channel struct {
|
|
name string
|
|
subscribersRWMut sync.RWMutex
|
|
subscribers []*net.Conn
|
|
consumerGroups []*ConsumerGroup
|
|
messageChan *chan string
|
|
}
|
|
|
|
func NewChannel(name string) *Channel {
|
|
messageChan := make(chan string)
|
|
|
|
return &Channel{
|
|
name: name,
|
|
subscribersRWMut: sync.RWMutex{},
|
|
subscribers: []*net.Conn{},
|
|
consumerGroups: []*ConsumerGroup{},
|
|
messageChan: &messageChan,
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) Start() {
|
|
go func() {
|
|
for {
|
|
message := <-*ch.messageChan
|
|
|
|
ch.subscribersRWMut.RLock()
|
|
|
|
for _, conn := range ch.subscribers {
|
|
go func(conn *net.Conn) {
|
|
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
|
|
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
|
|
rw.Flush()
|
|
|
|
(*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
|
defer func() {
|
|
(*conn).SetReadDeadline(time.Time{})
|
|
}()
|
|
|
|
if msg, err := utils.ReadMessage(rw); err != nil {
|
|
ch.Unsubscribe(conn)
|
|
} else {
|
|
if strings.TrimSpace(msg) != "+ACK" {
|
|
ch.Unsubscribe(conn)
|
|
}
|
|
}
|
|
}(conn)
|
|
}
|
|
|
|
ch.subscribersRWMut.RUnlock()
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
|
|
if consumerGroupName == nil && !utils.Contains[*net.Conn](ch.subscribers, conn) {
|
|
ch.subscribersRWMut.Lock()
|
|
defer ch.subscribersRWMut.Unlock()
|
|
ch.subscribers = append(ch.subscribers, conn)
|
|
return
|
|
}
|
|
|
|
groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool {
|
|
return group.name == consumerGroupName.(string)
|
|
})
|
|
|
|
if len(groups) == 0 {
|
|
go func() {
|
|
newGroup := NewConsumerGroup(consumerGroupName.(string))
|
|
newGroup.Start()
|
|
newGroup.Subscribe(conn)
|
|
ch.consumerGroups = append(ch.consumerGroups, newGroup)
|
|
}()
|
|
return
|
|
}
|
|
|
|
for _, group := range groups {
|
|
go group.Subscribe(conn)
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) Unsubscribe(conn *net.Conn) {
|
|
ch.subscribersRWMut.Lock()
|
|
defer ch.subscribersRWMut.Unlock()
|
|
|
|
ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool {
|
|
return c != conn
|
|
})
|
|
|
|
for _, group := range ch.consumerGroups {
|
|
go group.Unsubscribe(conn)
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) Publish(message string) {
|
|
for _, group := range ch.consumerGroups {
|
|
go group.Publish(message)
|
|
}
|
|
*ch.messageChan <- message
|
|
}
|
|
|
|
// PubSub container
|
|
type PubSub struct {
|
|
channels []*Channel
|
|
}
|
|
|
|
func NewPubSub() *PubSub {
|
|
return &PubSub{
|
|
channels: []*Channel{
|
|
NewChannel("chan"),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName interface{}, consumerGroup interface{}) {
|
|
// If no channel name is given, subscribe to all channels
|
|
if channelName == nil {
|
|
for _, channel := range ps.channels {
|
|
go channel.Subscribe(conn, nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Check if channel with given name exists
|
|
// If it does, subscribe the connection to the channel
|
|
// If it does not, create the channel and subscribe to it
|
|
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
|
return c.name == channelName
|
|
})
|
|
|
|
if len(channels) <= 0 {
|
|
go func() {
|
|
newChan := NewChannel(channelName.(string))
|
|
newChan.Start()
|
|
newChan.Subscribe(conn, consumerGroup)
|
|
ps.channels = append(ps.channels, newChan)
|
|
}()
|
|
return
|
|
}
|
|
|
|
for _, channel := range channels {
|
|
go channel.Subscribe(conn, consumerGroup)
|
|
}
|
|
}
|
|
|
|
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
|
|
if channelName == nil {
|
|
for _, channel := range ps.channels {
|
|
go channel.Unsubscribe(conn)
|
|
}
|
|
return
|
|
}
|
|
|
|
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
|
return c.name == channelName
|
|
})
|
|
|
|
for _, channel := range channels {
|
|
go channel.Unsubscribe(conn)
|
|
}
|
|
}
|
|
|
|
func (ps *PubSub) Publish(ctx context.Context, message string, channelName interface{}) {
|
|
if channelName == nil {
|
|
for _, channel := range ps.channels {
|
|
go channel.Publish(message)
|
|
}
|
|
return
|
|
}
|
|
|
|
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
|
return c.name == channelName
|
|
})
|
|
|
|
for _, channel := range channels {
|
|
go channel.Publish(message)
|
|
}
|
|
}
|