mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-09-27 04:16:06 +08:00

Refactored pubsub implementation to return MessageReader on embedded instance, which implements io.Reader for reading messages (#170) - @kelvinmwinuka
239 lines
5.8 KiB
Go
239 lines
5.8 KiB
Go
// Copyright 2024 Kelvin Clement Mwinuka
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package pubsub
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/gobwas/glob"
|
|
"github.com/tidwall/resp"
|
|
"log"
|
|
"net"
|
|
"slices"
|
|
"sync"
|
|
)
|
|
|
|
type Channel struct {
|
|
name string // Channel name. This can be a glob pattern string.
|
|
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
|
|
|
|
messages []string // Slice that holds messages.
|
|
messagesRWMut sync.RWMutex // RWMutex for accessing channel messages.
|
|
messagesCond *sync.Cond
|
|
|
|
tcpSubs map[*net.Conn]*resp.Conn // Map containing the channel's TCP subscribers.
|
|
tcpSubsRWMut sync.RWMutex // RWMutex for accessing TCP channel subscribers.
|
|
|
|
embeddedSubs []*EmbeddedSub // Slice containing embedded subscribers to this channel.
|
|
embeddedSubsRWMut sync.RWMutex // RWMutex for accessing embedded subscribers.
|
|
}
|
|
|
|
// WithName option sets the channels name.
|
|
func WithName(name string) func(channel *Channel) {
|
|
return func(channel *Channel) {
|
|
channel.name = name
|
|
}
|
|
}
|
|
|
|
// WithPattern option sets the compiled glob pattern for the channel if it's a pattern channel.
|
|
func WithPattern(pattern string) func(channel *Channel) {
|
|
return func(channel *Channel) {
|
|
channel.name = pattern
|
|
channel.pattern = glob.MustCompile(pattern)
|
|
}
|
|
}
|
|
|
|
func NewChannel(ctx context.Context, options ...func(channel *Channel)) *Channel {
|
|
channel := &Channel{
|
|
name: "",
|
|
pattern: nil,
|
|
|
|
messages: make([]string, 0),
|
|
messagesRWMut: sync.RWMutex{},
|
|
|
|
tcpSubs: make(map[*net.Conn]*resp.Conn),
|
|
tcpSubsRWMut: sync.RWMutex{},
|
|
|
|
embeddedSubs: make([]*EmbeddedSub, 0),
|
|
embeddedSubsRWMut: sync.RWMutex{},
|
|
}
|
|
channel.messagesCond = sync.NewCond(&channel.messagesRWMut)
|
|
|
|
for _, option := range options {
|
|
option(channel)
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Printf("closing channel %s\n", channel.name)
|
|
return
|
|
default:
|
|
channel.messagesRWMut.Lock()
|
|
for len(channel.messages) == 0 {
|
|
channel.messagesCond.Wait()
|
|
}
|
|
|
|
message := channel.messages[0]
|
|
channel.messages = channel.messages[1:]
|
|
channel.messagesRWMut.Unlock()
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(2)
|
|
|
|
// Send messages to embedded subscribers
|
|
go func() {
|
|
channel.embeddedSubsRWMut.RLock()
|
|
ewg := sync.WaitGroup{}
|
|
b, _ := json.Marshal([]string{"message", channel.name, message})
|
|
msg := append(b, byte('\n'))
|
|
for _, w := range channel.embeddedSubs {
|
|
ewg.Add(1)
|
|
go func(w *EmbeddedSub) {
|
|
_, _ = w.Write(msg)
|
|
ewg.Done()
|
|
}(w)
|
|
}
|
|
ewg.Wait()
|
|
channel.embeddedSubsRWMut.RUnlock()
|
|
wg.Done()
|
|
}()
|
|
|
|
// Send messages to TCP subscribers
|
|
go func() {
|
|
channel.tcpSubsRWMut.RLock()
|
|
cwg := sync.WaitGroup{}
|
|
for _, conn := range channel.tcpSubs {
|
|
cwg.Add(1)
|
|
go func(conn *resp.Conn) {
|
|
_ = conn.WriteArray([]resp.Value{
|
|
resp.StringValue("message"),
|
|
resp.StringValue(channel.name),
|
|
resp.StringValue(message),
|
|
})
|
|
cwg.Done()
|
|
}(conn)
|
|
}
|
|
cwg.Wait()
|
|
channel.tcpSubsRWMut.RUnlock()
|
|
wg.Done()
|
|
}()
|
|
|
|
wg.Wait()
|
|
}
|
|
}
|
|
}()
|
|
|
|
return channel
|
|
}
|
|
|
|
func (ch *Channel) Name() string {
|
|
return ch.name
|
|
}
|
|
|
|
func (ch *Channel) Pattern() glob.Glob {
|
|
return ch.pattern
|
|
}
|
|
|
|
func (ch *Channel) Subscribe(sub any, action string, chanIdx int) {
|
|
switch sub.(type) {
|
|
case *net.Conn:
|
|
ch.tcpSubsRWMut.Lock()
|
|
defer ch.tcpSubsRWMut.Unlock()
|
|
conn := sub.(*net.Conn)
|
|
if _, ok := ch.tcpSubs[conn]; !ok {
|
|
ch.tcpSubs[conn] = resp.NewConn(*conn)
|
|
}
|
|
r, _ := ch.tcpSubs[conn]
|
|
// Send subscription message
|
|
_ = r.WriteArray([]resp.Value{
|
|
resp.StringValue(action),
|
|
resp.StringValue(ch.name),
|
|
resp.IntegerValue(chanIdx + 1),
|
|
})
|
|
|
|
case *EmbeddedSub:
|
|
ch.embeddedSubsRWMut.Lock()
|
|
defer ch.embeddedSubsRWMut.Unlock()
|
|
w := sub.(*EmbeddedSub)
|
|
if !slices.ContainsFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
|
|
return writer == w
|
|
}) {
|
|
ch.embeddedSubs = append(ch.embeddedSubs, w)
|
|
}
|
|
// Send subscription message
|
|
b, _ := json.Marshal([]string{action, ch.name, fmt.Sprintf("%d", chanIdx+1)})
|
|
msg := append(b, byte('\n'))
|
|
_, _ = w.Write(msg)
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) Unsubscribe(sub any) bool {
|
|
switch sub.(type) {
|
|
default:
|
|
return false
|
|
|
|
case *net.Conn:
|
|
ch.tcpSubsRWMut.Lock()
|
|
defer ch.tcpSubsRWMut.Unlock()
|
|
conn := sub.(*net.Conn)
|
|
if _, ok := ch.tcpSubs[conn]; !ok {
|
|
return false
|
|
}
|
|
delete(ch.tcpSubs, conn)
|
|
return true
|
|
|
|
case *EmbeddedSub:
|
|
ch.embeddedSubsRWMut.Lock()
|
|
defer ch.embeddedSubsRWMut.Unlock()
|
|
w := sub.(*EmbeddedSub)
|
|
deleted := false
|
|
ch.embeddedSubs = slices.DeleteFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
|
|
deleted = writer == w
|
|
return deleted
|
|
})
|
|
return deleted
|
|
}
|
|
}
|
|
|
|
func (ch *Channel) Publish(message string) {
|
|
ch.messagesRWMut.Lock()
|
|
defer ch.messagesRWMut.Unlock()
|
|
ch.messages = append(ch.messages, message)
|
|
ch.messagesCond.Signal()
|
|
}
|
|
|
|
func (ch *Channel) IsActive() bool {
|
|
ch.tcpSubsRWMut.RLock()
|
|
defer ch.tcpSubsRWMut.RUnlock()
|
|
|
|
ch.embeddedSubsRWMut.RLock()
|
|
defer ch.embeddedSubsRWMut.RUnlock()
|
|
|
|
return len(ch.tcpSubs)+len(ch.embeddedSubs) > 0
|
|
}
|
|
|
|
func (ch *Channel) NumSubs() int {
|
|
ch.tcpSubsRWMut.RLock()
|
|
defer ch.tcpSubsRWMut.RUnlock()
|
|
|
|
ch.embeddedSubsRWMut.RLock()
|
|
defer ch.embeddedSubsRWMut.RUnlock()
|
|
|
|
return len(ch.tcpSubs) + len(ch.embeddedSubs)
|
|
}
|