diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go index 0a81f0c..a9a1301 100644 --- a/src/modules/list/commands.go +++ b/src/modules/list/commands.go @@ -144,7 +144,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * } else { i-- } - } return bytes, nil diff --git a/src/modules/pubsub/channel.go b/src/modules/pubsub/channel.go index 41d7781..41ae70f 100644 --- a/src/modules/pubsub/channel.go +++ b/src/modules/pubsub/channel.go @@ -83,19 +83,31 @@ func (ch *Channel) Subscribe(conn *net.Conn) { } } -func (ch *Channel) Unsubscribe(conn *net.Conn, waitGroup *sync.WaitGroup) { +func (ch *Channel) Unsubscribe(conn *net.Conn) bool { ch.subscribersRWMut.Lock() defer ch.subscribersRWMut.Unlock() + var removed bool + ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool { - return c == conn + if c == conn { + removed = true + return true + } + return false }) - if waitGroup != nil { - waitGroup.Done() - } + return removed } func (ch *Channel) Publish(message string) { *ch.messageChan <- message } + +func (ch *Channel) IsActive() bool { + return len(ch.subscribers) > 0 +} + +func (ch *Channel) NumSubs() int { + return len(ch.subscribers) +} diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go index bfbf64c..5ed322b 100644 --- a/src/modules/pubsub/commands.go +++ b/src/modules/pubsub/commands.go @@ -3,6 +3,7 @@ package pubsub import ( "context" "errors" + "fmt" "github.com/echovault/echovault/src/utils" "net" "strings" @@ -22,9 +23,9 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con switch strings.ToLower(cmd[0]) { case "subscribe": - pubsub.Subscribe(ctx, conn, channels, false) + return pubsub.Subscribe(ctx, conn, channels, false), nil case "psubscribe": - pubsub.Subscribe(ctx, conn, channels, true) + return pubsub.Subscribe(ctx, conn, channels, true), nil } return []byte{}, nil @@ -38,16 +39,14 @@ func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, c channels := cmd[1:] - if len(channels) == 0 { - pubsub.Unsubscribe(ctx, conn, "*") - return []byte(utils.OK_RESPONSE), nil + switch strings.ToLower(cmd[0]) { + case "unsubscribe": + return pubsub.Unsubscribe(ctx, conn, channels, false), nil + case "punsubscribe": + return pubsub.Unsubscribe(ctx, conn, channels, true), nil + default: + return []byte{}, nil } - - for _, channel := range channels { - pubsub.Unsubscribe(ctx, conn, channel) - } - - return []byte(utils.OK_RESPONSE), nil } func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -62,6 +61,41 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn return []byte(utils.OK_RESPONSE), nil } +func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + if len(cmd) > 3 { + return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + } + + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + + pattern := "" + if len(cmd) == 3 { + pattern = cmd[2] + } + + return pubsub.Channels(ctx, pattern), nil +} + +func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + num := pubsub.NumPat(ctx) + return []byte(fmt.Sprintf(":%d\r\n", num)), nil +} + +func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + return pubsub.NumSub(ctx, cmd[2:]), nil +} + func Commands() []utils.Command { return []utils.Command{ { @@ -119,5 +153,57 @@ it's currently subscribe to.`, }, HandlerFunc: handleUnsubscribe, }, + { + Command: "punsubscribe", + Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, + Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns. +If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that +it's currently subscribe to.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { + // Treat the channels as keys + return cmd[1:], nil + }, + HandlerFunc: handleUnsubscribe, + }, + { + Command: "pubsub", + Categories: []string{}, + Description: "", + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) { + return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand") + }, + SubCommands: []utils.SubCommand{ + { + Command: "channels", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that +match the given pattern. If no pattern is provided, all active channels are returned. Active channels are +channels with 1 or more subscribers.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handlePubSubChannels, + }, + { + Command: "numpat", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handlePubSubNumPat, + }, + { + Command: "numsub", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided +channel name and how many clients are currently subscribed to the channel.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil }, + HandlerFunc: handlePubSubNumSubs, + }, + }, + }, } } diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index 1407c0c..433e447 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -3,8 +3,7 @@ package pubsub import ( "context" "fmt" - "io" - "log" + "github.com/gobwas/glob" "net" "slices" "sync" @@ -23,7 +22,8 @@ func NewPubSub() *PubSub { } } -func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) { +func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { + res := fmt.Sprintf("*%d\r\n", len(channels)) for i := 0; i < len(channels); i++ { // Check if channel with given name exists @@ -49,43 +49,79 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []stri ps.channels[channelIdx].Subscribe(conn) } - var res string if len(channels) > 1 { // If subscribing to more than one channel, write array to verify the subscription of this channel - res = fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1) + res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1) } else { // Ony one channel, simply send "subscribe" simple string response res = "+subscribe\r\n" } - - w := io.Writer(*conn) - if _, err := w.Write([]byte(res)); err != nil { - log.Println(err) - } } + + return []byte(res) } -func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName string) { +func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { ps.channelsRWMut.RLock() ps.channelsRWMut.RUnlock() - if channelName == "*" { - wg := &sync.WaitGroup{} + action := "unsubscribe" + if withPattern { + action = "subscribe" + } + + unsubscribed := make(map[int]string) + count := 1 + + // If the channels slice is empty, unsubscribe from all channels. + if len(channels) <= 0 { for _, channel := range ps.channels { - wg.Add(1) - go channel.Unsubscribe(conn, wg) + if channel.Unsubscribe(conn) { + unsubscribed[1] = channel.name + count += 1 + } } - wg.Wait() - return } - channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool { - return channel.name == channelName - }) - - if channelIdx != -1 { - ps.channels[channelIdx].Unsubscribe(conn, nil) + // If withPattern is false, unsubscribe from channels where the name exactly matches channel name. + if !withPattern { + for _, channel := range ps.channels { // For each channel in PubSub + for _, c := range channels { // For each channel name provided + if channel.name == c && channel.Unsubscribe(conn) { + unsubscribed[count] = channel.name + count += 1 + } + } + } } + + // If withPattern is true, unsubscribe from channels where pattern matches pattern provided, + // also unsubscribe from channels where the name matches the given pattern. + if withPattern { + for _, pattern := range channels { + g := glob.MustCompile(pattern) + for _, channel := range ps.channels { + // If it's a pattern channel, directly compare the patterns + if channel.pattern != nil && channel.name == pattern { + unsubscribed[count] = channel.name + count += 1 + continue + } + // If this is a regular channel, check if the channel name matches the pattern given + if g.Match(channel.name) { + unsubscribed[count] = channel.name + count += 1 + } + } + } + } + + res := fmt.Sprintf("*%d\r\n", len(unsubscribed)) + for key, value := range unsubscribed { + res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key) + } + + return []byte(res) } func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) { @@ -93,8 +129,6 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin defer ps.channelsRWMut.RUnlock() for _, channel := range ps.channels { - fmt.Println(channel.name, channel.pattern) - // If it's a regular channel, check if the channel name matches the name given if channel.pattern == nil { if channel.name == channelName { @@ -108,3 +142,62 @@ func (ps *PubSub) Publish(ctx context.Context, message string, channelName strin } } } + +func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte { + var count int + var res string + + if pattern == "" { + for _, channel := range ps.channels { + if channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + } + } + + res = fmt.Sprintf("*%d\r\n%s", count, res) + return []byte(res) + } + + g := glob.MustCompile(pattern) + + for _, channel := range ps.channels { + // If channel is a pattern channel, then directly compare the channel name to pattern + if channel.pattern != nil && channel.name == pattern && channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + continue + } + if g.Match(channel.name) && channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + } + } + + return []byte(res) +} + +func (ps *PubSub) NumPat(ctx context.Context) int { + var count int + for _, channel := range ps.channels { + if channel.pattern != nil { + count += 1 + } + } + return count +} + +func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte { + res := fmt.Sprintf("*%d\r\n", len(channels)) + for _, channel := range channels { + chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool { + return c.name == channel + }) + if chanIdx == -1 { + res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel) + continue + } + res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs()) + } + return []byte(res) +} diff --git a/src/server/server.go b/src/server/server.go index 95b2712..cb36647 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -219,6 +219,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { res, err := server.handleCommand(ctx, message, &conn, false) + if err != nil && errors.Is(err, io.EOF) { + break + } + if err != nil { if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil { log.Println(err) @@ -234,10 +238,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { } if len(res) <= chunkSize { - _, err = w.Write(res) - if err != nil { - log.Println(err) - } + _, _ = w.Write(res) continue } @@ -246,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { for { // If the current start index is less than chunkSize from length, return the remaining bytes. if len(res)-1-startIndex < chunkSize { - _, _ = w.Write(res[startIndex:]) + _, err = w.Write(res[startIndex:]) + if err != nil { + log.Println(err) + } break } n, _ := w.Write(res[startIndex : startIndex+chunkSize])