diff --git a/client/Makefile b/client/Makefile new file mode 100644 index 0000000..fcab375 --- /dev/null +++ b/client/Makefile @@ -0,0 +1,5 @@ +build: + go build -o ./bin/client ./*.go + +run: + ./bin/client --config config.json \ No newline at end of file diff --git a/client/main.go b/client/main.go index 5c79b7e..bb1b505 100644 --- a/client/main.go +++ b/client/main.go @@ -9,8 +9,6 @@ import ( "io" "net" "os" - - "github.com/kelvinmwinuka/memstore/utils" ) func main() { @@ -81,7 +79,7 @@ func main() { } // Serialize command and send to connection - encoded, err := utils.Encode(string(in)) + encoded, err := Encode(string(in)) if err != nil { fmt.Println(err) @@ -92,7 +90,7 @@ func main() { connRW.Flush() // Read response from server - message, err := utils.ReadMessage(connRW) + message, err := ReadMessage(connRW) if err != nil && err == io.EOF { fmt.Println(err) @@ -101,7 +99,7 @@ func main() { fmt.Println(err) } - decoded, err := utils.Decode(message) + decoded, err := Decode(message) if err != nil { fmt.Println(err) diff --git a/utils/utils.go b/client/utils.go similarity index 99% rename from utils/utils.go rename to client/utils.go index 90f4966..8c54598 100644 --- a/utils/utils.go +++ b/client/utils.go @@ -1,4 +1,4 @@ -package utils +package main import ( "bufio" diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..b529cee --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,18 @@ +version: '3.8' + +networks: + testnet: + driver: bridge + +services: + node1: + container_name: node1 + build: + context: . + dockerfile: ./server/Dockerfile + ports: + - "7480:7480" + - "7946:7946" + - "8000:8000" + networks: + - testnet \ No newline at end of file diff --git a/server/Dockerfile b/server/Dockerfile new file mode 100644 index 0000000..aa5433d --- /dev/null +++ b/server/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.20.0 + +WORKDIR /app/memstore + +COPY ["./server", "./vendor", "./go.mod", "./go.sum", "./"] +COPY ["./openssl/server", "./ssl"] + +RUN go build -o bin/server ./*.go + +CMD ["./bin/server", "--config", "./config.json", "--tls"] + +EXPOSE 7480 +EXPOSE 8000 +EXPOSE 7946 + diff --git a/server/Makefile b/server/Makefile deleted file mode 100644 index 3313b5c..0000000 --- a/server/Makefile +++ /dev/null @@ -1,12 +0,0 @@ -build-plugins: - go build -buildmode=plugin -o bin/plugins/commands/command_ping.so plugins/commands/ping/ping.go - go build -buildmode=plugin -o bin/plugins/commands/command_setget.so plugins/commands/setget/setget.go - go build -buildmode=plugin -o bin/plugins/commands/command_list.so plugins/commands/list/list.go - -build-server: - go build -o bin/server ./*.go - -build: build-plugins build-server - -run: - ./bin/server diff --git a/server/plugins/commands/list/list.go b/server/command_list.go similarity index 72% rename from server/plugins/commands/list/list.go rename to server/command_list.go index 87fe3e7..8b949ff 100644 --- a/server/plugins/commands/list/list.go +++ b/server/command_list.go @@ -5,78 +5,63 @@ import ( "fmt" "math" "strings" - - "github.com/kelvinmwinuka/memstore/utils" ) -const ( - OK = "+OK\r\n\n" -) - -type Server interface { - Lock() - Unlock() - GetData(key string) interface{} - SetData(key string, value interface{}) -} - -type plugin struct { +type ListCommand struct { name string commands []string description string } -var Plugin plugin - -func (p *plugin) Name() string { +func (p *ListCommand) Name() string { return p.name } -func (p *plugin) Commands() []string { +func (p *ListCommand) Commands() []string { return p.commands } -func (p *plugin) Description() string { +func (p *ListCommand) Description() string { return p.description } -func (p *plugin) HandleCommand(cmd []string, server interface{}, conn *bufio.Writer) { +func (p *ListCommand) HandleCommand(cmd []string, server *Server, conn *bufio.Writer) { c := strings.ToLower(cmd[0]) switch { case c == "llen": - handleLLen(cmd, server.(Server), conn) + handleLLen(cmd, server, conn) case c == "lindex": - handleLIndex(cmd, server.(Server), conn) + handleLIndex(cmd, server, conn) case c == "lrange": - handleLRange(cmd, server.(Server), conn) + handleLRange(cmd, server, conn) case c == "lset": - handleLSet(cmd, server.(Server), conn) + handleLSet(cmd, server, conn) case c == "ltrim": - handleLTrim(cmd, server.(Server), conn) + handleLTrim(cmd, server, conn) case c == "lrem": - handleLRem(cmd, server.(Server), conn) + handleLRem(cmd, server, conn) case c == "lmove": - handleLMove(cmd, server.(Server), conn) + handleLMove(cmd, server, conn) - case utils.Contains[string]([]string{"lpush", "lpushx"}, c): - handleLPush(cmd, server.(Server), conn) + case Contains[string]([]string{"lpush", "lpushx"}, c): + handleLPush(cmd, server, conn) - case utils.Contains[string]([]string{"rpush", "rpushx"}, c): - handleRPush(cmd, server.(Server), conn) + case Contains[string]([]string{"rpush", "rpushx"}, c): + handleRPush(cmd, server, conn) - case utils.Contains[string]([]string{"lpop", "rpop"}, c): - handlePop(cmd, server.(Server), conn) + case Contains[string]([]string{"lpop", "rpop"}, c): + handlePop(cmd, server, conn) } } -func handleLLen(cmd []string, server Server, conn *bufio.Writer) { +func handleLLen(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 2 { conn.Write([]byte("-Error wrong number of args for LLEN command\r\n\n")) conn.Flush() @@ -99,14 +84,14 @@ func handleLLen(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLIndex(cmd []string, server Server, conn *bufio.Writer) { +func handleLIndex(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 3 { conn.Write([]byte("-Error wrong number of args for LINDEX command\r\n\n")) conn.Flush() return } - index, ok := utils.AdaptType(cmd[2]).(int) + index, ok := AdaptType(cmd[2]).(int) if !ok { conn.Write([]byte("-Error index must be an integer\r\n\n")) @@ -137,15 +122,15 @@ func handleLIndex(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLRange(cmd []string, server Server, conn *bufio.Writer) { +func handleLRange(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 4 { conn.Write([]byte("-Error wrong number of arguments for LRANGE command\r\n\n")) conn.Flush() return } - start, startOk := utils.AdaptType(cmd[2]).(int) - end, endOk := utils.AdaptType(cmd[3]).(int) + start, startOk := AdaptType(cmd[2]).(int) + end, endOk := AdaptType(cmd[3]).(int) if !startOk || !endOk { conn.Write([]byte("-Error both start and end indices must be integers\r\n\n")) @@ -223,7 +208,7 @@ func handleLRange(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLSet(cmd []string, server Server, conn *bufio.Writer) { +func handleLSet(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 4 { conn.Write([]byte("-Error wrong number of arguments for LSET command\r\n\n")) conn.Flush() @@ -241,7 +226,7 @@ func handleLSet(cmd []string, server Server, conn *bufio.Writer) { return } - index, ok := utils.AdaptType(cmd[2]).(int) + index, ok := AdaptType(cmd[2]).(int) if !ok { server.Unlock() @@ -257,7 +242,7 @@ func handleLSet(cmd []string, server Server, conn *bufio.Writer) { return } - list[index] = utils.AdaptType(cmd[3]) + list[index] = AdaptType(cmd[3]) server.SetData(cmd[1], list) server.Unlock() @@ -265,15 +250,15 @@ func handleLSet(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLTrim(cmd []string, server Server, conn *bufio.Writer) { +func handleLTrim(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 4 { conn.Write([]byte("-Error wrong number of args for command LTRIM \r\n\n")) conn.Flush() return } - start, startOk := utils.AdaptType(cmd[2]).(int) - end, endOk := utils.AdaptType(cmd[3]).(int) + start, startOk := AdaptType(cmd[2]).(int) + end, endOk := AdaptType(cmd[3]).(int) if !startOk || !endOk { conn.Write([]byte("-Error start and end indices must be integers\r\n\n")) @@ -319,7 +304,7 @@ func handleLTrim(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLRem(cmd []string, server Server, conn *bufio.Writer) { +func handleLRem(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 4 { conn.Write([]byte("-Error wrong number of arguments for LREM command\r\n\n")) conn.Flush() @@ -327,7 +312,7 @@ func handleLRem(cmd []string, server Server, conn *bufio.Writer) { } value := cmd[3] - count, ok := utils.AdaptType(cmd[2]).(int) + count, ok := AdaptType(cmd[2]).(int) if !ok { conn.Write([]byte("-Error count must be an integer\r\n\n")) @@ -375,7 +360,7 @@ func handleLRem(cmd []string, server Server, conn *bufio.Writer) { } } - list = utils.Filter[interface{}](list, func(elem interface{}) bool { + list = Filter[interface{}](list, func(elem interface{}) bool { return elem != nil }) @@ -386,7 +371,7 @@ func handleLRem(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLMove(cmd []string, server Server, conn *bufio.Writer) { +func handleLMove(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 5 { conn.Write([]byte("-Error wrong number of arguments for LMOVE command\r\n\n")) conn.Flush() @@ -396,7 +381,7 @@ func handleLMove(cmd []string, server Server, conn *bufio.Writer) { whereFrom := strings.ToLower(cmd[3]) whereTo := strings.ToLower(cmd[4]) - if !utils.Contains[string]([]string{"left", "right"}, whereFrom) || !utils.Contains[string]([]string{"left", "right"}, whereTo) { + if !Contains[string]([]string{"left", "right"}, whereFrom) || !Contains[string]([]string{"left", "right"}, whereTo) { conn.Write([]byte("-Error wherefrom and whereto arguments must be either LEFT or RIGHT\r\n\n")) conn.Flush() return @@ -436,7 +421,7 @@ func handleLMove(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleLPush(cmd []string, server Server, conn *bufio.Writer) { +func handleLPush(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) < 3 { conn.Write([]byte(fmt.Sprintf("-Error wrong number of arguments for %s command\r\n\n", strings.ToUpper(cmd[0])))) conn.Flush() @@ -448,7 +433,7 @@ func handleLPush(cmd []string, server Server, conn *bufio.Writer) { newElems := []interface{}{} for _, elem := range cmd[2:] { - newElems = append(newElems, utils.AdaptType(elem)) + newElems = append(newElems, AdaptType(elem)) } currentList := server.GetData(cmd[1]) @@ -484,7 +469,7 @@ func handleLPush(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handleRPush(cmd []string, server Server, conn *bufio.Writer) { +func handleRPush(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) < 3 { conn.Write([]byte(fmt.Sprintf("-Error wrong number of arguments for %s command\r\n\n", strings.ToUpper(cmd[0])))) conn.Flush() @@ -496,7 +481,7 @@ func handleRPush(cmd []string, server Server, conn *bufio.Writer) { newElems := []interface{}{} for _, elem := range cmd[2:] { - newElems = append(newElems, utils.AdaptType(elem)) + newElems = append(newElems, AdaptType(elem)) } currentList := server.GetData(cmd[1]) @@ -531,7 +516,7 @@ func handleRPush(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func handlePop(cmd []string, server Server, conn *bufio.Writer) { +func handlePop(cmd []string, server *Server, conn *bufio.Writer) { if len(cmd) != 2 { conn.Write([]byte(fmt.Sprintf("-Error wrong number of args for %s command\r\n\n", strings.ToUpper(cmd[0])))) conn.Flush() @@ -562,22 +547,24 @@ func handlePop(cmd []string, server Server, conn *bufio.Writer) { conn.Flush() } -func init() { - Plugin.name = "ListCommand" - Plugin.commands = []string{ - "lpush", // (LPUSH key value1 [value2]) Prepends one or more values to the beginning of a list, creates the list if it does not exist. - "lpushx", // (LPUSHX key value) Prepends a value to the beginning of a list only if the list exists. - "lpop", // (LPOP key) Removes and returns the first element of a list. - "llen", // (LLEN key) Return the length of a list. - "lrange", // (LRANGE key start end) Return a range of elements between the given indices. - "lindex", // (LINDEX key index) Gets list element by index. - "lset", // (LSET key index value) Sets the value of an element in a list by its index. - "ltrim", // (LTRIM key start end) Trims a list to the specified range. - "lrem", // (LREM key count value) Remove elements from list. - "lmove", // (LMOVE source destination Move element from one list to the other specifying left/right for both lists. - "rpop", // (RPOP key) Removes and gets the last element in a list. - "rpush", // (RPUSH key value [value2]) Appends one or multiple elements to the end of a list. - "rpushx", // (RPUSHX key value) Appends an element to the end of a list, only if the list exists. +func NewListCommand() *ListCommand { + return &ListCommand{ + name: "ListCommand", + commands: []string{ + "lpush", // (LPUSH key value1 [value2]) Prepends one or more values to the beginning of a list, creates the list if it does not exist. + "lpushx", // (LPUSHX key value) Prepends a value to the beginning of a list only if the list exists. + "lpop", // (LPOP key) Removes and returns the first element of a list. + "llen", // (LLEN key) Return the length of a list. + "lrange", // (LRANGE key start end) Return a range of elements between the given indices. + "lindex", // (LINDEX key index) Gets list element by index. + "lset", // (LSET key index value) Sets the value of an element in a list by its index. + "ltrim", // (LTRIM key start end) Trims a list to the specified range. + "lrem", // (LREM key count value) Remove elements from list. + "lmove", // (LMOVE source destination Move element from one list to the other specifying left/right for both lists. + "rpop", // (RPOP key) Removes and gets the last element in a list. + "rpush", // (RPUSH key value [value2]) Appends one or multiple elements to the end of a list. + "rpushx", // (RPUSHX key value) Appends an element to the end of a list, only if the list exists. + }, + description: "Handle List commands", } - Plugin.description = "Handle List commands" } diff --git a/server/command_ping.go b/server/command_ping.go new file mode 100644 index 0000000..c8d5ffa --- /dev/null +++ b/server/command_ping.go @@ -0,0 +1,43 @@ +package main + +import "bufio" + +type PingCommand struct { + name string + commands []string + description string +} + +func (p *PingCommand) Name() string { + return p.name +} + +func (p *PingCommand) Commands() []string { + return p.commands +} + +func (p *PingCommand) Description() string { + return p.description +} + +func (p *PingCommand) HandleCommand(cmd []string, server *Server, conn *bufio.Writer) { + switch len(cmd) { + default: + conn.Write([]byte("-Error wrong number of arguments for PING command\r\n\n")) + conn.Flush() + case 1: + conn.Write([]byte("+PONG\r\n\n")) + conn.Flush() + case 2: + conn.Write([]byte("+" + cmd[1] + "\r\n\n")) + conn.Flush() + } +} + +func NewPingCommand() *PingCommand { + return &PingCommand{ + name: "PingCommand", + commands: []string{"ping"}, + description: "Handle PING command", + } +} diff --git a/server/plugins/commands/setget/setget.go b/server/command_set_get.go similarity index 59% rename from server/plugins/commands/setget/setget.go rename to server/command_set_get.go index 50c6dfa..bdcb6de 100644 --- a/server/plugins/commands/setget/setget.go +++ b/server/command_set_get.go @@ -4,49 +4,38 @@ import ( "bufio" "fmt" "strings" - - "github.com/kelvinmwinuka/memstore/utils" ) -type Server interface { - Lock() - Unlock() - GetData(key string) interface{} - SetData(key string, value interface{}) -} - -type plugin struct { +type SetGetCommand struct { name string commands []string description string } -var Plugin plugin - -func (p *plugin) Name() string { +func (p *SetGetCommand) Name() string { return p.name } -func (p *plugin) Commands() []string { +func (p *SetGetCommand) Commands() []string { return p.commands } -func (p *plugin) Description() string { +func (p *SetGetCommand) Description() string { return p.description } -func (p *plugin) HandleCommand(cmd []string, server interface{}, conn *bufio.Writer) { +func (p *SetGetCommand) HandleCommand(cmd []string, server *Server, conn *bufio.Writer) { switch strings.ToLower(cmd[0]) { case "get": - handleGet(cmd, server.(Server), conn) + handleGet(cmd, server, conn) case "set": - handleSet(cmd, server.(Server), conn) + handleSet(cmd, server, conn) case "mget": - handleMGet(cmd, server.(Server), conn) + handleMGet(cmd, server, conn) } } -func handleGet(cmd []string, s Server, conn *bufio.Writer) { +func handleGet(cmd []string, s *Server, conn *bufio.Writer) { if len(cmd) != 2 { conn.Write([]byte("-Error wrong number of args for GET command\r\n\n")) @@ -67,7 +56,7 @@ func handleGet(cmd []string, s Server, conn *bufio.Writer) { conn.Flush() } -func handleMGet(cmd []string, s Server, conn *bufio.Writer) { +func handleMGet(cmd []string, s *Server, conn *bufio.Writer) { if len(cmd) < 2 { conn.Write([]byte("-Error wrong number of args for MGET command\r\n\n")) conn.Flush() @@ -99,22 +88,24 @@ func handleMGet(cmd []string, s Server, conn *bufio.Writer) { conn.Flush() } -func handleSet(cmd []string, s Server, conn *bufio.Writer) { +func handleSet(cmd []string, s *Server, conn *bufio.Writer) { switch x := len(cmd); { default: conn.Write([]byte("-Error wrong number of args for SET command\r\n\n")) conn.Flush() case x == 3: s.Lock() - s.SetData(cmd[1], utils.AdaptType(cmd[2])) + s.SetData(cmd[1], AdaptType(cmd[2])) s.Unlock() conn.Write([]byte("+OK\r\n\n")) conn.Flush() } } -func init() { - Plugin.name = "GetCommand" - Plugin.commands = []string{"set", "get", "mget"} - Plugin.description = "Handle basic SET, GET and MGET commands" +func NewSetGetCommand() *SetGetCommand { + return &SetGetCommand{ + name: "GetCommand", + commands: []string{"set", "get", "mget"}, + description: "Handle basic SET, GET and MGET commands", + } } diff --git a/server/config.go b/server/config.go index 723a253..c6c25b2 100644 --- a/server/config.go +++ b/server/config.go @@ -10,16 +10,17 @@ import ( ) type Config struct { - TLS bool `json:"tls" yaml:"tls"` - Key string `json:"key" yaml:"key"` - Cert string `json:"cert" yaml:"cert"` - Port uint16 `json:"port" yaml:"port"` - HTTP bool `json:"http" yaml:"http"` - Plugins string `json:"plugins" yaml:"plugins"` - ClusterPort uint16 `json:"clusterPort" yaml:"clusterPort"` - ServerID string `json:"serverId" yaml:"serverId"` - JoinAddr string `json:"joinAddr" yaml:"joinAddr"` - Addr string + TLS bool `json:"tls" yaml:"tls"` + Key string `json:"key" yaml:"key"` + Cert string `json:"cert" yaml:"cert"` + Port uint16 `json:"port" yaml:"port"` + HTTP bool `json:"http" yaml:"http"` + Plugins string `json:"plugins" yaml:"plugins"` + ServerID string `json:"serverId" yaml:"serverId"` + JoinAddr string `json:"joinAddr" yaml:"joinAddr"` + BindAddr string `json:"bindAddr" yaml:"bindAddr"` + RaftBindPort uint16 `json:"raftPort" yaml:"raftPort"` + MemberListBindPort uint16 `json:"mlPort" yaml:"mlPort"` } func GetConfig() Config { @@ -29,9 +30,11 @@ func GetConfig() Config { port := flag.Int("port", 7480, "Port to use. Default is 7480") http := flag.Bool("http", false, "Use HTTP protocol instead of raw TCP. Default is false") plugins := flag.String("plugins", ".", "The path to the plugins folder.") - clusterPort := flag.Int("clusterPort", 7481, "Port to use for intra-cluster communication. Leave on the client.") serverId := flag.String("serverId", "1", "Server ID in raft cluster. Leave empty for client.") joinAddr := flag.String("joinAddr", "", "Address of cluster member in a cluster to you want to join.") + bindAddr := flag.String("bindAddr", "127.0.0.1", "Address to bind the server to.") + raftBindPort := flag.Int("clusterPort", 7481, "Port to use for intra-cluster communication. Leave on the client.") + mlBindPort := flag.Int("mlPort", 7946, "Port to use for memberlist communication.") config := flag.String( "config", "", @@ -40,10 +43,22 @@ func GetConfig() Config { flag.Parse() - var conf Config + conf := Config{ + TLS: *tls, + Key: *key, + Cert: *cert, + HTTP: *http, + Port: uint16(*port), + ServerID: *serverId, + Plugins: *plugins, + JoinAddr: *joinAddr, + BindAddr: *bindAddr, + RaftBindPort: uint16(*raftBindPort), + MemberListBindPort: uint16(*mlBindPort), + } if len(*config) > 0 { - // Load config from config file + // Override configurations from file if f, err := os.Open(*config); err != nil { panic(err) } else { @@ -60,18 +75,6 @@ func GetConfig() Config { } } - } else { - conf = Config{ - TLS: *tls, - Key: *key, - Cert: *cert, - HTTP: *http, - Port: uint16(*port), - ClusterPort: uint16(*clusterPort), - ServerID: *serverId, - Plugins: *plugins, - JoinAddr: *joinAddr, - } } return conf diff --git a/server/main.go b/server/main.go index 0c5c80e..b42b821 100644 --- a/server/main.go +++ b/server/main.go @@ -3,30 +3,17 @@ package main import ( "bufio" "crypto/tls" - "errors" "fmt" "io" - "log" "net" "net/http" - "os" - "path" - "plugin" "strings" "sync" "github.com/hashicorp/memberlist" "github.com/hashicorp/raft" - "github.com/kelvinmwinuka/memstore/utils" ) -type Plugin interface { - Name() string - Commands() []string - Description() string - HandleCommand(cmd []string, server interface{}, conn *bufio.Writer) -} - type Data struct { mu sync.Mutex data map[string]interface{} @@ -35,7 +22,7 @@ type Data struct { type Server struct { config Config data Data - plugins []Plugin + commands []Command raft *raft.Raft memberList *memberlist.Memberlist } @@ -60,7 +47,7 @@ func (server *Server) handleConnection(conn net.Conn) { connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) for { - message, err := utils.ReadMessage(connRW) + message, err := ReadMessage(connRW) if err != nil && err == io.EOF { // Connection closed @@ -72,7 +59,7 @@ func (server *Server) handleConnection(conn net.Conn) { continue } - if cmd, err := utils.Decode(message); err != nil { + if cmd, err := Decode(message); err != nil { // Return error to client connRW.Write([]byte(fmt.Sprintf("-Error %s\r\n\n", err.Error()))) connRW.Flush() @@ -81,9 +68,9 @@ func (server *Server) handleConnection(conn net.Conn) { // Look for plugin that handles this command and trigger it handled := false - for _, plugin := range server.plugins { - if utils.Contains[string](plugin.Commands(), strings.ToLower(cmd[0])) { - plugin.HandleCommand(cmd, server, connRW.Writer) + for _, c := range server.commands { + if Contains[string](c.Commands(), strings.ToLower(cmd[0])) { + c.HandleCommand(cmd, server, connRW.Writer) handled = true } } @@ -104,13 +91,13 @@ func (server *Server) StartTCP() { if conf.TLS { // TLS - fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.Addr, conf.Port) + fmt.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key) if err != nil { panic(err) } - if l, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", conf.Addr, conf.Port), &tls.Config{ + if l, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), &tls.Config{ Certificates: []tls.Certificate{cer}, }); err != nil { panic(err) @@ -121,8 +108,8 @@ func (server *Server) StartTCP() { if !conf.TLS { // TCP - fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.Addr, conf.Port) - if l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", conf.Addr, conf.Port)); err != nil { + fmt.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + if l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port)); err != nil { panic(err) } else { listener = l @@ -151,11 +138,11 @@ func (server *Server) StartHTTP() { var err error if conf.TLS { - fmt.Printf("Starting HTTPS server at Address %s, Port %d...\n", conf.Addr, conf.Port) - err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", conf.Addr, conf.Port), conf.Cert, conf.Key, nil) + fmt.Printf("Starting HTTPS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), conf.Cert, conf.Key, nil) } else { - fmt.Printf("Starting HTTP server at Address %s, Port %d...\n", conf.Addr, conf.Port) - err = http.ListenAndServe(fmt.Sprintf("%s:%d", conf.Addr, conf.Port), nil) + fmt.Printf("Starting HTTP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + err = http.ListenAndServe(fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), nil) } if err != nil { @@ -163,81 +150,17 @@ func (server *Server) StartHTTP() { } } -func (server *Server) LoadPlugins() { - conf := server.config - - // Load plugins - pluginDirs, err := os.ReadDir(conf.Plugins) - - if err != nil { - log.Fatal(err) - } - - for _, file := range pluginDirs { - if file.IsDir() { - switch file.Name() { - case "commands": - files, err := os.ReadDir(path.Join(conf.Plugins, "commands")) - - if err != nil { - log.Fatal(err) - } - - for _, file := range files { - if !strings.HasSuffix(file.Name(), ".so") { - // Skip files that are not .so - continue - } - p, err := plugin.Open(path.Join(conf.Plugins, "commands", file.Name())) - if err != nil { - log.Fatal(err) - } - - pluginSymbol, err := p.Lookup("Plugin") - if err != nil { - fmt.Printf("unexpected plugin symbol in plugin %s\n", file.Name()) - continue - } - - plugin, ok := pluginSymbol.(Plugin) - if !ok { - fmt.Printf("invalid plugin signature in plugin %s \n", file.Name()) - continue - } - - // Check if a plugin that handles the same command already exists - for _, loadedPlugin := range server.plugins { - containsMutual, elem := utils.ContainsMutual[string](loadedPlugin.Commands(), plugin.Commands()) - if containsMutual { - fmt.Printf("plugin that handles %s command already exists. Please handle a different command.\n", elem) - } - } - - server.plugins = append(server.plugins, plugin) - } - } - } - } -} - func (server *Server) Start() { server.data.data = make(map[string]interface{}) conf := server.config - server.LoadPlugins() - if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) { fmt.Println("Must provide key and certificate file paths for TLS mode.") return } - if addr, err := getServerAddresses(); err != nil { - log.Fatal(err) - } else { - conf.Addr = addr - server.config.Addr = addr - } + server.MemberListInit() if conf.HTTP { server.StartHTTP() @@ -246,28 +169,18 @@ func (server *Server) Start() { } } -func getServerAddresses() (string, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - log.Fatal(err) - } - for _, address := range addrs { - // check the address type and if it is not a loopback the display it - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - return ipnet.IP.String(), nil - } - } - } - - return "", errors.New("could not get IP Addresses") -} - func main() { config := GetConfig() + fmt.Println(config) + server := &Server{ config: config, + commands: []Command{ + NewPingCommand(), + NewSetGetCommand(), + NewListCommand(), + }, } server.Start() } diff --git a/server/memberlist.go b/server/memberlist.go index d6e1761..c258062 100644 --- a/server/memberlist.go +++ b/server/memberlist.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "log" "github.com/hashicorp/memberlist" @@ -8,12 +9,33 @@ import ( func (server *Server) MemberListInit() { // Triggered before RaftInit - memberList, err := memberlist.Create(memberlist.DefaultLocalConfig()) + cfg := memberlist.DefaultLocalConfig() + cfg.BindAddr = server.config.BindAddr + cfg.BindPort = int(server.config.MemberListBindPort) + + list, err := memberlist.Create(cfg) + server.memberList = list + if err != nil { - log.Fatal("Could not start memberlist cluster.") + log.Fatal(err) } - server.memberList = memberList + if server.config.JoinAddr != "" { + n, err := server.memberList.Join([]string{server.config.JoinAddr}) + + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Joined cluster. Contacted %d nodes.\n", n) + } + + // go func() { + // for { + // fmt.Println(server.memberList.NumMembers()) + // time.Sleep(2 * time.Second) + // } + // }() } func (server *Server) ShutdownMemberList() { diff --git a/utils/mock.go b/server/mock.go similarity index 98% rename from utils/mock.go rename to server/mock.go index 0fc1c0d..44db4d2 100644 --- a/utils/mock.go +++ b/server/mock.go @@ -1,4 +1,4 @@ -package utils +package main import ( "bytes" diff --git a/server/plugins/commands/ping/ping.go b/server/plugins/commands/ping/ping.go deleted file mode 100644 index 3888be0..0000000 --- a/server/plugins/commands/ping/ping.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import "bufio" - -type Server interface { - Lock() - Unlock() - GetData(key string) interface{} - SetData(key string, value interface{}) -} - -type plugin struct { - name string - commands []string - description string -} - -var Plugin plugin - -func (p *plugin) Name() string { - return p.name -} - -func (p *plugin) Commands() []string { - return p.commands -} - -func (p *plugin) Description() string { - return p.description -} - -func (p *plugin) HandleCommand(cmd []string, server interface{}, conn *bufio.Writer) { - switch len(cmd) { - default: - conn.Write([]byte("-Error wrong number of arguments for PING command\r\n\n")) - conn.Flush() - case 1: - conn.Write([]byte("+PONG\r\n\n")) - conn.Flush() - case 2: - conn.Write([]byte("+" + cmd[1] + "\r\n\n")) - conn.Flush() - } -} - -func init() { - Plugin.name = "PingCommand" - Plugin.commands = []string{"ping"} - Plugin.description = "Handle PING command" -} diff --git a/server/plugins/commands/ping/ping_test.go b/server/plugins/commands/ping/ping_test.go deleted file mode 100644 index a6652a0..0000000 --- a/server/plugins/commands/ping/ping_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package main - -import ( - "bufio" - "strings" - "testing" - - "github.com/kelvinmwinuka/memstore/utils" -) - -func TestHandleCommand(t *testing.T) { - server := &utils.MockServer{} - - cw := &utils.CustomWriter{} - writer := bufio.NewWriter(cw) - - tests := []struct { - cmd []string - expected string - }{ - {[]string{"ping"}, "+PONG\r\n\n"}, - {[]string{"ping", "Ping Test"}, "+Ping Test\r\n\n"}, - {[]string{"ping", "Ping Test", "Error"}, "-Error wrong number of arguments for PING command\r\n\n"}, - } - - for _, tt := range tests { - cw.Buf.Reset() - Plugin.HandleCommand(tt.cmd, server, writer) - if tt.expected != cw.Buf.String() { - t.Errorf("Expected %s, Got %s", strings.TrimSpace(tt.expected), strings.TrimSpace(cw.Buf.String())) - } - } -} diff --git a/server/plugins/commands/setget/setget_test.go b/server/plugins/commands/setget/setget_test.go deleted file mode 100644 index bcc00d5..0000000 --- a/server/plugins/commands/setget/setget_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "bufio" - "strings" - "sync" - "testing" - - "github.com/kelvinmwinuka/memstore/utils" -) - -const ( - OK = "+OK\r\n\n" -) - -func TestHandleCommand(t *testing.T) { - server := utils.MockServer{ - Data: utils.MockData{ - Mu: sync.Mutex{}, - Data: make(map[string]interface{}), - }, - } - - cw := utils.CustomWriter{} - writer := bufio.NewWriter(&cw) - - tests := []struct { - cmd []string - expected string - }{ - // SET test cases - {[]string{"set", "key1", "value1"}, OK}, - {[]string{"set", "key2", "30"}, OK}, - {[]string{"set", "key3", "3.142"}, OK}, - {[]string{"set", "key4", "part1", "part2", "part3"}, "-Error wrong number of args for SET command\r\n\n"}, - {[]string{"set"}, "-Error wrong number of args for SET command\r\n\n"}, - - // GET test cases - {[]string{"get", "key1"}, "+value1\r\n\n"}, - {[]string{"get", "key2"}, "+30\r\n\n"}, - {[]string{"get", "key3"}, "+3.142\r\n\n"}, - {[]string{"get", "key4"}, "+nil\r\n\n"}, - {[]string{"get"}, "-Error wrong number of args for GET command\r\n\n"}, - {[]string{"get", "key1", "key2"}, "-Error wrong number of args for GET command\r\n\n"}, - - // MGET test cases - {[]string{"mget", "key1", "key2", "key3", "key4"}, "*4\r\n$6\r\nvalue1\r\n$2\r\n30\r\n$5\r\n3.142\r\n$3\r\nnil\r\n\n"}, - {[]string{"mget", "key5", "key6"}, "*2\r\n$3\r\nnil\r\n$3\r\nnil\r\n\n"}, - {[]string{"mget"}, "-Error wrong number of args for MGET command\r\n\n"}, - } - - for _, tt := range tests { - cw.Buf.Reset() - Plugin.HandleCommand(tt.cmd, &server, writer) - if tt.expected != cw.Buf.String() { - t.Errorf("Expected %s, Got %s", strings.TrimSpace(tt.expected), strings.TrimSpace(cw.Buf.String())) - } - } -} diff --git a/server/raft.go b/server/raft.go index 69cf9b2..a638862 100644 --- a/server/raft.go +++ b/server/raft.go @@ -23,7 +23,7 @@ func (server *Server) RaftInit() { raftStableStore := raft.NewInmemStore() raftSnapshotStore := raft.NewInmemSnapshotStore() - raftAddr := fmt.Sprintf("%s:%d", conf.Addr, conf.ClusterPort) + raftAddr := fmt.Sprintf("%s:%d", conf.BindAddr, conf.RaftBindPort) raftAdvertiseAddr, err := net.ResolveTCPAddr("tcp", raftAddr) if err != nil { log.Fatal("Could not resolve advertise address.") diff --git a/server/utils.go b/server/utils.go new file mode 100644 index 0000000..887d5db --- /dev/null +++ b/server/utils.go @@ -0,0 +1,166 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/csv" + "errors" + "fmt" + "math" + "math/big" + "reflect" + "strings" + + "github.com/tidwall/resp" +) + +const ( + OK = "+OK\r\n\n" +) + +type Command interface { + Name() string + Commands() []string + Description() string + HandleCommand(cmd []string, server *Server, conn *bufio.Writer) +} + +func Contains[T comparable](arr []T, elem T) bool { + for _, v := range arr { + if v == elem { + return true + } + } + return false +} + +func ContainsMutual[T comparable](arr1 []T, arr2 []T) (bool, T) { + for _, a := range arr1 { + for _, b := range arr2 { + if a == b { + return true, a + } + } + } + return false, arr1[0] +} + +func IsInteger(n float64) bool { + return math.Mod(n, 1.0) == 0 +} + +func AdaptType(s string) interface{} { + // Adapt the type of the parameter to string, float64 or int + n, _, err := big.ParseFloat(s, 10, 256, big.RoundingMode(big.Exact)) + + if err != nil { + return s + } + + if n.IsInt() { + i, _ := n.Int64() + return i + } + + return n +} + +func IncrBy(num interface{}, by interface{}) (interface{}, error) { + if !Contains[string]([]string{"int", "float64"}, reflect.TypeOf(num).String()) { + return nil, errors.New("can only increment number") + } + if !Contains[string]([]string{"int", "float64"}, reflect.TypeOf(by).String()) { + return nil, errors.New("can only increment by number") + } + + n, _ := num.(float64) + b, _ := by.(float64) + res := n + b + + if IsInteger(res) { + return int(res), nil + } + + return res, nil +} + +func Filter[T comparable](arr []T, test func(elem T) bool) (res []T) { + for _, e := range arr { + if test(e) { + res = append(res, e) + } + } + return +} + +func tokenize(comm string) ([]string, error) { + r := csv.NewReader(strings.NewReader(comm)) + r.Comma = ' ' + return r.Read() +} + +func Encode(comm string) (string, error) { + tokens, err := tokenize(comm) + + if err != nil { + return "", errors.New("could not parse command") + } + + str := fmt.Sprintf("*%d\r\n", len(tokens)) + + for i, token := range tokens { + if i == 0 { + str += fmt.Sprintf("$%d\r\n%s\r\n", len(token), strings.ToUpper(token)) + } else { + str += fmt.Sprintf("$%d\r\n%s\r\n", len(token), token) + } + } + + str += "\n" + + return str, nil +} + +func Decode(raw string) ([]string, error) { + rd := resp.NewReader(bytes.NewBufferString(raw)) + res := []string{} + + v, _, err := rd.ReadValue() + + if err != nil { + return nil, err + } + + if Contains[string]([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) { + return []string{v.String()}, nil + } + + if v.Type().String() == "Array" { + for _, elem := range v.Array() { + res = append(res, elem.String()) + } + } + + return res, nil +} + +func ReadMessage(r *bufio.ReadWriter) (message string, err error) { + var line [][]byte + + for { + b, _, err := r.ReadLine() + + if err != nil { + return "", err + } + + if bytes.Equal(b, []byte("")) { + // End of message + break + } + + line = append(line, b) + } + + return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil +}