mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 08:23:30 +08:00
Replaced all instances of utils.Contains with slices.Contains.
Deleted Contains function from the utils.go file. Replaced bufio.ReadWriter with io.Reader and io.Writer respectively. Updated PubSub module to remove posibility of not passing a channel name when subscribing or publishing.
This commit is contained in:
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
type BroadcastMessage struct {
|
||||
NodeMeta
|
||||
Action string `json:"Action"`
|
||||
Content string `json:"Content"`
|
||||
ContentHash string `json:"ContentHash"`
|
||||
ConnId string `json:"ConnId"`
|
||||
Action string `json:"Action"`
|
||||
Content []byte `json:"Content"`
|
||||
ContentHash [16]byte `json:"ContentHash"`
|
||||
ConnId string `json:"ConnId"`
|
||||
}
|
||||
|
||||
// Invalidates Implements Broadcast interface
|
||||
|
||||
@@ -103,14 +103,14 @@ func (m *MemberList) broadcastRaftAddress(ctx context.Context) {
|
||||
m.broadcastQueue.QueueBroadcast(&msg)
|
||||
}
|
||||
|
||||
func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd string) {
|
||||
func (m *MemberList) ForwardDataMutation(ctx context.Context, cmd []byte) {
|
||||
// This function is only called by non-leaders
|
||||
// It uses the broadcast queue to forward a data mutation within the cluster
|
||||
connId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string)
|
||||
m.broadcastQueue.QueueBroadcast(&BroadcastMessage{
|
||||
Action: "MutateData",
|
||||
Content: cmd,
|
||||
ContentHash: fmt.Sprintf("%x", md5.Sum([]byte(cmd))),
|
||||
ContentHash: md5.Sum(cmd),
|
||||
ConnId: connId,
|
||||
NodeMeta: NodeMeta{
|
||||
ServerID: raft.ServerID(m.options.Config.ServerID),
|
||||
|
||||
@@ -264,6 +264,11 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
|
||||
}
|
||||
}
|
||||
|
||||
// Skip ack
|
||||
if strings.EqualFold(comm, "ack") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the command is 'auth', then return early and allow it
|
||||
if strings.EqualFold(comm, "auth") {
|
||||
// TODO: Add rate limiting to prevent auth spamming
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -333,7 +334,7 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n
|
||||
whereFrom := strings.ToLower(cmd[3])
|
||||
whereTo := strings.ToLower(cmd[4])
|
||||
|
||||
if !utils.Contains[string]([]string{"left", "right"}, whereFrom) || !utils.Contains[string]([]string{"left", "right"}, whereTo) {
|
||||
if !slices.Contains([]string{"left", "right"}, whereFrom) || !slices.Contains([]string{"left", "right"}, whereTo) {
|
||||
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func NewModule() Plugin {
|
||||
return []string{}, nil
|
||||
},
|
||||
HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
return []byte("_\r\n\r\n"), nil
|
||||
return []byte("$-1\r\n\r\n"), nil
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -32,9 +32,6 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
|
||||
return nil, errors.New("could not load pubsub")
|
||||
}
|
||||
switch len(cmd) {
|
||||
case 1:
|
||||
// Subscribe to all channels
|
||||
pubsub.Subscribe(ctx, conn, nil, nil)
|
||||
case 2:
|
||||
// Subscribe to specified channel
|
||||
pubsub.Subscribe(ctx, conn, cmd[1], nil)
|
||||
@@ -44,7 +41,7 @@ func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, con
|
||||
default:
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
return []byte("+SUBSCRIBE_OK\r\n\n"), nil
|
||||
return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
|
||||
}
|
||||
|
||||
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
|
||||
@@ -68,13 +65,10 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
|
||||
if !ok {
|
||||
return nil, errors.New("could not load pubsub")
|
||||
}
|
||||
if len(cmd) == 3 {
|
||||
pubsub.Publish(ctx, cmd[2], cmd[1])
|
||||
} else if len(cmd) == 2 {
|
||||
pubsub.Publish(ctx, cmd[1], nil)
|
||||
} else {
|
||||
if len(cmd) != 3 {
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
pubsub.Publish(ctx, cmd[2], cmd[1])
|
||||
return []byte(utils.OK_RESPONSE), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"container/ring"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -39,29 +40,44 @@ func (cg *ConsumerGroup) SendMessage(message string) {
|
||||
|
||||
cg.subscribersRWMut.RUnlock()
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
|
||||
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
|
||||
rw.Flush()
|
||||
w := io.Writer(*conn)
|
||||
r := io.Reader(*conn)
|
||||
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
// Wait for an ACK
|
||||
// If no ACK is received within a time limit, remove this connection from subscribers and retry
|
||||
(*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond))
|
||||
if msg, err := utils.ReadMessage(rw); err != nil {
|
||||
if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
if msg, err := utils.ReadMessage(r); err != nil {
|
||||
// Remove the connection from subscribers list
|
||||
cg.Unsubscribe(conn)
|
||||
// Reset the deadline
|
||||
(*conn).SetReadDeadline(time.Time{})
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
// Retry sending the message
|
||||
cg.SendMessage(message)
|
||||
} else {
|
||||
if strings.TrimSpace(msg) != "+ACK" {
|
||||
if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
|
||||
cg.Unsubscribe(conn)
|
||||
(*conn).SetReadDeadline(time.Time{})
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
cg.SendMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
(*conn).SetDeadline(time.Time{})
|
||||
if err := (*conn).SetDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
cg.subscribers = cg.subscribers.Next()
|
||||
}
|
||||
|
||||
@@ -152,19 +168,31 @@ func (ch *Channel) Start() {
|
||||
|
||||
for _, conn := range ch.subscribers {
|
||||
go func(conn *net.Conn) {
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(*conn), bufio.NewWriter(*conn))
|
||||
rw.WriteString(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))
|
||||
rw.Flush()
|
||||
w := io.Writer(*conn)
|
||||
r := io.Reader(*conn)
|
||||
|
||||
(*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
defer func() {
|
||||
(*conn).SetReadDeadline(time.Time{})
|
||||
if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
if msg, err := utils.ReadMessage(rw); err != nil {
|
||||
if msg, err := utils.ReadMessage(r); err != nil {
|
||||
ch.Unsubscribe(conn)
|
||||
} else {
|
||||
if strings.TrimSpace(msg) != "+ACK" {
|
||||
if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
|
||||
ch.Unsubscribe(conn)
|
||||
}
|
||||
}
|
||||
@@ -177,7 +205,7 @@ func (ch *Channel) Start() {
|
||||
}
|
||||
|
||||
func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
|
||||
if consumerGroupName == nil && !utils.Contains[*net.Conn](ch.subscribers, conn) {
|
||||
if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
|
||||
ch.subscribersRWMut.Lock()
|
||||
defer ch.subscribersRWMut.Unlock()
|
||||
ch.subscribers = append(ch.subscribers, conn)
|
||||
@@ -230,31 +258,21 @@ type PubSub struct {
|
||||
|
||||
func NewPubSub() *PubSub {
|
||||
return &PubSub{
|
||||
channels: []*Channel{
|
||||
NewChannel("chan"),
|
||||
},
|
||||
channels: []*Channel{},
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName interface{}, consumerGroup interface{}) {
|
||||
// If no channel name is given, subscribe to all channels
|
||||
if channelName == nil {
|
||||
for _, channel := range ps.channels {
|
||||
go channel.Subscribe(conn, nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
|
||||
// Check if channel with given name exists
|
||||
// If it does, subscribe the connection to the channel
|
||||
// If it does not, create the channel and subscribe to it
|
||||
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
||||
return c.name == channelName
|
||||
channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
|
||||
return channel.name == channelName
|
||||
})
|
||||
|
||||
if len(channels) <= 0 {
|
||||
if channelIdx == -1 {
|
||||
go func() {
|
||||
newChan := NewChannel(channelName.(string))
|
||||
newChan := NewChannel(channelName)
|
||||
newChan.Start()
|
||||
newChan.Subscribe(conn, consumerGroup)
|
||||
ps.channels = append(ps.channels, newChan)
|
||||
@@ -262,9 +280,7 @@ func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName int
|
||||
return
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
go channel.Subscribe(conn, consumerGroup)
|
||||
}
|
||||
go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
|
||||
}
|
||||
|
||||
func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
|
||||
@@ -284,18 +300,10 @@ func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName i
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub) Publish(ctx context.Context, message string, channelName interface{}) {
|
||||
if channelName == nil {
|
||||
for _, channel := range ps.channels {
|
||||
go channel.Publish(message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
|
||||
channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
|
||||
return c.name == channelName
|
||||
})
|
||||
|
||||
for _, channel := range channels {
|
||||
go channel.Publish(message)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package set
|
||||
import (
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"math/rand"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type Set struct {
|
||||
@@ -72,7 +73,7 @@ func (set *Set) GetRandom(count int) []string {
|
||||
// Count is positive, do not allow repeat elements
|
||||
for i := 0; i < utils.AbsInt(count); {
|
||||
n = rand.Intn(len(keys))
|
||||
if !utils.Contains(res, keys[n]) {
|
||||
if !slices.Contains(res, keys[n]) {
|
||||
res = append(res, keys[n])
|
||||
keys = utils.Filter(keys, func(elem string) bool {
|
||||
return elem != keys[n]
|
||||
|
||||
@@ -52,7 +52,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
}
|
||||
switch utils.AdaptType(cmd[i]).(type) {
|
||||
case string:
|
||||
if utils.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) {
|
||||
if slices.Contains([]string{"-inf", "+inf"}, strings.ToLower(cmd[i])) {
|
||||
membersStartIndex = i
|
||||
}
|
||||
case float64:
|
||||
@@ -111,11 +111,11 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
|
||||
if membersStartIndex > 2 {
|
||||
options := cmd[2:membersStartIndex]
|
||||
for _, option := range options {
|
||||
if utils.Contains([]string{"xx", "nx"}, strings.ToLower(option)) {
|
||||
if slices.Contains([]string{"xx", "nx"}, strings.ToLower(option)) {
|
||||
updatePolicy = option
|
||||
continue
|
||||
}
|
||||
if utils.Contains([]string{"gt", "lt"}, strings.ToLower(option)) {
|
||||
if slices.Contains([]string{"gt", "lt"}, strings.ToLower(option)) {
|
||||
comparison = option
|
||||
continue
|
||||
}
|
||||
@@ -1725,7 +1725,7 @@ respectively.`,
|
||||
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
|
||||
}
|
||||
endIdx := slices.IndexFunc(cmd, func(s string) bool {
|
||||
return utils.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
|
||||
return slices.Contains([]string{"MIN", "MAX", "COUNT"}, strings.ToUpper(s))
|
||||
})
|
||||
if endIdx == -1 {
|
||||
return cmd[1:], nil
|
||||
|
||||
@@ -138,7 +138,7 @@ func (set *SortedSet) AddOrUpdate(
|
||||
if !set.Contains(m.value) {
|
||||
return count, fmt.Errorf("cannot increment member %s as it does not exist in the sorted set", m.value)
|
||||
}
|
||||
if utils.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
|
||||
if slices.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
|
||||
return count, errors.New("cannot increment -inf or +inf")
|
||||
}
|
||||
set.members[m.value] = MemberObject{
|
||||
|
||||
@@ -3,7 +3,6 @@ package sorted_set
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,7 +18,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
|
||||
if weightsIndex != -1 {
|
||||
firstModifierIndex = weightsIndex
|
||||
for i := weightsIndex + 1; i < len(cmd); i++ {
|
||||
if utils.Contains([]string{"aggregate", "withscores"}, cmd[i]) {
|
||||
if slices.Contains([]string{"aggregate", "withscores"}, cmd[i]) {
|
||||
break
|
||||
}
|
||||
w, err := strconv.Atoi(cmd[i])
|
||||
@@ -43,7 +42,7 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
|
||||
if aggregateIndex >= len(cmd)-1 {
|
||||
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
|
||||
}
|
||||
if !utils.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) {
|
||||
if !slices.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) {
|
||||
return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX")
|
||||
}
|
||||
aggregate = strings.ToLower(cmd[aggregateIndex+1])
|
||||
@@ -93,7 +92,7 @@ func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
|
||||
if !ok {
|
||||
return "", err
|
||||
}
|
||||
if !utils.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
|
||||
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
|
||||
return "", err
|
||||
}
|
||||
return policy, nil
|
||||
@@ -108,7 +107,7 @@ func validateComparison(comparison interface{}) (string, error) {
|
||||
if !ok {
|
||||
return "", err
|
||||
}
|
||||
if !utils.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
|
||||
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
|
||||
return "", err
|
||||
}
|
||||
return comp, nil
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
@@ -85,18 +84,20 @@ func (server *Server) StartTCP(ctx context.Context) {
|
||||
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
server.ACL.RegisterConnection(&conn)
|
||||
|
||||
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
w := io.Writer(conn)
|
||||
r := io.Reader(conn)
|
||||
|
||||
cid := server.ConnID.Add(1)
|
||||
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
|
||||
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
|
||||
|
||||
for {
|
||||
message, err := utils.ReadMessage(connRW)
|
||||
message, err := utils.ReadMessage(r)
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Connection closed
|
||||
// TODO: Remove this connection from channel subscriptions
|
||||
break
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
@@ -115,15 +116,19 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
|
||||
if cmd, err := utils.Decode(message); err != nil {
|
||||
// Return error to client
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error())))
|
||||
connRW.Flush()
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
command, err := server.getCommand(cmd[0])
|
||||
|
||||
if err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -138,47 +143,69 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !server.IsInCluster() || !synchronize {
|
||||
if res, err := handler(ctx, cmd, server, &conn); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-%s\r\n\n", err.Error())))
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
if command.Command == "ack" {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(res); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
// TODO: Write successful, add entry to AOF
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle other commands that need to be synced across the cluster
|
||||
if server.raft.IsRaftLeader() {
|
||||
if res, err := server.raftApply(ctx, cmd); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error())))
|
||||
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
if _, err := w.Write(res); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward message to leader and return immediate OK response
|
||||
if server.Config.ForwardCommand {
|
||||
server.memberList.ForwardDataMutation(ctx, message)
|
||||
connRW.Write([]byte(utils.OK_RESPONSE))
|
||||
connRW.Flush()
|
||||
if _, err := w.Write([]byte(utils.OK_RESPONSE)); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
connRW.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n"))
|
||||
connRW.Flush()
|
||||
if _, err := w.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
if err := conn.Close(); err != nil {
|
||||
// TODO: Log error at configured logger
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) Start(ctx context.Context) {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -31,15 +31,6 @@ func AdaptType(s string) interface{} {
|
||||
return f
|
||||
}
|
||||
|
||||
func Contains[T comparable](arr []T, elem T) bool {
|
||||
for _, v := range arr {
|
||||
if v == elem {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
|
||||
for _, e := range arr {
|
||||
if test(e) {
|
||||
@@ -49,9 +40,9 @@ func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
|
||||
return
|
||||
}
|
||||
|
||||
func Decode(raw string) ([]string, error) {
|
||||
rd := resp.NewReader(bytes.NewBufferString(raw))
|
||||
res := []string{}
|
||||
func Decode(raw []byte) ([]string, error) {
|
||||
rd := resp.NewReader(bytes.NewBuffer(raw))
|
||||
var res []string
|
||||
|
||||
v, _, err := rd.ReadValue()
|
||||
|
||||
@@ -59,7 +50,7 @@ func Decode(raw string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if Contains[string]([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
|
||||
if slices.Contains([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
|
||||
return []string{v.String()}, nil
|
||||
}
|
||||
|
||||
@@ -72,25 +63,28 @@ func Decode(raw string) ([]string, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func ReadMessage(r *bufio.ReadWriter) (message string, err error) {
|
||||
var line [][]byte
|
||||
func ReadMessage(r io.Reader) ([]byte, error) {
|
||||
delim := []byte{'\r', '\n', '\r', '\n'}
|
||||
buffSize := 8
|
||||
buff := make([]byte, buffSize)
|
||||
|
||||
var n int
|
||||
var err error
|
||||
var res []byte
|
||||
|
||||
for {
|
||||
b, _, err := r.ReadLine()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if bytes.Equal(b, []byte("")) {
|
||||
// End of message
|
||||
n, err = r.Read(buff)
|
||||
res = append(res, buff...)
|
||||
if n < buffSize || err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
line = append(line, b)
|
||||
if bytes.Equal(buff[len(buff)-4:], delim) {
|
||||
break
|
||||
}
|
||||
clear(buff)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil
|
||||
return res, err
|
||||
}
|
||||
|
||||
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {
|
||||
|
||||
Reference in New Issue
Block a user