diff --git a/client/main.go b/client/main.go index 6e8eef1..b1ffeee 100644 --- a/client/main.go +++ b/client/main.go @@ -14,6 +14,7 @@ import ( "path" "github.com/kelvinmwinuka/memstore/serialization" + "github.com/kelvinmwinuka/memstore/utils" "gopkg.in/yaml.v3" ) @@ -122,8 +123,7 @@ func main() { stdioRW.Flush() if in, err := stdioRW.ReadBytes(byte('\n')); err != nil { - stdioRW.Write([]byte(fmt.Sprintf("ERROR: %s\n", err))) - stdioRW.Flush() + fmt.Println(err) } else { in := bytes.TrimSpace(in) @@ -133,28 +133,30 @@ func main() { } if err := serialization.Encode(connRW, string(in)); err != nil { - stdioRW.Write([]byte(fmt.Sprintf("%s\n", err.Error()))) - stdioRW.Flush() + fmt.Println(err) } else { // Write encoded command to the connection connRW.Write([]byte("\n")) connRW.Flush() // Read response from server - if l, _, err := connRW.ReadLine(); err != nil { - if err == io.EOF { - // Break loop when we detect an EOF (Connection closed) - stdioRW.Write([]byte("Connection closed")) - stdioRW.Flush() - break - } else { - stdioRW.Write([]byte(err.Error())) - stdioRW.Flush() - } - } else { - stdioRW.Write(l) - stdioRW.Flush() + message, err := utils.ReadMessage(connRW) + + if err != nil && err == io.EOF { + fmt.Println(err) + break + } else if err != nil { + fmt.Println(err) } + + decoded, err := serialization.Decode(message) + + if err != nil { + fmt.Println(err) + continue + } + + fmt.Println(decoded) } } } diff --git a/serialization/decode.go b/serialization/decode.go index 13c0151..81ba5fe 100644 --- a/serialization/decode.go +++ b/serialization/decode.go @@ -9,6 +9,7 @@ import ( func Decode(raw string) ([]string, error) { rd := resp.NewReader(bytes.NewBufferString(raw)) + res := []string{} v, _, err := rd.ReadValue() @@ -16,8 +17,6 @@ func Decode(raw string) ([]string, error) { return nil, err } - res := []string{} - if v.Type().String() == "SimpleString" { return []string{v.String()}, nil } diff --git a/serialization/encode.go b/serialization/encode.go index 199d38e..a830c07 100644 --- a/serialization/encode.go +++ b/serialization/encode.go @@ -29,7 +29,6 @@ func encodeError(wr *resp.Writer, tokens []string) { } func encodeSimpleString(wr *resp.Writer, tokens []string) error { - fmt.Println(tokens, len(tokens)) switch len(tokens) { default: return fmt.Errorf(wrong_args_error, strings.ToUpper(tokens[0])) @@ -40,6 +39,23 @@ func encodeSimpleString(wr *resp.Writer, tokens []string) error { } } +func encodeArray(wr *resp.Writer, tokens []string) error { + switch l := len(tokens); { + default: + return fmt.Errorf(wrong_args_error, strings.ToUpper(tokens[0])) + case l > 0: + arr := []resp.Value{} + + for _, token := range tokens[1:] { + arr = append(arr, resp.AnyValue(token)) + } + + wr.WriteArray(arr) + + return nil + } +} + func encodePingPong(wr *resp.Writer, tokens []string) error { switch len(tokens) { default: @@ -171,6 +187,8 @@ func Encode(buf io.ReadWriter, comm string) error { err = encodeIncr(wr, tokens) case "simplestring": err = encodeSimpleString(wr, tokens) + case "array": + err = encodeArray(wr, tokens) case "Error": encodeError(wr, tokens) err = errors.New("failed to parse command") diff --git a/server/main.go b/server/main.go index a9ef9f8..dcfc0c3 100644 --- a/server/main.go +++ b/server/main.go @@ -2,7 +2,6 @@ package main import ( "bufio" - "bytes" "crypto/tls" "encoding/json" "flag" @@ -16,6 +15,7 @@ import ( "sync" "github.com/kelvinmwinuka/memstore/serialization" + "github.com/kelvinmwinuka/memstore/utils" "gopkg.in/yaml.v3" ) @@ -37,56 +37,54 @@ type Server struct { data Data } -func (server *Server) hanndleConnection(conn net.Conn) { +func (server *Server) handleConnection(conn net.Conn) { connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - var line [][]byte - for { - b, _, err := connRW.ReadLine() + message, err := utils.ReadMessage(connRW) if err != nil && err == io.EOF { fmt.Println(err) break } - line = append(line, b) + if err != nil { + fmt.Println(err) + continue + } - if bytes.Equal(b, []byte("")) { - // End of RESP message - // sw.Write(bytes.Join(line, []byte("\\r\\n"))) - // sw.Flush() + if cmd, err := serialization.Decode(message); err != nil { + // Return error to client + serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error())) + fmt.Println("Server: ", err) + continue + } else { + // Return encoded message to client - if cmd, err := serialization.Decode(string(bytes.Join(line, []byte("\r\n")))); err != nil { - // Return error to client - serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error())) - continue - } else { - // Return encoded message to client - - switch strings.ToLower(cmd[0]) { - default: - fmt.Println("The command is unknown") - case "ping": - if len(cmd) == 1 { - serialization.Encode(connRW, "SimpleString PONG") - connRW.Flush() - } - if len(cmd) == 2 { - fmt.Println(cmd) - serialization.Encode(connRW, fmt.Sprintf("SimpleString \"%s\"", cmd[1])) - connRW.Flush() - } - case "set": - fmt.Println("Set the value") - case "get": - fmt.Println("Get the value") - case "mget": - fmt.Println("Get the multiple values requested") + switch strings.ToLower(cmd[0]) { + default: + fmt.Println("The command is unknown") + case "ping": + if len(cmd) == 1 { + serialization.Encode(connRW, "SimpleString PONG") + connRW.Write([]byte("\n")) + connRW.Flush() } + if len(cmd) == 2 { + serialization.Encode(connRW, fmt.Sprintf("SimpleString \"%s\"", cmd[1])) + connRW.Write([]byte("\n")) + connRW.Flush() + } + case "set": + fmt.Println("Set the value") + case "get": + fmt.Println("Get the value") + case "mget": + fmt.Println("Get the multiple values requested") + serialization.Encode(connRW, "Array THIS IS THE ARRAY") + connRW.Write([]byte("\n")) + connRW.Flush() } - - line = [][]byte{} } } @@ -132,7 +130,7 @@ func (server *Server) StartTCP() { continue } // Read loop for connection - go server.hanndleConnection(conn) + go server.handleConnection(conn) } } diff --git a/utils/utils.go b/utils/utils.go index 1bc8998..cc0cd63 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,11 @@ package utils -import "math" +import ( + "bufio" + "bytes" + "fmt" + "math" +) func Contains[T comparable](arr []T, elem T) bool { for _, v := range arr { @@ -14,3 +19,24 @@ func Contains[T comparable](arr []T, elem T) bool { func IsInteger(n float64) bool { return math.Mod(n, 1.0) == 0 } + +func ReadMessage(r *bufio.ReadWriter) (message string, err error) { + var line [][]byte + + for { + b, _, err := r.ReadLine() + + if err != nil { + return "", err + } + + if bytes.Equal(b, []byte("")) { + // End of message + break + } + + line = append(line, b) + } + + return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil +}