diff --git a/README.md b/README.md index 737680f..c2b2d40 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ If you want to read my code in this repository, here is a simple guidance. I suggest focusing on the following directories: -- server: the tcp server +- tcp: the tcp server - redis: the redis protocol parser - datastruct: the implements of data structures - dict: a concurrent hash map diff --git a/redis.conf b/redis.conf index 8a84b89..d2bc3ad 100644 --- a/redis.conf +++ b/redis.conf @@ -4,3 +4,6 @@ maxclients 128 appendonly yes appendfilename appendonly.aof + +peers localhost:6399 +self localhost:6399 diff --git a/src/cluster/cluster.go b/src/cluster/cluster.go new file mode 100644 index 0000000..2b63246 --- /dev/null +++ b/src/cluster/cluster.go @@ -0,0 +1,92 @@ +package cluster + +import ( + "fmt" + "github.com/HDT3213/godis/src/config" + "github.com/HDT3213/godis/src/db" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/consistenthash" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/redis/client" + "github.com/HDT3213/godis/src/redis/reply" + "runtime/debug" + "strings" +) + +type Cluster struct { + self string + + db *db.DB + peerPicker *consistenthash.Map + peers map[string]*client.Client +} + +const ( + replicas = 4 +) + +func MakeCluster() *Cluster { + cluster := &Cluster{ + self: config.Properties.Self, + + db: db.MakeDB(), + peerPicker: consistenthash.New(replicas, nil), + peers: make(map[string]*client.Client), + } + if config.Properties.Peers != nil && len(config.Properties.Peers) > 0 { + cluster.peerPicker.Add(config.Properties.Peers...) + } + return cluster +} + +// args contains all +type CmdFunc func(cluster *Cluster, c redis.Client, args [][]byte) redis.Reply + +func (cluster *Cluster) Close() { + cluster.db.Close() +} + +var router = MakeRouter() + +func (cluster *Cluster) Exec(c redis.Client, args [][]byte) (result redis.Reply) { + defer func() { + if err := recover(); err != nil { + logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) + result = &reply.UnknownErrReply{} + } + }() + + cmd := strings.ToLower(string(args[0])) + cmdFunc, ok := router[cmd] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmd + "'") + } + result = cmdFunc(cluster, c, args) + return +} + +// relay command to peer +func (cluster *Cluster) Relay(key string, c redis.Client, args [][]byte) redis.Reply { + peer := cluster.peerPicker.Get(key) + if peer == cluster.self { + // to self db + return cluster.db.Exec(c, args) + } else { + peerClient, ok := cluster.peers[peer] + // lazy init + if !ok { + var err error + peerClient, err = client.MakeClient(peer) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + peerClient.Start() + cluster.peers[peer] = peerClient + } + return peerClient.Send(args) + } +} + +func (cluster *Cluster) AfterClientClose(c redis.Client) { + +} diff --git a/src/cluster/router.go b/src/cluster/router.go new file mode 100644 index 0000000..52d4c89 --- /dev/null +++ b/src/cluster/router.go @@ -0,0 +1,102 @@ +package cluster + +import "github.com/HDT3213/godis/src/interface/redis" + +func defaultFunc(cluster *Cluster, c redis.Client, args [][]byte) redis.Reply { + key := string(args[1]) + return cluster.Relay(key, c, args) +} + +func MakeRouter() map[string]CmdFunc { + routerMap := make(map[string]CmdFunc) + //routerMap["ping"] = defaultFunc + + //routerMap["del"] = Del + routerMap["expire"] = defaultFunc + routerMap["expireat"] = defaultFunc + routerMap["pexpire"] = defaultFunc + routerMap["pexpireat"] = defaultFunc + routerMap["ttl"] = defaultFunc + routerMap["pttl"] = defaultFunc + routerMap["persist"] = defaultFunc + routerMap["exists"] = defaultFunc + routerMap["type"] = defaultFunc + //routerMap["rename"] = Rename + //routerMap["renamenx"] = RenameNx + + routerMap["set"] = defaultFunc + routerMap["setnx"] = defaultFunc + routerMap["setex"] = defaultFunc + routerMap["psetex"] = defaultFunc + //routerMap["mset"] = MSet + //routerMap["mget"] = MGet + //routerMap["msetnx"] = MSetNX + routerMap["get"] = defaultFunc + routerMap["getset"] = defaultFunc + routerMap["incr"] = defaultFunc + routerMap["incrby"] = defaultFunc + routerMap["incrbyfloat"] = defaultFunc + routerMap["decr"] = defaultFunc + routerMap["decrby"] = defaultFunc + + routerMap["lpush"] = defaultFunc + routerMap["lpushx"] = defaultFunc + routerMap["rpush"] = defaultFunc + routerMap["rpushx"] = defaultFunc + routerMap["lpop"] = defaultFunc + routerMap["rpop"] = defaultFunc + //routerMap["rpoplpush"] = RPopLPush + routerMap["lrem"] = defaultFunc + routerMap["llen"] = defaultFunc + routerMap["lindex"] = defaultFunc + routerMap["lset"] = defaultFunc + routerMap["lrange"] = defaultFunc + + routerMap["hset"] = defaultFunc + routerMap["hsetnx"] = defaultFunc + routerMap["hget"] = defaultFunc + routerMap["hexists"] = defaultFunc + routerMap["hdel"] = defaultFunc + routerMap["hlen"] = defaultFunc + routerMap["hmget"] = defaultFunc + routerMap["hmset"] = defaultFunc + routerMap["hkeys"] = defaultFunc + routerMap["hvals"] = defaultFunc + routerMap["hgetall"] = defaultFunc + routerMap["hincrby"] = defaultFunc + routerMap["hincrbyfloat"] = defaultFunc + + routerMap["sadd"] = defaultFunc + routerMap["sismember"] = defaultFunc + routerMap["srem"] = defaultFunc + routerMap["scard"] = defaultFunc + routerMap["smembers"] = defaultFunc + routerMap["sinter"] = defaultFunc + routerMap["sinterstore"] = defaultFunc + routerMap["sunion"] = defaultFunc + routerMap["sunionstore"] = defaultFunc + routerMap["sdiff"] = defaultFunc + routerMap["sdiffstore"] = defaultFunc + routerMap["srandmember"] = defaultFunc + + routerMap["zadd"] = defaultFunc + routerMap["zscore"] = defaultFunc + routerMap["zincrby"] = defaultFunc + routerMap["zrank"] = defaultFunc + routerMap["zcount"] = defaultFunc + routerMap["zrevrank"] = defaultFunc + routerMap["zcard"] = defaultFunc + routerMap["zrange"] = defaultFunc + routerMap["zrevrange"] = defaultFunc + routerMap["zrangebyscore"] = defaultFunc + routerMap["zrevrangebyscore"] = defaultFunc + routerMap["zrem"] = defaultFunc + routerMap["zremrangebyscore"] = defaultFunc + routerMap["zremrangebyrank"] = defaultFunc + + //routerMap["flushdb"] = FlushDB + //routerMap["flushall"] = FlushAll + //routerMap["keys"] = Keys + + return routerMap +} diff --git a/src/cluster/string.go b/src/cluster/string.go new file mode 100644 index 0000000..916b1b5 --- /dev/null +++ b/src/cluster/string.go @@ -0,0 +1 @@ +package cluster diff --git a/src/cmd/main.go b/src/cmd/main.go index 91cf2b3..534d6be 100644 --- a/src/cmd/main.go +++ b/src/cmd/main.go @@ -4,9 +4,8 @@ import ( "fmt" "github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/redis/handler" - "github.com/HDT3213/godis/src/server" - "time" + RedisServer "github.com/HDT3213/godis/src/redis/server" + "github.com/HDT3213/godis/src/tcp" ) func main() { @@ -18,9 +17,7 @@ func main() { TimeFormat: "2006-01-02", }) - server.ListenAndServe(&server.Config{ + tcp.ListenAndServe(&tcp.Config{ Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port), - MaxConnect: uint32(config.Properties.MaxClients), - Timeout: 2 * time.Second, - }, handler.MakeHandler()) + }, RedisServer.MakeHandler()) } diff --git a/src/config/config.go b/src/config/config.go index 3c13308..0cf1ccb 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -11,11 +11,13 @@ import ( ) type PropertyHolder struct { - Bind string `cfg:"bind"` - Port int `cfg:"port"` - AppendOnly bool `cfg:"appendOnly"` - AppendFilename string `cfg:"appendFilename"` - MaxClients int `cfg:"maxclients"` + Bind string `cfg:"bind"` + Port int `cfg:"port"` + AppendOnly bool `cfg:"appendOnly"` + AppendFilename string `cfg:"appendFilename"` + MaxClients int `cfg:"maxclients"` + Peers []string `cfg:"peers"` + Self string `cfg:"self"` } var Properties *PropertyHolder @@ -30,13 +32,7 @@ func init() { } func LoadConfig(configFilename string) *PropertyHolder { - // open config file - config := &PropertyHolder{ - Bind: "127.0.0.1", - Port: 6379, - AppendOnly: true, - AppendFilename: "appendonly.aof", - } + config := Properties file, err := os.Open(configFilename) if err != nil { log.Print(err) @@ -55,7 +51,7 @@ func LoadConfig(configFilename string) *PropertyHolder { pivot := strings.IndexAny(line, " ") if pivot > 0 && pivot < len(line)-1 { // separator found key := line[0:pivot] - value := line[pivot+1:] + value := strings.Trim(line[pivot+1:], " ") rawMap[strings.ToLower(key)] = value } } @@ -88,6 +84,11 @@ func LoadConfig(configFilename string) *PropertyHolder { case reflect.Bool: boolValue := "yes" == value fieldVal.SetBool(boolValue) + case reflect.Slice: + if field.Type.Elem().Kind() == reflect.String { + slice := strings.Split(value, ",") + fieldVal.Set(reflect.ValueOf(slice)) + } } } } diff --git a/src/lib/consistenthash/consistenthash.go b/src/lib/consistenthash/consistenthash.go new file mode 100644 index 0000000..090546b --- /dev/null +++ b/src/lib/consistenthash/consistenthash.go @@ -0,0 +1,62 @@ +package consistenthash + +import ( + "hash/crc32" + "sort" + "strconv" +) + +type HashFunc func(data []byte) uint32 + +type Map struct { + hashFunc HashFunc + replicas int + keys []int // sorted + hashMap map[int]string +} + +func New(replicas int, fn HashFunc) *Map { + m := &Map{ + replicas: replicas, + hashFunc: fn, + hashMap: make(map[int]string), + } + if m.hashFunc == nil { + m.hashFunc = crc32.ChecksumIEEE + } + return m +} + +func (m *Map) IsEmpty() bool { + return len(m.keys) == 0 +} + +func (m *Map) Add(keys ...string) { + for _, key := range keys { + for i := 0; i < m.replicas; i++ { + hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) +} + +// Get gets the closest item in the hash to the provided key. +func (m *Map) Get(key string) string { + if m.IsEmpty() { + return "" + } + + hash := int(m.hashFunc([]byte(key))) + + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash }) + + // Means we have cycled back to the first replica. + if idx == len(m.keys) { + idx = 0 + } + + return m.hashMap[m.keys[idx]] +} diff --git a/src/redis/client/client.go b/src/redis/client/client.go new file mode 100644 index 0000000..467db98 --- /dev/null +++ b/src/redis/client/client.go @@ -0,0 +1,309 @@ +package client + +import ( + "bufio" + "context" + "errors" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/redis/reply" + "io" + "net" + "strconv" + "strings" + "sync" + "time" +) + +type Client struct { + conn net.Conn + sendingReqs chan *Request + waitingReqs chan *Request + ticker *time.Ticker + addr string + + ctx context.Context + cancelFunc context.CancelFunc + writing *sync.WaitGroup +} + +type Request struct { + args [][]byte + reply redis.Reply + heartbeat bool + waiting *sync.WaitGroup +} + +const ( + chanSize = 256 +) + +func MakeClient(addr string) (*Client, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + return &Client{ + addr: addr, + conn: conn, + sendingReqs: make(chan *Request, chanSize), + waitingReqs: make(chan *Request, chanSize), + ctx: ctx, + cancelFunc: cancel, + writing: &sync.WaitGroup{}, + }, nil +} + +func (client *Client) Start() { + client.ticker = time.NewTicker(10 * time.Second) + go client.handleWrite() + go func() { + err := client.handleRead() + logger.Warn(err) + }() + go client.heartbeat() +} + +func (client *Client) Close() { + // send stop signal + client.cancelFunc() + + // wait stop process + client.writing.Wait() + + // clean + _ = client.conn.Close() + close(client.sendingReqs) + close(client.waitingReqs) +} + +func (client *Client) handleConnectionError(err error) error { + err1 := client.conn.Close() + if err1 != nil { + if opErr, ok := err1.(*net.OpError); ok { + if opErr.Err.Error() != "use of closed network connection" { + return err1 + } + } else { + return err1 + } + } + conn, err1 := net.Dial("tcp", client.addr) + if err1 != nil { + logger.Error(err1) + return err1 + } + client.conn = conn + go func() { + _ = client.handleRead() + }() + return nil +} + +func (client *Client) heartbeat() { +loop: + for { + select { + case <-client.ticker.C: + client.sendingReqs <- &Request{ + args: [][]byte{[]byte("PING")}, + heartbeat: true, + } + case <-client.ctx.Done(): + break loop + } + } +} + +func (client *Client) handleWrite() { + client.writing.Add(1) +loop: + for { + select { + case req := <-client.sendingReqs: + client.doRequest(req) + case <-client.ctx.Done(): + break loop + } + } + client.writing.Done() +} + +func (client *Client) Send(args [][]byte) redis.Reply { + request := &Request{ + args: args, + heartbeat: false, + waiting: &sync.WaitGroup{}, + } + request.waiting.Add(1) + client.sendingReqs <- request + request.waiting.Wait() + return request.reply +} + +func (client *Client) doRequest(req *Request) { + bytes := reply.MakeMultiBulkReply(req.args).ToBytes() + _, err := client.conn.Write(bytes) + i := 0 + for err != nil && i < 3 { + err = client.handleConnectionError(err) + if err == nil { + _, err = client.conn.Write(bytes) + } + i++ + } + if err == nil { + client.waitingReqs <- req + } +} + +func (client *Client) finishRequest(reply redis.Reply) { + request := <-client.waitingReqs + request.reply = reply + if request.waiting != nil { + request.waiting.Done() + } +} + +func (client *Client) handleRead() error { + reader := bufio.NewReader(client.conn) + downloading := false + expectedArgsCount := 0 + receivedCount := 0 + msgType := byte(0) // first char of msg + var args [][]byte + var fixedLen int64 = 0 + var err error + var msg []byte + for { + // read line + if fixedLen == 0 { // read normal line + msg, err = reader.ReadBytes('\n') + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + logger.Info("connection close") + } else { + logger.Warn(err) + } + + return errors.New("connection closed") + } + if len(msg) == 0 || msg[len(msg)-2] != '\r' { + return errors.New("protocol error") + } + } else { // read bulk line (binary safe) + msg = make([]byte, fixedLen+2) + _, err = io.ReadFull(reader, msg) + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return errors.New("connection closed") + } else { + return err + } + } + if len(msg) == 0 || + msg[len(msg)-2] != '\r' || + msg[len(msg)-1] != '\n' { + return errors.New("protocol error") + } + fixedLen = 0 + } + + // parse line + if !downloading { + // receive new response + if msg[0] == '*' { // multi bulk response + // bulk multi msg + expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) + if err != nil { + return errors.New("protocol error: " + err.Error()) + } + if expectedLine == 0 { + client.finishRequest(&reply.EmptyMultiBulkReply{}) + } else if expectedLine > 0 { + msgType = msg[0] + downloading = true + expectedArgsCount = int(expectedLine) + receivedCount = 0 + args = make([][]byte, expectedLine) + } else { + return errors.New("protocol error") + } + } else if msg[0] == '$' { // bulk response + fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) + if err != nil { + return err + } + if fixedLen == -1 { // null bulk + client.finishRequest(&reply.NullBulkReply{}) + fixedLen = 0 + } else if fixedLen > 0 { + msgType = msg[0] + downloading = true + expectedArgsCount = 1 + receivedCount = 0 + args = make([][]byte, 1) + } else { + return errors.New("protocol error") + } + } else { // single line response + str := strings.TrimSuffix(string(msg), "\n") + str = strings.TrimSuffix(str, "\r") + var result redis.Reply + switch msg[0] { + case '+': + result = reply.MakeStatusReply(str[1:]) + case '-': + result = reply.MakeErrReply(str[1:]) + case ':': + val, err := strconv.ParseInt(str[1:], 10, 64) + if err != nil { + return errors.New("protocol error") + } + result = reply.MakeIntReply(val) + } + client.finishRequest(result) + } + } else { + // receive following part of a request + line := msg[0 : len(msg)-2] + if line[0] == '$' { + fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return err + } + if fixedLen <= 0 { // null bulk in multi bulks + args[receivedCount] = []byte{} + receivedCount++ + fixedLen = 0 + } + } else { + args[receivedCount] = line + receivedCount++ + } + + // if sending finished + if receivedCount == expectedArgsCount { + downloading = false // finish downloading progress + + request := <-client.waitingReqs + if msgType == '*' { + request.reply = reply.MakeMultiBulkReply(args) + } else if msgType == '$' { + request.reply = reply.MakeBulkReply(args[0]) + } + + if request.waiting != nil { + request.waiting.Done() + } + + // finish reply + expectedArgsCount = 0 + receivedCount = 0 + args = nil + msgType = byte(0) + } + } + } +} diff --git a/src/redis/client/client_test.go b/src/redis/client/client_test.go new file mode 100644 index 0000000..b1eeee4 --- /dev/null +++ b/src/redis/client/client_test.go @@ -0,0 +1,104 @@ +package client + +import ( + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/redis/reply" + "testing" +) + +func TestClient(t *testing.T) { + logger.Setup(&logger.Settings{ + Path: "logs", + Name: "godis", + Ext: ".log", + TimeFormat: "2006-01-02", + }) + client, err := MakeClient("localhost:6379") + if err != nil { + t.Error(err) + } + client.Start() + + result := client.Send([][]byte{ + []byte("PING"), + }) + if statusRet, ok := result.(*reply.StatusReply); ok { + if statusRet.Status != "PONG" { + t.Error("`ping` failed, result: " + statusRet.Status) + } + } + + result = client.Send([][]byte{ + []byte("SET"), + []byte("a"), + []byte("a"), + }) + if statusRet, ok := result.(*reply.StatusReply); ok { + if statusRet.Status != "OK" { + t.Error("`set` failed, result: " + statusRet.Status) + } + } + + result = client.Send([][]byte{ + []byte("GET"), + []byte("a"), + }) + if bulkRet, ok := result.(*reply.BulkReply); ok { + if string(bulkRet.Arg) != "a" { + t.Error("`get` failed, result: " + string(bulkRet.Arg)) + } + } + + result = client.Send([][]byte{ + []byte("DEL"), + []byte("a"), + }) + if intRet, ok := result.(*reply.IntReply); ok { + if intRet.Code != 1 { + t.Error("`del` failed, result: " + string(intRet.Code)) + } + } + + result = client.Send([][]byte{ + []byte("GET"), + []byte("a"), + }) + if _, ok := result.(*reply.NullBulkReply); !ok { + t.Error("`get` failed, result: " + string(result.ToBytes())) + } + + result = client.Send([][]byte{ + []byte("DEL"), + []byte("arr"), + }) + + result = client.Send([][]byte{ + []byte("RPUSH"), + []byte("arr"), + []byte("1"), + []byte("2"), + []byte("c"), + }) + if intRet, ok := result.(*reply.IntReply); ok { + if intRet.Code != 3 { + t.Error("`rpush` failed, result: " + string(intRet.Code)) + } + } + + result = client.Send([][]byte{ + []byte("LRANGE"), + []byte("arr"), + []byte("0"), + []byte("-1"), + }) + if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok { + if len(multiBulkRet.Args) != 3 || + string(multiBulkRet.Args[0]) != "1" || + string(multiBulkRet.Args[1]) != "2" || + string(multiBulkRet.Args[2]) != "c" { + t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes())) + } + } + + client.Close() +} diff --git a/src/redis/handler/client.go b/src/redis/server/client.go similarity index 99% rename from src/redis/handler/client.go rename to src/redis/server/client.go index 5dce57b..8be64a5 100644 --- a/src/redis/handler/client.go +++ b/src/redis/server/client.go @@ -1,4 +1,4 @@ -package handler +package server import ( "github.com/HDT3213/godis/src/lib/sync/atomic" diff --git a/src/redis/handler/handler.go b/src/redis/server/handler.go similarity index 95% rename from src/redis/handler/handler.go rename to src/redis/server/handler.go index 38ec00d..25fc644 100644 --- a/src/redis/handler/handler.go +++ b/src/redis/server/handler.go @@ -1,4 +1,4 @@ -package handler +package server /* * A tcp.Handler implements redis protocol @@ -7,6 +7,8 @@ package handler import ( "bufio" "context" + "github.com/HDT3213/godis/src/cluster" + "github.com/HDT3213/godis/src/config" DBImpl "github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/interface/db" "github.com/HDT3213/godis/src/lib/logger" @@ -30,8 +32,15 @@ type Handler struct { } func MakeHandler() *Handler { + var db db.DB + if config.Properties.Peers != nil && + len(config.Properties.Peers) > 0 { + db = cluster.MakeCluster() + } else { + db = DBImpl.MakeDB() + } return &Handler{ - db: DBImpl.MakeDB(), + db: db, } } diff --git a/src/server/echo.go b/src/tcp/echo.go similarity index 99% rename from src/server/echo.go rename to src/tcp/echo.go index 93f041a..586dcb5 100644 --- a/src/server/echo.go +++ b/src/tcp/echo.go @@ -1,4 +1,4 @@ -package server +package tcp /** * A echo server to test whether the server is functioning normally diff --git a/src/server/server.go b/src/tcp/server.go similarity index 99% rename from src/server/server.go rename to src/tcp/server.go index b4b89dc..0b6f7b8 100644 --- a/src/server/server.go +++ b/src/tcp/server.go @@ -1,4 +1,4 @@ -package server +package tcp /** * A tcp server