Refactored PubSub Embedded API

Refactored pubsub implementation to return MessageReader on embedded instance, which implements io.Reader for reading messages (#170) - @kelvinmwinuka
This commit is contained in:
Kelvin Mwinuka
2025-01-26 22:37:14 +08:00
committed by GitHub
parent 4aab2e7799
commit ec69e52a5b
9 changed files with 631 additions and 456 deletions

View File

@@ -29,23 +29,34 @@ Subscribe to one or more patterns. This command accepts glob patterns.
]}
>
<TabItem value="go">
The Subscribe method returns a readMessage function.
This method is lazy so it must be invoked each time the you want to read the next message from
the pattern.
When subscribing to an'N' number of patterns, the first N messages will be
The PSubscribe method returns a MessageReader type which implements the `io.Reader` interface.
When subscribing to an'N' number of channels, the first N messages will be
the subscription confimations.
The readMessage functions returns a message object when called. The message
object is a string slice with the following inforamtion:
event type at index 0 (e.g. subscribe, message), pattern at index 1,
The message read is a series of bytes resulting from JSON marshalling an array. You will need to
unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
event type at index 0 (e.g. psubscribe, message), channel name at index 1,
message/subscription index at index 2.
Messages published to any channels that match the pattern will be received by the pattern subscriber.
```go
db, err := sugardb.NewSugarDB()
if err != nil {
log.Fatal(err)
}
readMessage := db.PSubscribe("subscribe_tag_1", "pattern_[12]", "pattern_h[ae]llo") // Return lazy readMessage function
for i := 0; i < 2; i++ {
message := readMessage() // Call the readMessage function for each channel subscription.
// Subscribe to multiple channel patterns, returs MessageReader
msgReader := db.PSubscribe("psubscribe_tag_1", "channel[12]", "pattern[12]")
// Read message into pre-defined buffer
msg := make([]byte, 1024)
_, err := msgReader.Read(msg)
// Trim all null bytes at the end of the message before unmarshalling.
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
log.Fatalf("json unmarshal error: %+v", err)
}
```
</TabItem>

View File

@@ -29,23 +29,32 @@ Subscribe to one or more channels.
]}
>
<TabItem value="go">
The Subscribe method returns a readMessage function.
This method is lazy so it must be invoked each time the you want to read the next message from
the channel.
The Subscribe method returns a MessageReader type which implements the `io.Reader` interface.
When subscribing to an'N' number of channels, the first N messages will be
the subscription confimations.
The readMessage functions returns a message object when called. The message
object is a string slice with the following inforamtion:
The message read is a series of bytes resulting from JSON marshalling an array. You will need to
unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
event type at index 0 (e.g. subscribe, message), channel name at index 1,
message/subscription index at index 2.
```go
db, err := sugardb.NewSugarDB()
if err != nil {
log.Fatal(err)
}
readMessage := db.Subscribe("subscribe_tag_1", "channel1", "channel2") // Return lazy readMessage function
for i := 0; i < 2; i++ {
message := readMessage() // Call the readMessage function for each channel subscription.
// Subscribe to multiple channel patterns, returs MessageReader.
msgReader := db.Subscribe("subscribe_tag_1", "channel1", "channel2")
// Read message into pre-defined buffer.
msg := make([]byte, 1024)
_, err := msgReader.Read(msg)
// Trim all null bytes at the end of the message before unmarshalling.
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
log.Fatalf("json unmarshal error: %+v", err)
}
```
</TabItem>

View File

@@ -15,19 +15,30 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/gobwas/glob"
"github.com/tidwall/resp"
"log"
"net"
"slices"
"sync"
)
type Channel struct {
name string // Channel name. This can be a glob pattern string.
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
subscribersRWMut sync.RWMutex // RWMutex to concurrency control when accessing channel subscribers.
subscribers map[*net.Conn]*resp.Conn // Map containing the channel subscribers.
messageChan *chan string // Messages published to this channel will be sent to this channel.
name string // Channel name. This can be a glob pattern string.
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
messages []string // Slice that holds messages.
messagesRWMut sync.RWMutex // RWMutex for accessing channel messages.
messagesCond *sync.Cond
tcpSubs map[*net.Conn]*resp.Conn // Map containing the channel's TCP subscribers.
tcpSubsRWMut sync.RWMutex // RWMutex for accessing TCP channel subscribers.
embeddedSubs []*EmbeddedSub // Slice containing embedded subscribers to this channel.
embeddedSubsRWMut sync.RWMutex // RWMutex for accessing embedded subscribers.
}
// WithName option sets the channels name.
@@ -45,46 +56,89 @@ func WithPattern(pattern string) func(channel *Channel) {
}
}
func NewChannel(options ...func(channel *Channel)) *Channel {
messageChan := make(chan string, 4096)
func NewChannel(ctx context.Context, options ...func(channel *Channel)) *Channel {
channel := &Channel{
name: "",
pattern: nil,
subscribersRWMut: sync.RWMutex{},
subscribers: make(map[*net.Conn]*resp.Conn),
messageChan: &messageChan,
name: "",
pattern: nil,
messages: make([]string, 0),
messagesRWMut: sync.RWMutex{},
tcpSubs: make(map[*net.Conn]*resp.Conn),
tcpSubsRWMut: sync.RWMutex{},
embeddedSubs: make([]*EmbeddedSub, 0),
embeddedSubsRWMut: sync.RWMutex{},
}
channel.messagesCond = sync.NewCond(&channel.messagesRWMut)
for _, option := range options {
option(channel)
}
return channel
}
func (ch *Channel) Start() {
go func() {
for {
message := <-*ch.messageChan
select {
case <-ctx.Done():
log.Printf("closing channel %s\n", channel.name)
return
default:
channel.messagesRWMut.Lock()
for len(channel.messages) == 0 {
channel.messagesCond.Wait()
}
ch.subscribersRWMut.RLock()
message := channel.messages[0]
channel.messages = channel.messages[1:]
channel.messagesRWMut.Unlock()
for _, conn := range ch.subscribers {
go func(conn *resp.Conn) {
if err := conn.WriteArray([]resp.Value{
resp.StringValue("message"),
resp.StringValue(ch.name),
resp.StringValue(message),
}); err != nil {
log.Println(err)
wg := sync.WaitGroup{}
wg.Add(2)
// Send messages to embedded subscribers
go func() {
channel.embeddedSubsRWMut.RLock()
ewg := sync.WaitGroup{}
b, _ := json.Marshal([]string{"message", channel.name, message})
msg := append(b, byte('\n'))
for _, w := range channel.embeddedSubs {
ewg.Add(1)
go func(w *EmbeddedSub) {
_, _ = w.Write(msg)
ewg.Done()
}(w)
}
}(conn)
}
ewg.Wait()
channel.embeddedSubsRWMut.RUnlock()
wg.Done()
}()
ch.subscribersRWMut.RUnlock()
// Send messages to TCP subscribers
go func() {
channel.tcpSubsRWMut.RLock()
cwg := sync.WaitGroup{}
for _, conn := range channel.tcpSubs {
cwg.Add(1)
go func(conn *resp.Conn) {
_ = conn.WriteArray([]resp.Value{
resp.StringValue("message"),
resp.StringValue(channel.name),
resp.StringValue(message),
})
cwg.Done()
}(conn)
}
cwg.Wait()
channel.tcpSubsRWMut.RUnlock()
wg.Done()
}()
wg.Wait()
}
}
}()
return channel
}
func (ch *Channel) Name() string {
@@ -95,56 +149,90 @@ func (ch *Channel) Pattern() glob.Glob {
return ch.pattern
}
func (ch *Channel) Subscribe(conn *net.Conn) bool {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
if _, ok := ch.subscribers[conn]; !ok {
ch.subscribers[conn] = resp.NewConn(*conn)
func (ch *Channel) Subscribe(sub any, action string, chanIdx int) {
switch sub.(type) {
case *net.Conn:
ch.tcpSubsRWMut.Lock()
defer ch.tcpSubsRWMut.Unlock()
conn := sub.(*net.Conn)
if _, ok := ch.tcpSubs[conn]; !ok {
ch.tcpSubs[conn] = resp.NewConn(*conn)
}
r, _ := ch.tcpSubs[conn]
// Send subscription message
_ = r.WriteArray([]resp.Value{
resp.StringValue(action),
resp.StringValue(ch.name),
resp.IntegerValue(chanIdx + 1),
})
case *EmbeddedSub:
ch.embeddedSubsRWMut.Lock()
defer ch.embeddedSubsRWMut.Unlock()
w := sub.(*EmbeddedSub)
if !slices.ContainsFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
return writer == w
}) {
ch.embeddedSubs = append(ch.embeddedSubs, w)
}
// Send subscription message
b, _ := json.Marshal([]string{action, ch.name, fmt.Sprintf("%d", chanIdx+1)})
msg := append(b, byte('\n'))
_, _ = w.Write(msg)
}
_, ok := ch.subscribers[conn]
return ok
}
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
if _, ok := ch.subscribers[conn]; !ok {
func (ch *Channel) Unsubscribe(sub any) bool {
switch sub.(type) {
default:
return false
case *net.Conn:
ch.tcpSubsRWMut.Lock()
defer ch.tcpSubsRWMut.Unlock()
conn := sub.(*net.Conn)
if _, ok := ch.tcpSubs[conn]; !ok {
return false
}
delete(ch.tcpSubs, conn)
return true
case *EmbeddedSub:
ch.embeddedSubsRWMut.Lock()
defer ch.embeddedSubsRWMut.Unlock()
w := sub.(*EmbeddedSub)
deleted := false
ch.embeddedSubs = slices.DeleteFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
deleted = writer == w
return deleted
})
return deleted
}
delete(ch.subscribers, conn)
return true
}
func (ch *Channel) Publish(message string) {
*ch.messageChan <- message
ch.messagesRWMut.Lock()
defer ch.messagesRWMut.Unlock()
ch.messages = append(ch.messages, message)
ch.messagesCond.Signal()
}
func (ch *Channel) IsActive() bool {
ch.subscribersRWMut.RLock()
defer ch.subscribersRWMut.RUnlock()
ch.tcpSubsRWMut.RLock()
defer ch.tcpSubsRWMut.RUnlock()
active := len(ch.subscribers) > 0
ch.embeddedSubsRWMut.RLock()
defer ch.embeddedSubsRWMut.RUnlock()
return active
return len(ch.tcpSubs)+len(ch.embeddedSubs) > 0
}
func (ch *Channel) NumSubs() int {
ch.subscribersRWMut.RLock()
defer ch.subscribersRWMut.RUnlock()
ch.tcpSubsRWMut.RLock()
defer ch.tcpSubsRWMut.RUnlock()
n := len(ch.subscribers)
ch.embeddedSubsRWMut.RLock()
defer ch.embeddedSubsRWMut.RUnlock()
return n
}
func (ch *Channel) Subscribers() map[*net.Conn]*resp.Conn {
ch.subscribersRWMut.RLock()
defer ch.subscribersRWMut.RUnlock()
subscribers := make(map[*net.Conn]*resp.Conn, len(ch.subscribers))
for k, v := range ch.subscribers {
subscribers[k] = v
}
return subscribers
return len(ch.tcpSubs) + len(ch.embeddedSubs)
}

View File

@@ -35,7 +35,8 @@ func handleSubscribe(params internal.HandlerFuncParams) ([]byte, error) {
}
withPattern := strings.EqualFold(params.Command[0], "psubscribe")
pubsub.Subscribe(params.Context, params.Connection, channels, withPattern)
pubsub.Subscribe(params.Connection, channels, withPattern)
return nil, nil
}
@@ -50,7 +51,7 @@ func handleUnsubscribe(params internal.HandlerFuncParams) ([]byte, error) {
withPattern := strings.EqualFold(params.Command[0], "punsubscribe")
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil
return pubsub.Unsubscribe(params.Connection, channels, withPattern), nil
}
func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
@@ -61,7 +62,7 @@ func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
pubsub.Publish(params.Context, params.Command[2], params.Command[1])
pubsub.Publish(params.Command[2], params.Command[1])
return []byte(constants.OkResponse), nil
}

View File

@@ -17,33 +17,31 @@ package pubsub
import (
"context"
"fmt"
"log"
"net"
"slices"
"sync"
"github.com/gobwas/glob"
"github.com/tidwall/resp"
)
type PubSub struct {
ctx context.Context
channels []*Channel
channelsRWMut sync.RWMutex
}
func NewPubSub() *PubSub {
func NewPubSub(ctx context.Context) *PubSub {
return &PubSub{
ctx: ctx,
channels: []*Channel{},
channelsRWMut: sync.RWMutex{},
}
}
func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) {
func (ps *PubSub) Subscribe(sub any, channels []string, withPattern bool) {
ps.channelsRWMut.Lock()
defer ps.channelsRWMut.Unlock()
r := resp.NewConn(*conn)
action := "subscribe"
if withPattern {
action = "psubscribe"
@@ -58,40 +56,46 @@ func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string
})
if channelIdx == -1 {
// Create new channel, start it, and subscribe to it
// Create new channel, if it does not exist
var newChan *Channel
if withPattern {
newChan = NewChannel(WithPattern(channels[i]))
newChan = NewChannel(ps.ctx, WithPattern(channels[i]))
} else {
newChan = NewChannel(WithName(channels[i]))
newChan = NewChannel(ps.ctx, WithName(channels[i]))
}
newChan.Start()
if newChan.Subscribe(conn) {
if err := r.WriteArray([]resp.Value{
resp.StringValue(action),
resp.StringValue(newChan.name),
resp.IntegerValue(i + 1),
}); err != nil {
log.Println(err)
}
ps.channels = append(ps.channels, newChan)
// Append the channel to the list of channels
ps.channels = append(ps.channels, newChan)
// Subscribe to the channel
switch sub.(type) {
case *net.Conn:
// Subscribe TCP connection
conn := sub.(*net.Conn)
newChan.Subscribe(conn, action, i)
case *EmbeddedSub:
// Subscribe io.Writer from embedded client
w := sub.(*EmbeddedSub)
newChan.Subscribe(w, action, i)
}
} else {
// Subscribe to existing channel
if ps.channels[channelIdx].Subscribe(conn) {
if err := r.WriteArray([]resp.Value{
resp.StringValue(action),
resp.StringValue(ps.channels[channelIdx].name),
resp.IntegerValue(i + 1),
}); err != nil {
log.Println(err)
}
switch sub.(type) {
case *net.Conn:
// Subscribe TCP connection
conn := sub.(*net.Conn)
ps.channels[channelIdx].Subscribe(conn, action, i)
case *EmbeddedSub:
// Subscribe io.Writer from embedded client
w := sub.(*EmbeddedSub)
ps.channels[channelIdx].Subscribe(w, action, i)
}
}
}
}
func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
func (ps *PubSub) Unsubscribe(sub any, channels []string, withPattern bool) []byte {
ps.channelsRWMut.RLock()
defer ps.channelsRWMut.RUnlock()
@@ -111,7 +115,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
if channel.pattern != nil { // Skip pattern channels
continue
}
if channel.Unsubscribe(conn) {
if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name
idx += 1
}
@@ -123,7 +127,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
if channel.pattern == nil { // Skip non-pattern channels
continue
}
if channel.Unsubscribe(conn) {
if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name
idx += 1
}
@@ -136,7 +140,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
// names exactly matches the pattern name.
for _, channel := range ps.channels { // For each channel in PubSub
for _, c := range channels { // For each channel name provided
if channel.name == c && channel.Unsubscribe(conn) {
if channel.name == c && channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name
idx += 1
}
@@ -151,7 +155,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
for _, channel := range ps.channels {
// If it's a pattern channel, directly compare the patterns
if channel.pattern != nil && channel.name == pattern {
if channel.Unsubscribe(conn) {
if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name
idx += 1
}
@@ -159,7 +163,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
}
// If this is a regular channel, check if the channel name matches the pattern given
if g.Match(channel.name) {
if channel.Unsubscribe(conn) {
if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name
idx += 1
}
@@ -176,7 +180,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
return []byte(res)
}
func (ps *PubSub) Publish(_ context.Context, message string, channelName string) {
func (ps *PubSub) Publish(message string, channelName string) {
ps.channelsRWMut.RLock()
defer ps.channelsRWMut.RUnlock()

View File

@@ -0,0 +1,59 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub
import (
"bufio"
"bytes"
"sync"
)
type EmbeddedSub struct {
mux sync.Mutex
buff *bytes.Buffer
writer *bufio.Writer
reader *bufio.Reader
}
func NewEmbeddedSub() *EmbeddedSub {
sub := &EmbeddedSub{
mux: sync.Mutex{},
buff: bytes.NewBuffer(make([]byte, 0)),
}
sub.writer = bufio.NewWriter(sub.buff)
sub.reader = bufio.NewReader(sub.buff)
return sub
}
func (sub *EmbeddedSub) Write(p []byte) (int, error) {
sub.mux.Lock()
defer sub.mux.Unlock()
n, err := sub.writer.Write(p)
if err != nil {
return n, err
}
err = sub.writer.Flush()
return n, err
}
func (sub *EmbeddedSub) Read(p []byte) (int, error) {
sub.mux.Lock()
defer sub.mux.Unlock()
chunk, err := sub.reader.ReadBytes(byte('\n'))
n := copy(p, chunk)
return n, err
}

View File

@@ -16,55 +16,23 @@ package sugardb
import (
"bytes"
"errors"
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/modules/pubsub"
"github.com/tidwall/resp"
"net"
"strings"
"sync"
)
type conn struct {
readConn *net.Conn
writeConn *net.Conn
type MessageReader struct {
embeddedSub *pubsub.EmbeddedSub
}
var connections sync.Map
// ReadPubSubMessage is returned by the Subscribe and PSubscribe functions.
//
// This function is lazy, therefore it needs to be invoked in order to read the next message.
// When the message is read, the function returns a string slice with 3 elements.
// Index 0 holds the event type which in this case will be "message". Index 1 holds the channel name.
// Index 2 holds the actual message.
type ReadPubSubMessage func() []string
func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
var readConn *net.Conn
var writeConn *net.Conn
if _, ok := connections.Load(tag); !ok {
// If connection with this name does not exist, create new connection.
rc, wc := net.Pipe()
readConn = &rc
writeConn = &wc
connections.Store(tag, conn{
readConn: &rc,
writeConn: &wc,
})
} else {
// Reuse existing connection.
c, ok := connections.Load(tag)
if !ok {
return nil, nil, errors.New("could not establish connection")
}
readConn = c.(conn).readConn
writeConn = c.(conn).writeConn
}
return readConn, writeConn, nil
func (reader *MessageReader) Read(p []byte) (int, error) {
return reader.embeddedSub.Read(p)
}
var subscriptions sync.Map
// Subscribe subscribes the caller to the list of provided channels.
//
// Parameters:
@@ -75,31 +43,22 @@ func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
//
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
// This function is blocking.
func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMessage, error) {
readConn, writeConn, err := establishConnections(tag)
if err != nil {
return func() []string {
return []string{}
}, err
func (server *SugarDB) Subscribe(tag string, channels ...string) (*MessageReader, error) {
var msgReader *MessageReader
sub, ok := subscriptions.Load(tag)
if !ok {
// Create new messageBuffer and store it in the subscriptions
msgReader = &MessageReader{
embeddedSub: pubsub.NewEmbeddedSub(),
}
} else {
msgReader = sub.(*MessageReader)
}
// Subscribe connection to the provided channels.
cmd := append([]string{"SUBSCRIBE"}, channels...)
go func() {
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
}()
server.pubSub.Subscribe(msgReader.embeddedSub, channels, false)
return func() []string {
r := resp.NewConn(*readConn)
v, _, _ := r.ReadValue()
res := make([]string, len(v.Array()))
for i := 0; i < len(res); i++ {
res[i] = v.Array()[i].String()
}
return res
}, nil
return msgReader, nil
}
// Unsubscribe unsubscribes the caller from the given channels.
@@ -110,12 +69,12 @@ func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMess
//
// `channels` - ...string - The list of channels to unsubscribe from.
func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
c, ok := connections.Load(tag)
sub, ok := subscriptions.Load(tag)
if !ok {
return
}
cmd := append([]string{"UNSUBSCRIBE"}, channels...)
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true)
msgReader := sub.(*MessageReader)
server.pubSub.Unsubscribe(msgReader, channels, false)
}
// PSubscribe subscribes the caller to the list of provided glob patterns.
@@ -128,31 +87,23 @@ func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
//
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
// This function is blocking.
func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMessage, error) {
readConn, writeConn, err := establishConnections(tag)
if err != nil {
return func() []string {
return []string{}
}, err
func (server *SugarDB) PSubscribe(tag string, patterns ...string) (*MessageReader, error) {
var msgReader *MessageReader
sub, ok := subscriptions.Load(tag)
if !ok {
// Create new messageBuffer and store it in the subscriptions
msgReader = &MessageReader{
embeddedSub: pubsub.NewEmbeddedSub(),
}
} else {
msgReader = sub.(*MessageReader)
}
// Subscribe connection to the provided channels
cmd := append([]string{"PSUBSCRIBE"}, patterns...)
go func() {
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
}()
server.pubSub.Subscribe(msgReader.embeddedSub, patterns, true)
return func() []string {
r := resp.NewConn(*readConn)
v, _, _ := r.ReadValue()
res := make([]string, len(v.Array()))
for i := 0; i < len(res); i++ {
res[i] = v.Array()[i].String()
}
return res
}, nil
return msgReader, nil
}
// PUnsubscribe unsubscribes the caller from the given glob patterns.
@@ -163,12 +114,12 @@ func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMes
//
// `patterns` - ...string - The list of glob patterns to unsubscribe from.
func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
c, ok := connections.Load(tag)
sub, ok := subscriptions.Load(tag)
if !ok {
return
}
cmd := append([]string{"PUNSUBSCRIBE"}, patterns...)
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true)
msgReader := sub.(*MessageReader)
server.pubSub.Unsubscribe(msgReader, patterns, true)
}
// Publish publishes a message to the given channel.
@@ -179,10 +130,12 @@ func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
//
// `message` - string - The message to publish to the specified channel.
//
// Returns: true when the publish is successful. This does not indicate whether each subscriber has received the message,
// only that the message has been published.
// Returns: true when successful. This does not indicate whether each subscriber has received the message,
// only that the message has been published to the channel.
func (server *SugarDB) Publish(channel, message string) (bool, error) {
b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true)
b, err := server.handleCommand(
server.context,
internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true)
if err != nil {
return false, err
}

View File

@@ -15,276 +15,326 @@
package sugardb
import (
"bytes"
"encoding/json"
"fmt"
"io"
"reflect"
"slices"
"testing"
"time"
)
func Test_Subscribe(t *testing.T) {
func TestSugarDB_PubSub(t *testing.T) {
server := createSugarDB()
t.Cleanup(func() {
server.ShutDown()
})
// Subscribe to channels.
tag := "tag"
channels := []string{"channel1", "channel2"}
readMessage, err := server.Subscribe(tag, channels...)
if err != nil {
t.Errorf("SUBSCRIBE() error = %v", err)
}
t.Run("TestSugarDB_(P)Subscribe", func(t *testing.T) {
t.Parallel()
for i := 0; i < len(channels); i++ {
message := readMessage()
// Check that we've received the subscribe messages.
if message[0] != "subscribe" {
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0])
tests := []struct {
name string
action string // subscribe | psubscribe
tag string
channels []string
pubChannels []string // Channels to publish messages to after subscriptions
wantMsg []string // Expected messages from after publishing
subFunc func(tag string, channels ...string) (*MessageReader, error)
unsubFunc func(tag string, channels ...string)
}{
{
name: "1. Subscribe to channels",
action: "subscribe",
tag: "tag_test_subscribe",
channels: []string{
"channel1",
"channel2",
},
pubChannels: []string{"channel1", "channel2"},
wantMsg: []string{
"message for channel1",
"message for channel2",
},
subFunc: server.Subscribe,
unsubFunc: server.Unsubscribe,
},
{
name: "2. Subscribe to patterns",
action: "psubscribe",
tag: "tag_test_psubscribe",
channels: []string{
"channel[34]",
"pattern[12]",
},
pubChannels: []string{
"channel3",
"channel4",
"pattern1",
"pattern2",
},
wantMsg: []string{
"message for channel3",
"message for channel4",
"message for pattern1",
"message for pattern2",
},
subFunc: server.PSubscribe,
unsubFunc: server.PUnsubscribe,
},
}
if !slices.Contains(channels, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
// Publish some messages to the channels.
for _, channel := range channels {
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
if err != nil {
t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Read messages from the channels
for i := 0; i < len(channels); i++ {
message := readMessage()
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(channels, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains([]string{"message for channel1", "message for channel2"}, message[2]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
t.Cleanup(func() {
tt.unsubFunc(tt.tag, tt.channels...)
})
// Unsubscribe from channels
server.Unsubscribe(tag, channels...)
}
func TestSugarDB_PSubscribe(t *testing.T) {
server := createSugarDB()
// Subscribe to channels.
tag := "tag"
patterns := []string{"channel[12]", "pattern[12]"}
readMessage, err := server.PSubscribe(tag, patterns...)
if err != nil {
t.Errorf("PSubscribe() error = %v", err)
}
for i := 0; i < len(patterns); i++ {
message := readMessage()
// Check that we've received the subscribe messages.
if message[0] != "psubscribe" {
t.Errorf("PSUBSCRIBE() expected index 0 for message at %d to be \"psubscribe\", got %s", i, message[0])
}
if !slices.Contains(patterns, message[1]) {
t.Errorf("PSUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
// Publish some messages to the channels.
for _, channel := range []string{"channel1", "channel2", "pattern1", "pattern2"} {
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
if err != nil {
t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
// Read messages from the channels
for i := 0; i < len(patterns)*2; i++ {
message := readMessage()
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(patterns, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains([]string{
"message for channel1", "message for channel2", "message for pattern1", "message for pattern2"}, message[2]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[2], i)
}
}
// Unsubscribe from channels
server.PUnsubscribe(tag, patterns...)
}
func TestSugarDB_PubSubChannels(t *testing.T) {
server := createSugarDB()
tests := []struct {
name string
tag string
channels []string
patterns []string
pattern string
want []string
wantErr bool
}{
{
name: "1. Get number of active channels",
tag: "tag",
channels: []string{"channel1", "channel2", "channel3", "channel4"},
patterns: []string{"channel[56]"},
pattern: "channel[123456]",
want: []string{"channel1", "channel2", "channel3", "channel4"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Subscribe to channels
readChannelMessages, err := server.Subscribe(tt.tag, tt.channels...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
for i := 0; i < len(tt.channels); i++ {
readChannelMessages()
}
// Subscribe to patterns
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
for i := 0; i < len(tt.patterns); i++ {
readPatternMessages()
}
got, err := server.PubSubChannels(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
}
for _, item := range got {
if !slices.Contains(tt.want, item) {
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
// Subscribe to channels.
readMessage, err := tt.subFunc(tt.tag, tt.channels...)
if err != nil {
t.Errorf("(P)SUBSCRIBE() error = %v", err)
}
}
})
}
}
func TestSugarDB_PubSubNumPat(t *testing.T) {
server := createSugarDB()
tests := []struct {
name string
tag string
patterns []string
want int
wantErr bool
}{
{
name: "1. Get number of active patterns on the server",
tag: "tag",
patterns: []string{"channel[56]", "channel[78]"},
want: 2,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Subscribe to patterns
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubNumPat() error = %v", err)
}
for i := 0; i < len(tt.patterns); i++ {
readPatternMessages()
}
got, err := server.PubSubNumPat()
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
}
})
}
}
for i := 0; i < len(tt.channels); i++ {
p := make([]byte, 1024)
_, err := readMessage.Read(p)
if err != nil {
t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
}
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
}
// Check that we've received the subscribe messages.
if message[0] != tt.action {
t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0])
}
if !slices.Contains(tt.channels, message[1]) {
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
func TestSugarDB_PubSubNumSub(t *testing.T) {
server := createSugarDB()
tests := []struct {
name string
subscriptions map[string]struct {
// Publish some messages to the channels.
for _, channel := range tt.pubChannels {
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
if err != nil {
t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
// Read messages from the channels
for i := 0; i < len(tt.pubChannels); i++ {
p := make([]byte, 1024)
_, err := readMessage.Read(p)
doneChan := make(chan struct{}, 1)
go func() {
for {
if err != nil && err == io.EOF {
_, err = readMessage.Read(p)
continue
}
doneChan <- struct{}{}
break
}
}()
select {
case <-time.After(500 * time.Millisecond):
t.Errorf("(P)SUBSCRIBE() timeout")
case <-doneChan:
if err != nil {
t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
}
}
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
}
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(tt.channels, message[1]) {
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains(tt.wantMsg, message[2]) {
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 2 for message %d", message[1], i)
}
}
})
}
})
t.Run("TestSugarDB_PubSubChannels", func(t *testing.T) {
tests := []struct {
name string
tag string
channels []string
patterns []string
pattern string
want []string
wantErr bool
}{
{
name: "1. Get number of active channels",
tag: "tag_test_channels_1",
channels: []string{"channel1", "channel2", "channel3", "channel4"},
patterns: []string{"channel[56]"},
pattern: "channel[123456]",
want: []string{"channel1", "channel2", "channel3", "channel4"},
wantErr: false,
},
}
channels []string
want map[string]int
wantErr bool
}{
{
name: "Get number of subscriptions for the given channels",
subscriptions: map[string]struct {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Subscribe to channels
_, err := server.Subscribe(tt.tag, tt.channels...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
// Subscribe to patterns
_, err = server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
got, err := server.PubSubChannels(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
}
for _, item := range got {
if !slices.Contains(tt.want, item) {
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
}
}
})
}
})
t.Run("TestSugarDB_PubSubNumPat", func(t *testing.T) {
tests := []struct {
name string
tag string
patterns []string
want int
wantErr bool
}{
{
name: "1. Get number of active patterns on the server",
tag: "tag_test_numpat_1",
patterns: []string{"channel[56]", "channel[78]"},
want: 2,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Subscribe to patterns
_, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubNumPat() error = %v", err)
}
got, err := server.PubSubNumPat()
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
}
})
}
})
t.Run("TestSugarDB_PubSubNumSub", func(t *testing.T) {
tests := []struct {
name string
subscriptions map[string]struct {
channels []string
patterns []string
}{
"tag1": {
channels: []string{"channel1", "channel2"},
patterns: []string{"channel[34]"},
}
channels []string
want map[string]int
wantErr bool
}{
{
name: "1. Get number of subscriptions for the given channels",
subscriptions: map[string]struct {
channels []string
patterns []string
}{
"tag1_test_numsub_1": {
channels: []string{"test_num_sub_channel1", "test_num_sub_channel2"},
patterns: []string{"test_num_sub_channel[34]"},
},
"tag2_test_numsub_2": {
channels: []string{"test_num_sub_channel2", "test_num_sub_channel3"},
patterns: []string{"test_num_sub_channel[23]"},
},
"tag3_test_numsub_3": {
channels: []string{"test_num_sub_channel2", "test_num_sub_channel4"},
patterns: []string{},
},
},
"tag2": {
channels: []string{"channel2", "channel3"},
patterns: []string{"channel[23]"},
channels: []string{
"test_num_sub_channel1",
"test_num_sub_channel2",
"test_num_sub_channel3",
"test_num_sub_channel4",
"test_num_sub_channel5",
},
"tag3": {
channels: []string{"channel2", "channel4"},
patterns: []string{},
want: map[string]int{
"test_num_sub_channel1": 1,
"test_num_sub_channel2": 3,
"test_num_sub_channel3": 1,
"test_num_sub_channel4": 1,
"test_num_sub_channel5": 0,
},
wantErr: false,
},
channels: []string{"channel1", "channel2", "channel3", "channel4", "channel5"},
want: map[string]int{"channel1": 1, "channel2": 3, "channel3": 1, "channel4": 1, "channel5": 0},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for tag, subs := range tt.subscriptions {
readPat, err := server.PSubscribe(tag, subs.patterns...)
if err != nil {
t.Errorf("PubSubNumSub() error = %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
for tag, subs := range tt.subscriptions {
_, err := server.PSubscribe(tag, subs.patterns...)
if err != nil {
t.Errorf("PubSubNumSub() error = %v", err)
}
_, err = server.Subscribe(tag, subs.channels...)
if err != nil {
t.Errorf("PubSubNumSub() error = %v", err)
}
}
for _, _ = range subs.patterns {
readPat()
got, err := server.PubSubNumSub(tt.channels...)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
return
}
readChan, err := server.Subscribe(tag, subs.channels...)
if err != nil {
t.Errorf("PubSubNumSub() error = %v", err)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
}
for _, _ = range subs.channels {
readChan()
}
}
got, err := server.PubSubNumSub(tt.channels...)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
}
})
}
})
}
})
}

View File

@@ -204,7 +204,7 @@ func NewSugarDB(options ...func(sugarDB *SugarDB)) (*SugarDB, error) {
sugarDB.acl = acl.NewACL(sugarDB.config)
// Set up Pub/Sub module
sugarDB.pubSub = pubsub.NewPubSub()
sugarDB.pubSub = pubsub.NewPubSub(sugarDB.context)
if sugarDB.isInCluster() {
sugarDB.raft = raft.NewRaft(raft.Opts{