mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-13 19:55:14 +08:00
331 lines
6.9 KiB
Go
331 lines
6.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"plugin"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/memberlist"
|
|
"github.com/hashicorp/raft"
|
|
"github.com/kelvinmwinuka/memstore/utils"
|
|
)
|
|
|
|
// Plugin is the contract a dynamically loaded command plugin must satisfy.
// LoadPlugins type-asserts each plugin's exported "Plugin" symbol to this
// interface before registering it.
type Plugin interface {
	// Name returns the plugin's name.
	Name() string

	// Commands returns the command names this plugin handles. The connection
	// handler matches them against the lower-cased first element of each
	// incoming command.
	Commands() []string

	// Description returns a description of the plugin.
	Description() string

	// HandleCommand is invoked with the decoded command, the server (passed
	// as an untyped value; the connection handler passes its *Server) and the
	// buffered writer of the client connection.
	HandleCommand(cmd []string, server any, conn *bufio.Writer)
}
|
|
|
|
// Data is the server's in-memory key/value store.
type Data struct {
	// mu is the mutex acquired and released through Server.Lock and
	// Server.Unlock.
	mu sync.Mutex

	// data maps keys to arbitrary values. It is nil until Server.Start
	// allocates it.
	data map[string]any
}
|
|
|
|
type Server struct {
|
|
config Config
|
|
data Data
|
|
plugins []Plugin
|
|
raft *raft.Raft
|
|
memberList *memberlist.Memberlist
|
|
}
|
|
|
|
func (server *Server) Lock() {
|
|
server.data.mu.Lock()
|
|
}
|
|
|
|
func (server *Server) Unlock() {
|
|
server.data.mu.Unlock()
|
|
}
|
|
|
|
func (server *Server) GetData(key string) interface{} {
|
|
return server.data.data[key]
|
|
}
|
|
|
|
func (server *Server) SetData(key string, value interface{}) {
|
|
server.data.data[key] = value
|
|
}
|
|
|
|
func (server *Server) handleConnection(conn net.Conn) {
|
|
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
|
|
|
for {
|
|
message, err := utils.ReadMessage(connRW)
|
|
|
|
if err != nil && err == io.EOF {
|
|
// Connection closed
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
continue
|
|
}
|
|
|
|
if cmd, err := utils.Decode(message); err != nil {
|
|
// Return error to client
|
|
connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error())))
|
|
connRW.Flush()
|
|
continue
|
|
} else {
|
|
// Look for plugin that handles this command and trigger it
|
|
handled := false
|
|
|
|
for _, plugin := range server.plugins {
|
|
if utils.Contains[string](plugin.Commands(), strings.ToLower(cmd[0])) {
|
|
plugin.HandleCommand(cmd, server, connRW.Writer)
|
|
handled = true
|
|
}
|
|
}
|
|
|
|
if !handled {
|
|
connRW.Write([]byte(fmt.Sprintf("-Error %s command not supported\r\n\n", strings.ToUpper(cmd[0]))))
|
|
connRW.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
conn.Close()
|
|
}
|
|
|
|
func (server *Server) StartTCP() {
|
|
conf := server.config
|
|
var listener net.Listener
|
|
|
|
if conf.TLS {
|
|
// TLS
|
|
fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.Addr, conf.Port)
|
|
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if l, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", conf.Addr, conf.Port), &tls.Config{
|
|
Certificates: []tls.Certificate{cer},
|
|
}); err != nil {
|
|
panic(err)
|
|
} else {
|
|
listener = l
|
|
}
|
|
}
|
|
|
|
if !conf.TLS {
|
|
// TCP
|
|
fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.Addr, conf.Port)
|
|
if l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", conf.Addr, conf.Port)); err != nil {
|
|
panic(err)
|
|
} else {
|
|
listener = l
|
|
}
|
|
}
|
|
|
|
// Listen to connection
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
fmt.Println("Could not establish connection")
|
|
continue
|
|
}
|
|
// Read loop for connection
|
|
go server.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
func (server *Server) StartHTTP() {
|
|
conf := server.config
|
|
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("Hello from memstore!"))
|
|
})
|
|
|
|
var err error
|
|
|
|
if conf.TLS {
|
|
fmt.Printf("Starting HTTPS server at Address %s, Port %d...\n", conf.Addr, conf.Port)
|
|
err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", conf.Addr, conf.Port), conf.Cert, conf.Key, nil)
|
|
} else {
|
|
fmt.Printf("Starting HTTP server at Address %s, Port %d...\n", conf.Addr, conf.Port)
|
|
err = http.ListenAndServe(fmt.Sprintf("%s:%d", conf.Addr, conf.Port), nil)
|
|
}
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func (server *Server) LoadPlugins() {
|
|
conf := server.config
|
|
|
|
// Load plugins
|
|
pluginDirs, err := os.ReadDir(conf.Plugins)
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for _, file := range pluginDirs {
|
|
if file.IsDir() {
|
|
switch file.Name() {
|
|
case "commands":
|
|
files, err := os.ReadDir(path.Join(conf.Plugins, "commands"))
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for _, file := range files {
|
|
if !strings.HasSuffix(file.Name(), ".so") {
|
|
// Skip files that are not .so
|
|
continue
|
|
}
|
|
p, err := plugin.Open(path.Join(conf.Plugins, "commands", file.Name()))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
pluginSymbol, err := p.Lookup("Plugin")
|
|
if err != nil {
|
|
fmt.Printf("unexpected plugin symbol in plugin %s\n", file.Name())
|
|
continue
|
|
}
|
|
|
|
plugin, ok := pluginSymbol.(Plugin)
|
|
if !ok {
|
|
fmt.Printf("invalid plugin signature in plugin %s \n", file.Name())
|
|
continue
|
|
}
|
|
|
|
// Check if a plugin that handles the same command already exists
|
|
for _, loadedPlugin := range server.plugins {
|
|
containsMutual, elem := utils.ContainsMutual[string](loadedPlugin.Commands(), plugin.Commands())
|
|
if containsMutual {
|
|
fmt.Printf("plugin that handles %s command already exists. Please handle a different command.\n", elem)
|
|
}
|
|
}
|
|
|
|
server.plugins = append(server.plugins, plugin)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *Server) Start() {
|
|
server.data.data = make(map[string]interface{})
|
|
|
|
conf := server.config
|
|
|
|
server.LoadPlugins()
|
|
|
|
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
|
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
|
return
|
|
}
|
|
|
|
if addr, err := getServerAddresses(); err != nil {
|
|
log.Fatal(err)
|
|
} else {
|
|
conf.Addr = addr
|
|
server.config.Addr = addr
|
|
}
|
|
|
|
raftConfig := raft.DefaultConfig()
|
|
raftConfig.LocalID = raft.ServerID(conf.ServerID)
|
|
|
|
raftLogStore := raft.NewInmemStore()
|
|
raftStableStore := raft.NewInmemStore()
|
|
raftSnapshotStore := raft.NewInmemSnapshotStore()
|
|
|
|
raftAddr := fmt.Sprintf("%s:%d", conf.Addr, conf.ClusterPort)
|
|
raftAdvertiseAddr, err := net.ResolveTCPAddr("tcp", raftAddr)
|
|
if err != nil {
|
|
log.Fatal("Could not resolve advertise address.")
|
|
}
|
|
|
|
raftTransport, err := raft.NewTCPTransport(
|
|
raftAddr,
|
|
raftAdvertiseAddr,
|
|
10,
|
|
500*time.Millisecond,
|
|
os.Stdout,
|
|
)
|
|
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Start raft server
|
|
raftServer, err := raft.NewRaft(
|
|
raftConfig,
|
|
&raft.MockFSM{},
|
|
raftLogStore,
|
|
raftStableStore,
|
|
raftSnapshotStore,
|
|
raftTransport,
|
|
)
|
|
|
|
if err != nil {
|
|
log.Fatalf("Could not start node with error; %s", err)
|
|
}
|
|
|
|
server.raft = raftServer
|
|
|
|
if conf.JoinAddr == "" {
|
|
// Bootstrap raft cluster
|
|
if err := server.raft.BootstrapCluster(raft.Configuration{
|
|
Servers: []raft.Server{
|
|
{
|
|
Suffrage: raft.Voter,
|
|
ID: raft.ServerID(conf.ServerID),
|
|
Address: raft.ServerAddress(raftAddr),
|
|
},
|
|
},
|
|
}).Error(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if conf.HTTP {
|
|
server.StartHTTP()
|
|
} else {
|
|
server.StartTCP()
|
|
}
|
|
}
|
|
|
|
// getServerAddresses returns the first non-loopback IPv4 address found among
// the host's interface addresses, in dotted-decimal form.
//
// It returns an empty string and an error when the interface addresses
// cannot be listed or when none of them is a non-loopback IPv4 address.
func getServerAddresses() (string, error) {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		// Hand the failure to the caller rather than exiting from inside a
		// function that already has an error return.
		return "", fmt.Errorf("listing interface addresses: %w", err)
	}

	for _, addr := range addrs {
		ipNet, ok := addr.(*net.IPNet)
		if !ok || ipNet.IP.IsLoopback() {
			continue
		}
		// To4 is nil for addresses that are not IPv4.
		if ipNet.IP.To4() != nil {
			return ipNet.IP.String(), nil
		}
	}

	return "", errors.New("could not get IP Addresses")
}
|
|
|
|
func main() {
|
|
config := GetConfig()
|
|
|
|
server := &Server{
|
|
config: config,
|
|
}
|
|
server.Start()
|
|
}
|