Files
SugarDB/server/main.go
Kelvin Clement Mwinuka f3d36b3c3d Added Error to decoded types.
Removed cmd_functions.go as it is replaced by plugin system.
Consolidated get and set plugins into one plugin setget that handles GET, SET and MGET commands.
Plugins can declare multiple commands that they handle.
2023-07-03 12:35:37 +08:00

232 lines
4.8 KiB
Go

package main
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path"
"plugin"
"strings"
"sync"
"github.com/kelvinmwinuka/memstore/serialization"
"github.com/kelvinmwinuka/memstore/utils"
)
type Plugin interface {
Name() string
Commands() []string
Description() string
HandleCommand(cmd []string, server interface{}, conn *bufio.Writer)
}
type Data struct {
mu sync.Mutex
data map[string]interface{}
}
type Server struct {
config utils.Config
data Data
plugins []Plugin
}
func (server *Server) GetData(key string) interface{} {
server.data.mu.Lock()
defer server.data.mu.Unlock()
return server.data.data[key]
}
func (server *Server) SetData(key string, value interface{}) {
server.data.mu.Lock()
defer server.data.mu.Unlock()
server.data.data[key] = value
}
func (server *Server) handleConnection(conn net.Conn) {
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
for {
message, err := utils.ReadMessage(connRW)
if err != nil && err == io.EOF {
fmt.Println(err)
break
}
if err != nil {
fmt.Println(err)
continue
}
if cmd, err := serialization.Decode(message); err != nil {
// Return error to client
serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error()))
continue
} else {
// Look for plugin that handles this command and trigger it
for _, plugin := range server.plugins {
fmt.Println(server)
if utils.Contains[string](plugin.Commands(), strings.ToLower(cmd[0])) {
plugin.HandleCommand(cmd, server, connRW.Writer)
}
}
}
}
conn.Close()
}
func (server *Server) StartTCP() {
conf := server.config
var listener net.Listener
if conf.TLS {
// TLS
fmt.Println("TCP/TLS mode enabled...")
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
if err != nil {
panic(err)
}
if l, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", conf.Port), &tls.Config{
Certificates: []tls.Certificate{cer},
}); err != nil {
panic(err)
} else {
listener = l
}
}
if !conf.TLS {
// TCP
fmt.Println("Starting server in TCP mode...")
if l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", conf.Port)); err != nil {
panic(err)
} else {
listener = l
}
}
// Listen to connection
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Could not establish connection")
continue
}
// Read loop for connection
go server.handleConnection(conn)
}
}
func (server *Server) StartHTTP() {
conf := server.config
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello from memstore!"))
})
var err error
if conf.TLS {
fmt.Println("Starting server in HTTPS mode...")
err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", "localhost", conf.Port), conf.Cert, conf.Key, nil)
} else {
fmt.Println("Starting server in HTTP mode...")
err = http.ListenAndServe(fmt.Sprintf("%s:%d", "localhost", conf.Port), nil)
}
if err != nil {
panic(err)
}
}
func (server *Server) LoadPlugins() {
conf := server.config
// Load plugins
pluginDirs, err := os.ReadDir(conf.Plugins)
if err != nil {
log.Fatal(err)
}
for _, file := range pluginDirs {
if file.IsDir() {
switch file.Name() {
case "commands":
files, err := os.ReadDir(path.Join(conf.Plugins, "commands"))
if err != nil {
log.Fatal(err)
}
for _, file := range files {
if !strings.HasSuffix(file.Name(), ".so") {
// Skip files that are not .so
continue
}
p, err := plugin.Open(path.Join(conf.Plugins, "commands", file.Name()))
if err != nil {
log.Fatal(err)
}
pluginSymbol, err := p.Lookup("Plugin")
if err != nil {
fmt.Printf("unexpected plugin symbol in plugin %s\n", file.Name())
continue
}
plugin, ok := pluginSymbol.(Plugin)
if !ok {
fmt.Printf("invalid plugin signature in plugin %s \n", file.Name())
continue
}
// Check if a plugin that handles the same command already exists
for _, loadedPlugin := range server.plugins {
containsMutual, elem := utils.ContainsMutual[string](loadedPlugin.Commands(), plugin.Commands())
if containsMutual {
fmt.Printf("plugin that handles %s command already exists. Please handle a different command.\n", elem)
}
}
server.plugins = append(server.plugins, plugin)
}
}
}
}
}
func (server *Server) Start() {
server.data.data = make(map[string]interface{})
server.config = utils.GetConfig()
conf := server.config
server.LoadPlugins()
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
fmt.Println("Must provide key and certificate file paths for TLS mode.")
return
}
if conf.HTTP {
server.StartHTTP()
} else {
server.StartTCP()
}
}
func main() {
server := Server{}
server.Start()
}