package main

import (
	"bufio"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"plugin"
	"strings"
	"sync"
	"time"

	"github.com/kelvinmwinuka/memstore/serialization"
	"github.com/kelvinmwinuka/memstore/utils"
)

// readHeaderTimeout bounds how long the HTTP server waits for a client to send
// request headers, so a client that trickles headers cannot pin a connection
// open indefinitely.
const readHeaderTimeout = 10 * time.Second

// Plugin is the contract a command plugin's exported "Plugin" symbol must
// satisfy to be registered by LoadPlugins.
type Plugin interface {
	Name() string
	// Commands lists the command names the plugin handles. handleConnection
	// compares them against the lower-cased first word of each request, so
	// they presumably need to be lower-case.
	Commands() []string
	Description() string
	// HandleCommand executes cmd (cmd[0] is the command name). server is the
	// *Server handling the connection; the reply is written to conn.
	HandleCommand(cmd []string, server any, conn *bufio.Writer)
}

// Data is the in-memory key/value store; every access to data goes through mu.
type Data struct {
	mu   sync.Mutex
	data map[string]any
}

// Server holds the runtime configuration, the shared key/value store and the
// command plugins loaded at start-up.
type Server struct {
	config  utils.Config
	data    Data
	plugins []Plugin
}

// GetData returns the value stored under key, or nil if the key is absent.
// It is safe for concurrent use.
func (server *Server) GetData(key string) any {
	server.data.mu.Lock()
	defer server.data.mu.Unlock()
	return server.data.data[key]
}

// SetData stores value under key, replacing any previous value.
// It is safe for concurrent use.
func (server *Server) SetData(key string, value any) {
	server.data.mu.Lock()
	defer server.data.mu.Unlock()
	server.data.data[key] = value
}

// address returns the host:port the server listens on.
func (server *Server) address() string {
	return fmt.Sprintf("%s:%d", "localhost", server.config.Port)
}

// isConnectionClosed reports whether err means the peer connection is gone
// (EOF, a closed socket, or any network-level error such as a reset). Without
// this check the read loop would retry a dead connection forever.
func isConnectionClosed(err error) bool {
	var netErr net.Error
	return errors.Is(err, io.EOF) ||
		errors.Is(err, io.ErrUnexpectedEOF) ||
		errors.Is(err, net.ErrClosed) ||
		errors.As(err, &netErr)
}

// findPlugin returns the loaded plugin whose Commands() contains command,
// or nil if no plugin handles it.
func (server *Server) findPlugin(command string) Plugin {
	for _, p := range server.plugins {
		if utils.Contains[string](p.Commands(), command) {
			return p
		}
	}
	return nil
}

// handleConnection serves one client: it reads and decodes messages until the
// connection is closed, dispatching each command to the plugin that handles
// it and replying with an error for undecodable, empty or unknown commands.
func (server *Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))

	for {
		message, err := utils.ReadMessage(connRW)
		if err != nil {
			fmt.Println(err)
			if isConnectionClosed(err) {
				return
			}
			continue
		}

		cmd, err := serialization.Decode(message)
		switch {
		case err != nil:
			serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error()))
		case len(cmd) == 0:
			// Indexing cmd[0] below would otherwise panic and take down the
			// whole process, not just this connection.
			serialization.Encode(connRW, "Error empty command")
		default:
			if handler := server.findPlugin(strings.ToLower(cmd[0])); handler != nil {
				handler.HandleCommand(cmd, server, connRW.Writer)
			} else {
				// Always answer, otherwise the client blocks waiting for a reply.
				serialization.Encode(connRW, fmt.Sprintf("Error command %s not supported", cmd[0]))
			}
		}

		// Encode and plugins write into the buffered writer; flushing here makes
		// sure the reply actually leaves (it is a no-op if already flushed).
		if err := connRW.Flush(); err != nil {
			fmt.Println(err)
			return
		}
	}
}

// listenTCP opens the raw TCP listener, wrapped in TLS when the config asks
// for it.
func (server *Server) listenTCP() (net.Listener, error) {
	conf := server.config

	if !conf.TLS {
		fmt.Println("Starting server in TCP mode...")
		return net.Listen("tcp", server.address())
	}

	fmt.Println("TCP/TLS mode enabled...")
	cert, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
	if err != nil {
		return nil, fmt.Errorf("loading TLS key pair: %w", err)
	}
	return tls.Listen("tcp", server.address(), &tls.Config{
		Certificates: []tls.Certificate{cert},
		// Refuse legacy TLS 1.0/1.1 handshakes.
		MinVersion: tls.VersionTLS12,
	})
}

// StartTCP listens for TCP (or TCP/TLS) connections and serves each one on its
// own goroutine. It panics if the listener cannot be created and returns only
// once the listener has been closed.
func (server *Server) StartTCP() {
	listener, err := server.listenTCP()
	if err != nil {
		panic(err)
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			// A closed listener fails every Accept; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return
			}
			fmt.Println("Could not establish connection:", err)
			continue
		}

		go server.handleConnection(conn)
	}
}

// StartHTTP serves a greeting on "/" over HTTP, or HTTPS when TLS is enabled.
// It panics if the server fails to start or stops with an error.
func (server *Server) StartHTTP() {
	conf := server.config

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Best effort: nothing useful can be done if the client has gone away.
		_, _ = w.Write([]byte("Hello from memstore!"))
	})

	httpServer := &http.Server{
		Addr:              server.address(),
		ReadHeaderTimeout: readHeaderTimeout,
		// Only consulted by ListenAndServeTLS.
		TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12},
	}

	var err error
	if conf.TLS {
		fmt.Println("Starting server in HTTPS mode...")
		err = httpServer.ListenAndServeTLS(conf.Cert, conf.Key)
	} else {
		fmt.Println("Starting server in HTTP mode...")
		err = httpServer.ListenAndServe()
	}
	if err != nil && !errors.Is(err, http.ErrServerClosed) {
		panic(err)
	}
}

// hasCommandConflict reports whether candidate claims a command already
// handled by a loaded plugin, logging the clashing command if so.
func (server *Server) hasCommandConflict(candidate Plugin) bool {
	for _, loaded := range server.plugins {
		containsMutual, elem := utils.ContainsMutual[string](loaded.Commands(), candidate.Commands())
		if containsMutual {
			fmt.Printf("plugin that handles %s command already exists. Please handle a different command.\n", elem)
			return true
		}
	}
	return false
}

// loadCommandPlugins opens every .so file in dir, resolves its exported
// "Plugin" symbol and registers it. Plugins with a missing or mistyped symbol,
// or that claim an already-handled command, are reported and skipped.
// A directory that cannot be read or a plugin that cannot be opened is fatal.
func (server *Server) loadCommandPlugins(dir string) {
	files, err := os.ReadDir(dir)
	if err != nil {
		log.Fatal(err)
	}

	for _, file := range files {
		if file.IsDir() || !strings.HasSuffix(file.Name(), ".so") {
			continue
		}

		p, err := plugin.Open(filepath.Join(dir, file.Name()))
		if err != nil {
			log.Fatal(err)
		}

		symbol, err := p.Lookup("Plugin")
		if err != nil {
			fmt.Printf("unexpected plugin symbol in plugin %s\n", file.Name())
			continue
		}

		cmdPlugin, ok := symbol.(Plugin)
		if !ok {
			fmt.Printf("invalid plugin signature in plugin %s \n", file.Name())
			continue
		}

		// Each command must have exactly one owner, otherwise several plugins
		// would answer the same request.
		if server.hasCommandConflict(cmdPlugin) {
			continue
		}

		server.plugins = append(server.plugins, cmdPlugin)
	}
}

// LoadPlugins loads the command plugins found in the "commands" subdirectory
// of the configured plugin directory. It exits the process if the plugin
// directory cannot be read.
func (server *Server) LoadPlugins() {
	root := server.config.Plugins

	entries, err := os.ReadDir(root)
	if err != nil {
		log.Fatal(err)
	}

	for _, entry := range entries {
		// Only the "commands" subdirectory is recognised; other entries are ignored.
		if entry.IsDir() && entry.Name() == "commands" {
			server.loadCommandPlugins(filepath.Join(root, entry.Name()))
		}
	}
}

// Start initialises the key/value store, fetches the configuration via
// utils.GetConfig, loads plugins and then serves HTTP(S) or raw TCP(/TLS)
// depending on conf.HTTP. It returns early if TLS is enabled without both a
// key and a certificate path.
func (server *Server) Start() {
	server.data.data = make(map[string]any)
	server.config = utils.GetConfig()
	conf := server.config

	server.LoadPlugins()

	if conf.TLS && (conf.Key == "" || conf.Cert == "") {
		fmt.Println("Must provide key and certificate file paths for TLS mode.")
		return
	}

	if conf.HTTP {
		server.StartHTTP()
	} else {
		server.StartTCP()
	}
}

func main() {
	server := Server{}
	server.Start()
}