mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-18 14:00:42 +08:00
217 lines
4.2 KiB
Go
217 lines
4.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"sync"
|
|
|
|
"github.com/kelvinmwinuka/memstore/serialization"
|
|
"github.com/tidwall/resp"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Config holds the server's startup options. main fills it from
// command-line flags or from a JSON/YAML config file, using the keys given in
// the struct tags.
type Config struct {
	TLS  bool   `json:"tls" yaml:"tls"`   // serve over TLS instead of plaintext
	Key  string `json:"key" yaml:"key"`   // path to the TLS private key file
	Cert string `json:"cert" yaml:"cert"` // path to the TLS certificate file
	HTTP bool   `json:"http" yaml:"http"` // use HTTP instead of raw TCP
	Port uint16 `json:"port" yaml:"port"` // port to listen on (bound to localhost)
}
|
|
|
|
// Data pairs a string-keyed map of arbitrary values with a mutex.
// NOTE(review): nothing in this file reads or writes it yet; presumably mu is
// meant to guard data once the store is implemented — confirm.
type Data struct {
	mu   sync.Mutex             // lock for the store
	data map[string]interface{} // stored values; nil until initialised
}
|
|
|
|
type Server struct {
|
|
config Config
|
|
data Data
|
|
}
|
|
|
|
func (server *Server) hanndleConnection(conn net.Conn) {
|
|
connRW := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
|
respWriter := resp.NewWriter(connRW)
|
|
|
|
var line [][]byte
|
|
|
|
for {
|
|
b, _, err := connRW.ReadLine()
|
|
|
|
if err != nil && err == io.EOF {
|
|
fmt.Println(err)
|
|
break
|
|
}
|
|
|
|
line = append(line, b)
|
|
|
|
if bytes.Equal(b, []byte("")) {
|
|
// End of RESP message
|
|
// sw.Write(bytes.Join(line, []byte("\\r\\n")))
|
|
// sw.Flush()
|
|
|
|
if cmd, err := serialization.Decode(string(bytes.Join(line, []byte("\r\n")))); err != nil {
|
|
fmt.Println(err)
|
|
// Return error to client
|
|
continue
|
|
} else {
|
|
// Return encoded message to client
|
|
|
|
if len(cmd) == 1 && cmd[0] == "PING" {
|
|
serialization.EncodeSimpleString(respWriter, "PONG")
|
|
connRW.Flush()
|
|
}
|
|
|
|
if len(cmd) == 2 && cmd[0] == "PING" {
|
|
fmt.Println(cmd)
|
|
serialization.EncodeSimpleString(respWriter, cmd[1])
|
|
connRW.Flush()
|
|
}
|
|
}
|
|
|
|
line = [][]byte{}
|
|
}
|
|
}
|
|
|
|
conn.Close()
|
|
}
|
|
|
|
func (server *Server) StartTCP() {
|
|
conf := server.config
|
|
var listener net.Listener
|
|
|
|
if conf.TLS {
|
|
// TLS
|
|
fmt.Println("TCP/TLS mode enabled...")
|
|
cer, err := tls.LoadX509KeyPair(conf.Cert, conf.Key)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if l, err := tls.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", conf.Port), &tls.Config{
|
|
Certificates: []tls.Certificate{cer},
|
|
}); err != nil {
|
|
panic(err)
|
|
} else {
|
|
listener = l
|
|
}
|
|
}
|
|
|
|
if !conf.TLS {
|
|
// TCP
|
|
fmt.Println("Starting server in TCP mode...")
|
|
if l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", conf.Port)); err != nil {
|
|
panic(err)
|
|
} else {
|
|
listener = l
|
|
}
|
|
}
|
|
|
|
// Listen to connection
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
fmt.Println("Could not establish connection")
|
|
continue
|
|
}
|
|
// Read loop for connection
|
|
go server.hanndleConnection(conn)
|
|
}
|
|
}
|
|
|
|
func (server *Server) StartHTTP() {
|
|
conf := server.config
|
|
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("Hello from memstore!"))
|
|
})
|
|
|
|
var err error
|
|
|
|
if conf.TLS {
|
|
fmt.Println("Starting server in HTTPS mode...")
|
|
err = http.ListenAndServeTLS(fmt.Sprintf("%s:%d", "localhost", conf.Port), conf.Cert, conf.Key, nil)
|
|
} else {
|
|
fmt.Println("Starting server in HTTP mode...")
|
|
err = http.ListenAndServe(fmt.Sprintf("%s:%d", "localhost", conf.Port), nil)
|
|
}
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func (server *Server) Start() {
|
|
conf := server.config
|
|
|
|
if conf.TLS && (len(conf.Key) <= 0 || len(conf.Cert) <= 0) {
|
|
fmt.Println("Must provide key and certificate file paths for TLS mode.")
|
|
return
|
|
}
|
|
|
|
if conf.HTTP {
|
|
server.StartHTTP()
|
|
} else {
|
|
server.StartTCP()
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
tls := flag.Bool("tls", false, "Start the server in TLS mode. Default is false")
|
|
key := flag.String("key", "", "The private key file path.")
|
|
cert := flag.String("cert", "", "The signed certificate file path.")
|
|
http := flag.Bool("http", false, "Use HTTP protocol instead of raw TCP. Default is false")
|
|
port := flag.Int("port", 7480, "Port to use. Default is 7480")
|
|
config := flag.String(
|
|
"config",
|
|
"",
|
|
`File path to a JSON or YAML config file.The values in this config file will override the flag values.`,
|
|
)
|
|
|
|
flag.Parse()
|
|
|
|
var conf Config
|
|
|
|
if len(*config) > 0 {
|
|
// Load config from config file
|
|
if f, err := os.Open(*config); err != nil {
|
|
panic(err)
|
|
} else {
|
|
defer f.Close()
|
|
|
|
ext := path.Ext(f.Name())
|
|
|
|
if ext == ".json" {
|
|
json.NewDecoder(f).Decode(&conf)
|
|
}
|
|
|
|
if ext == ".yaml" || ext == ".yml" {
|
|
yaml.NewDecoder(f).Decode(&conf)
|
|
}
|
|
}
|
|
|
|
} else {
|
|
conf = Config{
|
|
TLS: *tls,
|
|
Key: *key,
|
|
Cert: *cert,
|
|
HTTP: *http,
|
|
Port: uint16(*port),
|
|
}
|
|
}
|
|
|
|
server := Server{
|
|
config: conf,
|
|
}
|
|
|
|
server.Start()
|
|
}
|