Refactored PubSub Embedded API

Refactored pubsub implementation to return MessageReader on embedded instance, which implements io.Reader for reading messages (#170) - @kelvinmwinuka
This commit is contained in:
Kelvin Mwinuka
2025-01-26 22:37:14 +08:00
committed by GitHub
parent 4aab2e7799
commit ec69e52a5b
9 changed files with 631 additions and 456 deletions

View File

@@ -29,23 +29,34 @@ Subscribe to one or more patterns. This command accepts glob patterns.
]} ]}
> >
<TabItem value="go"> <TabItem value="go">
The Subscribe method returns a readMessage function. The PSubscribe method returns a MessageReader type which implements the `io.Reader` interface.
This method is lazy so it must be invoked each time the you want to read the next message from When subscribing to an'N' number of channels, the first N messages will be
the pattern.
When subscribing to an'N' number of patterns, the first N messages will be
the subscription confimations. the subscription confimations.
The readMessage functions returns a message object when called. The message
object is a string slice with the following inforamtion: The message read is a series of bytes resulting from JSON marshalling an array. You will need to
event type at index 0 (e.g. subscribe, message), pattern at index 1, unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
event type at index 0 (e.g. psubscribe, message), channel name at index 1,
message/subscription index at index 2. message/subscription index at index 2.
Messages published to any channels that match the pattern will be received by the pattern subscriber.
```go ```go
db, err := sugardb.NewSugarDB() db, err := sugardb.NewSugarDB()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
readMessage := db.PSubscribe("subscribe_tag_1", "pattern_[12]", "pattern_h[ae]llo") // Return lazy readMessage function
for i := 0; i < 2; i++ { // Subscribe to multiple channel patterns, returs MessageReader
message := readMessage() // Call the readMessage function for each channel subscription. msgReader := db.PSubscribe("psubscribe_tag_1", "channel[12]", "pattern[12]")
// Read message into pre-defined buffer
msg := make([]byte, 1024)
_, err := msgReader.Read(msg)
// Trim all null bytes at the end of the message before unmarshalling.
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
log.Fatalf("json unmarshal error: %+v", err)
} }
``` ```
</TabItem> </TabItem>

View File

@@ -29,23 +29,32 @@ Subscribe to one or more channels.
]} ]}
> >
<TabItem value="go"> <TabItem value="go">
The Subscribe method returns a readMessage function. The Subscribe method returns a MessageReader type which implements the `io.Reader` interface.
This method is lazy so it must be invoked each time the you want to read the next message from
the channel.
When subscribing to an'N' number of channels, the first N messages will be When subscribing to an'N' number of channels, the first N messages will be
the subscription confimations. the subscription confimations.
The readMessage functions returns a message object when called. The message
object is a string slice with the following inforamtion: The message read is a series of bytes resulting from JSON marshalling an array. You will need to
unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
event type at index 0 (e.g. subscribe, message), channel name at index 1, event type at index 0 (e.g. subscribe, message), channel name at index 1,
message/subscription index at index 2. message/subscription index at index 2.
```go ```go
db, err := sugardb.NewSugarDB() db, err := sugardb.NewSugarDB()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
readMessage := db.Subscribe("subscribe_tag_1", "channel1", "channel2") // Return lazy readMessage function
for i := 0; i < 2; i++ { // Subscribe to multiple channel patterns, returs MessageReader.
message := readMessage() // Call the readMessage function for each channel subscription. msgReader := db.Subscribe("subscribe_tag_1", "channel1", "channel2")
// Read message into pre-defined buffer.
msg := make([]byte, 1024)
_, err := msgReader.Read(msg)
// Trim all null bytes at the end of the message before unmarshalling.
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
log.Fatalf("json unmarshal error: %+v", err)
} }
``` ```
</TabItem> </TabItem>

View File

@@ -15,19 +15,30 @@
package pubsub package pubsub
import ( import (
"context"
"encoding/json"
"fmt"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"log" "log"
"net" "net"
"slices"
"sync" "sync"
) )
type Channel struct { type Channel struct {
name string // Channel name. This can be a glob pattern string. name string // Channel name. This can be a glob pattern string.
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel. pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
subscribersRWMut sync.RWMutex // RWMutex to concurrency control when accessing channel subscribers.
subscribers map[*net.Conn]*resp.Conn // Map containing the channel subscribers. messages []string // Slice that holds messages.
messageChan *chan string // Messages published to this channel will be sent to this channel. messagesRWMut sync.RWMutex // RWMutex for accessing channel messages.
messagesCond *sync.Cond
tcpSubs map[*net.Conn]*resp.Conn // Map containing the channel's TCP subscribers.
tcpSubsRWMut sync.RWMutex // RWMutex for accessing TCP channel subscribers.
embeddedSubs []*EmbeddedSub // Slice containing embedded subscribers to this channel.
embeddedSubsRWMut sync.RWMutex // RWMutex for accessing embedded subscribers.
} }
// WithName option sets the channels name. // WithName option sets the channels name.
@@ -45,46 +56,89 @@ func WithPattern(pattern string) func(channel *Channel) {
} }
} }
func NewChannel(options ...func(channel *Channel)) *Channel { func NewChannel(ctx context.Context, options ...func(channel *Channel)) *Channel {
messageChan := make(chan string, 4096)
channel := &Channel{ channel := &Channel{
name: "", name: "",
pattern: nil, pattern: nil,
subscribersRWMut: sync.RWMutex{},
subscribers: make(map[*net.Conn]*resp.Conn), messages: make([]string, 0),
messageChan: &messageChan, messagesRWMut: sync.RWMutex{},
tcpSubs: make(map[*net.Conn]*resp.Conn),
tcpSubsRWMut: sync.RWMutex{},
embeddedSubs: make([]*EmbeddedSub, 0),
embeddedSubsRWMut: sync.RWMutex{},
} }
channel.messagesCond = sync.NewCond(&channel.messagesRWMut)
for _, option := range options { for _, option := range options {
option(channel) option(channel)
} }
return channel
}
func (ch *Channel) Start() {
go func() { go func() {
for { for {
message := <-*ch.messageChan select {
case <-ctx.Done():
log.Printf("closing channel %s\n", channel.name)
return
default:
channel.messagesRWMut.Lock()
for len(channel.messages) == 0 {
channel.messagesCond.Wait()
}
ch.subscribersRWMut.RLock() message := channel.messages[0]
channel.messages = channel.messages[1:]
channel.messagesRWMut.Unlock()
for _, conn := range ch.subscribers { wg := sync.WaitGroup{}
go func(conn *resp.Conn) { wg.Add(2)
if err := conn.WriteArray([]resp.Value{
resp.StringValue("message"), // Send messages to embedded subscribers
resp.StringValue(ch.name), go func() {
resp.StringValue(message), channel.embeddedSubsRWMut.RLock()
}); err != nil { ewg := sync.WaitGroup{}
log.Println(err) b, _ := json.Marshal([]string{"message", channel.name, message})
msg := append(b, byte('\n'))
for _, w := range channel.embeddedSubs {
ewg.Add(1)
go func(w *EmbeddedSub) {
_, _ = w.Write(msg)
ewg.Done()
}(w)
} }
}(conn) ewg.Wait()
} channel.embeddedSubsRWMut.RUnlock()
wg.Done()
}()
ch.subscribersRWMut.RUnlock() // Send messages to TCP subscribers
go func() {
channel.tcpSubsRWMut.RLock()
cwg := sync.WaitGroup{}
for _, conn := range channel.tcpSubs {
cwg.Add(1)
go func(conn *resp.Conn) {
_ = conn.WriteArray([]resp.Value{
resp.StringValue("message"),
resp.StringValue(channel.name),
resp.StringValue(message),
})
cwg.Done()
}(conn)
}
cwg.Wait()
channel.tcpSubsRWMut.RUnlock()
wg.Done()
}()
wg.Wait()
}
} }
}() }()
return channel
} }
func (ch *Channel) Name() string { func (ch *Channel) Name() string {
@@ -95,56 +149,90 @@ func (ch *Channel) Pattern() glob.Glob {
return ch.pattern return ch.pattern
} }
func (ch *Channel) Subscribe(conn *net.Conn) bool { func (ch *Channel) Subscribe(sub any, action string, chanIdx int) {
ch.subscribersRWMut.Lock() switch sub.(type) {
defer ch.subscribersRWMut.Unlock() case *net.Conn:
if _, ok := ch.subscribers[conn]; !ok { ch.tcpSubsRWMut.Lock()
ch.subscribers[conn] = resp.NewConn(*conn) defer ch.tcpSubsRWMut.Unlock()
conn := sub.(*net.Conn)
if _, ok := ch.tcpSubs[conn]; !ok {
ch.tcpSubs[conn] = resp.NewConn(*conn)
}
r, _ := ch.tcpSubs[conn]
// Send subscription message
_ = r.WriteArray([]resp.Value{
resp.StringValue(action),
resp.StringValue(ch.name),
resp.IntegerValue(chanIdx + 1),
})
case *EmbeddedSub:
ch.embeddedSubsRWMut.Lock()
defer ch.embeddedSubsRWMut.Unlock()
w := sub.(*EmbeddedSub)
if !slices.ContainsFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
return writer == w
}) {
ch.embeddedSubs = append(ch.embeddedSubs, w)
}
// Send subscription message
b, _ := json.Marshal([]string{action, ch.name, fmt.Sprintf("%d", chanIdx+1)})
msg := append(b, byte('\n'))
_, _ = w.Write(msg)
} }
_, ok := ch.subscribers[conn]
return ok
} }
func (ch *Channel) Unsubscribe(conn *net.Conn) bool { func (ch *Channel) Unsubscribe(sub any) bool {
ch.subscribersRWMut.Lock() switch sub.(type) {
defer ch.subscribersRWMut.Unlock() default:
if _, ok := ch.subscribers[conn]; !ok {
return false return false
case *net.Conn:
ch.tcpSubsRWMut.Lock()
defer ch.tcpSubsRWMut.Unlock()
conn := sub.(*net.Conn)
if _, ok := ch.tcpSubs[conn]; !ok {
return false
}
delete(ch.tcpSubs, conn)
return true
case *EmbeddedSub:
ch.embeddedSubsRWMut.Lock()
defer ch.embeddedSubsRWMut.Unlock()
w := sub.(*EmbeddedSub)
deleted := false
ch.embeddedSubs = slices.DeleteFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
deleted = writer == w
return deleted
})
return deleted
} }
delete(ch.subscribers, conn)
return true
} }
func (ch *Channel) Publish(message string) { func (ch *Channel) Publish(message string) {
*ch.messageChan <- message ch.messagesRWMut.Lock()
defer ch.messagesRWMut.Unlock()
ch.messages = append(ch.messages, message)
ch.messagesCond.Signal()
} }
func (ch *Channel) IsActive() bool { func (ch *Channel) IsActive() bool {
ch.subscribersRWMut.RLock() ch.tcpSubsRWMut.RLock()
defer ch.subscribersRWMut.RUnlock() defer ch.tcpSubsRWMut.RUnlock()
active := len(ch.subscribers) > 0 ch.embeddedSubsRWMut.RLock()
defer ch.embeddedSubsRWMut.RUnlock()
return active return len(ch.tcpSubs)+len(ch.embeddedSubs) > 0
} }
func (ch *Channel) NumSubs() int { func (ch *Channel) NumSubs() int {
ch.subscribersRWMut.RLock() ch.tcpSubsRWMut.RLock()
defer ch.subscribersRWMut.RUnlock() defer ch.tcpSubsRWMut.RUnlock()
n := len(ch.subscribers) ch.embeddedSubsRWMut.RLock()
defer ch.embeddedSubsRWMut.RUnlock()
return n return len(ch.tcpSubs) + len(ch.embeddedSubs)
}
func (ch *Channel) Subscribers() map[*net.Conn]*resp.Conn {
ch.subscribersRWMut.RLock()
defer ch.subscribersRWMut.RUnlock()
subscribers := make(map[*net.Conn]*resp.Conn, len(ch.subscribers))
for k, v := range ch.subscribers {
subscribers[k] = v
}
return subscribers
} }

View File

@@ -35,7 +35,8 @@ func handleSubscribe(params internal.HandlerFuncParams) ([]byte, error) {
} }
withPattern := strings.EqualFold(params.Command[0], "psubscribe") withPattern := strings.EqualFold(params.Command[0], "psubscribe")
pubsub.Subscribe(params.Context, params.Connection, channels, withPattern)
pubsub.Subscribe(params.Connection, channels, withPattern)
return nil, nil return nil, nil
} }
@@ -50,7 +51,7 @@ func handleUnsubscribe(params internal.HandlerFuncParams) ([]byte, error) {
withPattern := strings.EqualFold(params.Command[0], "punsubscribe") withPattern := strings.EqualFold(params.Command[0], "punsubscribe")
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil return pubsub.Unsubscribe(params.Connection, channels, withPattern), nil
} }
func handlePublish(params internal.HandlerFuncParams) ([]byte, error) { func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
@@ -61,7 +62,7 @@ func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) != 3 { if len(params.Command) != 3 {
return nil, errors.New(constants.WrongArgsResponse) return nil, errors.New(constants.WrongArgsResponse)
} }
pubsub.Publish(params.Context, params.Command[2], params.Command[1]) pubsub.Publish(params.Command[2], params.Command[1])
return []byte(constants.OkResponse), nil return []byte(constants.OkResponse), nil
} }

View File

@@ -17,33 +17,31 @@ package pubsub
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net" "net"
"slices" "slices"
"sync" "sync"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"github.com/tidwall/resp"
) )
type PubSub struct { type PubSub struct {
ctx context.Context
channels []*Channel channels []*Channel
channelsRWMut sync.RWMutex channelsRWMut sync.RWMutex
} }
func NewPubSub() *PubSub { func NewPubSub(ctx context.Context) *PubSub {
return &PubSub{ return &PubSub{
ctx: ctx,
channels: []*Channel{}, channels: []*Channel{},
channelsRWMut: sync.RWMutex{}, channelsRWMut: sync.RWMutex{},
} }
} }
func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) { func (ps *PubSub) Subscribe(sub any, channels []string, withPattern bool) {
ps.channelsRWMut.Lock() ps.channelsRWMut.Lock()
defer ps.channelsRWMut.Unlock() defer ps.channelsRWMut.Unlock()
r := resp.NewConn(*conn)
action := "subscribe" action := "subscribe"
if withPattern { if withPattern {
action = "psubscribe" action = "psubscribe"
@@ -58,40 +56,46 @@ func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string
}) })
if channelIdx == -1 { if channelIdx == -1 {
// Create new channel, start it, and subscribe to it // Create new channel, if it does not exist
var newChan *Channel var newChan *Channel
if withPattern { if withPattern {
newChan = NewChannel(WithPattern(channels[i])) newChan = NewChannel(ps.ctx, WithPattern(channels[i]))
} else { } else {
newChan = NewChannel(WithName(channels[i])) newChan = NewChannel(ps.ctx, WithName(channels[i]))
} }
newChan.Start() // Append the channel to the list of channels
if newChan.Subscribe(conn) { ps.channels = append(ps.channels, newChan)
if err := r.WriteArray([]resp.Value{
resp.StringValue(action), // Subscribe to the channel
resp.StringValue(newChan.name), switch sub.(type) {
resp.IntegerValue(i + 1), case *net.Conn:
}); err != nil { // Subscribe TCP connection
log.Println(err) conn := sub.(*net.Conn)
} newChan.Subscribe(conn, action, i)
ps.channels = append(ps.channels, newChan)
case *EmbeddedSub:
// Subscribe io.Writer from embedded client
w := sub.(*EmbeddedSub)
newChan.Subscribe(w, action, i)
} }
} else { } else {
// Subscribe to existing channel // Subscribe to existing channel
if ps.channels[channelIdx].Subscribe(conn) { switch sub.(type) {
if err := r.WriteArray([]resp.Value{ case *net.Conn:
resp.StringValue(action), // Subscribe TCP connection
resp.StringValue(ps.channels[channelIdx].name), conn := sub.(*net.Conn)
resp.IntegerValue(i + 1), ps.channels[channelIdx].Subscribe(conn, action, i)
}); err != nil {
log.Println(err) case *EmbeddedSub:
} // Subscribe io.Writer from embedded client
w := sub.(*EmbeddedSub)
ps.channels[channelIdx].Subscribe(w, action, i)
} }
} }
} }
} }
func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { func (ps *PubSub) Unsubscribe(sub any, channels []string, withPattern bool) []byte {
ps.channelsRWMut.RLock() ps.channelsRWMut.RLock()
defer ps.channelsRWMut.RUnlock() defer ps.channelsRWMut.RUnlock()
@@ -111,7 +115,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
if channel.pattern != nil { // Skip pattern channels if channel.pattern != nil { // Skip pattern channels
continue continue
} }
if channel.Unsubscribe(conn) { if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name unsubscribed[idx] = channel.name
idx += 1 idx += 1
} }
@@ -123,7 +127,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
if channel.pattern == nil { // Skip non-pattern channels if channel.pattern == nil { // Skip non-pattern channels
continue continue
} }
if channel.Unsubscribe(conn) { if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name unsubscribed[idx] = channel.name
idx += 1 idx += 1
} }
@@ -136,7 +140,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
// names exactly matches the pattern name. // names exactly matches the pattern name.
for _, channel := range ps.channels { // For each channel in PubSub for _, channel := range ps.channels { // For each channel in PubSub
for _, c := range channels { // For each channel name provided for _, c := range channels { // For each channel name provided
if channel.name == c && channel.Unsubscribe(conn) { if channel.name == c && channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name unsubscribed[idx] = channel.name
idx += 1 idx += 1
} }
@@ -151,7 +155,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
for _, channel := range ps.channels { for _, channel := range ps.channels {
// If it's a pattern channel, directly compare the patterns // If it's a pattern channel, directly compare the patterns
if channel.pattern != nil && channel.name == pattern { if channel.pattern != nil && channel.name == pattern {
if channel.Unsubscribe(conn) { if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name unsubscribed[idx] = channel.name
idx += 1 idx += 1
} }
@@ -159,7 +163,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
} }
// If this is a regular channel, check if the channel name matches the pattern given // If this is a regular channel, check if the channel name matches the pattern given
if g.Match(channel.name) { if g.Match(channel.name) {
if channel.Unsubscribe(conn) { if channel.Unsubscribe(sub) {
unsubscribed[idx] = channel.name unsubscribed[idx] = channel.name
idx += 1 idx += 1
} }
@@ -176,7 +180,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
return []byte(res) return []byte(res)
} }
func (ps *PubSub) Publish(_ context.Context, message string, channelName string) { func (ps *PubSub) Publish(message string, channelName string) {
ps.channelsRWMut.RLock() ps.channelsRWMut.RLock()
defer ps.channelsRWMut.RUnlock() defer ps.channelsRWMut.RUnlock()

View File

@@ -0,0 +1,59 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub
import (
"bufio"
"bytes"
"sync"
)
type EmbeddedSub struct {
mux sync.Mutex
buff *bytes.Buffer
writer *bufio.Writer
reader *bufio.Reader
}
func NewEmbeddedSub() *EmbeddedSub {
sub := &EmbeddedSub{
mux: sync.Mutex{},
buff: bytes.NewBuffer(make([]byte, 0)),
}
sub.writer = bufio.NewWriter(sub.buff)
sub.reader = bufio.NewReader(sub.buff)
return sub
}
func (sub *EmbeddedSub) Write(p []byte) (int, error) {
sub.mux.Lock()
defer sub.mux.Unlock()
n, err := sub.writer.Write(p)
if err != nil {
return n, err
}
err = sub.writer.Flush()
return n, err
}
func (sub *EmbeddedSub) Read(p []byte) (int, error) {
sub.mux.Lock()
defer sub.mux.Unlock()
chunk, err := sub.reader.ReadBytes(byte('\n'))
n := copy(p, chunk)
return n, err
}

View File

@@ -16,55 +16,23 @@ package sugardb
import ( import (
"bytes" "bytes"
"errors"
"github.com/echovault/sugardb/internal" "github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/modules/pubsub"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net"
"strings" "strings"
"sync" "sync"
) )
type conn struct { type MessageReader struct {
readConn *net.Conn embeddedSub *pubsub.EmbeddedSub
writeConn *net.Conn
} }
var connections sync.Map func (reader *MessageReader) Read(p []byte) (int, error) {
return reader.embeddedSub.Read(p)
// ReadPubSubMessage is returned by the Subscribe and PSubscribe functions.
//
// This function is lazy, therefore it needs to be invoked in order to read the next message.
// When the message is read, the function returns a string slice with 3 elements.
// Index 0 holds the event type which in this case will be "message". Index 1 holds the channel name.
// Index 2 holds the actual message.
type ReadPubSubMessage func() []string
func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
var readConn *net.Conn
var writeConn *net.Conn
if _, ok := connections.Load(tag); !ok {
// If connection with this name does not exist, create new connection.
rc, wc := net.Pipe()
readConn = &rc
writeConn = &wc
connections.Store(tag, conn{
readConn: &rc,
writeConn: &wc,
})
} else {
// Reuse existing connection.
c, ok := connections.Load(tag)
if !ok {
return nil, nil, errors.New("could not establish connection")
}
readConn = c.(conn).readConn
writeConn = c.(conn).writeConn
}
return readConn, writeConn, nil
} }
var subscriptions sync.Map
// Subscribe subscribes the caller to the list of provided channels. // Subscribe subscribes the caller to the list of provided channels.
// //
// Parameters: // Parameters:
@@ -75,31 +43,22 @@ func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
// //
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance. // Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
// This function is blocking. // This function is blocking.
func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMessage, error) { func (server *SugarDB) Subscribe(tag string, channels ...string) (*MessageReader, error) {
readConn, writeConn, err := establishConnections(tag) var msgReader *MessageReader
if err != nil {
return func() []string { sub, ok := subscriptions.Load(tag)
return []string{} if !ok {
}, err // Create new messageBuffer and store it in the subscriptions
msgReader = &MessageReader{
embeddedSub: pubsub.NewEmbeddedSub(),
}
} else {
msgReader = sub.(*MessageReader)
} }
// Subscribe connection to the provided channels. server.pubSub.Subscribe(msgReader.embeddedSub, channels, false)
cmd := append([]string{"SUBSCRIBE"}, channels...)
go func() {
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
}()
return func() []string { return msgReader, nil
r := resp.NewConn(*readConn)
v, _, _ := r.ReadValue()
res := make([]string, len(v.Array()))
for i := 0; i < len(res); i++ {
res[i] = v.Array()[i].String()
}
return res
}, nil
} }
// Unsubscribe unsubscribes the caller from the given channels. // Unsubscribe unsubscribes the caller from the given channels.
@@ -110,12 +69,12 @@ func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMess
// //
// `channels` - ...string - The list of channels to unsubscribe from. // `channels` - ...string - The list of channels to unsubscribe from.
func (server *SugarDB) Unsubscribe(tag string, channels ...string) { func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
c, ok := connections.Load(tag) sub, ok := subscriptions.Load(tag)
if !ok { if !ok {
return return
} }
cmd := append([]string{"UNSUBSCRIBE"}, channels...) msgReader := sub.(*MessageReader)
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true) server.pubSub.Unsubscribe(msgReader, channels, false)
} }
// PSubscribe subscribes the caller to the list of provided glob patterns. // PSubscribe subscribes the caller to the list of provided glob patterns.
@@ -128,31 +87,23 @@ func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
// //
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance. // Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
// This function is blocking. // This function is blocking.
func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMessage, error) {
readConn, writeConn, err := establishConnections(tag) func (server *SugarDB) PSubscribe(tag string, patterns ...string) (*MessageReader, error) {
if err != nil { var msgReader *MessageReader
return func() []string {
return []string{} sub, ok := subscriptions.Load(tag)
}, err if !ok {
// Create new messageBuffer and store it in the subscriptions
msgReader = &MessageReader{
embeddedSub: pubsub.NewEmbeddedSub(),
}
} else {
msgReader = sub.(*MessageReader)
} }
// Subscribe connection to the provided channels server.pubSub.Subscribe(msgReader.embeddedSub, patterns, true)
cmd := append([]string{"PSUBSCRIBE"}, patterns...)
go func() {
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
}()
return func() []string { return msgReader, nil
r := resp.NewConn(*readConn)
v, _, _ := r.ReadValue()
res := make([]string, len(v.Array()))
for i := 0; i < len(res); i++ {
res[i] = v.Array()[i].String()
}
return res
}, nil
} }
// PUnsubscribe unsubscribes the caller from the given glob patterns. // PUnsubscribe unsubscribes the caller from the given glob patterns.
@@ -163,12 +114,12 @@ func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMes
// //
// `patterns` - ...string - The list of glob patterns to unsubscribe from. // `patterns` - ...string - The list of glob patterns to unsubscribe from.
func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) { func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
c, ok := connections.Load(tag) sub, ok := subscriptions.Load(tag)
if !ok { if !ok {
return return
} }
cmd := append([]string{"PUNSUBSCRIBE"}, patterns...) msgReader := sub.(*MessageReader)
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true) server.pubSub.Unsubscribe(msgReader, patterns, true)
} }
// Publish publishes a message to the given channel. // Publish publishes a message to the given channel.
@@ -179,10 +130,12 @@ func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
// //
// `message` - string - The message to publish to the specified channel. // `message` - string - The message to publish to the specified channel.
// //
// Returns: true when the publish is successful. This does not indicate whether each subscriber has received the message, // Returns: true when successful. This does not indicate whether each subscriber has received the message,
// only that the message has been published. // only that the message has been published to the channel.
func (server *SugarDB) Publish(channel, message string) (bool, error) { func (server *SugarDB) Publish(channel, message string) (bool, error) {
b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true) b, err := server.handleCommand(
server.context,
internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -15,276 +15,326 @@
package sugardb package sugardb
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"io"
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
"time"
) )
func Test_Subscribe(t *testing.T) { func TestSugarDB_PubSub(t *testing.T) {
server := createSugarDB() server := createSugarDB()
t.Cleanup(func() {
server.ShutDown()
})
// Subscribe to channels. t.Run("TestSugarDB_(P)Subscribe", func(t *testing.T) {
tag := "tag" t.Parallel()
channels := []string{"channel1", "channel2"}
readMessage, err := server.Subscribe(tag, channels...)
if err != nil {
t.Errorf("SUBSCRIBE() error = %v", err)
}
for i := 0; i < len(channels); i++ { tests := []struct {
message := readMessage() name string
// Check that we've received the subscribe messages. action string // subscribe | psubscribe
if message[0] != "subscribe" { tag string
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0]) channels []string
pubChannels []string // Channels to publish messages to after subscriptions
wantMsg []string // Expected messages from after publishing
subFunc func(tag string, channels ...string) (*MessageReader, error)
unsubFunc func(tag string, channels ...string)
}{
{
name: "1. Subscribe to channels",
action: "subscribe",
tag: "tag_test_subscribe",
channels: []string{
"channel1",
"channel2",
},
pubChannels: []string{"channel1", "channel2"},
wantMsg: []string{
"message for channel1",
"message for channel2",
},
subFunc: server.Subscribe,
unsubFunc: server.Unsubscribe,
},
{
name: "2. Subscribe to patterns",
action: "psubscribe",
tag: "tag_test_psubscribe",
channels: []string{
"channel[34]",
"pattern[12]",
},
pubChannels: []string{
"channel3",
"channel4",
"pattern1",
"pattern2",
},
wantMsg: []string{
"message for channel3",
"message for channel4",
"message for pattern1",
"message for pattern2",
},
subFunc: server.PSubscribe,
unsubFunc: server.PUnsubscribe,
},
} }
if !slices.Contains(channels, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
// Publish some messages to the channels. for _, tt := range tests {
for _, channel := range channels { t.Run(tt.name, func(t *testing.T) {
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel)) t.Parallel()
if err != nil {
t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
// Read messages from the channels t.Cleanup(func() {
for i := 0; i < len(channels); i++ { tt.unsubFunc(tt.tag, tt.channels...)
message := readMessage() })
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(channels, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains([]string{"message for channel1", "message for channel2"}, message[2]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
// Unsubscribe from channels // Subscribe to channels.
server.Unsubscribe(tag, channels...) readMessage, err := tt.subFunc(tt.tag, tt.channels...)
} if err != nil {
t.Errorf("(P)SUBSCRIBE() error = %v", err)
func TestSugarDB_PSubscribe(t *testing.T) {
server := createSugarDB()
// Subscribe to channels.
tag := "tag"
patterns := []string{"channel[12]", "pattern[12]"}
readMessage, err := server.PSubscribe(tag, patterns...)
if err != nil {
t.Errorf("PSubscribe() error = %v", err)
}
for i := 0; i < len(patterns); i++ {
message := readMessage()
// Check that we've received the subscribe messages.
if message[0] != "psubscribe" {
t.Errorf("PSUBSCRIBE() expected index 0 for message at %d to be \"psubscribe\", got %s", i, message[0])
}
if !slices.Contains(patterns, message[1]) {
t.Errorf("PSUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
}
// Publish some messages to the channels.
for _, channel := range []string{"channel1", "channel2", "pattern1", "pattern2"} {
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
if err != nil {
t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
// Read messages from the channels
for i := 0; i < len(patterns)*2; i++ {
message := readMessage()
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(patterns, message[1]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains([]string{
"message for channel1", "message for channel2", "message for pattern1", "message for pattern2"}, message[2]) {
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[2], i)
}
}
// Unsubscribe from channels
server.PUnsubscribe(tag, patterns...)
}
func TestSugarDB_PubSubChannels(t *testing.T) {
server := createSugarDB()
tests := []struct {
name string
tag string
channels []string
patterns []string
pattern string
want []string
wantErr bool
}{
{
name: "1. Get number of active channels",
tag: "tag",
channels: []string{"channel1", "channel2", "channel3", "channel4"},
patterns: []string{"channel[56]"},
pattern: "channel[123456]",
want: []string{"channel1", "channel2", "channel3", "channel4"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Subscribe to channels
readChannelMessages, err := server.Subscribe(tt.tag, tt.channels...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
for i := 0; i < len(tt.channels); i++ {
readChannelMessages()
}
// Subscribe to patterns
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
for i := 0; i < len(tt.patterns); i++ {
readPatternMessages()
}
got, err := server.PubSubChannels(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
}
for _, item := range got {
if !slices.Contains(tt.want, item) {
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
} }
}
})
}
}
func TestSugarDB_PubSubNumPat(t *testing.T) { for i := 0; i < len(tt.channels); i++ {
server := createSugarDB() p := make([]byte, 1024)
tests := []struct { _, err := readMessage.Read(p)
name string if err != nil {
tag string t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
patterns []string }
want int var message []string
wantErr bool if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
}{ t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
{ }
name: "1. Get number of active patterns on the server", // Check that we've received the subscribe messages.
tag: "tag", if message[0] != tt.action {
patterns: []string{"channel[56]", "channel[78]"}, t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0])
want: 2, }
wantErr: false, if !slices.Contains(tt.channels, message[1]) {
}, t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
} }
for _, tt := range tests { }
t.Run(tt.name, func(t *testing.T) {
// Subscribe to patterns
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubNumPat() error = %v", err)
}
for i := 0; i < len(tt.patterns); i++ {
readPatternMessages()
}
got, err := server.PubSubNumPat()
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
}
})
}
}
func TestSugarDB_PubSubNumSub(t *testing.T) { // Publish some messages to the channels.
server := createSugarDB() for _, channel := range tt.pubChannels {
tests := []struct { ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
name string if err != nil {
subscriptions map[string]struct { t.Errorf("PUBLISH() err = %v", err)
}
if !ok {
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
}
}
// Read messages from the channels
for i := 0; i < len(tt.pubChannels); i++ {
p := make([]byte, 1024)
_, err := readMessage.Read(p)
doneChan := make(chan struct{}, 1)
go func() {
for {
if err != nil && err == io.EOF {
_, err = readMessage.Read(p)
continue
}
doneChan <- struct{}{}
break
}
}()
select {
case <-time.After(500 * time.Millisecond):
t.Errorf("(P)SUBSCRIBE() timeout")
case <-doneChan:
if err != nil {
t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
}
}
var message []string
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
}
// Check that we've received the messages.
if message[0] != "message" {
t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
}
if !slices.Contains(tt.channels, message[1]) {
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
}
if !slices.Contains(tt.wantMsg, message[2]) {
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 2 for message %d", message[1], i)
}
}
})
}
})
t.Run("TestSugarDB_PubSubChannels", func(t *testing.T) {
tests := []struct {
name string
tag string
channels []string channels []string
patterns []string patterns []string
pattern string
want []string
wantErr bool
}{
{
name: "1. Get number of active channels",
tag: "tag_test_channels_1",
channels: []string{"channel1", "channel2", "channel3", "channel4"},
patterns: []string{"channel[56]"},
pattern: "channel[123456]",
want: []string{"channel1", "channel2", "channel3", "channel4"},
wantErr: false,
},
} }
channels []string for _, tt := range tests {
want map[string]int t.Run(tt.name, func(t *testing.T) {
wantErr bool t.Parallel()
}{ // Subscribe to channels
{ _, err := server.Subscribe(tt.tag, tt.channels...)
name: "Get number of subscriptions for the given channels", if err != nil {
subscriptions: map[string]struct { t.Errorf("PubSubChannels() error = %v", err)
}
// Subscribe to patterns
_, err = server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubChannels() error = %v", err)
}
got, err := server.PubSubChannels(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
}
for _, item := range got {
if !slices.Contains(tt.want, item) {
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
}
}
})
}
})
t.Run("TestSugarDB_PubSubNumPat", func(t *testing.T) {
tests := []struct {
name string
tag string
patterns []string
want int
wantErr bool
}{
{
name: "1. Get number of active patterns on the server",
tag: "tag_test_numpat_1",
patterns: []string{"channel[56]", "channel[78]"},
want: 2,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Subscribe to patterns
_, err := server.PSubscribe(tt.tag, tt.patterns...)
if err != nil {
t.Errorf("PubSubNumPat() error = %v", err)
}
got, err := server.PubSubNumPat()
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
}
})
}
})
t.Run("TestSugarDB_PubSubNumSub", func(t *testing.T) {
tests := []struct {
name string
subscriptions map[string]struct {
channels []string channels []string
patterns []string patterns []string
}{ }
"tag1": { channels []string
channels: []string{"channel1", "channel2"}, want map[string]int
patterns: []string{"channel[34]"}, wantErr bool
}{
{
name: "1. Get number of subscriptions for the given channels",
subscriptions: map[string]struct {
channels []string
patterns []string
}{
"tag1_test_numsub_1": {
channels: []string{"test_num_sub_channel1", "test_num_sub_channel2"},
patterns: []string{"test_num_sub_channel[34]"},
},
"tag2_test_numsub_2": {
channels: []string{"test_num_sub_channel2", "test_num_sub_channel3"},
patterns: []string{"test_num_sub_channel[23]"},
},
"tag3_test_numsub_3": {
channels: []string{"test_num_sub_channel2", "test_num_sub_channel4"},
patterns: []string{},
},
}, },
"tag2": { channels: []string{
channels: []string{"channel2", "channel3"}, "test_num_sub_channel1",
patterns: []string{"channel[23]"}, "test_num_sub_channel2",
"test_num_sub_channel3",
"test_num_sub_channel4",
"test_num_sub_channel5",
}, },
"tag3": { want: map[string]int{
channels: []string{"channel2", "channel4"}, "test_num_sub_channel1": 1,
patterns: []string{}, "test_num_sub_channel2": 3,
"test_num_sub_channel3": 1,
"test_num_sub_channel4": 1,
"test_num_sub_channel5": 0,
}, },
wantErr: false,
}, },
channels: []string{"channel1", "channel2", "channel3", "channel4", "channel5"}, }
want: map[string]int{"channel1": 1, "channel2": 3, "channel3": 1, "channel4": 1, "channel5": 0},
wantErr: false, for _, tt := range tests {
}, t.Run(tt.name, func(t *testing.T) {
} t.Parallel()
for _, tt := range tests { for tag, subs := range tt.subscriptions {
t.Run(tt.name, func(t *testing.T) { _, err := server.PSubscribe(tag, subs.patterns...)
for tag, subs := range tt.subscriptions { if err != nil {
readPat, err := server.PSubscribe(tag, subs.patterns...) t.Errorf("PubSubNumSub() error = %v", err)
if err != nil { }
t.Errorf("PubSubNumSub() error = %v", err)
_, err = server.Subscribe(tag, subs.channels...)
if err != nil {
t.Errorf("PubSubNumSub() error = %v", err)
}
} }
for _, _ = range subs.patterns { got, err := server.PubSubNumSub(tt.channels...)
readPat() if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
return
} }
readChan, err := server.Subscribe(tag, subs.channels...) if !reflect.DeepEqual(got, tt.want) {
if err != nil { t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
t.Errorf("PubSubNumSub() error = %v", err)
} }
for _, _ = range subs.channels { })
readChan() }
} })
}
got, err := server.PubSubNumSub(tt.channels...)
if (err != nil) != tt.wantErr {
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
}
})
}
} }

View File

@@ -204,7 +204,7 @@ func NewSugarDB(options ...func(sugarDB *SugarDB)) (*SugarDB, error) {
sugarDB.acl = acl.NewACL(sugarDB.config) sugarDB.acl = acl.NewACL(sugarDB.config)
// Set up Pub/Sub module // Set up Pub/Sub module
sugarDB.pubSub = pubsub.NewPubSub() sugarDB.pubSub = pubsub.NewPubSub(sugarDB.context)
if sugarDB.isInCluster() { if sugarDB.isInCluster() {
sugarDB.raft = raft.NewRaft(raft.Opts{ sugarDB.raft = raft.NewRaft(raft.Opts{