Replaced all instances of utils.Contains with slices.Contains.

Deleted Contains function from the utils.go file.
Replaced bufio.ReadWriter with io.Reader and io.Writer respectively.
Updated PubSub module to remove posibility of not passing a channel name when subscribing or publishing.
This commit is contained in:
Kelvin Clement Mwinuka
2024-01-23 23:38:07 +08:00
parent 490bddf80c
commit ce14f59f41
13 changed files with 152 additions and 123 deletions

View File

@@ -8,10 +8,10 @@ import (
type BroadcastMessage struct {
NodeMeta
Action string `json:"Action"`
Content string `json:"Content"`
ContentHash string `json:"ContentHash"`
ConnId string `json:"ConnId"`
Action string `json:"Action"`
Content []byte `json:"Content"`
ContentHash [16]byte `json:"ContentHash"`
ConnId string `json:"ConnId"`
}
// Invalidates Implements Broadcast interface

View File

@@ -103,14 +103,14 @@ func (m *MemberList) broadcastRaftAddress(ctx context.Context) {
m.broadcastQueue.QueueBroadcast(&msg)
}
func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd string) {
func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd []byte) {
// This function is only called by non-leaders
// It uses the broadcast queue to forward a data mutation within the cluster
connId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string)
m.broadcastQueue.QueueBroadcast(&BroadcastMessage{
Action: "MutateData",
Content: cmd,
ContentHash: fmt.Sprintf("%x", md5.Sum([]byte(cmd))),
ContentHash: md5.Sum(cmd),
ConnId: connId,
NodeMeta: NodeMeta{
ServerID: raft.ServerID(m.options.Config.ServerID),

View File

@@ -264,6 +264,11 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
}
}
// Skip ack
if strings.EqualFold(comm, "ack") {
return nil
}
// If the command is 'auth', then return early and allow it
if strings.EqualFold(comm, "auth") {
// TODO: Add rate limiting to prevent auth spamming

View File

@@ -7,6 +7,7 @@ import (
"github.com/kelvinmwinuka/memstore/src/utils"
"math"
"net"
"slices"
"strings"
)
@@ -333,7 +334,7 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n
whereFrom := strings.ToLower(cmd[3])
whereTo := strings.ToLower(cmd[4])
if !utils.Contains[string]([]string{"left", "right"}, whereFrom) || !utils.Contains[string]([]string{"left", "right"}, whereTo) {
if !slices.Contains([]string{"left", "right"}, whereFrom) || !slices.Contains([]string{"left", "right"}, whereTo) {
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
}

View File

@@ -59,7 +59,7 @@ func NewModule() Plugin {
return []string{}, nil
},
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
return []byte("_\r\n\r\n"), nil
return []byte("$-1\r\n\r\n"), nil
},
},
},

View File

@@ -32,9 +32,6 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
return nil, errors.New("could not load pubsub")
}
switch len(cmd) {
case 1:
// Subscribe to all channels
pubsub.Subscribe(ctx, conn, nil, nil)
case 2:
// Subscribe to specified channel
pubsub.Subscribe(ctx, conn, cmd[1], nil)
@@ -44,7 +41,7 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
default:
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []byte("+SUBSCRIBE_OK\r\n\n"), nil
return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
}
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -68,13 +65,10 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
if !ok {
return nil, errors.New("could not load pubsub")
}
if len(cmd) == 3 {
pubsub.Publish(ctx, cmd[2], cmd[1])
} else if len(cmd) == 2 {
pubsub.Publish(ctx, cmd[1], nil)
} else {
if len(cmd) != 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
pubsub.Publish(ctx, cmd[2], cmd[1])
return []byte(utils.OK_RESPONSE), nil
}

View File

@@ -1,13 +1,14 @@
package pubsub
import (
"bufio"
"bytes"
"container/ring"
"context"
"fmt"
"github.com/kelvinmwinuka/memstore/src/utils"
"io"
"net"
"strings"
"slices"
"sync"
"time"
)
@@ -39,29 +40,44 @@ func (cg *ConsumerGroup) SendMessage(message string) {
cg.subscribersRWMut.RUnlock()
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
rw.Flush()
w := io.Writer(*conn)
r := io.Reader(*conn)
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Wait for an ACK
// If no ACK is received within a time limit, remove this connection from subscribers and retry
(*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond))
if msg, err := utils.ReadMessage(rw); err != nil {
if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if msg, err := utils.ReadMessage(r); err != nil {
// Remove the connection from subscribers list
cg.Unsubscribe(conn)
// Reset the deadline
(*conn).SetReadDeadline(time.Time{})
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Retry sending the message
cg.SendMessage(message)
} else {
if strings.TrimSpace(msg) != "+ACK" {
if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
cg.Unsubscribe(conn)
(*conn).SetReadDeadline(time.Time{})
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.SendMessage(message)
}
}
(*conn).SetDeadline(time.Time{})
if err := (*conn).SetDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.subscribers = cg.subscribers.Next()
}
@@ -152,19 +168,31 @@ func (ch *Channel) Start() {
for _, conn := range ch.subscribers {
go func(conn *net.Conn) {
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
rw.Flush()
w := io.Writer(*conn)
r := io.Reader(*conn)
(*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond))
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
defer func() {
(*conn).SetReadDeadline(time.Time{})
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
}()
if msg, err := utils.ReadMessage(rw); err != nil {
if msg, err := utils.ReadMessage(r); err != nil {
ch.Unsubscribe(conn)
} else {
if strings.TrimSpace(msg) != "+ACK" {
if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
ch.Unsubscribe(conn)
}
}
@@ -177,7 +205,7 @@ func (ch *Channel) Start() {
}
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
if consumerGroupName == nil && !utils.Contains[*net.Conn](ch.subscribers, conn) {
if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock()
ch.subscribers = append(ch.subscribers, conn)
@@ -230,31 +258,21 @@ type PubSub struct {
func NewPubSub() *PubSub {
return &PubSub{
channels: []*Channel{
NewChannel("chan"),
},
channels: []*Channel{},
}
}
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName interface{}, consumerGroup interface{}) {
// If no channel name is given, subscribe to all channels
if channelName == nil {
for _, channel := range ps.channels {
go channel.Subscribe(conn, nil)
}
return
}
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
// Check if channel with given name exists
// If it does, subscribe the connection to the channel
// If it does not, create the channel and subscribe to it
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
return c.name == channelName
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
return channel.name == channelName
})
if len(channels) <= 0 {
if channelIdx == -1 {
go func() {
newChan := NewChannel(channelName.(string))
newChan := NewChannel(channelName)
newChan.Start()
newChan.Subscribe(conn, consumerGroup)
ps.channels = append(ps.channels, newChan)
@@ -262,9 +280,7 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName int
return
}
for _, channel := range channels {
go channel.Subscribe(conn, consumerGroup)
}
go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
}
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
@@ -284,18 +300,10 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName i
}
}
func (ps *PubSub) Publish(ctx context.Context, message string, channelName interface{}) {
if channelName == nil {
for _, channel := range ps.channels {
go channel.Publish(message)
}
return
}
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
return c.name == channelName
})
for _, channel := range channels {
go channel.Publish(message)
}

View File

@@ -3,6 +3,7 @@ package set
import (
"github.com/kelvinmwinuka/memstore/src/utils"
"math/rand"
"slices"
)
type Set struct {
@@ -72,7 +73,7 @@ func (set *Set) GetRandom(count int) []string {
// Count is positive, do not allow repeat elements
for i := 0; i < utils.AbsInt(count); {
n = rand.Intn(len(keys))
if !utils.Contains(res, keys[n]) {
if !slices.Contains(res, keys[n]) {
res = append(res, keys[n])
keys = utils.Filter(keys, func(elem string) bool {
return elem != keys[n]

View File

@@ -52,7 +52,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
switch utils.AdaptType(cmd[i]).(type) {
case string:
if utils.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) {
if slices.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) {
membersStartIndex = i
}
case float64:
@@ -111,11 +111,11 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
if membersStartIndex > 2 {
options := cmd[2:membersStartIndex]
for _, option := range options {
if utils.Contains([]string{"xx", "nx"}, strings.ToLower(option)) {
if slices.Contains([]string{"xx", "nx"}, strings.ToLower(option)) {
updatePolicy = option
continue
}
if utils.Contains([]string{"gt", "lt"}, strings.ToLower(option)) {
if slices.Contains([]string{"gt", "lt"}, strings.ToLower(option)) {
comparison = option
continue
}
@@ -1725,7 +1725,7 @@ respectively.`,
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
endIdx := slices.IndexFunc(cmd, func(s string) bool {
return utils.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
})
if endIdx == -1 {
return cmd[1:], nil

View File

@@ -138,7 +138,7 @@ func (set *SortedSet) AddOrUpdate(
if !set.Contains(m.value) {
return count, fmt.Errorf("cannot increment member %s as it does not exist in the sorted set", m.value)
}
if utils.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
if slices.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
return count, errors.New("cannot increment -inf or +inf")
}
set.members[m.value] = MemberObject{

View File

@@ -3,7 +3,6 @@ package sorted_set
import (
"cmp"
"errors"
"github.com/kelvinmwinuka/memstore/src/utils"
"slices"
"strconv"
"strings"
@@ -19,7 +18,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
if weightsIndex != -1 {
firstModifierIndex = weightsIndex
for i := weightsIndex + 1; i < len(cmd); i++ {
if utils.Contains([]string{"aggregate", "withscores"}, cmd[i]) {
if slices.Contains([]string{"aggregate", "withscores"}, cmd[i]) {
break
}
w, err := strconv.Atoi(cmd[i])
@@ -43,7 +42,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
if aggregateIndex >= len(cmd)-1 {
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
}
if !utils.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) {
if !slices.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) {
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
}
aggregate = strings.ToLower(cmd[aggregateIndex+1])
@@ -93,7 +92,7 @@ func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if !ok {
return "", err
}
if !utils.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
@@ -108,7 +107,7 @@ func validateComparison(comparison interface{}) (string, error) {
if !ok {
return "", err
}
if !utils.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil

View File

@@ -1,7 +1,6 @@
package server
import (
"bufio"
"context"
"crypto/tls"
"fmt"
@@ -85,18 +84,20 @@ func (server *Server) StartTCP(ctx context.Context) {
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
server.ACL.RegisterConnection(&conn)
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
w := io.Writer(conn)
r := io.Reader(conn)
cid := server.ConnID.Add(1)
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
for {
message, err := utils.ReadMessage(connRW)
message, err := utils.ReadMessage(r)
if err != nil {
if err == io.EOF {
// Connection closed
// TODO: Remove this connection from channel subscriptions
break
}
if err, ok := err.(net.Error); ok && err.Timeout() {
@@ -115,15 +116,19 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
if cmd, err := utils.Decode(message); err != nil {
// Return error to client
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error())))
connRW.Flush()
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
} else {
command, err := server.getCommand(cmd[0])
if err != nil {
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
connRW.Flush()
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
@@ -138,47 +143,69 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
}
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
connRW.Flush()
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
if !server.IsInCluster() || !synchronize {
if res, err := handler(ctx, cmd, server, &conn); err != nil {
connRW.Write([]byte(fmt.Sprintf("-%s\r\n\n", err.Error())))
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else {
connRW.Write(res)
if command.Command == "ack" {
continue
}
if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// TODO: Write successful, add entry to AOF
}
connRW.Flush()
continue
}
// Handle other commands that need to be synced across the cluster
if server.raft.IsRaftLeader() {
if res, err := server.raftApply(ctx, cmd); err != nil {
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error())))
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else {
connRW.Write(res)
if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
connRW.Flush()
continue
}
// Forward message to leader and return immediate OK response
if server.Config.ForwardCommand {
server.memberList.ForwardDataMutation(ctx, message)
connRW.Write([]byte(utils.OK_RESPONSE))
connRW.Flush()
if _, err := w.Write([]byte(utils.OK_RESPONSE)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
connRW.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n"))
connRW.Flush()
if _, err := w.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
}
conn.Close()
if err := conn.Close(); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
func (server *Server) Start(ctx context.Context) {

View File

@@ -1,11 +1,11 @@
package utils
import (
"bufio"
"bytes"
"fmt"
"io"
"math/big"
"net"
"slices"
"strings"
"time"
@@ -31,15 +31,6 @@ func AdaptType(s string) interface{} {
return f
}
func Contains[T comparable](arr []T, elem T) bool {
for _, v := range arr {
if v == elem {
return true
}
}
return false
}
func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
for _, e := range arr {
if test(e) {
@@ -49,9 +40,9 @@ func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
return
}
func Decode(raw string) ([]string, error) {
rd := resp.NewReader(bytes.NewBufferString(raw))
res := []string{}
func Decode(raw []byte) ([]string, error) {
rd := resp.NewReader(bytes.NewBuffer(raw))
var res []string
v, _, err := rd.ReadValue()
@@ -59,7 +50,7 @@ func Decode(raw string) ([]string, error) {
return nil, err
}
if Contains[string]([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
if slices.Contains([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
return []string{v.String()}, nil
}
@@ -72,25 +63,28 @@ func Decode(raw string) ([]string, error) {
return res, nil
}
func ReadMessage(r *bufio.ReadWriter) (message string, err error) {
var line [][]byte
func ReadMessage(r io.Reader) ([]byte, error) {
delim := []byte{'\r', '\n', '\r', '\n'}
buffSize := 8
buff := make([]byte, buffSize)
var n int
var err error
var res []byte
for {
b, _, err := r.ReadLine()
if err != nil {
return "", err
}
if bytes.Equal(b, []byte("")) {
// End of message
n, err = r.Read(buff)
res = append(res, buff...)
if n < buffSize || err != nil {
break
}
line = append(line, b)
if bytes.Equal(buff[len(buff)-4:], delim) {
break
}
clear(buff)
}
return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil
return res, err
}
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {