mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-17 05:20:47 +08:00
Removed http options flag as it will not be used for now.
Moved server cluster receiver functions into cluster.go. Moved keyspace receiver functions into keyspace.go. Moved module and command loading into modules.go. Updated Dockerfile and docker-compose to remove http flag to server.
This commit is contained in:
@@ -20,7 +20,6 @@ CMD "./server" \
|
||||
"--cert" "${CERT}" \
|
||||
"--pluginDir" "${PLUGIN_DIR}" \
|
||||
"--dataDir" "${DATA_DIR}" \
|
||||
"--http=${HTTP}" \
|
||||
"--tls=${TLS}" \
|
||||
"--inMemory=${IN_MEMORY}" \
|
||||
"--bootstrapCluster=${BOOTSTRAP_CLUSTER}" \
|
||||
|
423
src/main.go
423
src/main.go
@@ -1,416 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
ml "github.com/kelvinmwinuka/memstore/src/memberlist"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/acl"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/etc"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/get"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/hash"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/list"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/ping"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/set"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/sorted_set"
|
||||
str "github.com/kelvinmwinuka/memstore/src/modules/string"
|
||||
rl "github.com/kelvinmwinuka/memstore/src/raft"
|
||||
"io"
|
||||
"github.com/kelvinmwinuka/memstore/src/server"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
config utils.Config
|
||||
|
||||
connID atomic.Uint64
|
||||
|
||||
store map[string]interface{}
|
||||
keyLocks map[string]*sync.RWMutex
|
||||
keyCreationLock *sync.Mutex
|
||||
|
||||
commands []utils.Command
|
||||
|
||||
raft *rl.Raft
|
||||
memberList *ml.MemberList
|
||||
|
||||
cancelCh *chan os.Signal
|
||||
|
||||
ACL *acl.ACL
|
||||
PubSub *pubsub.PubSub
|
||||
}
|
||||
|
||||
func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) {
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
ok := server.keyLocks[key].TryLock()
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, context.Cause(ctx)
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) KeyUnlock(key string) {
|
||||
server.keyLocks[key].Unlock()
|
||||
}
|
||||
|
||||
func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) {
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
ok := server.keyLocks[key].TryRLock()
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, context.Cause(ctx)
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) KeyRUnlock(key string) {
|
||||
server.keyLocks[key].RUnlock()
|
||||
}
|
||||
|
||||
func (server *Server) KeyExists(key string) bool {
|
||||
return server.keyLocks[key] != nil
|
||||
}
|
||||
|
||||
func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) {
|
||||
server.keyCreationLock.Lock()
|
||||
defer server.keyCreationLock.Unlock()
|
||||
|
||||
if !server.KeyExists(key) {
|
||||
keyLock := &sync.RWMutex{}
|
||||
keyLock.Lock()
|
||||
server.keyLocks[key] = keyLock
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return server.KeyLock(ctx, key)
|
||||
}
|
||||
|
||||
func (server *Server) GetValue(key string) interface{} {
|
||||
return server.store[key]
|
||||
}
|
||||
|
||||
func (server *Server) SetValue(ctx context.Context, key string, value interface{}) {
|
||||
server.store[key] = value
|
||||
}
|
||||
|
||||
func (server *Server) GetAllCommands(ctx context.Context) []utils.Command {
|
||||
return server.commands
|
||||
}
|
||||
|
||||
func (server *Server) GetACL() interface{} {
|
||||
return server.ACL
|
||||
}
|
||||
|
||||
func (server *Server) GetPubSub() interface{} {
|
||||
return server.PubSub
|
||||
}
|
||||
|
||||
func (server *Server) getCommand(cmd string) (utils.Command, error) {
|
||||
for _, command := range server.commands {
|
||||
if strings.EqualFold(command.Command, cmd) {
|
||||
return command, nil
|
||||
}
|
||||
}
|
||||
return utils.Command{}, fmt.Errorf("command %s not supported", cmd)
|
||||
}
|
||||
|
||||
func (server *Server) raftApply(ctx context.Context, cmd []string) ([]byte, error) {
|
||||
serverId, _ := ctx.Value(utils.ContextServerID("ServerID")).(string)
|
||||
connectionId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string)
|
||||
|
||||
applyRequest := utils.ApplyRequest{
|
||||
ServerID: serverId,
|
||||
ConnectionID: connectionId,
|
||||
CMD: cmd,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(applyRequest)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("could not parse request")
|
||||
}
|
||||
|
||||
applyFuture := server.raft.Apply(b, 500*time.Millisecond)
|
||||
|
||||
if err := applyFuture.Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ok := applyFuture.Response().(utils.ApplyResponse)
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unprocessable entity %v", r)
|
||||
}
|
||||
|
||||
if r.Error != nil {
|
||||
return nil, r.Error
|
||||
}
|
||||
|
||||
return r.Response, nil
|
||||
}
|
||||
|
||||
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
server.ACL.RegisterConnection(&conn)
|
||||
|
||||
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
|
||||
cid := server.connID.Add(1)
|
||||
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
|
||||
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
|
||||
|
||||
for {
|
||||
message, err := utils.ReadMessage(connRW)
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Connection closed
|
||||
break
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
// Connection timeout
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
if err, ok := err.(tls.RecordHeaderError); ok {
|
||||
// TLS verification error
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
if cmd, err := utils.Decode(message); err != nil {
|
||||
// Return error to client
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error())))
|
||||
connRW.Flush()
|
||||
continue
|
||||
} else {
|
||||
command, err := server.getCommand(cmd[0])
|
||||
|
||||
if err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
synchronize := command.Sync
|
||||
handler := command.HandlerFunc
|
||||
|
||||
subCommand, ok := utils.GetSubCommand(command, cmd).(utils.SubCommand)
|
||||
|
||||
if ok {
|
||||
synchronize = subCommand.Sync
|
||||
handler = subCommand.HandlerFunc
|
||||
}
|
||||
|
||||
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if !server.IsInCluster() || !synchronize {
|
||||
if res, err := handler(ctx, cmd, server, &conn); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-%s\r\n\n", err.Error())))
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
// TODO: Write successful, add entry to AOF
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle other commands that need to be synced across the cluster
|
||||
if server.raft.IsRaftLeader() {
|
||||
if res, err := server.raftApply(ctx, cmd); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error())))
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward message to leader and return immediate OK response
|
||||
if server.config.ForwardCommand {
|
||||
server.memberList.ForwardDataMutation(ctx, message)
|
||||
connRW.Write([]byte(utils.OK_RESPONSE))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
connRW.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n"))
|
||||
connRW.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func (server *Server) StartTCP(ctx context.Context) {
|
||||
conf := server.config
|
||||
|
||||
listenConfig := net.ListenConfig{
|
||||
KeepAlive: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port))
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if !conf.TLS {
|
||||
// TCP
|
||||
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
}
|
||||
|
||||
if conf.TLS {
|
||||
// TLS
|
||||
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
listener = tls.NewListener(listener, &tls.Config{
|
||||
Certificates: []tls.Certificate{cer},
|
||||
})
|
||||
}
|
||||
|
||||
// Listen to connection
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
fmt.Println("Could not establish connection")
|
||||
continue
|
||||
}
|
||||
// Read loop for connection
|
||||
go server.handleConnection(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) StartHTTP(ctx context.Context) {
|
||||
conf := server.config
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello from memstore!"))
|
||||
})
|
||||
|
||||
var err error
|
||||
|
||||
if conf.TLS {
|
||||
fmt.Printf("Starting HTTPS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), conf.Cert, conf.Key, nil)
|
||||
} else {
|
||||
fmt.Printf("Starting HTTP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
err = http.ListenAndServe(fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) LoadCommands(plugin utils.Plugin) {
|
||||
commands := plugin.Commands()
|
||||
for _, command := range commands {
|
||||
server.commands = append(server.commands, command)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) LoadModules(ctx context.Context) {
|
||||
server.LoadCommands(acl.NewModule())
|
||||
server.LoadCommands(pubsub.NewModule())
|
||||
server.LoadCommands(ping.NewModule())
|
||||
server.LoadCommands(get.NewModule())
|
||||
server.LoadCommands(list.NewModule())
|
||||
server.LoadCommands(str.NewModule())
|
||||
server.LoadCommands(etc.NewModule())
|
||||
server.LoadCommands(set.NewModule())
|
||||
server.LoadCommands(sorted_set.NewModule())
|
||||
server.LoadCommands(hash.NewModule())
|
||||
}
|
||||
|
||||
func (server *Server) Start(ctx context.Context) {
|
||||
conf := server.config
|
||||
|
||||
server.store = make(map[string]interface{})
|
||||
server.keyLocks = make(map[string]*sync.RWMutex)
|
||||
server.keyCreationLock = &sync.Mutex{}
|
||||
|
||||
server.LoadModules(ctx)
|
||||
|
||||
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
||||
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
||||
return
|
||||
}
|
||||
|
||||
if server.IsInCluster() {
|
||||
// Initialise raft and memberlist
|
||||
server.raft = rl.NewRaft(rl.RaftOpts{
|
||||
Config: conf,
|
||||
Server: server,
|
||||
GetCommand: server.getCommand,
|
||||
})
|
||||
server.memberList = ml.NewMemberList(ml.MemberlistOpts{
|
||||
Config: conf,
|
||||
HasJoinedCluster: server.raft.HasJoinedCluster,
|
||||
AddVoter: server.raft.AddVoter,
|
||||
RemoveRaftServer: server.raft.RemoveServer,
|
||||
IsRaftLeader: server.raft.IsRaftLeader,
|
||||
ApplyMutate: server.raftApply,
|
||||
})
|
||||
server.raft.RaftInit(ctx)
|
||||
server.memberList.MemberListInit(ctx)
|
||||
}
|
||||
|
||||
if conf.HTTP {
|
||||
server.StartHTTP(ctx)
|
||||
} else {
|
||||
server.StartTCP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) IsInCluster() bool {
|
||||
return server.config.BootstrapCluster || server.config.JoinAddr != ""
|
||||
}
|
||||
|
||||
func (server *Server) ShutDown(ctx context.Context) {
|
||||
if server.IsInCluster() {
|
||||
server.raft.RaftShutdown(ctx)
|
||||
server.memberList.MemberListShutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
config, err := utils.GetConfig()
|
||||
|
||||
@@ -420,7 +22,7 @@ func main() {
|
||||
|
||||
ctx := context.WithValue(context.Background(), utils.ContextServerID("ServerID"), config.ServerID)
|
||||
|
||||
// Default BindAddr if it's not etc
|
||||
// Default BindAddr if it's not specified
|
||||
if config.BindAddr == "" {
|
||||
if addr, err := utils.GetIPAddress(); err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -432,20 +34,17 @@ func main() {
|
||||
cancelCh := make(chan os.Signal, 1)
|
||||
signal.Notify(cancelCh, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
|
||||
|
||||
server := &Server{
|
||||
config: config,
|
||||
|
||||
connID: atomic.Uint64{},
|
||||
|
||||
ACL: acl.NewACL(config),
|
||||
PubSub: pubsub.NewPubSub(),
|
||||
|
||||
cancelCh: &cancelCh,
|
||||
s := &server.Server{
|
||||
Config: config,
|
||||
ConnID: atomic.Uint64{},
|
||||
ACL: acl.NewACL(config),
|
||||
PubSub: pubsub.NewPubSub(),
|
||||
CancelCh: &cancelCh,
|
||||
}
|
||||
|
||||
go server.Start(ctx)
|
||||
go s.Start(ctx)
|
||||
|
||||
<-cancelCh
|
||||
|
||||
server.ShutDown(ctx)
|
||||
s.ShutDown(ctx)
|
||||
}
|
||||
|
49
src/server/cluster.go
Normal file
49
src/server/cluster.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (server *Server) IsInCluster() bool {
|
||||
return server.Config.BootstrapCluster || server.Config.JoinAddr != ""
|
||||
}
|
||||
|
||||
func (server *Server) raftApply(ctx context.Context, cmd []string) ([]byte, error) {
|
||||
serverId, _ := ctx.Value(utils.ContextServerID("ServerID")).(string)
|
||||
connectionId, _ := ctx.Value(utils.ContextConnID("ConnectionID")).(string)
|
||||
|
||||
applyRequest := utils.ApplyRequest{
|
||||
ServerID: serverId,
|
||||
ConnectionID: connectionId,
|
||||
CMD: cmd,
|
||||
}
|
||||
|
||||
b, err := json.Marshal(applyRequest)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("could not parse request")
|
||||
}
|
||||
|
||||
applyFuture := server.raft.Apply(b, 500*time.Millisecond)
|
||||
|
||||
if err := applyFuture.Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ok := applyFuture.Response().(utils.ApplyResponse)
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unprocessable entity %v", r)
|
||||
}
|
||||
|
||||
if r.Error != nil {
|
||||
return nil, r.Error
|
||||
}
|
||||
|
||||
return r.Response, nil
|
||||
}
|
73
src/server/keyspace.go
Normal file
73
src/server/keyspace.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) {
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
ok := server.keyLocks[key].TryLock()
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, context.Cause(ctx)
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) KeyUnlock(key string) {
|
||||
server.keyLocks[key].Unlock()
|
||||
}
|
||||
|
||||
func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) {
|
||||
ticker := time.NewTicker(5 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
ok := server.keyLocks[key].TryRLock()
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, context.Cause(ctx)
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) KeyRUnlock(key string) {
|
||||
server.keyLocks[key].RUnlock()
|
||||
}
|
||||
|
||||
func (server *Server) KeyExists(key string) bool {
|
||||
return server.keyLocks[key] != nil
|
||||
}
|
||||
|
||||
func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) {
|
||||
server.keyCreationLock.Lock()
|
||||
defer server.keyCreationLock.Unlock()
|
||||
|
||||
if !server.KeyExists(key) {
|
||||
keyLock := &sync.RWMutex{}
|
||||
keyLock.Lock()
|
||||
server.keyLocks[key] = keyLock
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return server.KeyLock(ctx, key)
|
||||
}
|
||||
|
||||
func (server *Server) GetValue(key string) interface{} {
|
||||
return server.store[key]
|
||||
}
|
||||
|
||||
func (server *Server) SetValue(ctx context.Context, key string, value interface{}) {
|
||||
server.store[key] = value
|
||||
}
|
59
src/server/modules.go
Normal file
59
src/server/modules.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/acl"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/etc"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/get"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/hash"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/list"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/ping"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/set"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/sorted_set"
|
||||
str "github.com/kelvinmwinuka/memstore/src/modules/string"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (server *Server) LoadCommands(plugin utils.Plugin) {
|
||||
commands := plugin.Commands()
|
||||
for _, command := range commands {
|
||||
server.commands = append(server.commands, command)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) LoadModules(ctx context.Context) {
|
||||
server.LoadCommands(acl.NewModule())
|
||||
server.LoadCommands(pubsub.NewModule())
|
||||
server.LoadCommands(ping.NewModule())
|
||||
server.LoadCommands(get.NewModule())
|
||||
server.LoadCommands(list.NewModule())
|
||||
server.LoadCommands(str.NewModule())
|
||||
server.LoadCommands(etc.NewModule())
|
||||
server.LoadCommands(set.NewModule())
|
||||
server.LoadCommands(sorted_set.NewModule())
|
||||
server.LoadCommands(hash.NewModule())
|
||||
}
|
||||
|
||||
func (server *Server) GetAllCommands(ctx context.Context) []utils.Command {
|
||||
return server.commands
|
||||
}
|
||||
|
||||
func (server *Server) GetACL() interface{} {
|
||||
return server.ACL
|
||||
}
|
||||
|
||||
func (server *Server) GetPubSub() interface{} {
|
||||
return server.PubSub
|
||||
}
|
||||
|
||||
func (server *Server) getCommand(cmd string) (utils.Command, error) {
|
||||
for _, command := range server.commands {
|
||||
if strings.EqualFold(command.Command, cmd) {
|
||||
return command, nil
|
||||
}
|
||||
}
|
||||
return utils.Command{}, fmt.Errorf("command %s not supported", cmd)
|
||||
}
|
@@ -1 +1,225 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/kelvinmwinuka/memstore/src/memberlist"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/acl"
|
||||
"github.com/kelvinmwinuka/memstore/src/modules/pubsub"
|
||||
"github.com/kelvinmwinuka/memstore/src/raft"
|
||||
"github.com/kelvinmwinuka/memstore/src/utils"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Config utils.Config
|
||||
|
||||
ConnID atomic.Uint64
|
||||
|
||||
store map[string]interface{}
|
||||
keyLocks map[string]*sync.RWMutex
|
||||
keyCreationLock *sync.Mutex
|
||||
|
||||
commands []utils.Command
|
||||
|
||||
raft *raft.Raft
|
||||
memberList *memberlist.MemberList
|
||||
|
||||
CancelCh *chan os.Signal
|
||||
|
||||
ACL *acl.ACL
|
||||
PubSub *pubsub.PubSub
|
||||
}
|
||||
|
||||
func (server *Server) StartTCP(ctx context.Context) {
|
||||
conf := server.Config
|
||||
|
||||
listenConfig := net.ListenConfig{
|
||||
KeepAlive: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
listener, err := listenConfig.Listen(ctx, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port))
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if !conf.TLS {
|
||||
// TCP
|
||||
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
}
|
||||
|
||||
if conf.TLS {
|
||||
// TLS
|
||||
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
|
||||
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
listener = tls.NewListener(listener, &tls.Config{
|
||||
Certificates: []tls.Certificate{cer},
|
||||
})
|
||||
}
|
||||
|
||||
// Listen to connection
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
fmt.Println("Could not establish connection")
|
||||
continue
|
||||
}
|
||||
// Read loop for connection
|
||||
go server.handleConnection(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
server.ACL.RegisterConnection(&conn)
|
||||
|
||||
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
|
||||
cid := server.ConnID.Add(1)
|
||||
ctx = context.WithValue(ctx, utils.ContextConnID("ConnectionID"),
|
||||
fmt.Sprintf("%s-%d", ctx.Value(utils.ContextServerID("ServerID")), cid))
|
||||
|
||||
for {
|
||||
message, err := utils.ReadMessage(connRW)
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Connection closed
|
||||
break
|
||||
}
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
// Connection timeout
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
if err, ok := err.(tls.RecordHeaderError); ok {
|
||||
// TLS verification error
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
if cmd, err := utils.Decode(message); err != nil {
|
||||
// Return error to client
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error())))
|
||||
connRW.Flush()
|
||||
continue
|
||||
} else {
|
||||
command, err := server.getCommand(cmd[0])
|
||||
|
||||
if err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
synchronize := command.Sync
|
||||
handler := command.HandlerFunc
|
||||
|
||||
subCommand, ok := utils.GetSubCommand(command, cmd).(utils.SubCommand)
|
||||
|
||||
if ok {
|
||||
synchronize = subCommand.Sync
|
||||
handler = subCommand.HandlerFunc
|
||||
}
|
||||
|
||||
if err := server.ACL.AuthorizeConnection(&conn, cmd, command, subCommand); err != nil {
|
||||
connRW.WriteString(fmt.Sprintf("-%s\r\n\n", err.Error()))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if !server.IsInCluster() || !synchronize {
|
||||
if res, err := handler(ctx, cmd, server, &conn); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-%s\r\n\n", err.Error())))
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
// TODO: Write successful, add entry to AOF
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle other commands that need to be synced across the cluster
|
||||
if server.raft.IsRaftLeader() {
|
||||
if res, err := server.raftApply(ctx, cmd); err != nil {
|
||||
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error())))
|
||||
} else {
|
||||
connRW.Write(res)
|
||||
}
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward message to leader and return immediate OK response
|
||||
if server.Config.ForwardCommand {
|
||||
server.memberList.ForwardDataMutation(ctx, message)
|
||||
connRW.Write([]byte(utils.OK_RESPONSE))
|
||||
connRW.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
connRW.Write([]byte("-Error not cluster leader, cannot carry out command\r\n\r\n"))
|
||||
connRW.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func (server *Server) Start(ctx context.Context) {
|
||||
conf := server.Config
|
||||
|
||||
server.store = make(map[string]interface{})
|
||||
server.keyLocks = make(map[string]*sync.RWMutex)
|
||||
server.keyCreationLock = &sync.Mutex{}
|
||||
|
||||
server.LoadModules(ctx)
|
||||
|
||||
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
||||
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
||||
return
|
||||
}
|
||||
|
||||
if server.IsInCluster() {
|
||||
// Initialise raft and memberlist
|
||||
server.raft = raft.NewRaft(raft.RaftOpts{
|
||||
Config: conf,
|
||||
Server: server,
|
||||
GetCommand: server.getCommand,
|
||||
})
|
||||
server.memberList = memberlist.NewMemberList(memberlist.MemberlistOpts{
|
||||
Config: conf,
|
||||
HasJoinedCluster: server.raft.HasJoinedCluster,
|
||||
AddVoter: server.raft.AddVoter,
|
||||
RemoveRaftServer: server.raft.RemoveServer,
|
||||
IsRaftLeader: server.raft.IsRaftLeader,
|
||||
ApplyMutate: server.raftApply,
|
||||
})
|
||||
server.raft.RaftInit(ctx)
|
||||
server.memberList.MemberListInit(ctx)
|
||||
}
|
||||
|
||||
server.StartTCP(ctx)
|
||||
}
|
||||
|
||||
func (server *Server) ShutDown(ctx context.Context) {
|
||||
if server.IsInCluster() {
|
||||
server.raft.RaftShutdown(ctx)
|
||||
server.memberList.MemberListShutdown(ctx)
|
||||
}
|
||||
}
|
||||
|
@@ -15,7 +15,6 @@ type Config struct {
|
||||
Key string `json:"key" yaml:"key"`
|
||||
Cert string `json:"cert" yaml:"cert"`
|
||||
Port uint16 `json:"port" yaml:"port"`
|
||||
HTTP bool `json:"http" yaml:"http"`
|
||||
PluginDir string `json:"plugins" yaml:"plugins"`
|
||||
ServerID string `json:"serverId" yaml:"serverId"`
|
||||
JoinAddr string `json:"joinAddr" yaml:"joinAddr"`
|
||||
@@ -36,7 +35,6 @@ func GetConfig() (Config, error) {
|
||||
key := flag.String("key", "", "The private key file path.")
|
||||
cert := flag.String("cert", "", "The signed certificate file path.")
|
||||
port := flag.Int("port", 7480, "Port to use. Default is 7480")
|
||||
http := flag.Bool("http", false, "Use HTTP protocol instead of raw TCP. Default is false")
|
||||
pluginDir := flag.String("pluginDir", "", "Directory where plugins are located.")
|
||||
serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.")
|
||||
joinAddr := flag.String("joinAddr", "", "Address of cluster member in a cluster to you want to join.")
|
||||
@@ -75,7 +73,6 @@ It is a plain text value by default but you can provide a SHA256 hash by adding
|
||||
TLS: *tls,
|
||||
Key: *key,
|
||||
Cert: *cert,
|
||||
HTTP: *http,
|
||||
PluginDir: *pluginDir,
|
||||
Port: uint16(*port),
|
||||
ServerID: *serverId,
|
||||
|
Reference in New Issue
Block a user