Files
SugarDB/server/main.go

331 lines
6.9 KiB
Go

package main
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path"
"plugin"
"strings"
"sync"
"time"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
"github.com/kelvinmwinuka/memstore/utils"
)
// Plugin is the contract a command plugin must satisfy. LoadPlugins looks up
// the exported symbol "Plugin" in each .so file and type-asserts it to this
// interface.
type Plugin interface {
// Name identifies the plugin (presumably a human-readable name; it is not
// read anywhere in this file).
Name() string
// Commands lists the command keywords this plugin handles. handleConnection
// lowercases the first token of a request before comparing it against this
// list, so entries are expected to be lowercase.
Commands() []string
// Description describes the plugin (not read anywhere in this file).
Description() string
// HandleCommand is invoked with the decoded command, the *Server passed as
// an untyped value, and the connection's buffered writer.
// NOTE(review): the caller does not flush after this returns, so the plugin
// is presumably responsible for flushing conn — confirm against plugins.
HandleCommand(cmd []string, server interface{}, conn *bufio.Writer)
}
// Data bundles the key/value store with the mutex intended to guard it.
type Data struct {
// mu is locked and unlocked through Server.Lock and Server.Unlock.
mu sync.Mutex
// data is the key/value map; Server.Start allocates it.
data map[string]interface{}
}
// Server holds the configuration and runtime state of a single node.
type Server struct {
// config is the configuration supplied by main; Start overwrites its Addr.
config Config
// data is the key/value store together with its mutex.
data Data
// plugins are the command plugins appended by LoadPlugins.
plugins []Plugin
// raft is the raft node created in Start.
raft *raft.Raft
// memberList is not assigned anywhere in the visible source.
memberList *memberlist.Memberlist
}
// Lock acquires the mutex that guards the key/value store. GetData and
// SetData do not take this lock themselves, so callers bracket their accesses
// with Lock and Unlock.
func (server *Server) Lock() {
	store := &server.data
	store.mu.Lock()
}
// Unlock releases the mutex acquired by Lock.
func (server *Server) Unlock() {
	store := &server.data
	store.mu.Unlock()
}
// GetData returns the value stored under key, or nil when the key is absent.
// It performs no locking; the caller is expected to hold the lock via Lock.
func (server *Server) GetData(key string) interface{} {
	store := server.data.data
	value := store[key]
	return value
}
// SetData stores value under key, replacing any previous entry.
// It performs no locking; the caller is expected to hold the lock via Lock.
func (server *Server) SetData(key string, value interface{}) {
	store := server.data.data
	store[key] = value
}
// handleConnection serves one client connection: it reads messages, decodes
// each into a command and dispatches it to every loaded plugin that lists the
// command's (lowercased) first token. It returns, closing the connection,
// when the peer disconnects, when reading fails, or when an error reply can
// no longer be written.
func (server *Server) handleConnection(conn net.Conn) {
	// Deferred so the socket is released on every exit path, including a
	// panic raised inside a plugin.
	defer conn.Close()

	connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))

	// sendError writes one formatted error reply and flushes it. It reports
	// false when the connection can no longer be written to.
	sendError := func(format string, args ...interface{}) bool {
		if _, err := fmt.Fprintf(connRW, format, args...); err != nil {
			fmt.Println(err)
			return false
		}
		if err := connRW.Flush(); err != nil {
			fmt.Println(err)
			return false
		}
		return true
	}

	for {
		message, err := utils.ReadMessage(connRW)
		if err != nil {
			// Any read failure ends the session. Network errors are sticky:
			// retrying the read would return the same error forever and spin
			// this goroutine. After a failed read the stream position is
			// unknown anyway, so the connection cannot be resynchronised.
			if !errors.Is(err, io.EOF) {
				fmt.Println(err)
			}
			return
		}

		cmd, err := utils.Decode(message)
		if err != nil {
			// Return error to client.
			if !sendError("-Error %s\r\n\n", err.Error()) {
				return
			}
			continue
		}
		if len(cmd) == 0 {
			// Guard the cmd[0] accesses below; an unrecovered panic in this
			// goroutine would take the whole process down.
			if !sendError("-Error %s\r\n\n", "empty command") {
				return
			}
			continue
		}

		// Look for plugins that handle this command and trigger them.
		name := strings.ToLower(cmd[0])
		handled := false
		for _, p := range server.plugins {
			if utils.Contains[string](p.Commands(), name) {
				p.HandleCommand(cmd, server, connRW.Writer)
				handled = true
			}
		}
		if !handled && !sendError("-Error %s command not supported\r\n\n", strings.ToUpper(cmd[0])) {
			return
		}
	}
}
// StartTCP opens the client-facing listener on config.Addr:config.Port, using
// TLS when config.TLS is set and plain TCP otherwise, then accepts connections
// forever, serving each one on its own goroutine. It panics if the key pair
// cannot be loaded or the listener cannot be created; it never returns
// normally.
func (server *Server) StartTCP() {
	conf := server.config
	endpoint := fmt.Sprintf("%s:%d", conf.Addr, conf.Port)

	var (
		listener  net.Listener
		listenErr error
	)
	switch {
	case conf.TLS:
		fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.Addr, conf.Port)
		keyPair, loadErr := tls.LoadX509KeyPair(conf.Cert, conf.Key)
		if loadErr != nil {
			panic(loadErr)
		}
		tlsConf := &tls.Config{
			Certificates: []tls.Certificate{keyPair},
		}
		listener, listenErr = tls.Listen("tcp", endpoint, tlsConf)
	default:
		fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.Addr, conf.Port)
		listener, listenErr = net.Listen("tcp", endpoint)
	}
	if listenErr != nil {
		panic(listenErr)
	}

	// Accept loop: a failed Accept is reported and retried.
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr == nil {
			// Each connection gets its own read loop.
			go server.handleConnection(conn)
			continue
		}
		fmt.Println("Could not establish connection")
	}
}
// StartHTTP registers a greeting handler for "/" on the default serve mux and
// serves HTTP (or HTTPS when config.TLS is set) on config.Addr:config.Port.
// It blocks while the server runs and panics with the error the server
// returns.
func (server *Server) StartHTTP() {
	conf := server.config

	http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("Hello from memstore!"))
	})

	endpoint := fmt.Sprintf("%s:%d", conf.Addr, conf.Port)

	if !conf.TLS {
		fmt.Printf("Starting HTTP server at Address %s, Port %d...\n", conf.Addr, conf.Port)
		if err := http.ListenAndServe(endpoint, nil); err != nil {
			panic(err)
		}
		return
	}

	fmt.Printf("Starting HTTPS server at Address %s, Port %d...\n", conf.Addr, conf.Port)
	if err := http.ListenAndServeTLS(endpoint, conf.Cert, conf.Key, nil); err != nil {
		panic(err)
	}
}
// LoadPlugins scans the "commands" sub-directory of config.Plugins and loads
// every ".so" file in it as a Go plugin. Each plugin must export a symbol
// named "Plugin" that satisfies the Plugin interface.
//
// Failure handling:
//   - an unreadable directory or a plugin that fails to open is fatal;
//   - a missing or mistyped "Plugin" symbol is reported and the file skipped;
//   - a plugin whose commands overlap an already loaded plugin is reported
//     and skipped, so each command has exactly one handler.
func (server *Server) LoadPlugins() {
	conf := server.config

	entries, err := os.ReadDir(conf.Plugins)
	if err != nil {
		log.Fatal(err)
	}

	for _, entry := range entries {
		// Only the "commands" directory carries plugins at the moment.
		if !entry.IsDir() || entry.Name() != "commands" {
			continue
		}

		commandsDir := path.Join(conf.Plugins, "commands")
		files, err := os.ReadDir(commandsDir)
		if err != nil {
			log.Fatal(err)
		}

		for _, file := range files {
			if !strings.HasSuffix(file.Name(), ".so") {
				// Skip files that are not .so
				continue
			}

			opened, err := plugin.Open(path.Join(commandsDir, file.Name()))
			if err != nil {
				log.Fatal(err)
			}

			symbol, err := opened.Lookup("Plugin")
			if err != nil {
				fmt.Printf("unexpected plugin symbol in plugin %s\n", file.Name())
				continue
			}

			candidate, ok := symbol.(Plugin)
			if !ok {
				fmt.Printf("invalid plugin signature in plugin %s \n", file.Name())
				continue
			}

			// Reject the candidate if any loaded plugin already handles one
			// of its commands. Previously the clash was only reported and the
			// plugin was loaded anyway, so handleConnection would have run
			// two handlers for a single command.
			conflict := false
			for _, loaded := range server.plugins {
				clash, elem := utils.ContainsMutual[string](loaded.Commands(), candidate.Commands())
				if clash {
					fmt.Printf("plugin that handles %s command already exists. Please handle a different command.\n", elem)
					conflict = true
				}
			}
			if conflict {
				continue
			}

			server.plugins = append(server.plugins, candidate)
		}
	}
}
// Start brings the node up:
//  1. allocates the key/value map and loads plugins;
//  2. in TLS mode, checks that key and certificate paths are set, printing a
//     message and returning if either is missing;
//  3. replaces config.Addr with the address from getServerAddresses;
//  4. creates a raft node on config.Addr:config.ClusterPort using
//     raft.NewInmemStore, raft.NewInmemSnapshotStore and raft.MockFSM, and
//     bootstraps a single-voter cluster when config.JoinAddr is empty;
//  5. hands control to StartHTTP or StartTCP depending on config.HTTP.
//
// Address, transport, raft and bootstrap failures terminate the process via
// log.Fatal.
func (server *Server) Start() {
	server.data.data = map[string]interface{}{}

	conf := server.config

	server.LoadPlugins()

	tlsPathsMissing := len(conf.Key) == 0 || len(conf.Cert) == 0
	if conf.TLS && tlsPathsMissing {
		fmt.Println("Must provide key and certificate file paths for TLS mode.")
		return
	}

	addr, err := getServerAddresses()
	if err != nil {
		log.Fatal(err)
	}
	// conf is a copy, so the stored configuration must be updated as well:
	// StartTCP and StartHTTP read server.config.
	conf.Addr = addr
	server.config.Addr = addr

	localID := raft.ServerID(conf.ServerID)

	raftConfig := raft.DefaultConfig()
	raftConfig.LocalID = localID

	logStore := raft.NewInmemStore()
	stableStore := raft.NewInmemStore()
	snapshotStore := raft.NewInmemSnapshotStore()

	raftAddr := fmt.Sprintf("%s:%d", conf.Addr, conf.ClusterPort)

	advertiseAddr, err := net.ResolveTCPAddr("tcp", raftAddr)
	if err != nil {
		log.Fatal("Could not resolve advertise address.")
	}

	transport, err := raft.NewTCPTransport(raftAddr, advertiseAddr, 10, 500*time.Millisecond, os.Stdout)
	if err != nil {
		log.Fatal(err)
	}

	// Start raft server
	node, err := raft.NewRaft(raftConfig, &raft.MockFSM{}, logStore, stableStore, snapshotStore, transport)
	if err != nil {
		log.Fatalf("Could not start node with error; %s", err)
	}
	server.raft = node

	if conf.JoinAddr == "" {
		// No join address: bootstrap a cluster with this node as sole voter.
		self := raft.Server{
			Suffrage: raft.Voter,
			ID:       localID,
			Address:  raft.ServerAddress(raftAddr),
		}
		bootstrap := server.raft.BootstrapCluster(raft.Configuration{
			Servers: []raft.Server{self},
		})
		if err := bootstrap.Error(); err != nil {
			log.Fatal(err)
		}
	}

	if !conf.HTTP {
		server.StartTCP()
		return
	}
	server.StartHTTP()
}
// getServerAddresses returns the first non-loopback IPv4 address found among
// the host's interface addresses, in dotted-decimal form.
//
// It returns an error when the interface addresses cannot be listed or when
// none of them is a non-loopback IPv4 address. The listing failure is now
// returned to the caller rather than terminating the process from inside
// this helper.
func getServerAddresses() (string, error) {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return "", fmt.Errorf("listing interface addresses: %w", err)
	}

	for _, address := range addrs {
		// Only IP networks that are not loopback are candidates.
		ipnet, ok := address.(*net.IPNet)
		if !ok || ipnet.IP.IsLoopback() {
			continue
		}
		// To4 is nil for addresses that have no IPv4 form.
		if ipnet.IP.To4() != nil {
			return ipnet.IP.String(), nil
		}
	}

	return "", errors.New("could not get IP Addresses")
}
func main() {
config := GetConfig()
server := &Server{
config: config,
}
server.Start()
}