Files
SugarDB/src/server/server.go
Kelvin Clement Mwinuka ce14f59f41 Replaced all instances of utils.Contains with slices.Contains.
Deleted Contains function from the utils.go file.
Replaced bufio.ReadWriter with io.Reader and io.Writer respectively.
Updated PubSub module to remove posibility of not passing a channel name when subscribing or publishing.
2024-01-23 23:38:07 +08:00

253 lines
6.1 KiB
Go

package server
import (
"context"
"crypto/tls"
"fmt"
"github.com/kelvinmwinuka/memstore/src/memberlist"
"github.com/kelvinmwinuka/memstore/src/modules/acl"
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
"github.com/kelvinmwinuka/memstore/src/raft"
"github.com/kelvinmwinuka/memstore/src/utils"
"io"
"log"
"net"
"os"
"sync"
"sync/atomic"
"time"
)
type Server struct {
Config utils.Config
ConnID atomic.Uint64
store map[string]interface{}
keyLocks map[string]*sync.RWMutex
keyCreationLock *sync.Mutex
commands []utils.Command
raft *raft.Raft
memberList *memberlist.MemberList
CancelCh *chan os.Signal
ACL *acl.ACL
PubSub *pubsub.PubSub
}
func (server *Server) StartTCP(ctx context.Context) {
conf := server.Config
listenConfig := net.ListenConfig{
KeepAlive: 200 * time.Millisecond,
}
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port))
if err != nil {
log.Fatal(err)
}
if !conf.TLS {
// TCP
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
}
if conf.TLS {
// TLS
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
if err != nil {
log.Fatal(err)
}
listener = tls.NewListener(listener, &tls.Config{
Certificates: []tls.Certificate{cer},
})
}
// Listen to connection
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Could not establish connection")
continue
}
// Read loop for connection
go server.handleConnection(ctx, conn)
}
}
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
server.ACL.RegisterConnection(&conn)
w := io.Writer(conn)
r := io.Reader(conn)
cid := server.ConnID.Add(1)
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
for {
message, err := utils.ReadMessage(r)
if err != nil {
if err == io.EOF {
// Connection closed
// TODO: Remove this connection from channel subscriptions
break
}
if err, ok := err.(net.Error); ok && err.Timeout() {
// Connection timeout
fmt.Println(err)
break
}
if err, ok := err.(tls.RecordHeaderError); ok {
// TLS verification error
fmt.Println(err)
break
}
fmt.Println(err)
break
}
if cmd, err := utils.Decode(message); err != nil {
// Return error to client
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
} else {
command, err := server.getCommand(cmd[0])
if err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
synchronize := command.Sync
handler := command.HandlerFunc
subCommand, ok := utils.GetSubCommand(command, cmd).(utils.SubCommand)
if ok {
synchronize = subCommand.Sync
handler = subCommand.HandlerFunc
}
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
if !server.IsInCluster() || !synchronize {
if res, err := handler(ctx, cmd, server, &conn); err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else {
if command.Command == "ack" {
continue
}
if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
// TODO: Write successful, add entry to AOF
}
continue
}
// Handle other commands that need to be synced across the cluster
if server.raft.IsRaftLeader() {
if res, err := server.raftApply(ctx, cmd); err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
} else {
if _, err := w.Write(res); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
continue
}
// Forward message to leader and return immediate OK response
if server.Config.ForwardCommand {
server.memberList.ForwardDataMutation(ctx, message)
if _, err := w.Write([]byte(utils.OK_RESPONSE)); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
continue
}
if _, err := w.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n")); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
}
if err := conn.Close(); err != nil {
// TODO: Log error at configured logger
fmt.Println(err)
}
}
func (server *Server) Start(ctx context.Context) {
conf := server.Config
server.store = make(map[string]interface{})
server.keyLocks = make(map[string]*sync.RWMutex)
server.keyCreationLock = &sync.Mutex{}
server.LoadModules(ctx)
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
fmt.Println("Must provide key and certificate file paths for TLS mode.")
return
}
if server.IsInCluster() {
// Initialise raft and memberlist
server.raft = raft.NewRaft(raft.RaftOpts{
Config: conf,
Server: server,
GetCommand: server.getCommand,
})
server.memberList = memberlist.NewMemberList(memberlist.MemberlistOpts{
Config: conf,
HasJoinedCluster: server.raft.HasJoinedCluster,
AddVoter: server.raft.AddVoter,
RemoveRaftServer: server.raft.RemoveServer,
IsRaftLeader: server.raft.IsRaftLeader,
ApplyMutate: server.raftApply,
})
server.raft.RaftInit(ctx)
server.memberList.MemberListInit(ctx)
}
server.StartTCP(ctx)
}
func (server *Server) ShutDown(ctx context.Context) {
if server.IsInCluster() {
server.raft.RaftShutdown(ctx)
server.memberList.MemberListShutdown(ctx)
}
}