diff --git a/docs/docs/commands/pubsub/psubscribe.mdx b/docs/docs/commands/pubsub/psubscribe.mdx index 5106811..8c82734 100644 --- a/docs/docs/commands/pubsub/psubscribe.mdx +++ b/docs/docs/commands/pubsub/psubscribe.mdx @@ -29,23 +29,34 @@ Subscribe to one or more patterns. This command accepts glob patterns. ]} > - The Subscribe method returns a readMessage function. - This method is lazy so it must be invoked each time the you want to read the next message from - the pattern. - When subscribing to an'N' number of patterns, the first N messages will be + The PSubscribe method returns a MessageReader type which implements the `io.Reader` interface. + When subscribing to an'N' number of channels, the first N messages will be the subscription confimations. - The readMessage functions returns a message object when called. The message - object is a string slice with the following inforamtion: - event type at index 0 (e.g. subscribe, message), pattern at index 1, + + The message read is a series of bytes resulting from JSON marshalling an array. You will need to + unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message: + event type at index 0 (e.g. psubscribe, message), channel name at index 1, message/subscription index at index 2. + + Messages published to any channels that match the pattern will be received by the pattern subscriber. + ```go db, err := sugardb.NewSugarDB() if err != nil { log.Fatal(err) } - readMessage := db.PSubscribe("subscribe_tag_1", "pattern_[12]", "pattern_h[ae]llo") // Return lazy readMessage function - for i := 0; i < 2; i++ { - message := readMessage() // Call the readMessage function for each channel subscription. + + // Subscribe to multiple channel patterns, returs MessageReader + msgReader := db.PSubscribe("psubscribe_tag_1", "channel[12]", "pattern[12]") + + // Read message into pre-defined buffer + msg := make([]byte, 1024) + _, err := msgReader.Read(msg) + + // Trim all null bytes at the end of the message before unmarshalling. + var message []string + if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil { + log.Fatalf("json unmarshal error: %+v", err) } ``` diff --git a/docs/docs/commands/pubsub/subscribe.mdx b/docs/docs/commands/pubsub/subscribe.mdx index 6975f61..55b22c3 100644 --- a/docs/docs/commands/pubsub/subscribe.mdx +++ b/docs/docs/commands/pubsub/subscribe.mdx @@ -29,23 +29,32 @@ Subscribe to one or more channels. ]} > - The Subscribe method returns a readMessage function. - This method is lazy so it must be invoked each time the you want to read the next message from - the channel. + The Subscribe method returns a MessageReader type which implements the `io.Reader` interface. When subscribing to an'N' number of channels, the first N messages will be the subscription confimations. - The readMessage functions returns a message object when called. The message - object is a string slice with the following inforamtion: + + The message read is a series of bytes resulting from JSON marshalling an array. You will need to + unmarshal it back into a string array in order to read it. Here's the anatomy of a pubsub message: event type at index 0 (e.g. subscribe, message), channel name at index 1, message/subscription index at index 2. + ```go db, err := sugardb.NewSugarDB() if err != nil { log.Fatal(err) } - readMessage := db.Subscribe("subscribe_tag_1", "channel1", "channel2") // Return lazy readMessage function - for i := 0; i < 2; i++ { - message := readMessage() // Call the readMessage function for each channel subscription. + + // Subscribe to multiple channel patterns, returs MessageReader. + msgReader := db.Subscribe("subscribe_tag_1", "channel1", "channel2") + + // Read message into pre-defined buffer. + msg := make([]byte, 1024) + _, err := msgReader.Read(msg) + + // Trim all null bytes at the end of the message before unmarshalling. + var message []string + if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil { + log.Fatalf("json unmarshal error: %+v", err) } ``` diff --git a/internal/modules/pubsub/channel.go b/internal/modules/pubsub/channel.go index 8200113..0961db9 100644 --- a/internal/modules/pubsub/channel.go +++ b/internal/modules/pubsub/channel.go @@ -15,19 +15,30 @@ package pubsub import ( + "context" + "encoding/json" + "fmt" "github.com/gobwas/glob" "github.com/tidwall/resp" "log" "net" + "slices" "sync" ) type Channel struct { - name string // Channel name. This can be a glob pattern string. - pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel. - subscribersRWMut sync.RWMutex // RWMutex to concurrency control when accessing channel subscribers. - subscribers map[*net.Conn]*resp.Conn // Map containing the channel subscribers. - messageChan *chan string // Messages published to this channel will be sent to this channel. + name string // Channel name. This can be a glob pattern string. + pattern glob.Glob // Compiled glob pattern. This is nil if the channel is not a pattern channel. + + messages []string // Slice that holds messages. + messagesRWMut sync.RWMutex // RWMutex for accessing channel messages. + messagesCond *sync.Cond + + tcpSubs map[*net.Conn]*resp.Conn // Map containing the channel's TCP subscribers. + tcpSubsRWMut sync.RWMutex // RWMutex for accessing TCP channel subscribers. + + embeddedSubs []*EmbeddedSub // Slice containing embedded subscribers to this channel. + embeddedSubsRWMut sync.RWMutex // RWMutex for accessing embedded subscribers. } // WithName option sets the channels name. @@ -45,46 +56,89 @@ func WithPattern(pattern string) func(channel *Channel) { } } -func NewChannel(options ...func(channel *Channel)) *Channel { - messageChan := make(chan string, 4096) - +func NewChannel(ctx context.Context, options ...func(channel *Channel)) *Channel { channel := &Channel{ - name: "", - pattern: nil, - subscribersRWMut: sync.RWMutex{}, - subscribers: make(map[*net.Conn]*resp.Conn), - messageChan: &messageChan, + name: "", + pattern: nil, + + messages: make([]string, 0), + messagesRWMut: sync.RWMutex{}, + + tcpSubs: make(map[*net.Conn]*resp.Conn), + tcpSubsRWMut: sync.RWMutex{}, + + embeddedSubs: make([]*EmbeddedSub, 0), + embeddedSubsRWMut: sync.RWMutex{}, } + channel.messagesCond = sync.NewCond(&channel.messagesRWMut) for _, option := range options { option(channel) } - return channel -} - -func (ch *Channel) Start() { go func() { for { - message := <-*ch.messageChan + select { + case <-ctx.Done(): + log.Printf("closing channel %s\n", channel.name) + return + default: + channel.messagesRWMut.Lock() + for len(channel.messages) == 0 { + channel.messagesCond.Wait() + } - ch.subscribersRWMut.RLock() + message := channel.messages[0] + channel.messages = channel.messages[1:] + channel.messagesRWMut.Unlock() - for _, conn := range ch.subscribers { - go func(conn *resp.Conn) { - if err := conn.WriteArray([]resp.Value{ - resp.StringValue("message"), - resp.StringValue(ch.name), - resp.StringValue(message), - }); err != nil { - log.Println(err) + wg := sync.WaitGroup{} + wg.Add(2) + + // Send messages to embedded subscribers + go func() { + channel.embeddedSubsRWMut.RLock() + ewg := sync.WaitGroup{} + b, _ := json.Marshal([]string{"message", channel.name, message}) + msg := append(b, byte('\n')) + for _, w := range channel.embeddedSubs { + ewg.Add(1) + go func(w *EmbeddedSub) { + _, _ = w.Write(msg) + ewg.Done() + }(w) } - }(conn) - } + ewg.Wait() + channel.embeddedSubsRWMut.RUnlock() + wg.Done() + }() - ch.subscribersRWMut.RUnlock() + // Send messages to TCP subscribers + go func() { + channel.tcpSubsRWMut.RLock() + cwg := sync.WaitGroup{} + for _, conn := range channel.tcpSubs { + cwg.Add(1) + go func(conn *resp.Conn) { + _ = conn.WriteArray([]resp.Value{ + resp.StringValue("message"), + resp.StringValue(channel.name), + resp.StringValue(message), + }) + cwg.Done() + }(conn) + } + cwg.Wait() + channel.tcpSubsRWMut.RUnlock() + wg.Done() + }() + + wg.Wait() + } } }() + + return channel } func (ch *Channel) Name() string { @@ -95,56 +149,90 @@ func (ch *Channel) Pattern() glob.Glob { return ch.pattern } -func (ch *Channel) Subscribe(conn *net.Conn) bool { - ch.subscribersRWMut.Lock() - defer ch.subscribersRWMut.Unlock() - if _, ok := ch.subscribers[conn]; !ok { - ch.subscribers[conn] = resp.NewConn(*conn) +func (ch *Channel) Subscribe(sub any, action string, chanIdx int) { + switch sub.(type) { + case *net.Conn: + ch.tcpSubsRWMut.Lock() + defer ch.tcpSubsRWMut.Unlock() + conn := sub.(*net.Conn) + if _, ok := ch.tcpSubs[conn]; !ok { + ch.tcpSubs[conn] = resp.NewConn(*conn) + } + r, _ := ch.tcpSubs[conn] + // Send subscription message + _ = r.WriteArray([]resp.Value{ + resp.StringValue(action), + resp.StringValue(ch.name), + resp.IntegerValue(chanIdx + 1), + }) + + case *EmbeddedSub: + ch.embeddedSubsRWMut.Lock() + defer ch.embeddedSubsRWMut.Unlock() + w := sub.(*EmbeddedSub) + if !slices.ContainsFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool { + return writer == w + }) { + ch.embeddedSubs = append(ch.embeddedSubs, w) + } + // Send subscription message + b, _ := json.Marshal([]string{action, ch.name, fmt.Sprintf("%d", chanIdx+1)}) + msg := append(b, byte('\n')) + _, _ = w.Write(msg) } - _, ok := ch.subscribers[conn] - return ok } -func (ch *Channel) Unsubscribe(conn *net.Conn) bool { - ch.subscribersRWMut.Lock() - defer ch.subscribersRWMut.Unlock() - if _, ok := ch.subscribers[conn]; !ok { +func (ch *Channel) Unsubscribe(sub any) bool { + switch sub.(type) { + default: return false + + case *net.Conn: + ch.tcpSubsRWMut.Lock() + defer ch.tcpSubsRWMut.Unlock() + conn := sub.(*net.Conn) + if _, ok := ch.tcpSubs[conn]; !ok { + return false + } + delete(ch.tcpSubs, conn) + return true + + case *EmbeddedSub: + ch.embeddedSubsRWMut.Lock() + defer ch.embeddedSubsRWMut.Unlock() + w := sub.(*EmbeddedSub) + deleted := false + ch.embeddedSubs = slices.DeleteFunc(ch.embeddedSubs, func(writer *EmbeddedSub) bool { + deleted = writer == w + return deleted + }) + return deleted } - delete(ch.subscribers, conn) - return true } func (ch *Channel) Publish(message string) { - *ch.messageChan <- message + ch.messagesRWMut.Lock() + defer ch.messagesRWMut.Unlock() + ch.messages = append(ch.messages, message) + ch.messagesCond.Signal() } func (ch *Channel) IsActive() bool { - ch.subscribersRWMut.RLock() - defer ch.subscribersRWMut.RUnlock() + ch.tcpSubsRWMut.RLock() + defer ch.tcpSubsRWMut.RUnlock() - active := len(ch.subscribers) > 0 + ch.embeddedSubsRWMut.RLock() + defer ch.embeddedSubsRWMut.RUnlock() - return active + return len(ch.tcpSubs)+len(ch.embeddedSubs) > 0 } func (ch *Channel) NumSubs() int { - ch.subscribersRWMut.RLock() - defer ch.subscribersRWMut.RUnlock() + ch.tcpSubsRWMut.RLock() + defer ch.tcpSubsRWMut.RUnlock() - n := len(ch.subscribers) + ch.embeddedSubsRWMut.RLock() + defer ch.embeddedSubsRWMut.RUnlock() - return n -} - -func (ch *Channel) Subscribers() map[*net.Conn]*resp.Conn { - ch.subscribersRWMut.RLock() - defer ch.subscribersRWMut.RUnlock() - - subscribers := make(map[*net.Conn]*resp.Conn, len(ch.subscribers)) - for k, v := range ch.subscribers { - subscribers[k] = v - } - - return subscribers + return len(ch.tcpSubs) + len(ch.embeddedSubs) } diff --git a/internal/modules/pubsub/commands.go b/internal/modules/pubsub/commands.go index f404d96..5e57328 100644 --- a/internal/modules/pubsub/commands.go +++ b/internal/modules/pubsub/commands.go @@ -35,7 +35,8 @@ func handleSubscribe(params internal.HandlerFuncParams) ([]byte, error) { } withPattern := strings.EqualFold(params.Command[0], "psubscribe") - pubsub.Subscribe(params.Context, params.Connection, channels, withPattern) + + pubsub.Subscribe(params.Connection, channels, withPattern) return nil, nil } @@ -50,7 +51,7 @@ func handleUnsubscribe(params internal.HandlerFuncParams) ([]byte, error) { withPattern := strings.EqualFold(params.Command[0], "punsubscribe") - return pubsub.Unsubscribe(params.Context, params.Connection, channels, withPattern), nil + return pubsub.Unsubscribe(params.Connection, channels, withPattern), nil } func handlePublish(params internal.HandlerFuncParams) ([]byte, error) { @@ -61,7 +62,7 @@ func handlePublish(params internal.HandlerFuncParams) ([]byte, error) { if len(params.Command) != 3 { return nil, errors.New(constants.WrongArgsResponse) } - pubsub.Publish(params.Context, params.Command[2], params.Command[1]) + pubsub.Publish(params.Command[2], params.Command[1]) return []byte(constants.OkResponse), nil } diff --git a/internal/modules/pubsub/pubsub.go b/internal/modules/pubsub/pubsub.go index 49f94bc..a1c2bb2 100644 --- a/internal/modules/pubsub/pubsub.go +++ b/internal/modules/pubsub/pubsub.go @@ -17,33 +17,31 @@ package pubsub import ( "context" "fmt" - "log" "net" "slices" "sync" "github.com/gobwas/glob" - "github.com/tidwall/resp" ) type PubSub struct { + ctx context.Context channels []*Channel channelsRWMut sync.RWMutex } -func NewPubSub() *PubSub { +func NewPubSub(ctx context.Context) *PubSub { return &PubSub{ + ctx: ctx, channels: []*Channel{}, channelsRWMut: sync.RWMutex{}, } } -func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) { +func (ps *PubSub) Subscribe(sub any, channels []string, withPattern bool) { ps.channelsRWMut.Lock() defer ps.channelsRWMut.Unlock() - r := resp.NewConn(*conn) - action := "subscribe" if withPattern { action = "psubscribe" @@ -58,40 +56,46 @@ func (ps *PubSub) Subscribe(_ context.Context, conn *net.Conn, channels []string }) if channelIdx == -1 { - // Create new channel, start it, and subscribe to it + // Create new channel, if it does not exist var newChan *Channel if withPattern { - newChan = NewChannel(WithPattern(channels[i])) + newChan = NewChannel(ps.ctx, WithPattern(channels[i])) } else { - newChan = NewChannel(WithName(channels[i])) + newChan = NewChannel(ps.ctx, WithName(channels[i])) } - newChan.Start() - if newChan.Subscribe(conn) { - if err := r.WriteArray([]resp.Value{ - resp.StringValue(action), - resp.StringValue(newChan.name), - resp.IntegerValue(i + 1), - }); err != nil { - log.Println(err) - } - ps.channels = append(ps.channels, newChan) + // Append the channel to the list of channels + ps.channels = append(ps.channels, newChan) + + // Subscribe to the channel + switch sub.(type) { + case *net.Conn: + // Subscribe TCP connection + conn := sub.(*net.Conn) + newChan.Subscribe(conn, action, i) + + case *EmbeddedSub: + // Subscribe io.Writer from embedded client + w := sub.(*EmbeddedSub) + newChan.Subscribe(w, action, i) } } else { // Subscribe to existing channel - if ps.channels[channelIdx].Subscribe(conn) { - if err := r.WriteArray([]resp.Value{ - resp.StringValue(action), - resp.StringValue(ps.channels[channelIdx].name), - resp.IntegerValue(i + 1), - }); err != nil { - log.Println(err) - } + switch sub.(type) { + case *net.Conn: + // Subscribe TCP connection + conn := sub.(*net.Conn) + ps.channels[channelIdx].Subscribe(conn, action, i) + + case *EmbeddedSub: + // Subscribe io.Writer from embedded client + w := sub.(*EmbeddedSub) + ps.channels[channelIdx].Subscribe(w, action, i) } } } } -func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { +func (ps *PubSub) Unsubscribe(sub any, channels []string, withPattern bool) []byte { ps.channelsRWMut.RLock() defer ps.channelsRWMut.RUnlock() @@ -111,7 +115,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri if channel.pattern != nil { // Skip pattern channels continue } - if channel.Unsubscribe(conn) { + if channel.Unsubscribe(sub) { unsubscribed[idx] = channel.name idx += 1 } @@ -123,7 +127,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri if channel.pattern == nil { // Skip non-pattern channels continue } - if channel.Unsubscribe(conn) { + if channel.Unsubscribe(sub) { unsubscribed[idx] = channel.name idx += 1 } @@ -136,7 +140,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri // names exactly matches the pattern name. for _, channel := range ps.channels { // For each channel in PubSub for _, c := range channels { // For each channel name provided - if channel.name == c && channel.Unsubscribe(conn) { + if channel.name == c && channel.Unsubscribe(sub) { unsubscribed[idx] = channel.name idx += 1 } @@ -151,7 +155,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri for _, channel := range ps.channels { // If it's a pattern channel, directly compare the patterns if channel.pattern != nil && channel.name == pattern { - if channel.Unsubscribe(conn) { + if channel.Unsubscribe(sub) { unsubscribed[idx] = channel.name idx += 1 } @@ -159,7 +163,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri } // If this is a regular channel, check if the channel name matches the pattern given if g.Match(channel.name) { - if channel.Unsubscribe(conn) { + if channel.Unsubscribe(sub) { unsubscribed[idx] = channel.name idx += 1 } @@ -176,7 +180,7 @@ func (ps *PubSub) Unsubscribe(_ context.Context, conn *net.Conn, channels []stri return []byte(res) } -func (ps *PubSub) Publish(_ context.Context, message string, channelName string) { +func (ps *PubSub) Publish(message string, channelName string) { ps.channelsRWMut.RLock() defer ps.channelsRWMut.RUnlock() diff --git a/internal/modules/pubsub/sub.go b/internal/modules/pubsub/sub.go new file mode 100644 index 0000000..d9a3486 --- /dev/null +++ b/internal/modules/pubsub/sub.go @@ -0,0 +1,59 @@ +// Copyright 2024 Kelvin Clement Mwinuka +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pubsub + +import ( + "bufio" + "bytes" + "sync" +) + +type EmbeddedSub struct { + mux sync.Mutex + buff *bytes.Buffer + writer *bufio.Writer + reader *bufio.Reader +} + +func NewEmbeddedSub() *EmbeddedSub { + sub := &EmbeddedSub{ + mux: sync.Mutex{}, + buff: bytes.NewBuffer(make([]byte, 0)), + } + sub.writer = bufio.NewWriter(sub.buff) + sub.reader = bufio.NewReader(sub.buff) + return sub +} + +func (sub *EmbeddedSub) Write(p []byte) (int, error) { + sub.mux.Lock() + defer sub.mux.Unlock() + n, err := sub.writer.Write(p) + if err != nil { + return n, err + } + err = sub.writer.Flush() + return n, err +} + +func (sub *EmbeddedSub) Read(p []byte) (int, error) { + sub.mux.Lock() + defer sub.mux.Unlock() + + chunk, err := sub.reader.ReadBytes(byte('\n')) + n := copy(p, chunk) + + return n, err +} diff --git a/sugardb/api_pubsub.go b/sugardb/api_pubsub.go index 5c0d452..6d03231 100644 --- a/sugardb/api_pubsub.go +++ b/sugardb/api_pubsub.go @@ -16,55 +16,23 @@ package sugardb import ( "bytes" - "errors" "github.com/echovault/sugardb/internal" + "github.com/echovault/sugardb/internal/modules/pubsub" "github.com/tidwall/resp" - "net" "strings" "sync" ) -type conn struct { - readConn *net.Conn - writeConn *net.Conn +type MessageReader struct { + embeddedSub *pubsub.EmbeddedSub } -var connections sync.Map - -// ReadPubSubMessage is returned by the Subscribe and PSubscribe functions. -// -// This function is lazy, therefore it needs to be invoked in order to read the next message. -// When the message is read, the function returns a string slice with 3 elements. -// Index 0 holds the event type which in this case will be "message". Index 1 holds the channel name. -// Index 2 holds the actual message. -type ReadPubSubMessage func() []string - -func establishConnections(tag string) (*net.Conn, *net.Conn, error) { - var readConn *net.Conn - var writeConn *net.Conn - - if _, ok := connections.Load(tag); !ok { - // If connection with this name does not exist, create new connection. - rc, wc := net.Pipe() - readConn = &rc - writeConn = &wc - connections.Store(tag, conn{ - readConn: &rc, - writeConn: &wc, - }) - } else { - // Reuse existing connection. - c, ok := connections.Load(tag) - if !ok { - return nil, nil, errors.New("could not establish connection") - } - readConn = c.(conn).readConn - writeConn = c.(conn).writeConn - } - - return readConn, writeConn, nil +func (reader *MessageReader) Read(p []byte) (int, error) { + return reader.embeddedSub.Read(p) } +var subscriptions sync.Map + // Subscribe subscribes the caller to the list of provided channels. // // Parameters: @@ -75,31 +43,22 @@ func establishConnections(tag string) (*net.Conn, *net.Conn, error) { // // Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance. // This function is blocking. -func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMessage, error) { - readConn, writeConn, err := establishConnections(tag) - if err != nil { - return func() []string { - return []string{} - }, err +func (server *SugarDB) Subscribe(tag string, channels ...string) (*MessageReader, error) { + var msgReader *MessageReader + + sub, ok := subscriptions.Load(tag) + if !ok { + // Create new messageBuffer and store it in the subscriptions + msgReader = &MessageReader{ + embeddedSub: pubsub.NewEmbeddedSub(), + } + } else { + msgReader = sub.(*MessageReader) } - // Subscribe connection to the provided channels. - cmd := append([]string{"SUBSCRIBE"}, channels...) - go func() { - _, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true) - }() + server.pubSub.Subscribe(msgReader.embeddedSub, channels, false) - return func() []string { - r := resp.NewConn(*readConn) - v, _, _ := r.ReadValue() - - res := make([]string, len(v.Array())) - for i := 0; i < len(res); i++ { - res[i] = v.Array()[i].String() - } - - return res - }, nil + return msgReader, nil } // Unsubscribe unsubscribes the caller from the given channels. @@ -110,12 +69,12 @@ func (server *SugarDB) Subscribe(tag string, channels ...string) (ReadPubSubMess // // `channels` - ...string - The list of channels to unsubscribe from. func (server *SugarDB) Unsubscribe(tag string, channels ...string) { - c, ok := connections.Load(tag) + sub, ok := subscriptions.Load(tag) if !ok { return } - cmd := append([]string{"UNSUBSCRIBE"}, channels...) - _, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true) + msgReader := sub.(*MessageReader) + server.pubSub.Unsubscribe(msgReader, channels, false) } // PSubscribe subscribes the caller to the list of provided glob patterns. @@ -128,31 +87,23 @@ func (server *SugarDB) Unsubscribe(tag string, channels ...string) { // // Returns: ReadPubSubMessage function which reads the next message sent to the subscription instance. // This function is blocking. -func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMessage, error) { - readConn, writeConn, err := establishConnections(tag) - if err != nil { - return func() []string { - return []string{} - }, err + +func (server *SugarDB) PSubscribe(tag string, patterns ...string) (*MessageReader, error) { + var msgReader *MessageReader + + sub, ok := subscriptions.Load(tag) + if !ok { + // Create new messageBuffer and store it in the subscriptions + msgReader = &MessageReader{ + embeddedSub: pubsub.NewEmbeddedSub(), + } + } else { + msgReader = sub.(*MessageReader) } - // Subscribe connection to the provided channels - cmd := append([]string{"PSUBSCRIBE"}, patterns...) - go func() { - _, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), writeConn, false, true) - }() + server.pubSub.Subscribe(msgReader.embeddedSub, patterns, true) - return func() []string { - r := resp.NewConn(*readConn) - v, _, _ := r.ReadValue() - - res := make([]string, len(v.Array())) - for i := 0; i < len(res); i++ { - res[i] = v.Array()[i].String() - } - - return res - }, nil + return msgReader, nil } // PUnsubscribe unsubscribes the caller from the given glob patterns. @@ -163,12 +114,12 @@ func (server *SugarDB) PSubscribe(tag string, patterns ...string) (ReadPubSubMes // // `patterns` - ...string - The list of glob patterns to unsubscribe from. func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) { - c, ok := connections.Load(tag) + sub, ok := subscriptions.Load(tag) if !ok { return } - cmd := append([]string{"PUNSUBSCRIBE"}, patterns...) - _, _ = server.handleCommand(server.context, internal.EncodeCommand(cmd), c.(conn).writeConn, false, true) + msgReader := sub.(*MessageReader) + server.pubSub.Unsubscribe(msgReader, patterns, true) } // Publish publishes a message to the given channel. @@ -179,10 +130,12 @@ func (server *SugarDB) PUnsubscribe(tag string, patterns ...string) { // // `message` - string - The message to publish to the specified channel. // -// Returns: true when the publish is successful. This does not indicate whether each subscriber has received the message, -// only that the message has been published. +// Returns: true when successful. This does not indicate whether each subscriber has received the message, +// only that the message has been published to the channel. func (server *SugarDB) Publish(channel, message string) (bool, error) { - b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true) + b, err := server.handleCommand( + server.context, + internal.EncodeCommand([]string{"PUBLISH", channel, message}), nil, false, true) if err != nil { return false, err } diff --git a/sugardb/api_pubsub_test.go b/sugardb/api_pubsub_test.go index 10abf67..ad98922 100644 --- a/sugardb/api_pubsub_test.go +++ b/sugardb/api_pubsub_test.go @@ -15,276 +15,326 @@ package sugardb import ( + "bytes" + "encoding/json" "fmt" + "io" "reflect" "slices" "testing" + "time" ) -func Test_Subscribe(t *testing.T) { +func TestSugarDB_PubSub(t *testing.T) { server := createSugarDB() + t.Cleanup(func() { + server.ShutDown() + }) - // Subscribe to channels. - tag := "tag" - channels := []string{"channel1", "channel2"} - readMessage, err := server.Subscribe(tag, channels...) - if err != nil { - t.Errorf("SUBSCRIBE() error = %v", err) - } + t.Run("TestSugarDB_(P)Subscribe", func(t *testing.T) { + t.Parallel() - for i := 0; i < len(channels); i++ { - message := readMessage() - // Check that we've received the subscribe messages. - if message[0] != "subscribe" { - t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0]) + tests := []struct { + name string + action string // subscribe | psubscribe + tag string + channels []string + pubChannels []string // Channels to publish messages to after subscriptions + wantMsg []string // Expected messages from after publishing + subFunc func(tag string, channels ...string) (*MessageReader, error) + unsubFunc func(tag string, channels ...string) + }{ + { + name: "1. Subscribe to channels", + action: "subscribe", + tag: "tag_test_subscribe", + channels: []string{ + "channel1", + "channel2", + }, + pubChannels: []string{"channel1", "channel2"}, + wantMsg: []string{ + "message for channel1", + "message for channel2", + }, + subFunc: server.Subscribe, + unsubFunc: server.Unsubscribe, + }, + { + name: "2. Subscribe to patterns", + action: "psubscribe", + tag: "tag_test_psubscribe", + channels: []string{ + "channel[34]", + "pattern[12]", + }, + pubChannels: []string{ + "channel3", + "channel4", + "pattern1", + "pattern2", + }, + wantMsg: []string{ + "message for channel3", + "message for channel4", + "message for pattern1", + "message for pattern2", + }, + subFunc: server.PSubscribe, + unsubFunc: server.PUnsubscribe, + }, } - if !slices.Contains(channels, message[1]) { - t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) - } - } - // Publish some messages to the channels. - for _, channel := range channels { - ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel)) - if err != nil { - t.Errorf("PUBLISH() err = %v", err) - } - if !ok { - t.Errorf("PUBLISH() could not publish message to channel %s", channel) - } - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() - // Read messages from the channels - for i := 0; i < len(channels); i++ { - message := readMessage() - // Check that we've received the messages. - if message[0] != "message" { - t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0]) - } - if !slices.Contains(channels, message[1]) { - t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) - } - if !slices.Contains([]string{"message for channel1", "message for channel2"}, message[2]) { - t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) - } - } + t.Cleanup(func() { + tt.unsubFunc(tt.tag, tt.channels...) + }) - // Unsubscribe from channels - server.Unsubscribe(tag, channels...) -} - -func TestSugarDB_PSubscribe(t *testing.T) { - server := createSugarDB() - - // Subscribe to channels. - tag := "tag" - patterns := []string{"channel[12]", "pattern[12]"} - readMessage, err := server.PSubscribe(tag, patterns...) - if err != nil { - t.Errorf("PSubscribe() error = %v", err) - } - - for i := 0; i < len(patterns); i++ { - message := readMessage() - // Check that we've received the subscribe messages. - if message[0] != "psubscribe" { - t.Errorf("PSUBSCRIBE() expected index 0 for message at %d to be \"psubscribe\", got %s", i, message[0]) - } - if !slices.Contains(patterns, message[1]) { - t.Errorf("PSUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) - } - } - - // Publish some messages to the channels. - for _, channel := range []string{"channel1", "channel2", "pattern1", "pattern2"} { - ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel)) - if err != nil { - t.Errorf("PUBLISH() err = %v", err) - } - if !ok { - t.Errorf("PUBLISH() could not publish message to channel %s", channel) - } - } - - // Read messages from the channels - for i := 0; i < len(patterns)*2; i++ { - message := readMessage() - // Check that we've received the messages. - if message[0] != "message" { - t.Errorf("SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0]) - } - if !slices.Contains(patterns, message[1]) { - t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) - } - if !slices.Contains([]string{ - "message for channel1", "message for channel2", "message for pattern1", "message for pattern2"}, message[2]) { - t.Errorf("SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[2], i) - } - } - - // Unsubscribe from channels - server.PUnsubscribe(tag, patterns...) -} - -func TestSugarDB_PubSubChannels(t *testing.T) { - server := createSugarDB() - tests := []struct { - name string - tag string - channels []string - patterns []string - pattern string - want []string - wantErr bool - }{ - { - name: "1. Get number of active channels", - tag: "tag", - channels: []string{"channel1", "channel2", "channel3", "channel4"}, - patterns: []string{"channel[56]"}, - pattern: "channel[123456]", - want: []string{"channel1", "channel2", "channel3", "channel4"}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Subscribe to channels - readChannelMessages, err := server.Subscribe(tt.tag, tt.channels...) - if err != nil { - t.Errorf("PubSubChannels() error = %v", err) - } - - for i := 0; i < len(tt.channels); i++ { - readChannelMessages() - } - // Subscribe to patterns - readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...) - if err != nil { - t.Errorf("PubSubChannels() error = %v", err) - } - - for i := 0; i < len(tt.patterns); i++ { - readPatternMessages() - } - got, err := server.PubSubChannels(tt.pattern) - if (err != nil) != tt.wantErr { - t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr) - return - } - if len(got) != len(tt.want) { - t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want)) - } - for _, item := range got { - if !slices.Contains(tt.want, item) { - t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item) + // Subscribe to channels. + readMessage, err := tt.subFunc(tt.tag, tt.channels...) + if err != nil { + t.Errorf("(P)SUBSCRIBE() error = %v", err) } - } - }) - } -} -func TestSugarDB_PubSubNumPat(t *testing.T) { - server := createSugarDB() - tests := []struct { - name string - tag string - patterns []string - want int - wantErr bool - }{ - { - name: "1. Get number of active patterns on the server", - tag: "tag", - patterns: []string{"channel[56]", "channel[78]"}, - want: 2, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Subscribe to patterns - readPatternMessages, err := server.PSubscribe(tt.tag, tt.patterns...) - if err != nil { - t.Errorf("PubSubNumPat() error = %v", err) - } - for i := 0; i < len(tt.patterns); i++ { - readPatternMessages() - } - got, err := server.PubSubNumPat() - if (err != nil) != tt.wantErr { - t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want) - } - }) - } -} + for i := 0; i < len(tt.channels); i++ { + p := make([]byte, 1024) + _, err := readMessage.Read(p) + if err != nil { + t.Errorf("(P)SUBSCRIBE() read error: %+v", err) + } + var message []string + if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil { + t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err) + } + // Check that we've received the subscribe messages. + if message[0] != tt.action { + t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"subscribe\", got %s", i, message[0]) + } + if !slices.Contains(tt.channels, message[1]) { + t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) + } + } -func TestSugarDB_PubSubNumSub(t *testing.T) { - server := createSugarDB() - tests := []struct { - name string - subscriptions map[string]struct { + // Publish some messages to the channels. + for _, channel := range tt.pubChannels { + ok, err := server.Publish(channel, fmt.Sprintf("message for %s", channel)) + if err != nil { + t.Errorf("PUBLISH() err = %v", err) + } + if !ok { + t.Errorf("PUBLISH() could not publish message to channel %s", channel) + } + } + + // Read messages from the channels + for i := 0; i < len(tt.pubChannels); i++ { + p := make([]byte, 1024) + _, err := readMessage.Read(p) + + doneChan := make(chan struct{}, 1) + go func() { + for { + if err != nil && err == io.EOF { + _, err = readMessage.Read(p) + continue + } + doneChan <- struct{}{} + break + } + }() + + select { + case <-time.After(500 * time.Millisecond): + t.Errorf("(P)SUBSCRIBE() timeout") + case <-doneChan: + if err != nil { + t.Errorf("(P)SUBSCRIBE() read error: %+v", err) + } + } + + var message []string + if err = json.Unmarshal(bytes.TrimRight(p, "\x00"), &message); err != nil { + t.Errorf("(P)SUBSCRIBE() json unmarshal error: %+v", err) + } + // Check that we've received the messages. + if message[0] != "message" { + t.Errorf("(P)SUBSCRIBE() expected index 0 for message at %d to be \"message\", got %s", i, message[0]) + } + if !slices.Contains(tt.channels, message[1]) { + t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 1 for message %d", message[1], i) + } + if !slices.Contains(tt.wantMsg, message[2]) { + t.Errorf("(P)SUBSCRIBE() unexpected string \"%s\" at index 2 for message %d", message[1], i) + } + } + }) + } + }) + + t.Run("TestSugarDB_PubSubChannels", func(t *testing.T) { + tests := []struct { + name string + tag string channels []string patterns []string + pattern string + want []string + wantErr bool + }{ + { + name: "1. Get number of active channels", + tag: "tag_test_channels_1", + channels: []string{"channel1", "channel2", "channel3", "channel4"}, + patterns: []string{"channel[56]"}, + pattern: "channel[123456]", + want: []string{"channel1", "channel2", "channel3", "channel4"}, + wantErr: false, + }, } - channels []string - want map[string]int - wantErr bool - }{ - { - name: "Get number of subscriptions for the given channels", - subscriptions: map[string]struct { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Subscribe to channels + _, err := server.Subscribe(tt.tag, tt.channels...) + if err != nil { + t.Errorf("PubSubChannels() error = %v", err) + } + + // Subscribe to patterns + _, err = server.PSubscribe(tt.tag, tt.patterns...) + if err != nil { + t.Errorf("PubSubChannels() error = %v", err) + } + + got, err := server.PubSubChannels(tt.pattern) + if (err != nil) != tt.wantErr { + t.Errorf("PubSubChannels() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got) != len(tt.want) { + t.Errorf("PubSubChannels() got response length %d, want %d", len(got), len(tt.want)) + } + for _, item := range got { + if !slices.Contains(tt.want, item) { + t.Errorf("PubSubChannels() unexpected item \"%s\", in response", item) + } + } + }) + } + }) + + t.Run("TestSugarDB_PubSubNumPat", func(t *testing.T) { + tests := []struct { + name string + tag string + patterns []string + want int + wantErr bool + }{ + { + name: "1. Get number of active patterns on the server", + tag: "tag_test_numpat_1", + patterns: []string{"channel[56]", "channel[78]"}, + want: 2, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Subscribe to patterns + _, err := server.PSubscribe(tt.tag, tt.patterns...) + if err != nil { + t.Errorf("PubSubNumPat() error = %v", err) + } + + got, err := server.PubSubNumPat() + if (err != nil) != tt.wantErr { + t.Errorf("PubSubNumPat() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("PubSubNumPat() got = %v, want %v", got, tt.want) + } + }) + } + }) + + t.Run("TestSugarDB_PubSubNumSub", func(t *testing.T) { + tests := []struct { + name string + subscriptions map[string]struct { channels []string patterns []string - }{ - "tag1": { - channels: []string{"channel1", "channel2"}, - patterns: []string{"channel[34]"}, + } + channels []string + want map[string]int + wantErr bool + }{ + { + name: "1. Get number of subscriptions for the given channels", + subscriptions: map[string]struct { + channels []string + patterns []string + }{ + "tag1_test_numsub_1": { + channels: []string{"test_num_sub_channel1", "test_num_sub_channel2"}, + patterns: []string{"test_num_sub_channel[34]"}, + }, + "tag2_test_numsub_2": { + channels: []string{"test_num_sub_channel2", "test_num_sub_channel3"}, + patterns: []string{"test_num_sub_channel[23]"}, + }, + "tag3_test_numsub_3": { + channels: []string{"test_num_sub_channel2", "test_num_sub_channel4"}, + patterns: []string{}, + }, }, - "tag2": { - channels: []string{"channel2", "channel3"}, - patterns: []string{"channel[23]"}, + channels: []string{ + "test_num_sub_channel1", + "test_num_sub_channel2", + "test_num_sub_channel3", + "test_num_sub_channel4", + "test_num_sub_channel5", }, - "tag3": { - channels: []string{"channel2", "channel4"}, - patterns: []string{}, + want: map[string]int{ + "test_num_sub_channel1": 1, + "test_num_sub_channel2": 3, + "test_num_sub_channel3": 1, + "test_num_sub_channel4": 1, + "test_num_sub_channel5": 0, }, + wantErr: false, }, - channels: []string{"channel1", "channel2", "channel3", "channel4", "channel5"}, - want: map[string]int{"channel1": 1, "channel2": 3, "channel3": 1, "channel4": 1, "channel5": 0}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for tag, subs := range tt.subscriptions { - readPat, err := server.PSubscribe(tag, subs.patterns...) - if err != nil { - t.Errorf("PubSubNumSub() error = %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + for tag, subs := range tt.subscriptions { + _, err := server.PSubscribe(tag, subs.patterns...) + if err != nil { + t.Errorf("PubSubNumSub() error = %v", err) + } + + _, err = server.Subscribe(tag, subs.channels...) + if err != nil { + t.Errorf("PubSubNumSub() error = %v", err) + } + } - for _, _ = range subs.patterns { - readPat() + got, err := server.PubSubNumSub(tt.channels...) + if (err != nil) != tt.wantErr { + t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr) + return } - readChan, err := server.Subscribe(tag, subs.channels...) - if err != nil { - t.Errorf("PubSubNumSub() error = %v", err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want) } - for _, _ = range subs.channels { - readChan() - } - } - got, err := server.PubSubNumSub(tt.channels...) - if (err != nil) != tt.wantErr { - t.Errorf("PubSubNumSub() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("PubSubNumSub() got = %v, want %v", got, tt.want) - } - }) - } + }) + } + }) } diff --git a/sugardb/sugardb.go b/sugardb/sugardb.go index 938471c..9c5b4f4 100644 --- a/sugardb/sugardb.go +++ b/sugardb/sugardb.go @@ -204,7 +204,7 @@ func NewSugarDB(options ...func(sugarDB *SugarDB)) (*SugarDB, error) { sugarDB.acl = acl.NewACL(sugarDB.config) // Set up Pub/Sub module - sugarDB.pubSub = pubsub.NewPubSub() + sugarDB.pubSub = pubsub.NewPubSub(sugarDB.context) if sugarDB.isInCluster() { sugarDB.raft = raft.NewRaft(raft.Opts{