mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-09-26 20:11:15 +08:00
Refactored PubSub Embedded API
Refactored pubsub implementation to return MessageReader on embedded instance, which implements io.Reader for reading messages (#170) - @kelvinmwinuka
This commit is contained in:
@@ -29,23 +29,34 @@ Subscribe to one or more patterns. This command accepts glob patterns.
|
||||
]}
|
||||
>
|
||||
<TabItem value="go">
|
||||
The Subscribe method returns a readMessage function.
|
||||
This method is lazy so it must be invoked each time the you want to read the next message from
|
||||
the pattern.
|
||||
When subscribing to an'N' number of patterns, the first N messages will be
|
||||
The PSubscribe method returns a MessageReader type which implements the `io.Reader` interface.
|
||||
When subscribing to an'N' number of channels, the first N messages will be
|
||||
the subscription confimations.
|
||||
The readMessage functions returns a message object when called. The message
|
||||
object is a string slice with the following inforamtion:
|
||||
event type at index 0 (e.g. subscribe, message), pattern at index 1,
|
||||
|
||||
The message read is a series of bytes resulting from JSON marshalling an array. You will need to
|
||||
unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
|
||||
event type at index 0 (e.g. psubscribe, message), channel name at index 1,
|
||||
message/subscription index at index 2.
|
||||
|
||||
Messages published to any channels that match the pattern will be received by the pattern subscriber.
|
||||
|
||||
```go
|
||||
db, err := sugardb.NewSugarDB()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
readMessage := db.PSubscribe("subscribe_tag_1", "pattern_[12]", "pattern_h[ae]llo") // Return lazy readMessage function
|
||||
for i := 0; i < 2; i++ {
|
||||
message := readMessage() // Call the readMessage function for each channel subscription.
|
||||
|
||||
// Subscribe to multiple channel patterns, returs MessageReader
|
||||
msgReader := db.PSubscribe("psubscribe_tag_1", "channel[12]", "pattern[12]")
|
||||
|
||||
// Read message into pre-defined buffer
|
||||
msg := make([]byte, 1024)
|
||||
_, err := msgReader.Read(msg)
|
||||
|
||||
// Trim all null bytes at the end of the message before unmarshalling.
|
||||
var message []string
|
||||
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
|
||||
log.Fatalf("json unmarshal error: %+v", err)
|
||||
}
|
||||
```
|
||||
</TabItem>
|
||||
|
@@ -29,23 +29,32 @@ Subscribe to one or more channels.
|
||||
]}
|
||||
>
|
||||
<TabItem value="go">
|
||||
The Subscribe method returns a readMessage function.
|
||||
This method is lazy so it must be invoked each time the you want to read the next message from
|
||||
the channel.
|
||||
The Subscribe method returns a MessageReader type which implements the `io.Reader` interface.
|
||||
When subscribing to an'N' number of channels, the first N messages will be
|
||||
the subscription confimations.
|
||||
The readMessage functions returns a message object when called. The message
|
||||
object is a string slice with the following inforamtion:
|
||||
|
||||
The message read is a series of bytes resulting from JSON marshalling an array. You will need to
|
||||
unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message:
|
||||
event type at index 0 (e.g. subscribe, message), channel name at index 1,
|
||||
message/subscription index at index 2.
|
||||
|
||||
```go
|
||||
db, err := sugardb.NewSugarDB()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
readMessage := db.Subscribe("subscribe_tag_1", "channel1", "channel2") // Return lazy readMessage function
|
||||
for i := 0; i < 2; i++ {
|
||||
message := readMessage() // Call the readMessage function for each channel subscription.
|
||||
|
||||
// Subscribe to multiple channel patterns, returs MessageReader.
|
||||
msgReader := db.Subscribe("subscribe_tag_1", "channel1", "channel2")
|
||||
|
||||
// Read message into pre-defined buffer.
|
||||
msg := make([]byte, 1024)
|
||||
_, err := msgReader.Read(msg)
|
||||
|
||||
// Trim all null bytes at the end of the message before unmarshalling.
|
||||
var message []string
|
||||
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
|
||||
log.Fatalf("json unmarshal error: %+v", err)
|
||||
}
|
||||
```
|
||||
</TabItem>
|
||||
|
@@ -15,19 +15,30 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/tidwall/resp"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
name string // Channel name. This can be a glob pattern string.
|
||||
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
|
||||
subscribersRWMut sync.RWMutex // RWMutex to concurrency control when accessing channel subscribers.
|
||||
subscribers map[*net.Conn]*resp.Conn // Map containing the channel subscribers.
|
||||
messageChan *chan string // Messages published to this channel will be sent to this channel.
|
||||
name string // Channel name. This can be a glob pattern string.
|
||||
pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel.
|
||||
|
||||
messages []string // Slice that holds messages.
|
||||
messagesRWMut sync.RWMutex // RWMutex for accessing channel messages.
|
||||
messagesCond *sync.Cond
|
||||
|
||||
tcpSubs map[*net.Conn]*resp.Conn // Map containing the channel's TCP subscribers.
|
||||
tcpSubsRWMut sync.RWMutex // RWMutex for accessing TCP channel subscribers.
|
||||
|
||||
embeddedSubs []*EmbeddedSub // Slice containing embedded subscribers to this channel.
|
||||
embeddedSubsRWMut sync.RWMutex // RWMutex for accessing embedded subscribers.
|
||||
}
|
||||
|
||||
// WithName option sets the channels name.
|
||||
@@ -45,46 +56,89 @@ func WithPattern(pattern string) func(channel *Channel) {
|
||||
}
|
||||
}
|
||||
|
||||
func NewChannel(options ...func(channel *Channel)) *Channel {
|
||||
messageChan := make(chan string, 4096)
|
||||
|
||||
func NewChannel(ctx context.Context, options ...func(channel *Channel)) *Channel {
|
||||
channel := &Channel{
|
||||
name: "",
|
||||
pattern: nil,
|
||||
subscribersRWMut: sync.RWMutex{},
|
||||
subscribers: make(map[*net.Conn]*resp.Conn),
|
||||
messageChan: &messageChan,
|
||||
name: "",
|
||||
pattern: nil,
|
||||
|
||||
messages: make([]string, 0),
|
||||
messagesRWMut: sync.RWMutex{},
|
||||
|
||||
tcpSubs: make(map[*net.Conn]*resp.Conn),
|
||||
tcpSubsRWMut: sync.RWMutex{},
|
||||
|
||||
embeddedSubs: make([]*EmbeddedSub, 0),
|
||||
embeddedSubsRWMut: sync.RWMutex{},
|
||||
}
|
||||
channel.messagesCond = sync.NewCond(&channel.messagesRWMut)
|
||||
|
||||
for _, option := range options {
|
||||
option(channel)
|
||||
}
|
||||
|
||||
return channel
|
||||
}
|
||||
|
||||
func (ch *Channel) Start() {
|
||||
go func() {
|
||||
for {
|
||||
message := <-*ch.messageChan
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("closing channel %s\n", channel.name)
|
||||
return
|
||||
default:
|
||||
channel.messagesRWMut.Lock()
|
||||
for len(channel.messages) == 0 {
|
||||
channel.messagesCond.Wait()
|
||||
}
|
||||
|
||||
ch.subscribersRWMut.RLock()
|
||||
message := channel.messages[0]
|
||||
channel.messages = channel.messages[1:]
|
||||
channel.messagesRWMut.Unlock()
|
||||
|
||||
for _, conn := range ch.subscribers {
|
||||
go func(conn *resp.Conn) {
|
||||
if err := conn.WriteArray([]resp.Value{
|
||||
resp.StringValue("message"),
|
||||
resp.StringValue(ch.name),
|
||||
resp.StringValue(message),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
// Send messages to embedded subscribers
|
||||
go func() {
|
||||
channel.embeddedSubsRWMut.RLock()
|
||||
ewg := sync.WaitGroup{}
|
||||
b, _ := json.Marshal([]string{"message", channel.name, message})
|
||||
msg := append(b, byte('\n'))
|
||||
for _, w := range channel.embeddedSubs {
|
||||
ewg.Add(1)
|
||||
go func(w *EmbeddedSub) {
|
||||
_, _ = w.Write(msg)
|
||||
ewg.Done()
|
||||
}(w)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
ewg.Wait()
|
||||
channel.embeddedSubsRWMut.RUnlock()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
ch.subscribersRWMut.RUnlock()
|
||||
// Send messages to TCP subscribers
|
||||
go func() {
|
||||
channel.tcpSubsRWMut.RLock()
|
||||
cwg := sync.WaitGroup{}
|
||||
for _, conn := range channel.tcpSubs {
|
||||
cwg.Add(1)
|
||||
go func(conn *resp.Conn) {
|
||||
_ = conn.WriteArray([]resp.Value{
|
||||
resp.StringValue("message"),
|
||||
resp.StringValue(channel.name),
|
||||
resp.StringValue(message),
|
||||
})
|
||||
cwg.Done()
|
||||
}(conn)
|
||||
}
|
||||
cwg.Wait()
|
||||
channel.tcpSubsRWMut.RUnlock()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return channel
|
||||
}
|
||||
|
||||
func (ch *Channel) Name() string {
|
||||
@@ -95,56 +149,90 @@ func (ch *Channel) Pattern() glob.Glob {
|
||||
return ch.pattern
|
||||
}
|
||||
|
||||
func (ch *Channel) Subscribe(conn *net.Conn) bool {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
if _, ok := ch.subscribers[conn]; !ok {
|
||||
ch.subscribers[conn] = resp.NewConn(*conn)
|
||||
func (ch *Channel) Subscribe(sub any, action string, chanIdx int) {
|
||||
switch sub.(type) {
|
||||
case *net.Conn:
|
||||
ch.tcpSubsRWMut.Lock()
|
||||
defer ch.tcpSubsRWMut.Unlock()
|
||||
conn := sub.(*net.Conn)
|
||||
if _, ok := ch.tcpSubs[conn]; !ok {
|
||||
ch.tcpSubs[conn] = resp.NewConn(*conn)
|
||||
}
|
||||
r, _ := ch.tcpSubs[conn]
|
||||
// Send subscription message
|
||||
_ = r.WriteArray([]resp.Value{
|
||||
resp.StringValue(action),
|
||||
resp.StringValue(ch.name),
|
||||
resp.IntegerValue(chanIdx + 1),
|
||||
})
|
||||
|
||||
case *EmbeddedSub:
|
||||
ch.embeddedSubsRWMut.Lock()
|
||||
defer ch.embeddedSubsRWMut.Unlock()
|
||||
w := sub.(*EmbeddedSub)
|
||||
if !slices.ContainsFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
|
||||
return writer == w
|
||||
}) {
|
||||
ch.embeddedSubs = append(ch.embeddedSubs, w)
|
||||
}
|
||||
// Send subscription message
|
||||
b, _ := json.Marshal([]string{action, ch.name, fmt.Sprintf("%d", chanIdx+1)})
|
||||
msg := append(b, byte('\n'))
|
||||
_, _ = w.Write(msg)
|
||||
}
|
||||
_, ok := ch.subscribers[conn]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
if _, ok := ch.subscribers[conn]; !ok {
|
||||
func (ch *Channel) Unsubscribe(sub any) bool {
|
||||
switch sub.(type) {
|
||||
default:
|
||||
return false
|
||||
|
||||
case *net.Conn:
|
||||
ch.tcpSubsRWMut.Lock()
|
||||
defer ch.tcpSubsRWMut.Unlock()
|
||||
conn := sub.(*net.Conn)
|
||||
if _, ok := ch.tcpSubs[conn]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(ch.tcpSubs, conn)
|
||||
return true
|
||||
|
||||
case *EmbeddedSub:
|
||||
ch.embeddedSubsRWMut.Lock()
|
||||
defer ch.embeddedSubsRWMut.Unlock()
|
||||
w := sub.(*EmbeddedSub)
|
||||
deleted := false
|
||||
ch.embeddedSubs = slices.DeleteFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool {
|
||||
deleted = writer == w
|
||||
return deleted
|
||||
})
|
||||
return deleted
|
||||
}
|
||||
delete(ch.subscribers, conn)
|
||||
return true
|
||||
}
|
||||
|
||||
func (ch *Channel) Publish(message string) {
|
||||
*ch.messageChan <- message
|
||||
ch.messagesRWMut.Lock()
|
||||
defer ch.messagesRWMut.Unlock()
|
||||
ch.messages = append(ch.messages, message)
|
||||
ch.messagesCond.Signal()
|
||||
}
|
||||
|
||||
func (ch *Channel) IsActive() bool {
|
||||
ch.subscribersRWMut.RLock()
|
||||
defer ch.subscribersRWMut.RUnlock()
|
||||
ch.tcpSubsRWMut.RLock()
|
||||
defer ch.tcpSubsRWMut.RUnlock()
|
||||
|
||||
active := len(ch.subscribers) > 0
|
||||
ch.embeddedSubsRWMut.RLock()
|
||||
defer ch.embeddedSubsRWMut.RUnlock()
|
||||
|
||||
return active
|
||||
return len(ch.tcpSubs)+len(ch.embeddedSubs) > 0
|
||||
}
|
||||
|
||||
func (ch *Channel) NumSubs() int {
|
||||
ch.subscribersRWMut.RLock()
|
||||
defer ch.subscribersRWMut.RUnlock()
|
||||
ch.tcpSubsRWMut.RLock()
|
||||
defer ch.tcpSubsRWMut.RUnlock()
|
||||
|
||||
n := len(ch.subscribers)
|
||||
ch.embeddedSubsRWMut.RLock()
|
||||
defer ch.embeddedSubsRWMut.RUnlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (ch *Channel) Subscribers() map[*net.Conn]*resp.Conn {
|
||||
ch.subscribersRWMut.RLock()
|
||||
defer ch.subscribersRWMut.RUnlock()
|
||||
|
||||
subscribers := make(map[*net.Conn]*resp.Conn, len(ch.subscribers))
|
||||
for k, v := range ch.subscribers {
|
||||
subscribers[k] = v
|
||||
}
|
||||
|
||||
return subscribers
|
||||
return len(ch.tcpSubs) + len(ch.embeddedSubs)
|
||||
}
|
||||
|
@@ -35,7 +35,8 @@ func handleSubscribe(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
}
|
||||
|
||||
withPattern := strings.EqualFold(params.Command[0], "psubscribe")
|
||||
pubsub.Subscribe(params.Context, params.Connection, channels, withPattern)
|
||||
|
||||
pubsub.Subscribe(params.Connection, channels, withPattern)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -50,7 +51,7 @@ func handleUnsubscribe(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
|
||||
withPattern := strings.EqualFold(params.Command[0], "punsubscribe")
|
||||
|
||||
return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil
|
||||
return pubsub.Unsubscribe(params.Connection, channels, withPattern), nil
|
||||
}
|
||||
|
||||
func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
@@ -61,7 +62,7 @@ func handlePublish(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
if len(params.Command) != 3 {
|
||||
return nil, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
pubsub.Publish(params.Context, params.Command[2], params.Command[1])
|
||||
pubsub.Publish(params.Command[2], params.Command[1])
|
||||
return []byte(constants.OkResponse), nil
|
||||
}
|
||||
|
||||
|
@@ -17,33 +17,31 @@ package pubsub
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
type PubSub struct {
|
||||
ctx context.Context
|
||||
channels []*Channel
|
||||
channelsRWMut sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPubSub() *PubSub {
|
||||
func NewPubSub(ctx context.Context) *PubSub {
|
||||
return &PubSub{
|
||||
ctx: ctx,
|
||||
channels: []*Channel{},
|
||||
channelsRWMut: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) {
|
||||
func (ps *PubSub) Subscribe(sub any, channels []string, withPattern bool) {
|
||||
ps.channelsRWMut.Lock()
|
||||
defer ps.channelsRWMut.Unlock()
|
||||
|
||||
r := resp.NewConn(*conn)
|
||||
|
||||
action := "subscribe"
|
||||
if withPattern {
|
||||
action = "psubscribe"
|
||||
@@ -58,40 +56,46 @@ func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string
|
||||
})
|
||||
|
||||
if channelIdx == -1 {
|
||||
// Create new channel, start it, and subscribe to it
|
||||
// Create new channel, if it does not exist
|
||||
var newChan *Channel
|
||||
if withPattern {
|
||||
newChan = NewChannel(WithPattern(channels[i]))
|
||||
newChan = NewChannel(ps.ctx, WithPattern(channels[i]))
|
||||
} else {
|
||||
newChan = NewChannel(WithName(channels[i]))
|
||||
newChan = NewChannel(ps.ctx, WithName(channels[i]))
|
||||
}
|
||||
newChan.Start()
|
||||
if newChan.Subscribe(conn) {
|
||||
if err := r.WriteArray([]resp.Value{
|
||||
resp.StringValue(action),
|
||||
resp.StringValue(newChan.name),
|
||||
resp.IntegerValue(i + 1),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
ps.channels = append(ps.channels, newChan)
|
||||
// Append the channel to the list of channels
|
||||
ps.channels = append(ps.channels, newChan)
|
||||
|
||||
// Subscribe to the channel
|
||||
switch sub.(type) {
|
||||
case *net.Conn:
|
||||
// Subscribe TCP connection
|
||||
conn := sub.(*net.Conn)
|
||||
newChan.Subscribe(conn, action, i)
|
||||
|
||||
case *EmbeddedSub:
|
||||
// Subscribe io.Writer from embedded client
|
||||
w := sub.(*EmbeddedSub)
|
||||
newChan.Subscribe(w, action, i)
|
||||
}
|
||||
} else {
|
||||
// Subscribe to existing channel
|
||||
if ps.channels[channelIdx].Subscribe(conn) {
|
||||
if err := r.WriteArray([]resp.Value{
|
||||
resp.StringValue(action),
|
||||
resp.StringValue(ps.channels[channelIdx].name),
|
||||
resp.IntegerValue(i + 1),
|
||||
}); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
switch sub.(type) {
|
||||
case *net.Conn:
|
||||
// Subscribe TCP connection
|
||||
conn := sub.(*net.Conn)
|
||||
ps.channels[channelIdx].Subscribe(conn, action, i)
|
||||
|
||||
case *EmbeddedSub:
|
||||
// Subscribe io.Writer from embedded client
|
||||
w := sub.(*EmbeddedSub)
|
||||
ps.channels[channelIdx].Subscribe(w, action, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
|
||||
func (ps *PubSub) Unsubscribe(sub any, channels []string, withPattern bool) []byte {
|
||||
ps.channelsRWMut.RLock()
|
||||
defer ps.channelsRWMut.RUnlock()
|
||||
|
||||
@@ -111,7 +115,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
if channel.pattern != nil { // Skip pattern channels
|
||||
continue
|
||||
}
|
||||
if channel.Unsubscribe(conn) {
|
||||
if channel.Unsubscribe(sub) {
|
||||
unsubscribed[idx] = channel.name
|
||||
idx += 1
|
||||
}
|
||||
@@ -123,7 +127,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
if channel.pattern == nil { // Skip non-pattern channels
|
||||
continue
|
||||
}
|
||||
if channel.Unsubscribe(conn) {
|
||||
if channel.Unsubscribe(sub) {
|
||||
unsubscribed[idx] = channel.name
|
||||
idx += 1
|
||||
}
|
||||
@@ -136,7 +140,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
// names exactly matches the pattern name.
|
||||
for _, channel := range ps.channels { // For each channel in PubSub
|
||||
for _, c := range channels { // For each channel name provided
|
||||
if channel.name == c && channel.Unsubscribe(conn) {
|
||||
if channel.name == c && channel.Unsubscribe(sub) {
|
||||
unsubscribed[idx] = channel.name
|
||||
idx += 1
|
||||
}
|
||||
@@ -151,7 +155,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
for _, channel := range ps.channels {
|
||||
// If it's a pattern channel, directly compare the patterns
|
||||
if channel.pattern != nil && channel.name == pattern {
|
||||
if channel.Unsubscribe(conn) {
|
||||
if channel.Unsubscribe(sub) {
|
||||
unsubscribed[idx] = channel.name
|
||||
idx += 1
|
||||
}
|
||||
@@ -159,7 +163,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
}
|
||||
// If this is a regular channel, check if the channel name matches the pattern given
|
||||
if g.Match(channel.name) {
|
||||
if channel.Unsubscribe(conn) {
|
||||
if channel.Unsubscribe(sub) {
|
||||
unsubscribed[idx] = channel.name
|
||||
idx += 1
|
||||
}
|
||||
@@ -176,7 +180,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri
|
||||
return []byte(res)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Publish(_ context.Context, message string, channelName string) {
|
||||
func (ps *PubSub) Publish(message string, channelName string) {
|
||||
ps.channelsRWMut.RLock()
|
||||
defer ps.channelsRWMut.RUnlock()
|
||||
|
||||
|
59
internal/modules/pubsub/sub.go
Normal file
59
internal/modules/pubsub/sub.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2024 Kelvin Clement Mwinuka
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type EmbeddedSub struct {
|
||||
mux sync.Mutex
|
||||
buff *bytes.Buffer
|
||||
writer *bufio.Writer
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func NewEmbeddedSub() *EmbeddedSub {
|
||||
sub := &EmbeddedSub{
|
||||
mux: sync.Mutex{},
|
||||
buff: bytes.NewBuffer(make([]byte, 0)),
|
||||
}
|
||||
sub.writer = bufio.NewWriter(sub.buff)
|
||||
sub.reader = bufio.NewReader(sub.buff)
|
||||
return sub
|
||||
}
|
||||
|
||||
func (sub *EmbeddedSub) Write(p []byte) (int, error) {
|
||||
sub.mux.Lock()
|
||||
defer sub.mux.Unlock()
|
||||
n, err := sub.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
err = sub.writer.Flush()
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (sub *EmbeddedSub) Read(p []byte) (int, error) {
|
||||
sub.mux.Lock()
|
||||
defer sub.mux.Unlock()
|
||||
|
||||
chunk, err := sub.reader.ReadBytes(byte('\n'))
|
||||
n := copy(p, chunk)
|
||||
|
||||
return n, err
|
||||
}
|
@@ -16,55 +16,23 @@ package sugardb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/echovault/sugardb/internal"
|
||||
"github.com/echovault/sugardb/internal/modules/pubsub"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
readConn *net.Conn
|
||||
writeConn *net.Conn
|
||||
type MessageReader struct {
|
||||
embeddedSub *pubsub.EmbeddedSub
|
||||
}
|
||||
|
||||
var connections sync.Map
|
||||
|
||||
// ReadPubSubMessage is returned by the Subscribe and PSubscribe functions.
|
||||
//
|
||||
// This function is lazy, therefore it needs to be invoked in order to read the next message.
|
||||
// When the message is read, the function returns a string slice with 3 elements.
|
||||
// Index 0 holds the event type which in this case will be "message". Index 1 holds the channel name.
|
||||
// Index 2 holds the actual message.
|
||||
type ReadPubSubMessage func() []string
|
||||
|
||||
func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
|
||||
var readConn *net.Conn
|
||||
var writeConn *net.Conn
|
||||
|
||||
if _, ok := connections.Load(tag); !ok {
|
||||
// If connection with this name does not exist, create new connection.
|
||||
rc, wc := net.Pipe()
|
||||
readConn = &rc
|
||||
writeConn = &wc
|
||||
connections.Store(tag, conn{
|
||||
readConn: &rc,
|
||||
writeConn: &wc,
|
||||
})
|
||||
} else {
|
||||
// Reuse existing connection.
|
||||
c, ok := connections.Load(tag)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("could not establish connection")
|
||||
}
|
||||
readConn = c.(conn).readConn
|
||||
writeConn = c.(conn).writeConn
|
||||
}
|
||||
|
||||
return readConn, writeConn, nil
|
||||
func (reader *MessageReader) Read(p []byte) (int, error) {
|
||||
return reader.embeddedSub.Read(p)
|
||||
}
|
||||
|
||||
var subscriptions sync.Map
|
||||
|
||||
// Subscribe subscribes the caller to the list of provided channels.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -75,31 +43,22 @@ func establishConnections(tag string) (*net.Conn, *net.Conn, error) {
|
||||
//
|
||||
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
|
||||
// This function is blocking.
|
||||
func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMessage, error) {
|
||||
readConn, writeConn, err := establishConnections(tag)
|
||||
if err != nil {
|
||||
return func() []string {
|
||||
return []string{}
|
||||
}, err
|
||||
func (server *SugarDB) Subscribe(tag string, channels ...string) (*MessageReader, error) {
|
||||
var msgReader *MessageReader
|
||||
|
||||
sub, ok := subscriptions.Load(tag)
|
||||
if !ok {
|
||||
// Create new messageBuffer and store it in the subscriptions
|
||||
msgReader = &MessageReader{
|
||||
embeddedSub: pubsub.NewEmbeddedSub(),
|
||||
}
|
||||
} else {
|
||||
msgReader = sub.(*MessageReader)
|
||||
}
|
||||
|
||||
// Subscribe connection to the provided channels.
|
||||
cmd := append([]string{"SUBSCRIBE"}, channels...)
|
||||
go func() {
|
||||
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
|
||||
}()
|
||||
server.pubSub.Subscribe(msgReader.embeddedSub, channels, false)
|
||||
|
||||
return func() []string {
|
||||
r := resp.NewConn(*readConn)
|
||||
v, _, _ := r.ReadValue()
|
||||
|
||||
res := make([]string, len(v.Array()))
|
||||
for i := 0; i < len(res); i++ {
|
||||
res[i] = v.Array()[i].String()
|
||||
}
|
||||
|
||||
return res
|
||||
}, nil
|
||||
return msgReader, nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes the caller from the given channels.
|
||||
@@ -110,12 +69,12 @@ func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMess
|
||||
//
|
||||
// `channels` - ...string - The list of channels to unsubscribe from.
|
||||
func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
|
||||
c, ok := connections.Load(tag)
|
||||
sub, ok := subscriptions.Load(tag)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cmd := append([]string{"UNSUBSCRIBE"}, channels...)
|
||||
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true)
|
||||
msgReader := sub.(*MessageReader)
|
||||
server.pubSub.Unsubscribe(msgReader, channels, false)
|
||||
}
|
||||
|
||||
// PSubscribe subscribes the caller to the list of provided glob patterns.
|
||||
@@ -128,31 +87,23 @@ func (server *SugarDB) Unsubscribe(tag string, channels ...string) {
|
||||
//
|
||||
// Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance.
|
||||
// This function is blocking.
|
||||
func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMessage, error) {
|
||||
readConn, writeConn, err := establishConnections(tag)
|
||||
if err != nil {
|
||||
return func() []string {
|
||||
return []string{}
|
||||
}, err
|
||||
|
||||
func (server *SugarDB) PSubscribe(tag string, patterns ...string) (*MessageReader, error) {
|
||||
var msgReader *MessageReader
|
||||
|
||||
sub, ok := subscriptions.Load(tag)
|
||||
if !ok {
|
||||
// Create new messageBuffer and store it in the subscriptions
|
||||
msgReader = &MessageReader{
|
||||
embeddedSub: pubsub.NewEmbeddedSub(),
|
||||
}
|
||||
} else {
|
||||
msgReader = sub.(*MessageReader)
|
||||
}
|
||||
|
||||
// Subscribe connection to the provided channels
|
||||
cmd := append([]string{"PSUBSCRIBE"}, patterns...)
|
||||
go func() {
|
||||
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true)
|
||||
}()
|
||||
server.pubSub.Subscribe(msgReader.embeddedSub, patterns, true)
|
||||
|
||||
return func() []string {
|
||||
r := resp.NewConn(*readConn)
|
||||
v, _, _ := r.ReadValue()
|
||||
|
||||
res := make([]string, len(v.Array()))
|
||||
for i := 0; i < len(res); i++ {
|
||||
res[i] = v.Array()[i].String()
|
||||
}
|
||||
|
||||
return res
|
||||
}, nil
|
||||
return msgReader, nil
|
||||
}
|
||||
|
||||
// PUnsubscribe unsubscribes the caller from the given glob patterns.
|
||||
@@ -163,12 +114,12 @@ func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMes
|
||||
//
|
||||
// `patterns` - ...string - The list of glob patterns to unsubscribe from.
|
||||
func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
|
||||
c, ok := connections.Load(tag)
|
||||
sub, ok := subscriptions.Load(tag)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cmd := append([]string{"PUNSUBSCRIBE"}, patterns...)
|
||||
_, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true)
|
||||
msgReader := sub.(*MessageReader)
|
||||
server.pubSub.Unsubscribe(msgReader, patterns, true)
|
||||
}
|
||||
|
||||
// Publish publishes a message to the given channel.
|
||||
@@ -179,10 +130,12 @@ func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) {
|
||||
//
|
||||
// `message` - string - The message to publish to the specified channel.
|
||||
//
|
||||
// Returns: true when the publish is successful. This does not indicate whether each subscriber has received the message,
|
||||
// only that the message has been published.
|
||||
// Returns: true when successful. This does not indicate whether each subscriber has received the message,
|
||||
// only that the message has been published to the channel.
|
||||
func (server *SugarDB) Publish(channel, message string) (bool, error) {
|
||||
b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true)
|
||||
b, err := server.handleCommand(
|
||||
server.context,
|
||||
internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@@ -15,276 +15,326 @@
|
||||
package sugardb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_Subscribe(t *testing.T) {
|
||||
func TestSugarDB_PubSub(t *testing.T) {
|
||||
server := createSugarDB()
|
||||
t.Cleanup(func() {
|
||||
server.ShutDown()
|
||||
})
|
||||
|
||||
// Subscribe to channels.
|
||||
tag := "tag"
|
||||
channels := []string{"channel1", "channel2"}
|
||||
readMessage, err := server.Subscribe(tag, channels...)
|
||||
if err != nil {
|
||||
t.Errorf("SUBSCRIBE() error = %v", err)
|
||||
}
|
||||
t.Run("TestSugarDB_(P)Subscribe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < len(channels); i++ {
|
||||
message := readMessage()
|
||||
// Check that we've received the subscribe messages.
|
||||
if message[0] != "subscribe" {
|
||||
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0])
|
||||
tests := []struct {
|
||||
name string
|
||||
action string // subscribe | psubscribe
|
||||
tag string
|
||||
channels []string
|
||||
pubChannels []string // Channels to publish messages to after subscriptions
|
||||
wantMsg []string // Expected messages from after publishing
|
||||
subFunc func(tag string, channels ...string) (*MessageReader, error)
|
||||
unsubFunc func(tag string, channels ...string)
|
||||
}{
|
||||
{
|
||||
name: "1. Subscribe to channels",
|
||||
action: "subscribe",
|
||||
tag: "tag_test_subscribe",
|
||||
channels: []string{
|
||||
"channel1",
|
||||
"channel2",
|
||||
},
|
||||
pubChannels: []string{"channel1", "channel2"},
|
||||
wantMsg: []string{
|
||||
"message for channel1",
|
||||
"message for channel2",
|
||||
},
|
||||
subFunc: server.Subscribe,
|
||||
unsubFunc: server.Unsubscribe,
|
||||
},
|
||||
{
|
||||
name: "2. Subscribe to patterns",
|
||||
action: "psubscribe",
|
||||
tag: "tag_test_psubscribe",
|
||||
channels: []string{
|
||||
"channel[34]",
|
||||
"pattern[12]",
|
||||
},
|
||||
pubChannels: []string{
|
||||
"channel3",
|
||||
"channel4",
|
||||
"pattern1",
|
||||
"pattern2",
|
||||
},
|
||||
wantMsg: []string{
|
||||
"message for channel3",
|
||||
"message for channel4",
|
||||
"message for pattern1",
|
||||
"message for pattern2",
|
||||
},
|
||||
subFunc: server.PSubscribe,
|
||||
unsubFunc: server.PUnsubscribe,
|
||||
},
|
||||
}
|
||||
if !slices.Contains(channels, message[1]) {
|
||||
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish some messages to the channels.
|
||||
for _, channel := range channels {
|
||||
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
|
||||
if err != nil {
|
||||
t.Errorf("PUBLISH() err = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Read messages from the channels
|
||||
for i := 0; i < len(channels); i++ {
|
||||
message := readMessage()
|
||||
// Check that we've received the messages.
|
||||
if message[0] != "message" {
|
||||
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
|
||||
}
|
||||
if !slices.Contains(channels, message[1]) {
|
||||
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
if !slices.Contains([]string{"message for channel1", "message for channel2"}, message[2]) {
|
||||
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
tt.unsubFunc(tt.tag, tt.channels...)
|
||||
})
|
||||
|
||||
// Unsubscribe from channels
|
||||
server.Unsubscribe(tag, channels...)
|
||||
}
|
||||
|
||||
func TestSugarDB_PSubscribe(t *testing.T) {
|
||||
server := createSugarDB()
|
||||
|
||||
// Subscribe to channels.
|
||||
tag := "tag"
|
||||
patterns := []string{"channel[12]", "pattern[12]"}
|
||||
readMessage, err := server.PSubscribe(tag, patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PSubscribe() error = %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < len(patterns); i++ {
|
||||
message := readMessage()
|
||||
// Check that we've received the subscribe messages.
|
||||
if message[0] != "psubscribe" {
|
||||
t.Errorf("PSUBSCRIBE() expected index 0 for message at %d to be \"psubscribe\", got %s", i, message[0])
|
||||
}
|
||||
if !slices.Contains(patterns, message[1]) {
|
||||
t.Errorf("PSUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish some messages to the channels.
|
||||
for _, channel := range []string{"channel1", "channel2", "pattern1", "pattern2"} {
|
||||
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
|
||||
if err != nil {
|
||||
t.Errorf("PUBLISH() err = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
|
||||
}
|
||||
}
|
||||
|
||||
// Read messages from the channels
|
||||
for i := 0; i < len(patterns)*2; i++ {
|
||||
message := readMessage()
|
||||
// Check that we've received the messages.
|
||||
if message[0] != "message" {
|
||||
t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
|
||||
}
|
||||
if !slices.Contains(patterns, message[1]) {
|
||||
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
if !slices.Contains([]string{
|
||||
"message for channel1", "message for channel2", "message for pattern1", "message for pattern2"}, message[2]) {
|
||||
t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[2], i)
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe from channels
|
||||
server.PUnsubscribe(tag, patterns...)
|
||||
}
|
||||
|
||||
func TestSugarDB_PubSubChannels(t *testing.T) {
|
||||
server := createSugarDB()
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
channels []string
|
||||
patterns []string
|
||||
pattern string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Get number of active channels",
|
||||
tag: "tag",
|
||||
channels: []string{"channel1", "channel2", "channel3", "channel4"},
|
||||
patterns: []string{"channel[56]"},
|
||||
pattern: "channel[123456]",
|
||||
want: []string{"channel1", "channel2", "channel3", "channel4"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Subscribe to channels
|
||||
readChannelMessages, err := server.Subscribe(tt.tag, tt.channels...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubChannels() error = %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < len(tt.channels); i++ {
|
||||
readChannelMessages()
|
||||
}
|
||||
// Subscribe to patterns
|
||||
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubChannels() error = %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < len(tt.patterns); i++ {
|
||||
readPatternMessages()
|
||||
}
|
||||
got, err := server.PubSubChannels(tt.pattern)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
|
||||
}
|
||||
for _, item := range got {
|
||||
if !slices.Contains(tt.want, item) {
|
||||
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
|
||||
// Subscribe to channels.
|
||||
readMessage, err := tt.subFunc(tt.tag, tt.channels...)
|
||||
if err != nil {
|
||||
t.Errorf("(P)SUBSCRIBE() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSugarDB_PubSubNumPat(t *testing.T) {
|
||||
server := createSugarDB()
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
patterns []string
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Get number of active patterns on the server",
|
||||
tag: "tag",
|
||||
patterns: []string{"channel[56]", "channel[78]"},
|
||||
want: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Subscribe to patterns
|
||||
readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumPat() error = %v", err)
|
||||
}
|
||||
for i := 0; i < len(tt.patterns); i++ {
|
||||
readPatternMessages()
|
||||
}
|
||||
got, err := server.PubSubNumPat()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(tt.channels); i++ {
|
||||
p := make([]byte, 1024)
|
||||
_, err := readMessage.Read(p)
|
||||
if err != nil {
|
||||
t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
|
||||
}
|
||||
var message []string
|
||||
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
|
||||
t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
|
||||
}
|
||||
// Check that we've received the subscribe messages.
|
||||
if message[0] != tt.action {
|
||||
t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0])
|
||||
}
|
||||
if !slices.Contains(tt.channels, message[1]) {
|
||||
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSugarDB_PubSubNumSub(t *testing.T) {
|
||||
server := createSugarDB()
|
||||
tests := []struct {
|
||||
name string
|
||||
subscriptions map[string]struct {
|
||||
// Publish some messages to the channels.
|
||||
for _, channel := range tt.pubChannels {
|
||||
ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel))
|
||||
if err != nil {
|
||||
t.Errorf("PUBLISH() err = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("PUBLISH() could not publish message to channel %s", channel)
|
||||
}
|
||||
}
|
||||
|
||||
// Read messages from the channels
|
||||
for i := 0; i < len(tt.pubChannels); i++ {
|
||||
p := make([]byte, 1024)
|
||||
_, err := readMessage.Read(p)
|
||||
|
||||
doneChan := make(chan struct{}, 1)
|
||||
go func() {
|
||||
for {
|
||||
if err != nil && err == io.EOF {
|
||||
_, err = readMessage.Read(p)
|
||||
continue
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("(P)SUBSCRIBE() timeout")
|
||||
case <-doneChan:
|
||||
if err != nil {
|
||||
t.Errorf("(P)SUBSCRIBE() read error: %+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var message []string
|
||||
if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil {
|
||||
t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err)
|
||||
}
|
||||
// Check that we've received the messages.
|
||||
if message[0] != "message" {
|
||||
t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0])
|
||||
}
|
||||
if !slices.Contains(tt.channels, message[1]) {
|
||||
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i)
|
||||
}
|
||||
if !slices.Contains(tt.wantMsg, message[2]) {
|
||||
t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 2 for message %d", message[1], i)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSugarDB_PubSubChannels", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
channels []string
|
||||
patterns []string
|
||||
pattern string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Get number of active channels",
|
||||
tag: "tag_test_channels_1",
|
||||
channels: []string{"channel1", "channel2", "channel3", "channel4"},
|
||||
patterns: []string{"channel[56]"},
|
||||
pattern: "channel[123456]",
|
||||
want: []string{"channel1", "channel2", "channel3", "channel4"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
channels []string
|
||||
want map[string]int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Get number of subscriptions for the given channels",
|
||||
subscriptions: map[string]struct {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Subscribe to channels
|
||||
_, err := server.Subscribe(tt.tag, tt.channels...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubChannels() error = %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to patterns
|
||||
_, err = server.PSubscribe(tt.tag, tt.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubChannels() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := server.PubSubChannels(tt.pattern)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want))
|
||||
}
|
||||
for _, item := range got {
|
||||
if !slices.Contains(tt.want, item) {
|
||||
t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSugarDB_PubSubNumPat", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
patterns []string
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Get number of active patterns on the server",
|
||||
tag: "tag_test_numpat_1",
|
||||
patterns: []string{"channel[56]", "channel[78]"},
|
||||
want: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Subscribe to patterns
|
||||
_, err := server.PSubscribe(tt.tag, tt.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumPat() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := server.PubSubNumPat()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSugarDB_PubSubNumSub", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subscriptions map[string]struct {
|
||||
channels []string
|
||||
patterns []string
|
||||
}{
|
||||
"tag1": {
|
||||
channels: []string{"channel1", "channel2"},
|
||||
patterns: []string{"channel[34]"},
|
||||
}
|
||||
channels []string
|
||||
want map[string]int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Get number of subscriptions for the given channels",
|
||||
subscriptions: map[string]struct {
|
||||
channels []string
|
||||
patterns []string
|
||||
}{
|
||||
"tag1_test_numsub_1": {
|
||||
channels: []string{"test_num_sub_channel1", "test_num_sub_channel2"},
|
||||
patterns: []string{"test_num_sub_channel[34]"},
|
||||
},
|
||||
"tag2_test_numsub_2": {
|
||||
channels: []string{"test_num_sub_channel2", "test_num_sub_channel3"},
|
||||
patterns: []string{"test_num_sub_channel[23]"},
|
||||
},
|
||||
"tag3_test_numsub_3": {
|
||||
channels: []string{"test_num_sub_channel2", "test_num_sub_channel4"},
|
||||
patterns: []string{},
|
||||
},
|
||||
},
|
||||
"tag2": {
|
||||
channels: []string{"channel2", "channel3"},
|
||||
patterns: []string{"channel[23]"},
|
||||
channels: []string{
|
||||
"test_num_sub_channel1",
|
||||
"test_num_sub_channel2",
|
||||
"test_num_sub_channel3",
|
||||
"test_num_sub_channel4",
|
||||
"test_num_sub_channel5",
|
||||
},
|
||||
"tag3": {
|
||||
channels: []string{"channel2", "channel4"},
|
||||
patterns: []string{},
|
||||
want: map[string]int{
|
||||
"test_num_sub_channel1": 1,
|
||||
"test_num_sub_channel2": 3,
|
||||
"test_num_sub_channel3": 1,
|
||||
"test_num_sub_channel4": 1,
|
||||
"test_num_sub_channel5": 0,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
channels: []string{"channel1", "channel2", "channel3", "channel4", "channel5"},
|
||||
want: map[string]int{"channel1": 1, "channel2": 3, "channel3": 1, "channel4": 1, "channel5": 0},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for tag, subs := range tt.subscriptions {
|
||||
readPat, err := server.PSubscribe(tag, subs.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumSub() error = %v", err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for tag, subs := range tt.subscriptions {
|
||||
_, err := server.PSubscribe(tag, subs.patterns...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumSub() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = server.Subscribe(tag, subs.channels...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumSub() error = %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
for _, _ = range subs.patterns {
|
||||
readPat()
|
||||
got, err := server.PubSubNumSub(tt.channels...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
readChan, err := server.Subscribe(tag, subs.channels...)
|
||||
if err != nil {
|
||||
t.Errorf("PubSubNumSub() error = %v", err)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
for _, _ = range subs.channels {
|
||||
readChan()
|
||||
}
|
||||
}
|
||||
got, err := server.PubSubNumSub(tt.channels...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -204,7 +204,7 @@ func NewSugarDB(options ...func(sugarDB *SugarDB)) (*SugarDB, error) {
|
||||
sugarDB.acl = acl.NewACL(sugarDB.config)
|
||||
|
||||
// Set up Pub/Sub module
|
||||
sugarDB.pubSub = pubsub.NewPubSub()
|
||||
sugarDB.pubSub = pubsub.NewPubSub(sugarDB.context)
|
||||
|
||||
if sugarDB.isInCluster() {
|
||||
sugarDB.raft = raft.NewRaft(raft.Opts{
|
||||
|
Reference in New Issue
Block a user