Replaced all instances of utils.Contains with slices.Contains.

Deleted Contains function from the utils.go file.
Replaced bufio.ReadWriter with io.Reader and io.Writer respectively.
Updated PubSub module to remove posibility of not passing a channel name when subscribing or publishing.
This commit is contained in:
Kelvin Clement Mwinuka
2024-01-23 23:38:07 +08:00
parent 490bddf80c
commit ce14f59f41
13 changed files with 152 additions and 123 deletions

View File

@@ -8,10 +8,10 @@ import (
type BroadcastMessage struct { type BroadcastMessage struct {
NodeMeta NodeMeta
Action string `json:"Action"` Action string `json:"Action"`
Content string `json:"Content"` Content []byte `json:"Content"`
ContentHash string `json:"ContentHash"` ContentHash [16]byte `json:"ContentHash"`
ConnId string `json:"ConnId"` ConnId string `json:"ConnId"`
} }
// Invalidates Implements Broadcast interface // Invalidates Implements Broadcast interface

View File

@@ -103,14 +103,14 @@ func (m *MemberList) broadcastRaftAddress(ctx context.Context) {
m.broadcastQueue.QueueBroadcast(&msg) m.broadcastQueue.QueueBroadcast(&msg)
} }
func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd string) { func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd []byte) {
// This function is only called by non-leaders // This function is only called by non-leaders
// It uses the broadcast queue to forward a data mutation within the cluster // It uses the broadcast queue to forward a data mutation within the cluster
connId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string) connId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string)
m.broadcastQueue.QueueBroadcast(&BroadcastMessage{ m.broadcastQueue.QueueBroadcast(&BroadcastMessage{
Action: "MutateData", Action: "MutateData",
Content: cmd, Content: cmd,
ContentHash: fmt.Sprintf("%x", md5.Sum([]byte(cmd))), ContentHash: md5.Sum(cmd),
ConnId: connId, ConnId: connId,
NodeMeta: NodeMeta{ NodeMeta: NodeMeta{
ServerID: raft.ServerID(m.options.Config.ServerID), ServerID: raft.ServerID(m.options.Config.ServerID),

View File

@@ -264,6 +264,11 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
} }
} }
// Skip ack
if strings.EqualFold(comm, "ack") {
return nil
}
// If the command is 'auth', then return early and allow it // If the command is 'auth', then return early and allow it
if strings.EqualFold(comm, "auth") { if strings.EqualFold(comm, "auth") {
// TODO: Add rate limiting to prevent auth spamming // TODO: Add rate limiting to prevent auth spamming

View File

@@ -7,6 +7,7 @@ import (
"github.com/kelvinmwinuka/memstore/src/utils" "github.com/kelvinmwinuka/memstore/src/utils"
"math" "math"
"net" "net"
"slices"
"strings" "strings"
) )
@@ -333,7 +334,7 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n
whereFrom := strings.ToLower(cmd[3]) whereFrom := strings.ToLower(cmd[3])
whereTo := strings.ToLower(cmd[4]) whereTo := strings.ToLower(cmd[4])
if !utils.Contains[string]([]string{"left", "right"}, whereFrom) || !utils.Contains[string]([]string{"left", "right"}, whereTo) { if !slices.Contains([]string{"left", "right"}, whereFrom) || !slices.Contains([]string{"left", "right"}, whereTo) {
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT") return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
} }

View File

@@ -59,7 +59,7 @@ func NewModule() Plugin {
return []string{}, nil return []string{}, nil
}, },
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
return []byte("_\r\n\r\n"), nil return []byte("$-1\r\n\r\n"), nil
}, },
}, },
}, },

View File

@@ -32,9 +32,6 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
return nil, errors.New("could not load pubsub") return nil, errors.New("could not load pubsub")
} }
switch len(cmd) { switch len(cmd) {
case 1:
// Subscribe to all channels
pubsub.Subscribe(ctx, conn, nil, nil)
case 2: case 2:
// Subscribe to specified channel // Subscribe to specified channel
pubsub.Subscribe(ctx, conn, cmd[1], nil) pubsub.Subscribe(ctx, conn, cmd[1], nil)
@@ -44,7 +41,7 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
default: default:
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
return []byte("+SUBSCRIBE_OK\r\n\n"), nil return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
} }
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -68,13 +65,10 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
if !ok { if !ok {
return nil, errors.New("could not load pubsub") return nil, errors.New("could not load pubsub")
} }
if len(cmd) == 3 { if len(cmd) != 3 {
pubsub.Publish(ctx, cmd[2], cmd[1])
} else if len(cmd) == 2 {
pubsub.Publish(ctx, cmd[1], nil)
} else {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
pubsub.Publish(ctx, cmd[2], cmd[1])
return []byte(utils.OK_RESPONSE), nil return []byte(utils.OK_RESPONSE), nil
} }

View File

@@ -1,13 +1,14 @@
package pubsub package pubsub
import ( import (
"bufio" "bytes"
"container/ring" "container/ring"
"context" "context"
"fmt" "fmt"
"github.com/kelvinmwinuka/memstore/src/utils" "github.com/kelvinmwinuka/memstore/src/utils"
"io"
"net" "net"
"strings" "slices"
"sync" "sync"
"time" "time"
) )
@@ -39,29 +40,44 @@ func (cg *ConsumerGroup) SendMessage(message string) {
cg.subscribersRWMut.RUnlock() cg.subscribersRWMut.RUnlock()
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn)) w := io.Writer(*conn)
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message)) r := io.Reader(*conn)
rw.Flush()
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Wait for an ACK // Wait for an ACK
// If no ACK is received within a time limit, remove this connection from subscribers and retry // If no ACK is received within a time limit, remove this connection from subscribers and retry
(*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)) if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
if msg, err := utils.ReadMessage(rw); err != nil { // TODO: Log error at configured logger
fmt.Println(err)
}
if msg, err := utils.ReadMessage(r); err != nil {
// Remove the connection from subscribers list // Remove the connection from subscribers list
cg.Unsubscribe(conn) cg.Unsubscribe(conn)
// Reset the deadline // Reset the deadline
(*conn).SetReadDeadline(time.Time{}) if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// Retry sending the message // Retry sending the message
cg.SendMessage(message) cg.SendMessage(message)
} else { } else {
if strings.TrimSpace(msg) != "+ACK" { if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
cg.Unsubscribe(conn) cg.Unsubscribe(conn)
(*conn).SetReadDeadline(time.Time{}) if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.SendMessage(message) cg.SendMessage(message)
} }
} }
(*conn).SetDeadline(time.Time{}) if err := (*conn).SetDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
cg.subscribers = cg.subscribers.Next() cg.subscribers = cg.subscribers.Next()
} }
@@ -152,19 +168,31 @@ func (ch *Channel) Start() {
for _, conn := range ch.subscribers { for _, conn := range ch.subscribers {
go func(conn *net.Conn) { go func(conn *net.Conn) {
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn)) w := io.Writer(*conn)
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message)) r := io.Reader(*conn)
rw.Flush()
(*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)) if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
defer func() { defer func() {
(*conn).SetReadDeadline(time.Time{}) if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
ch.Unsubscribe(conn)
}
}() }()
if msg, err := utils.ReadMessage(rw); err != nil { if msg, err := utils.ReadMessage(r); err != nil {
ch.Unsubscribe(conn) ch.Unsubscribe(conn)
} else { } else {
if strings.TrimSpace(msg) != "+ACK" { if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
ch.Unsubscribe(conn) ch.Unsubscribe(conn)
} }
} }
@@ -177,7 +205,7 @@ func (ch *Channel) Start() {
} }
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) { func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
if consumerGroupName == nil && !utils.Contains[*net.Conn](ch.subscribers, conn) { if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
ch.subscribersRWMut.Lock() ch.subscribersRWMut.Lock()
defer ch.subscribersRWMut.Unlock() defer ch.subscribersRWMut.Unlock()
ch.subscribers = append(ch.subscribers, conn) ch.subscribers = append(ch.subscribers, conn)
@@ -230,31 +258,21 @@ type PubSub struct {
func NewPubSub() *PubSub { func NewPubSub() *PubSub {
return &PubSub{ return &PubSub{
channels: []*Channel{ channels: []*Channel{},
NewChannel("chan"),
},
} }
} }
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName interface{}, consumerGroup interface{}) { func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
// If no channel name is given, subscribe to all channels
if channelName == nil {
for _, channel := range ps.channels {
go channel.Subscribe(conn, nil)
}
return
}
// Check if channel with given name exists // Check if channel with given name exists
// If it does, subscribe the connection to the channel // If it does, subscribe the connection to the channel
// If it does not, create the channel and subscribe to it // If it does not, create the channel and subscribe to it
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool { channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
return c.name == channelName return channel.name == channelName
}) })
if len(channels) <= 0 { if channelIdx == -1 {
go func() { go func() {
newChan := NewChannel(channelName.(string)) newChan := NewChannel(channelName)
newChan.Start() newChan.Start()
newChan.Subscribe(conn, consumerGroup) newChan.Subscribe(conn, consumerGroup)
ps.channels = append(ps.channels, newChan) ps.channels = append(ps.channels, newChan)
@@ -262,9 +280,7 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName int
return return
} }
for _, channel := range channels { go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
go channel.Subscribe(conn, consumerGroup)
}
} }
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) { func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
@@ -284,18 +300,10 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName i
} }
} }
func (ps *PubSub) Publish(ctx context.Context, message string, channelName interface{}) { func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
if channelName == nil {
for _, channel := range ps.channels {
go channel.Publish(message)
}
return
}
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool { channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
return c.name == channelName return c.name == channelName
}) })
for _, channel := range channels { for _, channel := range channels {
go channel.Publish(message) go channel.Publish(message)
} }

View File

@@ -3,6 +3,7 @@ package set
import ( import (
"github.com/kelvinmwinuka/memstore/src/utils" "github.com/kelvinmwinuka/memstore/src/utils"
"math/rand" "math/rand"
"slices"
) )
type Set struct { type Set struct {
@@ -72,7 +73,7 @@ func (set *Set) GetRandom(count int) []string {
// Count is positive, do not allow repeat elements // Count is positive, do not allow repeat elements
for i := 0; i < utils.AbsInt(count); { for i := 0; i < utils.AbsInt(count); {
n = rand.Intn(len(keys)) n = rand.Intn(len(keys))
if !utils.Contains(res, keys[n]) { if !slices.Contains(res, keys[n]) {
res = append(res, keys[n]) res = append(res, keys[n])
keys = utils.Filter(keys, func(elem string) bool { keys = utils.Filter(keys, func(elem string) bool {
return elem != keys[n] return elem != keys[n]

View File

@@ -52,7 +52,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
} }
switch utils.AdaptType(cmd[i]).(type) { switch utils.AdaptType(cmd[i]).(type) {
case string: case string:
if utils.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) { if slices.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) {
membersStartIndex = i membersStartIndex = i
} }
case float64: case float64:
@@ -111,11 +111,11 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
if membersStartIndex > 2 { if membersStartIndex > 2 {
options := cmd[2:membersStartIndex] options := cmd[2:membersStartIndex]
for _, option := range options { for _, option := range options {
if utils.Contains([]string{"xx", "nx"}, strings.ToLower(option)) { if slices.Contains([]string{"xx", "nx"}, strings.ToLower(option)) {
updatePolicy = option updatePolicy = option
continue continue
} }
if utils.Contains([]string{"gt", "lt"}, strings.ToLower(option)) { if slices.Contains([]string{"gt", "lt"}, strings.ToLower(option)) {
comparison = option comparison = option
continue continue
} }
@@ -1725,7 +1725,7 @@ respectively.`,
return nil, errors.New(utils.WRONG_ARGS_RESPONSE) return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
} }
endIdx := slices.IndexFunc(cmd, func(s string) bool { endIdx := slices.IndexFunc(cmd, func(s string) bool {
return utils.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s)) return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
}) })
if endIdx == -1 { if endIdx == -1 {
return cmd[1:], nil return cmd[1:], nil

View File

@@ -138,7 +138,7 @@ func (set *SortedSet) AddOrUpdate(
if !set.Contains(m.value) { if !set.Contains(m.value) {
return count, fmt.Errorf("cannot increment member %s as it does not exist in the sorted set", m.value) return count, fmt.Errorf("cannot increment member %s as it does not exist in the sorted set", m.value)
} }
if utils.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) { if slices.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
return count, errors.New("cannot increment -inf or +inf") return count, errors.New("cannot increment -inf or +inf")
} }
set.members[m.value] = MemberObject{ set.members[m.value] = MemberObject{

View File

@@ -3,7 +3,6 @@ package sorted_set
import ( import (
"cmp" "cmp"
"errors" "errors"
"github.com/kelvinmwinuka/memstore/src/utils"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -19,7 +18,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
if weightsIndex != -1 { if weightsIndex != -1 {
firstModifierIndex = weightsIndex firstModifierIndex = weightsIndex
for i := weightsIndex + 1; i < len(cmd); i++ { for i := weightsIndex + 1; i < len(cmd); i++ {
if utils.Contains([]string{"aggregate", "withscores"}, cmd[i]) { if slices.Contains([]string{"aggregate", "withscores"}, cmd[i]) {
break break
} }
w, err := strconv.Atoi(cmd[i]) w, err := strconv.Atoi(cmd[i])
@@ -43,7 +42,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
if aggregateIndex >= len(cmd)-1 { if aggregateIndex >= len(cmd)-1 {
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX") return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
} }
if !utils.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) { if !slices.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) {
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX") return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
} }
aggregate = strings.ToLower(cmd[aggregateIndex+1]) aggregate = strings.ToLower(cmd[aggregateIndex+1])
@@ -93,7 +92,7 @@ func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if !ok { if !ok {
return "", err return "", err
} }
if !utils.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) { if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err return "", err
} }
return policy, nil return policy, nil
@@ -108,7 +107,7 @@ func validateComparison(comparison interface{}) (string, error) {
if !ok { if !ok {
return "", err return "", err
} }
if !utils.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) { if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err return "", err
} }
return comp, nil return comp, nil

View File

@@ -1,7 +1,6 @@
package server package server
import ( import (
"bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
@@ -85,18 +84,20 @@ func (server *Server) StartTCP(ctx context.Context) {
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
server.ACL.RegisterConnection(&conn) server.ACL.RegisterConnection(&conn)
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) w := io.Writer(conn)
r := io.Reader(conn)
cid := server.ConnID.Add(1) cid := server.ConnID.Add(1)
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"), ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid)) fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
for { for {
message, err := utils.ReadMessage(connRW) message, err := utils.ReadMessage(r)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
// Connection closed // Connection closed
// TODO: Remove this connection from channel subscriptions
break break
} }
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {
@@ -115,15 +116,19 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
if cmd, err := utils.Decode(message); err != nil { if cmd, err := utils.Decode(message); err != nil {
// Return error to client // Return error to client
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error()))) if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
connRW.Flush() // TODO: Log error at configured logger
fmt.Println(err)
}
continue continue
} else { } else {
command, err := server.getCommand(cmd[0]) command, err := server.getCommand(cmd[0])
if err != nil { if err != nil {
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error())) if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
connRW.Flush() // TODO: Log error at configured logger
fmt.Println(err)
}
continue continue
} }
@@ -138,47 +143,69 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
} }
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil { if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error())) if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
connRW.Flush() // TODO: Log error at configured logger
fmt.Println(err)
}
continue continue
} }
if !server.IsInCluster() || !synchronize { if !server.IsInCluster() || !synchronize {
if res, err := handler(ctx, cmd, server, &conn); err != nil { if res, err := handler(ctx, cmd, server, &conn); err != nil {
connRW.Write([]byte(fmt.Sprintf("-%s\r\n\n", err.Error()))) if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else { } else {
connRW.Write(res) if command.Command == "ack" {
continue
}
if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// TODO: Write successful, add entry to AOF // TODO: Write successful, add entry to AOF
} }
connRW.Flush()
continue continue
} }
// Handle other commands that need to be synced across the cluster // Handle other commands that need to be synced across the cluster
if server.raft.IsRaftLeader() { if server.raft.IsRaftLeader() {
if res, err := server.raftApply(ctx, cmd); err != nil { if res, err := server.raftApply(ctx, cmd); err != nil {
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))) if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else { } else {
connRW.Write(res) if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} }
connRW.Flush()
continue continue
} }
// Forward message to leader and return immediate OK response // Forward message to leader and return immediate OK response
if server.Config.ForwardCommand { if server.Config.ForwardCommand {
server.memberList.ForwardDataMutation(ctx, message) server.memberList.ForwardDataMutation(ctx, message)
connRW.Write([]byte(utils.OK_RESPONSE)) if _, err := w.Write([]byte(utils.OK_RESPONSE)); err != nil {
connRW.Flush() // TODO: Log error at configured logger
fmt.Println(err)
}
continue continue
} }
connRW.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")) if _, err := w.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")); err != nil {
connRW.Flush() // TODO: Log error at configured logger
fmt.Println(err)
}
} }
} }
conn.Close() if err := conn.Close(); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} }
func (server *Server) Start(ctx context.Context) { func (server *Server) Start(ctx context.Context) {

View File

@@ -1,11 +1,11 @@
package utils package utils
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "io"
"math/big" "math/big"
"net" "net"
"slices"
"strings" "strings"
"time" "time"
@@ -31,15 +31,6 @@ func AdaptType(s string) interface{} {
return f return f
} }
func Contains[T comparable](arr []T, elem T) bool {
for _, v := range arr {
if v == elem {
return true
}
}
return false
}
func Filter[T any](arr []T, test func(elem T) bool) (res []T) { func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
for _, e := range arr { for _, e := range arr {
if test(e) { if test(e) {
@@ -49,9 +40,9 @@ func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
return return
} }
func Decode(raw string) ([]string, error) { func Decode(raw []byte) ([]string, error) {
rd := resp.NewReader(bytes.NewBufferString(raw)) rd := resp.NewReader(bytes.NewBuffer(raw))
res := []string{} var res []string
v, _, err := rd.ReadValue() v, _, err := rd.ReadValue()
@@ -59,7 +50,7 @@ func Decode(raw string) ([]string, error) {
return nil, err return nil, err
} }
if Contains[string]([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) { if slices.Contains([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
return []string{v.String()}, nil return []string{v.String()}, nil
} }
@@ -72,25 +63,28 @@ func Decode(raw string) ([]string, error) {
return res, nil return res, nil
} }
func ReadMessage(r *bufio.ReadWriter) (message string, err error) { func ReadMessage(r io.Reader) ([]byte, error) {
var line [][]byte delim := []byte{'\r', '\n', '\r', '\n'}
buffSize := 8
buff := make([]byte, buffSize)
var n int
var err error
var res []byte
for { for {
b, _, err := r.ReadLine() n, err = r.Read(buff)
res = append(res, buff...)
if err != nil { if n < buffSize || err != nil {
return "", err
}
if bytes.Equal(b, []byte("")) {
// End of message
break break
} }
if bytes.Equal(buff[len(buff)-4:], delim) {
line = append(line, b) break
}
clear(buff)
} }
return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil return res, err
} }
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff { func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {