mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-21 06:59:24 +08:00

Deleted Contains function from the utils.go file. Replaced bufio.ReadWriter with io.Reader and io.Writer respectively. Updated PubSub module to remove posibility of not passing a channel name when subscribing or publishing.
253 lines
6.1 KiB
Go
253 lines
6.1 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"github.com/kelvinmwinuka/memstore/src/memberlist"
|
|
"github.com/kelvinmwinuka/memstore/src/modules/acl"
|
|
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
|
|
"github.com/kelvinmwinuka/memstore/src/raft"
|
|
"github.com/kelvinmwinuka/memstore/src/utils"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
type Server struct {
|
|
Config utils.Config
|
|
|
|
ConnID atomic.Uint64
|
|
|
|
store map[string]interface{}
|
|
keyLocks map[string]*sync.RWMutex
|
|
keyCreationLock *sync.Mutex
|
|
|
|
commands []utils.Command
|
|
|
|
raft *raft.Raft
|
|
memberList *memberlist.MemberList
|
|
|
|
CancelCh *chan os.Signal
|
|
|
|
ACL *acl.ACL
|
|
PubSub *pubsub.PubSub
|
|
}
|
|
|
|
func (server *Server) StartTCP(ctx context.Context) {
|
|
conf := server.Config
|
|
|
|
listenConfig := net.ListenConfig{
|
|
KeepAlive: 200 * time.Millisecond,
|
|
}
|
|
|
|
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port))
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
if !conf.TLS {
|
|
// TCP
|
|
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
|
}
|
|
|
|
if conf.TLS {
|
|
// TLS
|
|
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
|
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
listener = tls.NewListener(listener, &tls.Config{
|
|
Certificates: []tls.Certificate{cer},
|
|
})
|
|
}
|
|
|
|
// Listen to connection
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
fmt.Println("Could not establish connection")
|
|
continue
|
|
}
|
|
// Read loop for connection
|
|
go server.handleConnection(ctx, conn)
|
|
}
|
|
}
|
|
|
|
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
|
server.ACL.RegisterConnection(&conn)
|
|
|
|
w := io.Writer(conn)
|
|
r := io.Reader(conn)
|
|
|
|
cid := server.ConnID.Add(1)
|
|
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
|
|
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
|
|
|
|
for {
|
|
message, err := utils.ReadMessage(r)
|
|
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
// Connection closed
|
|
// TODO: Remove this connection from channel subscriptions
|
|
break
|
|
}
|
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
// Connection timeout
|
|
fmt.Println(err)
|
|
break
|
|
}
|
|
if err, ok := err.(tls.RecordHeaderError); ok {
|
|
// TLS verification error
|
|
fmt.Println(err)
|
|
break
|
|
}
|
|
fmt.Println(err)
|
|
break
|
|
}
|
|
|
|
if cmd, err := utils.Decode(message); err != nil {
|
|
// Return error to client
|
|
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
continue
|
|
} else {
|
|
command, err := server.getCommand(cmd[0])
|
|
|
|
if err != nil {
|
|
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
synchronize := command.Sync
|
|
handler := command.HandlerFunc
|
|
|
|
subCommand, ok := utils.GetSubCommand(command, cmd).(utils.SubCommand)
|
|
|
|
if ok {
|
|
synchronize = subCommand.Sync
|
|
handler = subCommand.HandlerFunc
|
|
}
|
|
|
|
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
|
|
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if !server.IsInCluster() || !synchronize {
|
|
if res, err := handler(ctx, cmd, server, &conn); err != nil {
|
|
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
} else {
|
|
if command.Command == "ack" {
|
|
continue
|
|
}
|
|
if _, err := w.Write(res); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
// TODO: Write successful, add entry to AOF
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Handle other commands that need to be synced across the cluster
|
|
if server.raft.IsRaftLeader() {
|
|
if res, err := server.raftApply(ctx, cmd); err != nil {
|
|
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
} else {
|
|
if _, err := w.Write(res); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Forward message to leader and return immediate OK response
|
|
if server.Config.ForwardCommand {
|
|
server.memberList.ForwardDataMutation(ctx, message)
|
|
if _, err := w.Write([]byte(utils.OK_RESPONSE)); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
if _, err := w.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := conn.Close(); err != nil {
|
|
// TODO: Log error at configured logger
|
|
fmt.Println(err)
|
|
}
|
|
}
|
|
|
|
func (server *Server) Start(ctx context.Context) {
|
|
conf := server.Config
|
|
|
|
server.store = make(map[string]interface{})
|
|
server.keyLocks = make(map[string]*sync.RWMutex)
|
|
server.keyCreationLock = &sync.Mutex{}
|
|
|
|
server.LoadModules(ctx)
|
|
|
|
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
|
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
|
return
|
|
}
|
|
|
|
if server.IsInCluster() {
|
|
// Initialise raft and memberlist
|
|
server.raft = raft.NewRaft(raft.RaftOpts{
|
|
Config: conf,
|
|
Server: server,
|
|
GetCommand: server.getCommand,
|
|
})
|
|
server.memberList = memberlist.NewMemberList(memberlist.MemberlistOpts{
|
|
Config: conf,
|
|
HasJoinedCluster: server.raft.HasJoinedCluster,
|
|
AddVoter: server.raft.AddVoter,
|
|
RemoveRaftServer: server.raft.RemoveServer,
|
|
IsRaftLeader: server.raft.IsRaftLeader,
|
|
ApplyMutate: server.raftApply,
|
|
})
|
|
server.raft.RaftInit(ctx)
|
|
server.memberList.MemberListInit(ctx)
|
|
}
|
|
|
|
server.StartTCP(ctx)
|
|
}
|
|
|
|
func (server *Server) ShutDown(ctx context.Context) {
|
|
if server.IsInCluster() {
|
|
server.raft.RaftShutdown(ctx)
|
|
server.memberList.MemberListShutdown(ctx)
|
|
}
|
|
}
|