Added array serialization to encode function. Moved message reading to utils packages as it's used by both server and client

This commit is contained in:
Kelvin Clement Mwinuka
2023-07-01 21:59:41 +08:00
parent eb188deb7b
commit ef140c8427
5 changed files with 103 additions and 60 deletions

View File

@@ -14,6 +14,7 @@ import (
"path" "path"
"github.com/kelvinmwinuka/memstore/serialization" "github.com/kelvinmwinuka/memstore/serialization"
"github.com/kelvinmwinuka/memstore/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -122,8 +123,7 @@ func main() {
stdioRW.Flush() stdioRW.Flush()
if in, err := stdioRW.ReadBytes(byte('\n')); err != nil { if in, err := stdioRW.ReadBytes(byte('\n')); err != nil {
stdioRW.Write([]byte(fmt.Sprintf("ERROR: %s\n", err))) fmt.Println(err)
stdioRW.Flush()
} else { } else {
in := bytes.TrimSpace(in) in := bytes.TrimSpace(in)
@@ -133,28 +133,30 @@ func main() {
} }
if err := serialization.Encode(connRW, string(in)); err != nil { if err := serialization.Encode(connRW, string(in)); err != nil {
stdioRW.Write([]byte(fmt.Sprintf("%s\n", err.Error()))) fmt.Println(err)
stdioRW.Flush()
} else { } else {
// Write encoded command to the connection // Write encoded command to the connection
connRW.Write([]byte("\n")) connRW.Write([]byte("\n"))
connRW.Flush() connRW.Flush()
// Read response from server // Read response from server
if l, _, err := connRW.ReadLine(); err != nil { message, err := utils.ReadMessage(connRW)
if err == io.EOF {
// Break loop when we detect an EOF (Connection closed) if err != nil && err == io.EOF {
stdioRW.Write([]byte("Connection closed")) fmt.Println(err)
stdioRW.Flush() break
break } else if err != nil {
} else { fmt.Println(err)
stdioRW.Write([]byte(err.Error()))
stdioRW.Flush()
}
} else {
stdioRW.Write(l)
stdioRW.Flush()
} }
decoded, err := serialization.Decode(message)
if err != nil {
fmt.Println(err)
continue
}
fmt.Println(decoded)
} }
} }
} }

View File

@@ -9,6 +9,7 @@ import (
func Decode(raw string) ([]string, error) { func Decode(raw string) ([]string, error) {
rd := resp.NewReader(bytes.NewBufferString(raw)) rd := resp.NewReader(bytes.NewBufferString(raw))
res := []string{}
v, _, err := rd.ReadValue() v, _, err := rd.ReadValue()
@@ -16,8 +17,6 @@ func Decode(raw string) ([]string, error) {
return nil, err return nil, err
} }
res := []string{}
if v.Type().String() == "SimpleString" { if v.Type().String() == "SimpleString" {
return []string{v.String()}, nil return []string{v.String()}, nil
} }

View File

@@ -29,7 +29,6 @@ func encodeError(wr *resp.Writer, tokens []string) {
} }
func encodeSimpleString(wr *resp.Writer, tokens []string) error { func encodeSimpleString(wr *resp.Writer, tokens []string) error {
fmt.Println(tokens, len(tokens))
switch len(tokens) { switch len(tokens) {
default: default:
return fmt.Errorf(wrong_args_error, strings.ToUpper(tokens[0])) return fmt.Errorf(wrong_args_error, strings.ToUpper(tokens[0]))
@@ -40,6 +39,23 @@ func encodeSimpleString(wr *resp.Writer, tokens []string) error {
} }
} }
func encodeArray(wr *resp.Writer, tokens []string) error {
switch l := len(tokens); {
default:
return fmt.Errorf(wrong_args_error, strings.ToUpper(tokens[0]))
case l > 0:
arr := []resp.Value{}
for _, token := range tokens[1:] {
arr = append(arr, resp.AnyValue(token))
}
wr.WriteArray(arr)
return nil
}
}
func encodePingPong(wr *resp.Writer, tokens []string) error { func encodePingPong(wr *resp.Writer, tokens []string) error {
switch len(tokens) { switch len(tokens) {
default: default:
@@ -171,6 +187,8 @@ func Encode(buf io.ReadWriter, comm string) error {
err = encodeIncr(wr, tokens) err = encodeIncr(wr, tokens)
case "simplestring": case "simplestring":
err = encodeSimpleString(wr, tokens) err = encodeSimpleString(wr, tokens)
case "array":
err = encodeArray(wr, tokens)
case "Error": case "Error":
encodeError(wr, tokens) encodeError(wr, tokens)
err = errors.New("failed to parse command") err = errors.New("failed to parse command")

View File

@@ -2,7 +2,6 @@ package main
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"flag" "flag"
@@ -16,6 +15,7 @@ import (
"sync" "sync"
"github.com/kelvinmwinuka/memstore/serialization" "github.com/kelvinmwinuka/memstore/serialization"
"github.com/kelvinmwinuka/memstore/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -37,56 +37,54 @@ type Server struct {
data Data data Data
} }
func (server *Server) hanndleConnection(conn net.Conn) { func (server *Server) handleConnection(conn net.Conn) {
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
var line [][]byte
for { for {
b, _, err := connRW.ReadLine() message, err := utils.ReadMessage(connRW)
if err != nil && err == io.EOF { if err != nil && err == io.EOF {
fmt.Println(err) fmt.Println(err)
break break
} }
line = append(line, b) if err != nil {
fmt.Println(err)
continue
}
if bytes.Equal(b, []byte("")) { if cmd, err := serialization.Decode(message); err != nil {
// End of RESP message // Return error to client
// sw.Write(bytes.Join(line, []byte("\\r\\n"))) serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error()))
// sw.Flush() fmt.Println("Server: ", err)
continue
} else {
// Return encoded message to client
if cmd, err := serialization.Decode(string(bytes.Join(line, []byte("\r\n")))); err != nil { switch strings.ToLower(cmd[0]) {
// Return error to client default:
serialization.Encode(connRW, fmt.Sprintf("Error %s", err.Error())) fmt.Println("The command is unknown")
continue case "ping":
} else { if len(cmd) == 1 {
// Return encoded message to client serialization.Encode(connRW, "SimpleString PONG")
connRW.Write([]byte("\n"))
switch strings.ToLower(cmd[0]) { connRW.Flush()
default:
fmt.Println("The command is unknown")
case "ping":
if len(cmd) == 1 {
serialization.Encode(connRW, "SimpleString PONG")
connRW.Flush()
}
if len(cmd) == 2 {
fmt.Println(cmd)
serialization.Encode(connRW, fmt.Sprintf("SimpleString \"%s\"", cmd[1]))
connRW.Flush()
}
case "set":
fmt.Println("Set the value")
case "get":
fmt.Println("Get the value")
case "mget":
fmt.Println("Get the multiple values requested")
} }
if len(cmd) == 2 {
serialization.Encode(connRW, fmt.Sprintf("SimpleString \"%s\"", cmd[1]))
connRW.Write([]byte("\n"))
connRW.Flush()
}
case "set":
fmt.Println("Set the value")
case "get":
fmt.Println("Get the value")
case "mget":
fmt.Println("Get the multiple values requested")
serialization.Encode(connRW, "Array THIS IS THE ARRAY")
connRW.Write([]byte("\n"))
connRW.Flush()
} }
line = [][]byte{}
} }
} }
@@ -132,7 +130,7 @@ func (server *Server) StartTCP() {
continue continue
} }
// Read loop for connection // Read loop for connection
go server.hanndleConnection(conn) go server.handleConnection(conn)
} }
} }

View File

@@ -1,6 +1,11 @@
package utils package utils
import "math" import (
"bufio"
"bytes"
"fmt"
"math"
)
func Contains[T comparable](arr []T, elem T) bool { func Contains[T comparable](arr []T, elem T) bool {
for _, v := range arr { for _, v := range arr {
@@ -14,3 +19,24 @@ func Contains[T comparable](arr []T, elem T) bool {
func IsInteger(n float64) bool { func IsInteger(n float64) bool {
return math.Mod(n, 1.0) == 0 return math.Mod(n, 1.0) == 0
} }
func ReadMessage(r *bufio.ReadWriter) (message string, err error) {
var line [][]byte
for {
b, _, err := r.ReadLine()
if err != nil {
return "", err
}
if bytes.Equal(b, []byte("")) {
// End of message
break
}
line = append(line, b)
}
return fmt.Sprintf("%s\r\n", string(bytes.Join(line, []byte("\r\n")))), nil
}