diff --git a/README.md b/README.md index 00d4f48..9b13cd4 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Key Features: - Replication (experimental) - Server-side Cluster which is transparent to client. You can connect to any node in the cluster to access all data in the cluster. + - Use the raft algorithm to maintain cluster metadata. (experimental) - `MSET`, `MSETNX`, `DEL`, `Rename`, `RenameNX` command is supported and atomically executed in cluster mode, allow over multi node - `MULTI` Commands Transaction is supported within slot in cluster mode - Concurrent Core, so you don't have to worry about your commands blocking the server too much. diff --git a/README_CN.md b/README_CN.md index 4997d7b..64a2053 100644 --- a/README_CN.md +++ b/README_CN.md @@ -19,6 +19,7 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试 - 主从复制 (测试中) - Multi 命令开启的事务具有`原子性`和`隔离性`. 若在执行过程中遇到错误, godis 会回滚已执行的命令 - 内置集群模式. 集群对客户端是透明的, 您可以像使用单机版 redis 一样使用 godis 集群 + - 使用 raft 算法维护集群元数据(测试中) - `MSET`, `MSETNX`, `DEL`, `Rename`, `RenameNX` 命令在集群模式下原子性执行, 允许 key 在集群的不同节点上 - 在集群模式下支持在同一个 slot 内执行事务 - 并行引擎, 无需担心您的操作会阻塞整个服务器. diff --git a/cluster/cluster.go b/cluster/cluster.go index da02780..6581ce8 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -11,98 +11,109 @@ import ( "github.com/hdt3213/godis/config" database2 "github.com/hdt3213/godis/database" "github.com/hdt3213/godis/datastruct/dict" + "github.com/hdt3213/godis/datastruct/set" "github.com/hdt3213/godis/interface/database" "github.com/hdt3213/godis/interface/redis" - "github.com/hdt3213/godis/lib/consistenthash" "github.com/hdt3213/godis/lib/idgenerator" "github.com/hdt3213/godis/lib/logger" - "github.com/hdt3213/godis/lib/pool" - "github.com/hdt3213/godis/lib/utils" - "github.com/hdt3213/godis/redis/client" + "github.com/hdt3213/godis/redis/parser" "github.com/hdt3213/godis/redis/protocol" + "os" + "path" + "sync" ) -type PeerPicker interface { - AddNode(keys ...string) - PickNode(key string) string -} - // Cluster represents a node of godis cluster // it holds part of data and coordinates other nodes to finish transactions type Cluster struct { - self string + self string + addr string + db database.DBEngine + transactions *dict.SimpleDict // id -> Transaction + transactionMu sync.RWMutex + topology topology + slotMu sync.RWMutex + slots map[uint32]*hostSlot + idGenerator *idgenerator.IDGenerator - nodes []string - peerPicker PeerPicker - nodeConnections map[string]*pool.Pool + clientFactory clientFactory +} - db database.DBEngine - transactions *dict.SimpleDict // id -> Transaction +type peerClient interface { + Send(args [][]byte) redis.Reply +} - idGenerator *idgenerator.IDGenerator - // use a variable to allow injecting stub for testing - relayImpl func(cluster *Cluster, node string, c redis.Connection, cmdLine CmdLine) redis.Reply +type peerStream interface { + Stream() <-chan *parser.Payload + Close() error +} + +type clientFactory interface { + GetPeerClient(peerAddr string) (peerClient, error) + ReturnPeerClient(peerAddr string, peerClient peerClient) error + NewStream(peerAddr string, cmdLine CmdLine) (peerStream, error) + Close() error } const ( - replicas = 4 + slotStateHost = iota + slotStateImporting + slotStateMovingOut ) +// hostSlot stores status of host which hosted by current node +type hostSlot struct { + state uint32 + mu sync.RWMutex + // OldNodeID is the node which is moving out this slot + // only valid during slot is importing + oldNodeID string + // OldNodeID is the node which is importing this slot + // only valid during 
slot is moving out + newNodeID string + + /* importedKeys stores imported keys during migrating progress + * While this slot is migrating, if importedKeys does not have the given key, then current node will import key before execute commands + * + * In a migrating slot, the slot on the old node is immutable, we only delete a key in the new node. + * Therefore, we must distinguish between non-migrated key and deleted key. + * Even if a key has been deleted, it still exists in importedKeys, so we can distinguish between non-migrated and deleted. + */ + importedKeys *set.Set + // keys stores all keys in this slot + // Cluster.makeInsertCallback and Cluster.makeDeleteCallback will keep keys up to time + keys *set.Set +} + // if only one node involved in a transaction, just execute the command don't apply tcc procedure var allowFastTransaction = true // MakeCluster creates and starts a node of cluster func MakeCluster() *Cluster { cluster := &Cluster{ - self: config.Properties.Self, - db: database2.NewStandaloneServer(), - transactions: dict.MakeSimple(), - peerPicker: consistenthash.New(replicas, nil), - nodeConnections: make(map[string]*pool.Pool), - - idGenerator: idgenerator.MakeGenerator(config.Properties.Self), - relayImpl: defaultRelayImpl, + self: config.Properties.Self, + addr: config.Properties.AnnounceAddress(), + db: database2.NewStandaloneServer(), + transactions: dict.MakeSimple(), + idGenerator: idgenerator.MakeGenerator(config.Properties.Self), + clientFactory: newDefaultClientFactory(), } - - contains := make(map[string]struct{}) - nodes := make([]string, 0, len(config.Properties.Peers)+1) - for _, peer := range config.Properties.Peers { - if _, ok := contains[peer]; ok { - continue - } - contains[peer] = struct{}{} - nodes = append(nodes, peer) + topologyPersistFile := path.Join(config.Properties.Dir, config.Properties.ClusterConfigFile) + cluster.topology = newRaft(cluster, topologyPersistFile) + cluster.db.SetKeyInsertedCallback(cluster.makeInsertCallback()) + cluster.db.SetKeyDeletedCallback(cluster.makeDeleteCallback()) + cluster.slots = make(map[uint32]*hostSlot) + var err error + if topologyPersistFile != "" && fileExists(topologyPersistFile) { + err = cluster.LoadConfig() + } else if config.Properties.ClusterAsSeed { + err = cluster.startAsSeed(config.Properties.AnnounceAddress()) + } else { + err = cluster.Join(config.Properties.ClusterSeed) } - nodes = append(nodes, config.Properties.Self) - cluster.peerPicker.AddNode(nodes...) 
- connectionPoolConfig := pool.Config{ - MaxIdle: 1, - MaxActive: 16, + if err != nil { + panic(err) } - for _, p := range config.Properties.Peers { - peer := p - factory := func() (interface{}, error) { - c, err := client.MakeClient(peer) - if err != nil { - return nil, err - } - c.Start() - // all peers of cluster should use the same password - if config.Properties.RequirePass != "" { - c.Send(utils.ToCmdLine("AUTH", config.Properties.RequirePass)) - } - return c, nil - } - finalizer := func(x interface{}) { - cli, ok := x.(client.Client) - if !ok { - return - } - cli.Close() - } - cluster.nodeConnections[peer] = pool.New(factory, finalizer, connectionPoolConfig) - } - cluster.nodes = nodes return cluster } @@ -111,14 +122,11 @@ type CmdFunc func(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.R // Close stops current node of cluster func (cluster *Cluster) Close() { + _ = cluster.topology.Close() cluster.db.Close() - for _, pool := range cluster.nodeConnections { - pool.Close() - } + cluster.clientFactory.Close() } -var router = makeRouter() - func isAuthenticated(c redis.Connection) bool { if config.Properties.RequirePass == "" { return true @@ -163,10 +171,7 @@ func (cluster *Cluster) Exec(c redis.Connection, cmdLine [][]byte) (result redis } return execMulti(cluster, c, nil) } else if cmdName == "select" { - if len(cmdLine) != 2 { - return protocol.MakeArgNumErrReply(cmdName) - } - return execSelect(c, cmdLine) + return protocol.MakeErrReply("select not supported in cluster") } if c != nil && c.InMultiState() { return database2.EnqueueCmd(c, cmdLine) @@ -187,3 +192,38 @@ func (cluster *Cluster) AfterClientClose(c redis.Connection) { func (cluster *Cluster) LoadRDB(dec *core.Decoder) error { return cluster.db.LoadRDB(dec) } + +func (cluster *Cluster) makeInsertCallback() database.KeyEventCallback { + return func(dbIndex int, key string, entity *database.DataEntity) { + slotId := getSlot(key) + cluster.slotMu.RLock() + slot, ok := cluster.slots[slotId] + cluster.slotMu.RUnlock() + // As long as the command is executed, we should update slot.keys regardless of slot.state + if ok { + slot.mu.Lock() + defer slot.mu.Unlock() + slot.keys.Add(key) + } + } +} + +func (cluster *Cluster) makeDeleteCallback() database.KeyEventCallback { + return func(dbIndex int, key string, entity *database.DataEntity) { + slotId := getSlot(key) + cluster.slotMu.RLock() + slot, ok := cluster.slots[slotId] + cluster.slotMu.RUnlock() + // As long as the command is executed, we should update slot.keys regardless of slot.state + if ok { + slot.mu.Lock() + defer slot.mu.Unlock() + slot.keys.Remove(key) + } + } +} + +func fileExists(filename string) bool { + info, err := os.Stat(filename) + return err == nil && !info.IsDir() +} diff --git a/cluster/com.go b/cluster/com.go index d41d9cc..c667d81 100644 --- a/cluster/com.go +++ b/cluster/com.go @@ -1,69 +1,86 @@ package cluster import ( - "errors" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/utils" - "github.com/hdt3213/godis/redis/client" + "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" - "strconv" ) -func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) { - pool, ok := cluster.nodeConnections[peer] - if !ok { - return nil, errors.New("connection pool not found") - } - raw, err := pool.Get() - if err != nil { - return nil, err - } - conn, ok := raw.(*client.Client) - if !ok { - return nil, errors.New("connection pool make wrong type") - } - return conn, nil -} - -func 
(cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error { - pool, ok := cluster.nodeConnections[peer] - if !ok { - return errors.New("connection pool not found") - } - pool.Put(peerClient) - return nil -} - -var defaultRelayImpl = func(cluster *Cluster, node string, c redis.Connection, cmdLine CmdLine) redis.Reply { - if node == cluster.self { +// relay function relays command to peer or calls cluster.Exec +func (cluster *Cluster) relay(peerId string, c redis.Connection, cmdLine [][]byte) redis.Reply { + // use a variable to allow injecting stub for testing, see defaultRelayImpl + if peerId == cluster.self { // to self db - return cluster.db.Exec(c, cmdLine) + return cluster.Exec(c, cmdLine) } - peerClient, err := cluster.getPeerClient(node) + // peerId is peer.Addr + cli, err := cluster.clientFactory.GetPeerClient(peerId) if err != nil { return protocol.MakeErrReply(err.Error()) } defer func() { - _ = cluster.returnPeerClient(node, peerClient) + _ = cluster.clientFactory.ReturnPeerClient(peerId, cli) }() - peerClient.Send(utils.ToCmdLine("SELECT", strconv.Itoa(c.GetDBIndex()))) - return peerClient.Send(cmdLine) + return cli.Send(cmdLine) } -// relay function relays command to peer -// select db by c.GetDBIndex() -// cannot call Prepare, Commit, execRollback of self node -func (cluster *Cluster) relay(peer string, c redis.Connection, args [][]byte) redis.Reply { - // use a variable to allow injecting stub for testing - return cluster.relayImpl(cluster, peer, c, args) +// relayByKey function relays command to peer +// use routeKey to determine peer node +func (cluster *Cluster) relayByKey(routeKey string, c redis.Connection, args [][]byte) redis.Reply { + slotId := getSlot(routeKey) + peer := cluster.pickNode(slotId) + return cluster.relay(peer.ID, c, args) } // broadcast function broadcasts command to all node in cluster func (cluster *Cluster) broadcast(c redis.Connection, args [][]byte) map[string]redis.Reply { result := make(map[string]redis.Reply) - for _, node := range cluster.nodes { - reply := cluster.relay(node, c, args) - result[node] = reply + for _, node := range cluster.topology.GetNodes() { + reply := cluster.relay(node.ID, c, args) + result[node.Addr] = reply } return result } + +// ensureKey will migrate key to current node if the key is in a slot migrating to current node +// invoker should provide with locks of key +func (cluster *Cluster) ensureKey(key string) protocol.ErrorReply { + slotId := getSlot(key) + cluster.slotMu.RLock() + slot := cluster.slots[slotId] + cluster.slotMu.RUnlock() + if slot == nil { + return nil + } + if slot.state != slotStateImporting || slot.importedKeys.Has(key) { + return nil + } + resp := cluster.relay(slot.oldNodeID, connection.NewFakeConn(), utils.ToCmdLine("DumpKey_", key)) + if protocol.IsErrorReply(resp) { + return resp.(protocol.ErrorReply) + } + if protocol.IsEmptyMultiBulkReply(resp) { + slot.importedKeys.Add(key) + return nil + } + dumpResp := resp.(*protocol.MultiBulkReply) + if len(dumpResp.Args) != 2 { + return protocol.MakeErrReply("illegal dump key response") + } + // reuse copy to command ^_^ + resp = cluster.db.Exec(connection.NewFakeConn(), [][]byte{ + []byte("CopyTo"), []byte(key), dumpResp.Args[0], dumpResp.Args[1], + }) + if protocol.IsErrorReply(resp) { + return resp.(protocol.ErrorReply) + } + slot.importedKeys.Add(key) + return nil +} + +func (cluster *Cluster) ensureKeyWithoutLock(key string) protocol.ErrorReply { + cluster.db.RWLocks(0, []string{key}, nil) + defer cluster.db.RWUnLocks(0, 
[]string{key}, nil) + return cluster.ensureKey(key) +} diff --git a/cluster/com_factory.go b/cluster/com_factory.go new file mode 100644 index 0000000..c9ab8ae --- /dev/null +++ b/cluster/com_factory.go @@ -0,0 +1,142 @@ +package cluster + +import ( + "errors" + "fmt" + "github.com/hdt3213/godis/config" + "github.com/hdt3213/godis/datastruct/dict" + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/logger" + "github.com/hdt3213/godis/lib/pool" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/client" + "github.com/hdt3213/godis/redis/parser" + "github.com/hdt3213/godis/redis/protocol" + "net" +) + +type defaultClientFactory struct { + nodeConnections dict.Dict // map[string]*pool.Pool +} + +var connectionPoolConfig = pool.Config{ + MaxIdle: 1, + MaxActive: 16, +} + +// GetPeerClient gets a client with peer form pool +func (factory *defaultClientFactory) GetPeerClient(peerAddr string) (peerClient, error) { + var connectionPool *pool.Pool + raw, ok := factory.nodeConnections.Get(peerAddr) + if !ok { + creator := func() (interface{}, error) { + c, err := client.MakeClient(peerAddr) + if err != nil { + return nil, err + } + c.Start() + // all peers of cluster should use the same password + if config.Properties.RequirePass != "" { + authResp := c.Send(utils.ToCmdLine("AUTH", config.Properties.RequirePass)) + if !protocol.IsOKReply(authResp) { + return nil, fmt.Errorf("auth failed, resp: %s", string(authResp.ToBytes())) + } + } + return c, nil + } + finalizer := func(x interface{}) { + logger.Debug("destroy client") + cli, ok := x.(client.Client) + if !ok { + return + } + cli.Close() + } + connectionPool = pool.New(creator, finalizer, connectionPoolConfig) + factory.nodeConnections.Put(peerAddr, connectionPool) + } else { + connectionPool = raw.(*pool.Pool) + } + raw, err := connectionPool.Get() + if err != nil { + return nil, err + } + conn, ok := raw.(*client.Client) + if !ok { + return nil, errors.New("connection pool make wrong type") + } + return conn, nil +} + +// ReturnPeerClient returns client to pool +func (factory *defaultClientFactory) ReturnPeerClient(peer string, peerClient peerClient) error { + raw, ok := factory.nodeConnections.Get(peer) + if !ok { + return errors.New("connection pool not found") + } + raw.(*pool.Pool).Put(peerClient) + return nil +} + +type tcpStream struct { + conn net.Conn + ch <-chan *parser.Payload +} + +func (s *tcpStream) Stream() <-chan *parser.Payload { + return s.ch +} + +func (s *tcpStream) Close() error { + return s.conn.Close() +} + +func (factory *defaultClientFactory) NewStream(peerAddr string, cmdLine CmdLine) (peerStream, error) { + // todo: reuse connection + conn, err := net.Dial("tcp", peerAddr) + if err != nil { + return nil, fmt.Errorf("connect with %s failed: %v", peerAddr, err) + } + ch := parser.ParseStream(conn) + send2node := func(cmdLine CmdLine) redis.Reply { + req := protocol.MakeMultiBulkReply(cmdLine) + _, err := conn.Write(req.ToBytes()) + if err != nil { + return protocol.MakeErrReply(err.Error()) + } + resp := <-ch + if resp.Err != nil { + return protocol.MakeErrReply(resp.Err.Error()) + } + return resp.Data + } + if config.Properties.RequirePass != "" { + authResp := send2node(utils.ToCmdLine("AUTH", config.Properties.RequirePass)) + if !protocol.IsOKReply(authResp) { + return nil, fmt.Errorf("auth failed, resp: %s", string(authResp.ToBytes())) + } + } + req := protocol.MakeMultiBulkReply(cmdLine) + _, err = conn.Write(req.ToBytes()) + if err != nil { + return nil, 
protocol.MakeErrReply("send cmdLine failed: " + err.Error()) + } + return &tcpStream{ + conn: conn, + ch: ch, + }, nil +} + +func newDefaultClientFactory() *defaultClientFactory { + return &defaultClientFactory{ + nodeConnections: dict.MakeConcurrent(1), + } +} + +func (factory *defaultClientFactory) Close() error { + factory.nodeConnections.ForEach(func(key string, val interface{}) bool { + val.(*pool.Pool).Close() + return true + }) + return nil +} diff --git a/cluster/com_test.go b/cluster/com_test.go index 7978bc3..ecd9b76 100644 --- a/cluster/com_test.go +++ b/cluster/com_test.go @@ -9,13 +9,13 @@ import ( ) func TestExec(t *testing.T) { - testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) + testCluster := testCluster[0] conn := connection.NewFakeConn() for i := 0; i < 1000; i++ { key := RandString(4) value := RandString(4) - testCluster2.Exec(conn, toArgs("SET", key, value)) - ret := testCluster2.Exec(conn, toArgs("GET", key)) + testCluster.Exec(conn, toArgs("SET", key, value)) + ret := testCluster.Exec(conn, toArgs("GET", key)) asserts.AssertBulkReply(t, ret, value) } } @@ -27,27 +27,28 @@ func TestAuth(t *testing.T) { config.Properties.RequirePass = "" }() conn := connection.NewFakeConn() - ret := testNodeA.Exec(conn, toArgs("GET", "a")) + testCluster := testCluster[0] + ret := testCluster.Exec(conn, toArgs("GET", "a")) asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") - ret = testNodeA.Exec(conn, toArgs("AUTH", passwd)) + ret = testCluster.Exec(conn, toArgs("AUTH", passwd)) asserts.AssertStatusReply(t, ret, "OK") - ret = testNodeA.Exec(conn, toArgs("GET", "a")) + ret = testCluster.Exec(conn, toArgs("GET", "a")) asserts.AssertNotError(t, ret) } func TestRelay(t *testing.T) { - testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) + testNodeA := testCluster[1] key := RandString(4) value := RandString(4) conn := connection.NewFakeConn() - ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value)) + ret := testNodeA.relay(addresses[1], conn, toArgs("SET", key, value)) asserts.AssertNotError(t, ret) - ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key)) + ret = testNodeA.relay(addresses[1], conn, toArgs("GET", key)) asserts.AssertBulkReply(t, ret, value) } func TestBroadcast(t *testing.T) { - testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) + testCluster2 := testCluster[0] key := RandString(4) value := RandString(4) rets := testCluster2.broadcast(connection.NewFakeConn(), toArgs("SET", key, value)) diff --git a/cluster/copy.go b/cluster/copy.go index 49a5673..f0cb960 100644 --- a/cluster/copy.go +++ b/cluster/copy.go @@ -20,8 +20,8 @@ func Copy(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { } srcKey := string(args[1]) destKey := string(args[2]) - srcNode := cluster.peerPicker.PickNode(srcKey) - destNode := cluster.peerPicker.PickNode(destKey) + srcNode := cluster.pickNodeAddrByKey(srcKey) + destNode := cluster.pickNodeAddrByKey(destKey) replaceFlag := noReplace if len(args) > 3 { for i := 3; i < len(args); i++ { @@ -37,6 +37,7 @@ func Copy(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { } if srcNode == destNode { + args[0] = []byte("Copy_") // Copy_ will go directly to cluster.DB avoiding infinite recursion return cluster.relay(srcNode, c, args) } groupMap := map[string][]string{ @@ -47,7 +48,7 @@ func Copy(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { txID := cluster.idGenerator.NextID() txIDStr := strconv.FormatInt(txID, 10) // prepare Copy from - 
srcPrepareResp := cluster.relayPrepare(srcNode, c, makeArgs("Prepare", txIDStr, "CopyFrom", srcKey)) + srcPrepareResp := cluster.relay(srcNode, c, makeArgs("Prepare", txIDStr, "CopyFrom", srcKey)) if protocol.IsErrorReply(srcPrepareResp) { // rollback src node requestRollback(cluster, c, txID, map[string][]string{srcNode: {srcKey}}) @@ -59,11 +60,14 @@ func Copy(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return protocol.MakeErrReply("ERR invalid prepare response") } // prepare Copy to - destPrepareResp := cluster.relayPrepare(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), + destPrepareResp := cluster.relay(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), []byte("CopyTo"), []byte(destKey), srcPrepareMBR.Args[0], srcPrepareMBR.Args[1], []byte(replaceFlag))) - if protocol.IsErrorReply(destPrepareResp) { + if destErr, ok := destPrepareResp.(protocol.ErrorReply); ok { // rollback src node requestRollback(cluster, c, txID, groupMap) + if destErr.Error() == keyExistsErr { + return protocol.MakeIntReply(0) + } return destPrepareResp } if _, errReply := requestCommit(cluster, c, txID, groupMap); errReply != nil { diff --git a/cluster/copy_test.go b/cluster/copy_test.go index 11771fe..f0f343a 100644 --- a/cluster/copy_test.go +++ b/cluster/copy_test.go @@ -9,109 +9,117 @@ import ( func TestCopy(t *testing.T) { conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FlushALL")) + testNodeA := testCluster[0] + testNodeB := testCluster[1] + testNodeA.Exec(conn, utils.ToCmdLine("FlushALL")) + testNodeB.Exec(conn, utils.ToCmdLine("FlushALL")) // cross node copy - srcKey := testNodeA.self + utils.RandString(10) + srcKey := "127.0.0.1:6399Bk2r3Sz0V5" // use fix key to ensure hashing to different node + destKey := "127.0.0.1:7379CcdC0QOopF" value := utils.RandString(10) - destKey := testNodeB.self + utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) result := Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", srcKey)) asserts.AssertBulkReply(t, result, value) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + result = testNodeB.Exec(conn, utils.ToCmdLine("GET", destKey)) asserts.AssertBulkReply(t, result, value) // key exists result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) - asserts.AssertErrReply(t, result, keyExistsErr) + asserts.AssertIntReply(t, result, 0) // replace value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", srcKey)) asserts.AssertBulkReply(t, result, value) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + result = testNodeB.Exec(conn, utils.ToCmdLine("GET", destKey)) asserts.AssertBulkReply(t, result, value) // test copy expire time - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, 
"REPLACE")) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", srcKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("TTL", destKey)) + result = testNodeB.Exec(conn, utils.ToCmdLine("TTL", destKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) // same node copy - srcKey = testNodeA.self + utils.RandString(10) + srcKey = "{" + testNodeA.self + "}" + utils.RandString(10) + destKey = "{" + testNodeA.self + "}" + utils.RandString(9) value = utils.RandString(10) - destKey = srcKey + utils.RandString(2) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", srcKey)) asserts.AssertBulkReply(t, result, value) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", destKey)) asserts.AssertBulkReply(t, result, value) // key exists result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) asserts.AssertIntReply(t, result, 0) // replace value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", srcKey)) asserts.AssertBulkReply(t, result, value) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("GET", destKey)) asserts.AssertBulkReply(t, result, value) // test copy expire time - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", srcKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", destKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", destKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) +} + +func TestCopyTimeout(t *testing.T) { + conn := new(connection.FakeConn) + testNodeA := testCluster[0] + testNodeB := testCluster[1] + testNodeA.Exec(conn, utils.ToCmdLine("FlushALL")) + testNodeB.Exec(conn, utils.ToCmdLine("FlushALL")) // test src prepare failed - *simulateATimout = true - srcKey = testNodeA.self + utils.RandString(10) - destKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "ex", "1000")) - result = Rename(testNodeB, conn, utils.ToCmdLine("RENAME", srcKey, destKey)) + timeoutFlags[0] = true + srcKey := "127.0.0.1:6399Bk2r3Sz0V5" // use fix key to ensure hashing to different node + destKey := "127.0.0.1:7379CcdC0QOopF" + value := utils.RandString(10) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value, 
"ex", "1000")) + result := Rename(testNodeB, conn, utils.ToCmdLine("RENAME", srcKey, destKey)) asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", srcKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) + result = testNodeB.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) asserts.AssertIntReply(t, result, 0) - *simulateATimout = false + timeoutFlags[0] = false // test dest prepare failed - *simulateBTimout = true - srcKey = testNodeA.self + utils.RandString(10) - destKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode + timeoutFlags[1] = true value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "ex", "1000")) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "ex", "1000")) result = Rename(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", srcKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) + result = testNodeB.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) asserts.AssertIntReply(t, result, 0) - *simulateBTimout = false - + timeoutFlags[1] = false // Copying to another database srcKey = testNodeA.self + utils.RandString(10) value = utils.RandString(10) destKey = srcKey + utils.RandString(2) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + testNodeA.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "db", "1")) asserts.AssertErrReply(t, result, copyToAnotherDBErr) diff --git a/cluster/del.go b/cluster/del.go index e200030..cbe6087 100644 --- a/cluster/del.go +++ b/cluster/del.go @@ -19,7 +19,7 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { groupMap := cluster.groupBy(keys) if len(groupMap) == 1 && allowFastTransaction { // do fast for peer, group := range groupMap { // only one peerKeys - return cluster.relay(peer, c, makeArgs("DEL", group...)) + return cluster.relay(peer, c, makeArgs("Del_", group...)) } } // prepare @@ -31,11 +31,7 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { peerArgs := []string{txIDStr, "DEL"} peerArgs = append(peerArgs, peerKeys...) 
var resp redis.Reply - if peer == cluster.self { - resp = execPrepare(cluster, c, makeArgs("Prepare", peerArgs...)) - } else { - resp = cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) - } + resp = cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) if protocol.IsErrorReply(resp) { errReply = resp rollback = true diff --git a/cluster/del_test.go b/cluster/del_test.go index cc3205f..836f0f9 100644 --- a/cluster/del_test.go +++ b/cluster/del_test.go @@ -9,6 +9,7 @@ import ( func TestDel(t *testing.T) { conn := connection.NewFakeConn() allowFastTransaction = false + testNodeA := testCluster[0] testNodeA.Exec(conn, toArgs("SET", "a", "a")) ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c")) asserts.AssertNotError(t, ret) diff --git a/cluster/fixed_topo.go b/cluster/fixed_topo.go new file mode 100644 index 0000000..41c6f51 --- /dev/null +++ b/cluster/fixed_topo.go @@ -0,0 +1,58 @@ +package cluster + +import ( + "github.com/hdt3213/godis/redis/protocol" + "sync" +) + +// fixedTopology is a fixed cluster topology, used for test +type fixedTopology struct { + mu sync.RWMutex + nodeMap map[string]*Node + slots []*Slot + selfNodeID string +} + +func (fixed *fixedTopology) GetSelfNodeID() string { + return fixed.selfNodeID +} + +func (fixed *fixedTopology) GetNodes() []*Node { + fixed.mu.RLock() + defer fixed.mu.RUnlock() + result := make([]*Node, 0, len(fixed.nodeMap)) + for _, v := range fixed.nodeMap { + result = append(result, v) + } + return result +} + +func (fixed *fixedTopology) GetNode(nodeID string) *Node { + fixed.mu.RLock() + defer fixed.mu.RUnlock() + return fixed.nodeMap[nodeID] +} + +func (fixed *fixedTopology) GetSlots() []*Slot { + return fixed.slots +} + +func (fixed *fixedTopology) StartAsSeed(addr string) protocol.ErrorReply { + return nil +} + +func (fixed *fixedTopology) LoadConfigFile() protocol.ErrorReply { + return nil +} + +func (fixed *fixedTopology) Join(seed string) protocol.ErrorReply { + return protocol.MakeErrReply("fixed topology does not support join") +} + +func (fixed *fixedTopology) SetSlot(slotIDs []uint32, newNodeID string) protocol.ErrorReply { + return protocol.MakeErrReply("fixed topology does not support set slots") +} + +func (fixed *fixedTopology) Close() error { + return nil +} diff --git a/cluster/keys.go b/cluster/keys.go index 9ac78bd..b809941 100644 --- a/cluster/keys.go +++ b/cluster/keys.go @@ -6,8 +6,8 @@ import ( ) // FlushDB removes all data in current database -func FlushDB(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - replies := cluster.broadcast(c, args) +func FlushDB(cluster *Cluster, c redis.Connection, cmdLine [][]byte) redis.Reply { + replies := cluster.broadcast(c, modifyCmd(cmdLine, "FlushDB_")) var errReply protocol.ErrorReply for _, v := range replies { if protocol.IsErrorReply(v) { diff --git a/cluster/mset.go b/cluster/mset.go index 62d987b..43c1cea 100644 --- a/cluster/mset.go +++ b/cluster/mset.go @@ -22,15 +22,15 @@ func MGet(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { resultMap := make(map[string][]byte) groupMap := cluster.groupBy(keys) - for peer, group := range groupMap { - resp := cluster.relay(peer, c, makeArgs("MGET", group...)) + for peer, groupKeys := range groupMap { + resp := cluster.relay(peer, c, makeArgs("MGet_", groupKeys...)) if protocol.IsErrorReply(resp) { errReply := resp.(protocol.ErrorReply) - return protocol.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error())) + return protocol.MakeErrReply(fmt.Sprintf("ERR 
during get %s occurs: %v", groupKeys[0], errReply.Error())) } arrReply, _ := resp.(*protocol.MultiBulkReply) for i, v := range arrReply.Args { - key := group[i] + key := groupKeys[i] resultMap[key] = v } } @@ -59,7 +59,7 @@ func MSet(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { groupMap := cluster.groupBy(keys) if len(groupMap) == 1 && allowFastTransaction { // do fast for peer := range groupMap { - return cluster.relay(peer, c, cmdLine) + return cluster.relay(peer, c, modifyCmd(cmdLine, "MSet_")) } } @@ -73,12 +73,7 @@ func MSet(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { for _, k := range group { peerArgs = append(peerArgs, k, valueMap[k]) } - var resp redis.Reply - if peer == cluster.self { - resp = execPrepare(cluster, c, makeArgs("Prepare", peerArgs...)) - } else { - resp = cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) - } + resp := cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) if protocol.IsErrorReply(resp) { errReply = resp rollback = true @@ -117,7 +112,7 @@ func MSetNX(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { groupMap := cluster.groupBy(keys) if len(groupMap) == 1 && allowFastTransaction { // do fast for peer := range groupMap { - return cluster.relay(peer, c, cmdLine) + return cluster.relay(peer, c, modifyCmd(cmdLine, "MSetNX_")) } } @@ -133,7 +128,7 @@ func MSetNX(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { for _, k := range group { nodeArgs = append(nodeArgs, k, valueMap[k]) } - resp := cluster.relayPrepare(node, c, makeArgs("Prepare", nodeArgs...)) + resp := cluster.relay(node, c, makeArgs("Prepare", nodeArgs...)) if protocol.IsErrorReply(resp) { re := resp.(protocol.ErrorReply) if re.Error() == keyExistsErr { diff --git a/cluster/mset_test.go b/cluster/mset_test.go index 9f81e4d..18127a0 100644 --- a/cluster/mset_test.go +++ b/cluster/mset_test.go @@ -9,6 +9,7 @@ import ( func TestMSet(t *testing.T) { conn := connection.NewFakeConn() allowFastTransaction = false + testNodeA := testCluster[0] ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b")) asserts.AssertNotError(t, ret) ret = testNodeA.Exec(conn, toArgs("MGET", "a", "b")) @@ -18,6 +19,7 @@ func TestMSet(t *testing.T) { func TestMSetNx(t *testing.T) { conn := connection.NewFakeConn() allowFastTransaction = false + testNodeA := testCluster[0] FlushAll(testNodeA, conn, toArgs("FLUSHALL")) ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b")) asserts.AssertNotError(t, ret) diff --git a/cluster/multi.go b/cluster/multi.go index ab5c1d4..9bfd740 100644 --- a/cluster/multi.go +++ b/cluster/multi.go @@ -4,12 +4,13 @@ import ( "github.com/hdt3213/godis/database" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" "strconv" ) -const relayMulti = "_multi" -const innerWatch = "_watch" +const relayMulti = "multi_" +const innerWatch = "watch_" var relayMultiBytes = []byte(relayMulti) @@ -50,6 +51,11 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R // out parser not support protocol.MultiRawReply, so we have to encode it if peer == cluster.self { + for _, key := range keys { + if errReply := cluster.ensureKey(key); errReply != nil { + return errReply + } + } return cluster.db.ExecMulti(conn, watching, cmdLines) } return execMultiOnOtherNode(cluster, conn, peer, watching, cmdLines) @@ -72,12 +78,7 @@ func execMultiOnOtherNode(cluster 
*Cluster, conn redis.Connection, peer string, relayCmdLine = append(relayCmdLine, encodeCmdLine([]CmdLine{watchingCmdLine})...) relayCmdLine = append(relayCmdLine, encodeCmdLine(cmdLines)...) var rawRelayResult redis.Reply - if peer == cluster.self { - // this branch just for testing - rawRelayResult = execRelayedMulti(cluster, conn, relayCmdLine) - } else { - rawRelayResult = cluster.relay(peer, conn, relayCmdLine) - } + rawRelayResult = cluster.relay(peer, connection.NewFakeConn(), relayCmdLine) if protocol.IsErrorReply(rawRelayResult) { return rawRelayResult } @@ -116,7 +117,7 @@ func execRelayedMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) txCmdLines = append(txCmdLines, mbr.Args) } watching := make(map[string]uint32) - watchCmdLine := txCmdLines[0] // format: _watch key1 ver1 key2 ver2... + watchCmdLine := txCmdLines[0] // format: watch_ key1 ver1 key2 ver2... for i := 2; i < len(watchCmdLine); i += 2 { key := string(watchCmdLine[i-1]) verStr := string(watchCmdLine[i]) @@ -146,8 +147,11 @@ func execWatch(cluster *Cluster, conn redis.Connection, args [][]byte) redis.Rep watching := conn.GetWatching() for _, bkey := range args { key := string(bkey) - peer := cluster.peerPicker.PickNode(key) - result := cluster.relay(peer, conn, utils.ToCmdLine("GetVer", key)) + err := cluster.ensureKey(key) + if err != nil { + return err + } + result := cluster.relayByKey(key, conn, utils.ToCmdLine("GetVer", key)) if protocol.IsErrorReply(result) { return result } diff --git a/cluster/multi_helper.go b/cluster/multi_helper.go index a1e89e6..2d1f6c2 100644 --- a/cluster/multi_helper.go +++ b/cluster/multi_helper.go @@ -35,6 +35,7 @@ func parseEncodedMultiRawReply(args [][]byte) (*protocol.MultiRawReply, error) { return protocol.MakeMultiRawReply(cmds), nil } +// todo: use multi raw reply instead of base64 func encodeMultiRawReply(src *protocol.MultiRawReply) *protocol.MultiBulkReply { args := make([][]byte, 0, len(src.Replies)) for _, rep := range src.Replies { diff --git a/cluster/multi_test.go b/cluster/multi_test.go index 4f9d8c2..8413196 100644 --- a/cluster/multi_test.go +++ b/cluster/multi_test.go @@ -9,14 +9,15 @@ import ( ) func TestMultiExecOnSelf(t *testing.T) { + testNodeA := testCluster[0] conn := new(connection.FakeConn) testNodeA.db.Exec(conn, utils.ToCmdLine("FLUSHALL")) result := testNodeA.Exec(conn, toArgs("MULTI")) asserts.AssertNotError(t, result) - key := utils.RandString(10) + key := "{abc}" + utils.RandString(10) value := utils.RandString(10) testNodeA.Exec(conn, utils.ToCmdLine("set", key, value)) - key2 := utils.RandString(10) + key2 := "{abc}" + utils.RandString(10) testNodeA.Exec(conn, utils.ToCmdLine("rpush", key2, value)) result = testNodeA.Exec(conn, utils.ToCmdLine("exec")) asserts.AssertNotError(t, result) @@ -27,8 +28,9 @@ func TestMultiExecOnSelf(t *testing.T) { } func TestEmptyMulti(t *testing.T) { + testNodeA := testCluster[0] conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FLUSHALL")) + testNodeA.Exec(conn, utils.ToCmdLine("FLUSHALL")) result := testNodeA.Exec(conn, toArgs("MULTI")) asserts.AssertNotError(t, result) result = testNodeA.Exec(conn, utils.ToCmdLine("GET", "a")) @@ -40,8 +42,9 @@ func TestEmptyMulti(t *testing.T) { } func TestMultiExecOnOthers(t *testing.T) { + testNodeA := testCluster[0] conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FLUSHALL")) + testNodeA.Exec(conn, utils.ToCmdLine("FLUSHALL")) result := testNodeA.Exec(conn, toArgs("MULTI")) asserts.AssertNotError(t, result) key 
:= utils.RandString(10) @@ -59,54 +62,31 @@ func TestMultiExecOnOthers(t *testing.T) { } func TestWatch(t *testing.T) { - conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FLUSHALL")) - key := utils.RandString(10) - value := utils.RandString(10) - testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) - testNodeA.Exec(conn, utils.ToCmdLine("set", key, value)) - result := testNodeA.Exec(conn, toArgs("MULTI")) - asserts.AssertNotError(t, result) - key2 := utils.RandString(10) - value2 := utils.RandString(10) - testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) - result = testNodeA.Exec(conn, utils.ToCmdLine("exec")) - asserts.AssertNotError(t, result) - result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) - asserts.AssertNullBulk(t, result) + testNodeA := testCluster[0] + for i := 0; i < 10; i++ { + conn := new(connection.FakeConn) + key := "{1}" + utils.RandString(10) + key2 := "{1}" + utils.RandString(10) // use hash tag to ensure same slot + value := utils.RandString(10) + testNodeA.Exec(conn, utils.ToCmdLine("FLUSHALL")) + testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) + testNodeA.Exec(conn, utils.ToCmdLine("set", key, value)) + result := testNodeA.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + value2 := utils.RandString(10) + testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) + result = testNodeA.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertNullBulk(t, result) - testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) - result = testNodeA.Exec(conn, toArgs("MULTI")) - asserts.AssertNotError(t, result) - testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) - result = testNodeA.Exec(conn, utils.ToCmdLine("exec")) - asserts.AssertNotError(t, result) - result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) - asserts.AssertBulkReply(t, result, value2) -} - -func TestWatch2(t *testing.T) { - conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FLUSHALL")) - key := utils.RandString(10) - value := utils.RandString(10) - testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) - testNodeA.Exec(conn, utils.ToCmdLine("set", key, value)) - result := testNodeA.Exec(conn, toArgs("MULTI")) - asserts.AssertNotError(t, result) - key2 := utils.RandString(10) - value2 := utils.RandString(10) - testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) - cmdLines := conn.GetQueuedCmdLine() - execMultiOnOtherNode(testNodeA, conn, testNodeA.self, conn.GetWatching(), cmdLines) - result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) - asserts.AssertNullBulk(t, result) - - testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) - result = testNodeA.Exec(conn, toArgs("MULTI")) - asserts.AssertNotError(t, result) - testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) - execMultiOnOtherNode(testNodeA, conn, testNodeA.self, conn.GetWatching(), cmdLines) - result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) - asserts.AssertBulkReply(t, result, value2) + testNodeA.Exec(conn, utils.ToCmdLine("watch", key)) + result = testNodeA.Exec(conn, toArgs("MULTI")) + asserts.AssertNotError(t, result) + testNodeA.Exec(conn, utils.ToCmdLine("set", key2, value2)) + result = testNodeA.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testNodeA.Exec(conn, utils.ToCmdLine("get", key2)) + asserts.AssertBulkReply(t, result, value2) + } } diff --git a/cluster/pubsub.go b/cluster/pubsub.go index 
a936cfa..ed86b2d 100644 --- a/cluster/pubsub.go +++ b/cluster/pubsub.go @@ -7,19 +7,13 @@ import ( ) const ( - relayPublish = "_publish" - publish = "publish" -) - -var ( - publishRelayCmd = []byte(relayPublish) - publishCmd = []byte(publish) + relayPublish = "publish_" ) // Publish broadcasts msg to all peers in cluster when receive publish command from client -func Publish(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { +func Publish(cluster *Cluster, c redis.Connection, cmdLine [][]byte) redis.Reply { var count int64 = 0 - results := cluster.broadcast(c, args) + results := cluster.broadcast(c, modifyCmd(cmdLine, relayPublish)) for _, val := range results { if errReply, ok := val.(protocol.ErrorReply); ok { logger.Error("publish occurs error: " + errReply.Error()) @@ -30,12 +24,6 @@ func Publish(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return protocol.MakeIntReply(count) } -// onRelayedPublish receives publish command from peer, just publish to local subscribing clients, do not relay to peers -func onRelayedPublish(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - args[0] = publishCmd - return cluster.db.Exec(c, args) // let local db.hub handle publish -} - // Subscribe puts the given connection into the given channel func Subscribe(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return cluster.db.Exec(c, args) // let local db.hub handle subscribe diff --git a/cluster/pubsub_test.go b/cluster/pubsub_test.go index 7c380fb..0894b1b 100644 --- a/cluster/pubsub_test.go +++ b/cluster/pubsub_test.go @@ -9,6 +9,7 @@ import ( ) func TestPublish(t *testing.T) { + testNodeA := testCluster[0] channel := utils.RandString(5) msg := utils.RandString(5) conn := connection.NewFakeConn() diff --git a/cluster/raft.go b/cluster/raft.go new file mode 100644 index 0000000..55dd496 --- /dev/null +++ b/cluster/raft.go @@ -0,0 +1,1075 @@ +package cluster + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/hdt3213/godis/config" + "github.com/hdt3213/godis/datastruct/lock" + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/logger" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/protocol" + "math/rand" + "os" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +const slotCount int = 16384 + +type raftState int + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +const ( + nodeFlagLeader uint32 = 1 << iota + nodeFlagCandidate + nodeFlagLearner +) + +const ( + follower raftState = iota + leader + candidate + learner +) + +var stateNames = map[raftState]string{ + follower: "follower", + leader: "leader", + candidate: "candidate", + learner: "learner", +} + +func (node *Node) setState(state raftState) { + node.Flags &= ^uint32(0x7) // clean + switch state { + case follower: + break + case leader: + node.Flags |= nodeFlagLeader + case candidate: + node.Flags |= nodeFlagCandidate + case learner: + node.Flags |= nodeFlagLearner + } +} + +func (node *Node) getState() raftState { + if node.Flags&nodeFlagLeader > 0 { + return leader + } + if node.Flags&nodeFlagCandidate > 0 { + return candidate + } + if node.Flags&nodeFlagLearner > 0 { + return learner + } + return follower +} + +type logEntry struct { + Term int + Index int + Event int + wg *sync.WaitGroup + // payload + SlotIDs []uint32 + NodeID string + Addr string +} + +func (e *logEntry) marshal() []byte { + bin, _ := json.Marshal(e) + return bin +} + 
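Annotation: `logEntry.marshal` simply JSON-encodes the entry, so only the exported fields travel inside heartbeats; the unexported `wg` used to block a local proposer never leaves the node. A standalone round-trip sketch (the `entry` struct here only mirrors the exported fields of `logEntry`, and the `Event` value is arbitrary):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// entry mirrors the JSON-visible fields of cluster.logEntry for illustration.
// An unexported field like wg *sync.WaitGroup would be skipped by encoding/json,
// so it stays local to the proposing node.
type entry struct {
	Term    int
	Index   int
	Event   int
	SlotIDs []uint32
	NodeID  string
	Addr    string
}

func main() {
	e := entry{Term: 3, Index: 42, Event: 1, SlotIDs: []uint32{100, 101}, NodeID: "127.0.0.1:6399", Addr: "127.0.0.1:6399"}
	bin, err := json.Marshal(e) // same shape as logEntry.marshal()
	if err != nil {
		panic(err)
	}
	var decoded entry
	if err := json.Unmarshal(bin, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.Term, decoded.Index, decoded.SlotIDs) // 3 42 [100 101]
}
```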
+func (e *logEntry) unmarshal(bin []byte) error { + err := json.Unmarshal(bin, e) + if err != nil { + return fmt.Errorf("illegal message: %v", err) + } + return nil +} + +type Raft struct { + cluster *Cluster + mu sync.RWMutex + selfNodeID string + slots []*Slot + leaderId string + nodes map[string]*Node + log []*logEntry // log index begin from 0 + baseIndex int // baseIndex + 1 == log[0].Index, it can be considered as the previous log index + baseTerm int // baseTerm is the term of the previous log entry + state raftState + term int + votedFor string + voteCount int + committedIndex int // index of the last committed logEntry + proposedIndex int // index of the last proposed logEntry + heartbeatChan chan *heartbeat + persistFile string + electionAlarm time.Time + closeChan chan struct{} + closed bool + + // for leader + nodeIndexMap map[string]*nodeStatus + nodeLock *lock.Locks +} + +func newRaft(cluster *Cluster, persistFilename string) *Raft { + return &Raft{ + cluster: cluster, + persistFile: persistFilename, + closeChan: make(chan struct{}), + } +} + +type heartbeat struct { + sender string + term int + entries []*logEntry + commitTo int +} + +type nodeStatus struct { + receivedIndex int // received log index, not committed index +} + +func (raft *Raft) GetNodes() []*Node { + raft.mu.RLock() + defer raft.mu.RUnlock() + result := make([]*Node, 0, len(raft.nodes)) + for _, v := range raft.nodes { + result = append(result, v) + } + return result +} + +func (raft *Raft) GetNode(nodeID string) *Node { + raft.mu.RLock() + defer raft.mu.RUnlock() + return raft.nodes[nodeID] +} + +func (raft *Raft) getLogEntries(beg, end int) []*logEntry { + if beg <= raft.baseIndex || end > raft.baseIndex+len(raft.log)+1 { + return nil + } + i := beg - raft.baseIndex - 1 + j := end - raft.baseIndex - 1 + return raft.log[i:j] +} + +func (raft *Raft) getLogEntriesFrom(beg int) []*logEntry { + if beg <= raft.baseIndex { + return nil + } + i := beg - raft.baseIndex - 1 + return raft.log[i:] +} + +func (raft *Raft) getLogEntry(idx int) *logEntry { + if idx < raft.baseIndex || idx >= raft.baseIndex+len(raft.log) { + return nil + } + return raft.log[idx-raft.baseIndex] +} + +func (raft *Raft) initLog(baseTerm, baseIndex int, entries []*logEntry) { + raft.baseIndex = baseIndex + raft.baseTerm = baseTerm + raft.log = entries +} + +const ( + electionTimeoutMaxMs = 4000 + electionTimeoutMinMs = 2800 +) + +func randRange(from, to int) int { + return rand.Intn(to-from) + from +} + +// nextElectionAlarm generates normal election timeout, with randomness +func nextElectionAlarm() time.Time { + return time.Now().Add(time.Duration(randRange(electionTimeoutMinMs, electionTimeoutMaxMs)) * time.Millisecond) +} + +func compareLogIndex(term1, index1, term2, index2 int) int { + if term1 != term2 { + return term1 - term2 + } + return index1 - index2 +} + +func (cluster *Cluster) asRaft() *Raft { + return cluster.topology.(*Raft) +} + +// StartAsSeed starts cluster as seed node +func (raft *Raft) StartAsSeed(listenAddr string) protocol.ErrorReply { + selfNodeID := listenAddr + raft.mu.Lock() + defer raft.mu.Unlock() + raft.slots = make([]*Slot, slotCount) + // claim all slots + for i := range raft.slots { + raft.slots[i] = &Slot{ + ID: uint32(i), + NodeID: selfNodeID, + } + } + raft.selfNodeID = selfNodeID + raft.leaderId = selfNodeID + raft.nodes = make(map[string]*Node) + raft.nodes[selfNodeID] = &Node{ + ID: selfNodeID, + Addr: listenAddr, + Slots: raft.slots, + } + raft.nodes[selfNodeID].setState(leader) + raft.nodeIndexMap = 
map[string]*nodeStatus{ + selfNodeID: { + receivedIndex: raft.proposedIndex, + }, + } + raft.start(leader) + raft.cluster.self = selfNodeID + return nil +} + +func (raft *Raft) GetSlots() []*Slot { + return raft.slots +} + +// GetSelfNodeID returns node id of current node +func (raft *Raft) GetSelfNodeID() string { + return raft.selfNodeID +} + +const raftClosed = "ERR raft has closed" + +func (raft *Raft) start(state raftState) { + raft.state = state + raft.heartbeatChan = make(chan *heartbeat, 1) + raft.electionAlarm = nextElectionAlarm() + //raft.nodeIndexMap = make(map[string]*nodeStatus) + go func() { + for { + if raft.closed { + logger.Info("quit raft job") + return + } + switch raft.state { + case follower: + raft.followerJob() + case candidate: + raft.candidateJob() + case leader: + raft.leaderJob() + } + } + }() +} + +func (raft *Raft) Close() error { + raft.closed = true + close(raft.closeChan) + return raft.persist() +} + +func (raft *Raft) followerJob() { + electionTimeout := time.Until(raft.electionAlarm) + select { + case hb := <-raft.heartbeatChan: + raft.mu.Lock() + nodeId := hb.sender + raft.nodes[nodeId].lastHeard = time.Now() + // todo: drop duplicate entry + raft.log = append(raft.log, hb.entries...) + raft.proposedIndex += len(hb.entries) + raft.applyLogEntries(raft.getLogEntries(raft.committedIndex+1, hb.commitTo+1)) + raft.committedIndex = hb.commitTo + raft.electionAlarm = nextElectionAlarm() + raft.mu.Unlock() + case <-time.After(electionTimeout): + // change to candidate + logger.Info("raft leader timeout") + raft.mu.Lock() + raft.electionAlarm = nextElectionAlarm() + if raft.votedFor != "" { + // received request-vote and has voted during waiting timeout + raft.mu.Unlock() + logger.Infof("%s has voted for %s, give up being a candidate", raft.selfNodeID, raft.votedFor) + return + } + logger.Info("change to candidate") + raft.state = candidate + raft.mu.Unlock() + case <-raft.closeChan: + return + } +} + +func (raft *Raft) getLogProgressWithinLock() (int, int) { + var lastLogTerm, lastLogIndex int + if len(raft.log) > 0 { + lastLog := raft.log[len(raft.log)-1] + lastLogTerm = lastLog.Term + lastLogIndex = lastLog.Index + } else { + lastLogTerm = raft.baseTerm + lastLogIndex = raft.baseIndex + } + return lastLogTerm, lastLogIndex +} + +func (raft *Raft) candidateJob() { + raft.mu.Lock() + + raft.term++ + raft.votedFor = raft.selfNodeID + raft.voteCount++ + currentTerm := raft.term + lastLogTerm, lastLogIndex := raft.getLogProgressWithinLock() + req := &voteReq{ + nodeID: raft.selfNodeID, + lastLogTerm: lastLogTerm, + lastLogIndex: lastLogIndex, + term: raft.term, + } + raft.mu.Unlock() + args := append([][]byte{ + []byte("raft"), + []byte("request-vote"), + }, req.marshal()...) 
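Annotation: `followerJob` applies `raft.getLogEntries(raft.committedIndex+1, hb.commitTo+1)`, a left-inclusive, right-exclusive range of global log indices that `getLogEntries` translates into offsets of the in-memory slice relative to `baseIndex` (the index of the entry just before `log[0]`). A standalone sketch of that translation with made-up indices:

```go
package main

import "fmt"

// getRange mirrors Raft.getLogEntries: beg/end are global log indices
// (left inclusive, right exclusive); baseIndex is the index of the entry
// that precedes log[0], e.g. the last index covered by a snapshot.
func getRange(log []int, baseIndex, beg, end int) []int {
	if beg <= baseIndex || end > baseIndex+len(log)+1 {
		return nil
	}
	i := beg - baseIndex - 1
	j := end - baseIndex - 1
	return log[i:j]
}

func main() {
	// entries 11..14 are in memory, everything up to index 10 is compacted
	log := []int{11, 12, 13, 14}
	baseIndex := 10
	// committedIndex=11, heartbeat says commitTo=13 -> apply entries 12 and 13
	fmt.Println(getRange(log, baseIndex, 12, 14)) // [12 13]
}
```

With `baseIndex = 10` the slice holds entries 11 through 14, so committing from 12 through 13 returns exactly those two entries.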
+ conn := connection.NewFakeConn() + wg := sync.WaitGroup{} + elected := make(chan struct{}, len(raft.nodes)) // may receive many elected message during an election, only handle the first one + voteFinished := make(chan struct{}) + for nodeID := range raft.nodes { + if nodeID == raft.selfNodeID { + continue + } + nodeID := nodeID + wg.Add(1) + go func() { + defer wg.Done() + rawResp := raft.cluster.relay(nodeID, conn, args) + if err, ok := rawResp.(protocol.ErrorReply); ok { + logger.Info(fmt.Sprintf("cannot get vote response from %s, %v", nodeID, err)) + return + } + respBody, ok := rawResp.(*protocol.MultiBulkReply) + if !ok { + logger.Info(fmt.Sprintf("cannot get vote response from %s, not a multi bulk reply", nodeID)) + return + } + resp := &voteResp{} + err := resp.unmarshal(respBody.Args) + if err != nil { + logger.Info(fmt.Sprintf("cannot get vote response from %s, %v", nodeID, err)) + return + } + + raft.mu.Lock() + defer raft.mu.Unlock() + logger.Info("received vote response from " + nodeID) + // check-lock-check + if currentTerm != raft.term || raft.state != candidate { + // vote has finished during waiting lock + logger.Info("vote has finished during waiting lock, current term " + strconv.Itoa(raft.term) + " state " + strconv.Itoa(int(raft.state))) + return + } + if resp.term > raft.term { + logger.Infof(fmt.Sprintf("vote response from %s has newer term %d", nodeID, resp.term)) + raft.term = resp.term + raft.state = follower + raft.votedFor = "" + raft.leaderId = resp.voteFor + return + } + + if resp.voteFor == raft.selfNodeID { + logger.Infof(fmt.Sprintf("get vote from %s", nodeID)) + raft.voteCount++ + if raft.voteCount >= len(raft.nodes)/2+1 { + logger.Info("elected to be the leader") + raft.state = leader + elected <- struct{}{} // notify the main goroutine to stop waiting + return + } + } + }() + } + go func() { + wg.Wait() + voteFinished <- struct{}{} + }() + + // wait vote finished or elected + select { + case <-voteFinished: + raft.mu.Lock() + if raft.term == currentTerm && raft.state == candidate { + logger.Infof("%s failed to be elected, back to follower", raft.selfNodeID) + raft.state = follower + raft.votedFor = "" + raft.voteCount = 0 + } + raft.mu.Unlock() + case <-elected: + raft.votedFor = "" + raft.voteCount = 0 + logger.Info("win election, take leader of term " + strconv.Itoa(currentTerm)) + case <-raft.closeChan: + return + } +} + +// getNodeIndexMap ask offset of each node and init nodeIndexMap as new leader +// invoker provide lock +func (raft *Raft) getNodeIndexMap() { + // ask node index + nodeIndexMap := make(map[string]*nodeStatus) + for _, node := range raft.nodes { + status := raft.askNodeIndex(node) + if status != nil { + nodeIndexMap[node.ID] = status + } + } + logger.Info("got offsets of nodes") + raft.nodeIndexMap = nodeIndexMap +} + +// askNodeIndex ask another node for its log index +// return nil if failed +func (raft *Raft) askNodeIndex(node *Node) *nodeStatus { + if node.ID == raft.selfNodeID { + return &nodeStatus{ + receivedIndex: raft.proposedIndex, + } + } + logger.Debugf("ask %s for offset", node.ID) + c := connection.NewFakeConn() + reply := raft.cluster.relay(node.Addr, c, utils.ToCmdLine("raft", "get-offset")) + if protocol.IsErrorReply(reply) { + logger.Infof("ask node %s index failed: %v", node.ID, reply) + return nil + } + return &nodeStatus{ + receivedIndex: int(reply.(*protocol.IntReply).Code), + } +} + +func (raft *Raft) leaderJob() { + raft.mu.Lock() + if raft.nodeIndexMap == nil { + // getNodeIndexMap with lock, because leader 
cannot work without nodeIndexMap + raft.getNodeIndexMap() + } + if raft.nodeLock == nil { + raft.nodeLock = lock.Make(1024) + } + var recvedIndices []int + for _, status := range raft.nodeIndexMap { + recvedIndices = append(recvedIndices, status.receivedIndex) + } + sort.Slice(recvedIndices, func(i, j int) bool { + return recvedIndices[i] > recvedIndices[j] + }) + // more than half of the nodes received entries, can be committed + commitTo := 0 + if len(recvedIndices) > 0 { + commitTo = recvedIndices[len(recvedIndices)/2] + } + // new node (received index is 0) may cause commitTo less than raft.committedIndex + if commitTo > raft.committedIndex { + toCommit := raft.getLogEntries(raft.committedIndex+1, commitTo+1) // left inclusive, right exclusive + raft.applyLogEntries(toCommit) + raft.committedIndex = commitTo + for _, entry := range toCommit { + if entry.wg != nil { + entry.wg.Done() + } + } + } + // save receivedIndex in local variable in case changed by other goroutines + proposalIndex := raft.proposedIndex + snapshot := raft.makeSnapshot() // the snapshot is consistent with the committed log + for _, node := range raft.nodes { + if node.ID == raft.selfNodeID { + continue + } + node := node + status := raft.nodeIndexMap[node.ID] + go func() { + raft.nodeLock.Lock(node.ID) + defer raft.nodeLock.UnLock(node.ID) + var cmdLine [][]byte + if status == nil { + logger.Debugf("node %s offline", node.ID) + status = raft.askNodeIndex(node) + if status != nil { + // get status, node has back online + raft.mu.Lock() + raft.nodeIndexMap[node.ID] = status + raft.mu.Unlock() + } else { + // node still offline + return + } + } + if status.receivedIndex < raft.baseIndex { + // some entries are missed due to change of leader, send full snapshot + cmdLine = utils.ToCmdLine( + "raft", + "load-snapshot", + raft.selfNodeID, + ) + // see makeSnapshotForFollower + cmdLine = append(cmdLine, []byte(node.ID), []byte(strconv.Itoa(int(follower)))) + cmdLine = append(cmdLine, snapshot[2:]...) + } else { + // leader has all needed entries, send normal heartbeat + req := &heartbeatRequest{ + leaderId: raft.leaderId, + term: raft.term, + commitTo: commitTo, + } + // append new entries to heartbeat payload + if proposalIndex > status.receivedIndex { + req.prevLogTerm = raft.getLogEntry(status.receivedIndex).Term + req.prevLogIndex = status.receivedIndex + req.entries = raft.getLogEntriesFrom(status.receivedIndex + 1) + } + cmdLine = utils.ToCmdLine( + "raft", + "heartbeat", + ) + cmdLine = append(cmdLine, req.marshal()...) + } + + conn := connection.NewFakeConn() + resp := raft.cluster.relay(node.ID, conn, cmdLine) + switch respPayload := resp.(type) { + case *protocol.MultiBulkReply: + term, _ := strconv.Atoi(string(respPayload.Args[0])) + recvedIndex, _ := strconv.Atoi(string(respPayload.Args[1])) + if term > raft.term { + // todo: rejoin as follower + return + } + raft.mu.Lock() + raft.nodeIndexMap[node.ID].receivedIndex = recvedIndex + raft.mu.Unlock() + case protocol.ErrorReply: + if respPayload.Error() == prevLogMismatch { + cmdLine = utils.ToCmdLine( + "raft", + "load-snapshot", + raft.selfNodeID, + ) + cmdLine = append(cmdLine, []byte(node.ID), []byte(strconv.Itoa(int(follower)))) + cmdLine = append(cmdLine, snapshot[2:]...) 
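The commit rule used in the leader loop above is worth spelling out: the leader sorts the indices its peers have acknowledged in descending order and takes the middle element, which is the highest index already present on a majority of nodes. A self-contained sketch of just that arithmetic (illustrative only, not the project's API):

```go
package main

import (
	"fmt"
	"sort"
)

// majorityReceivedIndex returns the highest log index that has been received
// by more than half of the nodes. After sorting the acknowledged indices in
// descending order, the element at position len/2 is covered by at least
// len/2+1 nodes, i.e. a quorum has every entry up to that index.
func majorityReceivedIndex(receivedIndices []int) int {
	if len(receivedIndices) == 0 {
		return 0
	}
	sorted := append([]int(nil), receivedIndices...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] > sorted[j] })
	return sorted[len(sorted)/2]
}

func main() {
	// five nodes: entries up to index 7 are present on three of them,
	// so 7 is the highest index that may be committed.
	fmt.Println(majorityReceivedIndex([]int{9, 8, 7, 3, 2})) // 7
}
```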
+				resp := raft.cluster.relay(node.ID, conn, cmdLine)
+				if err, ok := resp.(protocol.ErrorReply); ok {
+					logger.Errorf("heartbeat to %s failed: %v", node.ID, err)
+					return
+				}
+			} else if respPayload.Error() == nodeNotReady {
+				logger.Infof("%s is not ready yet", node.ID)
+				return
+			} else {
+				logger.Errorf("heartbeat to %s failed: %v", node.ID, respPayload.Error())
+				return
+			}
+
+			}
+		}()
+	}
+	raft.mu.Unlock()
+	time.Sleep(time.Millisecond * 1000)
+}
+
+func init() {
+	registerCmd("raft", execRaft)
+}
+
+func execRaft(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
+	raft := cluster.asRaft()
+	if raft.closed {
+		return protocol.MakeErrReply(raftClosed)
+	}
+	if len(args) < 2 {
+		return protocol.MakeArgNumErrReply("raft")
+	}
+	subCmd := strings.ToLower(string(args[1]))
+	switch subCmd {
+	case "request-vote":
+		// command line: raft request-vote nodeID term lastLogIndex lastLogTerm
+		// Decide whether to vote when other nodes solicit votes
+		return execRaftRequestVote(cluster, c, args[2:])
+	case "heartbeat":
+		// execRaftHeartbeat handles heartbeat from the leader as follower or learner
+		// command line: raft heartbeat leaderId term commitTo [prevLogTerm prevLogIndex entry...]
+		return execRaftHeartbeat(cluster, c, args[2:])
+	case "load-snapshot":
+		// execRaftLoadSnapshot loads a snapshot from the leader
+		// command line: raft load-snapshot leaderId snapshot(see raft.makeSnapshot)
+		return execRaftLoadSnapshot(cluster, c, args[2:])
+	case "propose":
+		// execRaftPropose handles event proposal as leader
+		// command line: raft propose
+		return execRaftPropose(cluster, c, args[2:])
+	case "join":
+		// execRaftJoin handles a request from a new node to join the raft group, the current node must be the leader
+		// command line: raft join addr
+ return execRaftJoin(cluster, c, args[2:]) + case "get-leader": + // execRaftGetLeader returns leader id and address + return execRaftGetLeader(cluster, c, args[2:]) + case "get-offset": + // execRaftGetOffset returns log offset of current leader + return execRaftGetOffset(cluster, c, args[2:]) + } + return protocol.MakeErrReply(" ERR unknown raft sub command '" + subCmd + "'") +} + +type voteReq struct { + nodeID string + term int + lastLogIndex int + lastLogTerm int +} + +func (req *voteReq) marshal() [][]byte { + lastLogIndexBin := []byte(strconv.Itoa(req.lastLogIndex)) + lastLogTermBin := []byte(strconv.Itoa(req.lastLogTerm)) + termBin := []byte(strconv.Itoa(req.term)) + return [][]byte{ + []byte(req.nodeID), + termBin, + lastLogIndexBin, + lastLogTermBin, + } +} + +func (req *voteReq) unmarshal(bin [][]byte) error { + req.nodeID = string(bin[0]) + term, err := strconv.Atoi(string(bin[1])) + if err != nil { + return fmt.Errorf("illegal term %s", string(bin[2])) + } + req.term = term + logIndex, err := strconv.Atoi(string(bin[2])) + if err != nil { + return fmt.Errorf("illegal index %s", string(bin[1])) + } + req.lastLogIndex = logIndex + logTerm, err := strconv.Atoi(string(bin[3])) + if err != nil { + return fmt.Errorf("illegal index %s", string(bin[1])) + } + req.lastLogTerm = logTerm + return nil +} + +type voteResp struct { + voteFor string + term int +} + +func (resp *voteResp) unmarshal(bin [][]byte) error { + if len(bin) != 2 { + return errors.New("illegal vote resp length") + } + resp.voteFor = string(bin[0]) + term, err := strconv.Atoi(string(bin[1])) + if err != nil { + return fmt.Errorf("illegal term: %s", string(bin[1])) + } + resp.term = term + return nil +} + +func (resp *voteResp) marshal() [][]byte { + return [][]byte{ + []byte(resp.voteFor), + []byte(strconv.Itoa(resp.term)), + } +} + +// execRaftRequestVote command line: raft request-vote nodeID index term +// Decide whether to vote when other nodes solicit votes +func execRaftRequestVote(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + if len(args) != 4 { + return protocol.MakeArgNumErrReply("raft request-vote") + } + req := &voteReq{} + err := req.unmarshal(args) + if err != nil { + return protocol.MakeErrReply(err.Error()) + } + raft := cluster.asRaft() + raft.mu.Lock() + defer raft.mu.Unlock() + logger.Info("recv request vote from " + req.nodeID + ", term: " + strconv.Itoa(req.term)) + resp := &voteResp{} + if req.term < raft.term { + resp.term = raft.term + resp.voteFor = raft.leaderId // tell candidate the new leader + logger.Info("deny request vote from " + req.nodeID + " for earlier term") + return protocol.MakeMultiBulkReply(resp.marshal()) + } + // todo: if req.term > raft.term step down as leader? 
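Before the handler continues below, it helps to see the vote decision end to end. The request is marshalled as nodeID, term, lastLogIndex, lastLogTerm (see voteReq.marshal above), and a vote is denied when the candidate's log is less up to date than the voter's: terms are compared first, indices break the tie. That is how compareLogIndex is used in the following lines. A hedged, self-contained sketch of that check (not the project's implementation):

```go
package main

import "fmt"

// moreUpToDate reports whether log A (termA, indexA) is at least as
// up to date as log B, using the standard Raft rule: the higher last
// term wins; equal terms fall back to the longer log.
func moreUpToDate(termA, indexA, termB, indexB int) bool {
	if termA != termB {
		return termA > termB
	}
	return indexA >= indexB
}

func main() {
	// candidate last entry: term 2, index 15; voter last entry: term 3, index 9
	fmt.Println(moreUpToDate(2, 15, 3, 9)) // false: the voter denies the vote
	// candidate last entry: term 3, index 10; voter last entry: term 3, index 9
	fmt.Println(moreUpToDate(3, 10, 3, 9)) // true: log progress is acceptable
}
```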
+ lastLogTerm, lastLogIndex := raft.getLogProgressWithinLock() + if compareLogIndex(req.lastLogTerm, req.lastLogIndex, lastLogTerm, lastLogIndex) < 0 { + resp.term = raft.term + resp.voteFor = raft.votedFor + logger.Info("deny request vote from " + req.nodeID + " for log progress") + logger.Info("request vote proposal index " + strconv.Itoa(req.lastLogIndex) + " self index " + strconv.Itoa(raft.proposedIndex)) + return protocol.MakeMultiBulkReply(resp.marshal()) + } + if raft.votedFor != "" && raft.votedFor != raft.selfNodeID { + resp.term = raft.term + resp.voteFor = raft.votedFor + logger.Info("deny request vote from " + req.nodeID + " for voted") + return protocol.MakeMultiBulkReply(resp.marshal()) + } + if raft.votedFor == raft.selfNodeID && + raft.voteCount == 1 { + // cancel vote for self to avoid live lock + raft.votedFor = "" + raft.voteCount = 0 + } + logger.Info("accept request vote from " + req.nodeID) + raft.votedFor = req.nodeID + raft.term = req.term + raft.electionAlarm = nextElectionAlarm() + resp.voteFor = req.nodeID + resp.term = raft.term + return protocol.MakeMultiBulkReply(resp.marshal()) +} + +type heartbeatRequest struct { + leaderId string + term int + commitTo int + prevLogTerm int + prevLogIndex int + entries []*logEntry +} + +func (req *heartbeatRequest) marshal() [][]byte { + cmdLine := utils.ToCmdLine( + req.leaderId, + strconv.Itoa(req.term), + strconv.Itoa(req.commitTo), + ) + if len(req.entries) > 0 { + cmdLine = append(cmdLine, + []byte(strconv.Itoa(req.prevLogTerm)), + []byte(strconv.Itoa(req.prevLogIndex)), + ) + for _, entry := range req.entries { + cmdLine = append(cmdLine, entry.marshal()) + } + } + return cmdLine +} + +func (req *heartbeatRequest) unmarshal(args [][]byte) protocol.ErrorReply { + if len(args) < 6 && len(args) != 3 { + return protocol.MakeArgNumErrReply("raft heartbeat") + } + req.leaderId = string(args[0]) + var err error + req.term, err = strconv.Atoi(string(args[1])) + if err != nil { + return protocol.MakeErrReply("illegal term: " + string(args[1])) + } + req.commitTo, err = strconv.Atoi(string(args[2])) + if err != nil { + return protocol.MakeErrReply("illegal commitTo: " + string(args[2])) + } + if len(args) > 3 { + req.prevLogTerm, err = strconv.Atoi(string(args[3])) + if err != nil { + return protocol.MakeErrReply("illegal commitTo: " + string(args[3])) + } + req.prevLogIndex, err = strconv.Atoi(string(args[4])) + if err != nil { + return protocol.MakeErrReply("illegal commitTo: " + string(args[4])) + } + for _, bin := range args[5:] { + entry := &logEntry{} + err = entry.unmarshal(bin) + if err != nil { + return protocol.MakeErrReply(err.Error()) + } + req.entries = append(req.entries, entry) + } + } + return nil +} + +const prevLogMismatch = "prev log mismatch" +const nodeNotReady = "not ready" + +// execRaftHeartbeat receives heartbeat from leader +// command line: raft heartbeat nodeID term commitTo prevTerm prevIndex [log entry] +// returns term and received index +func execRaftHeartbeat(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + raft := cluster.asRaft() + req := &heartbeatRequest{} + unmarshalErr := req.unmarshal(args) + if unmarshalErr != nil { + return unmarshalErr + } + if req.term < raft.term { + return protocol.MakeMultiBulkReply(utils.ToCmdLine( + strconv.Itoa(req.term), + strconv.Itoa(raft.proposedIndex), // new received index + )) + } else if req.term > raft.term { + logger.Info("accept new leader " + req.leaderId) + raft.mu.Lock() + // todo: if current node is not at follower state + 
raft.term = req.term + raft.votedFor = "" + raft.leaderId = req.leaderId + raft.mu.Unlock() + } + raft.mu.RLock() + // heartbeat may arrive earlier than follower ready + if raft.heartbeatChan == nil { + raft.mu.RUnlock() + return protocol.MakeErrReply(nodeNotReady) + } + if len(req.entries) > 0 && compareLogIndex(req.prevLogTerm, req.prevLogIndex, raft.baseTerm, raft.baseIndex) != 0 { + raft.mu.RUnlock() + return protocol.MakeErrReply(prevLogMismatch) + } + raft.mu.RUnlock() + + raft.heartbeatChan <- &heartbeat{ + sender: req.leaderId, + term: req.term, + entries: req.entries, + commitTo: req.commitTo, + } + return protocol.MakeMultiBulkReply(utils.ToCmdLine( + strconv.Itoa(req.term), + strconv.Itoa(raft.proposedIndex+len(req.entries)), // new received index + )) +} + +// execRaftLoadSnapshot load snapshot from leader +// command line: raft load-snapshot leaderId snapshot(see raft.makeSnapshot) +func execRaftLoadSnapshot(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + // leaderId snapshot + if len(args) < 5 { + return protocol.MakeArgNumErrReply("raft load snapshot") + } + raft := cluster.asRaft() + raft.mu.Lock() + defer raft.mu.Unlock() + if errReply := raft.loadSnapshot(args[1:]); errReply != nil { + return errReply + } + sender := string(args[0]) + raft.heartbeatChan <- &heartbeat{ + sender: sender, + term: raft.term, + entries: nil, + commitTo: raft.committedIndex, + } + return protocol.MakeMultiBulkReply(utils.ToCmdLine( + strconv.Itoa(raft.term), + strconv.Itoa(raft.proposedIndex), + )) +} + +var wgPool = sync.Pool{ + New: func() interface{} { + return &sync.WaitGroup{} + }, +} + +// execRaftGetLeader returns leader id and address +func execRaftGetLeader(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + raft := cluster.asRaft() + raft.mu.RLock() + leaderNode := raft.nodes[raft.leaderId] + raft.mu.RUnlock() + return protocol.MakeMultiBulkReply(utils.ToCmdLine( + leaderNode.ID, + leaderNode.Addr, + )) +} + +// execRaftGetOffset returns log offset of current leader +func execRaftGetOffset(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + raft := cluster.asRaft() + raft.mu.RLock() + proposalIndex := raft.proposedIndex + //committedIndex := raft.committedIndex + raft.mu.RUnlock() + return protocol.MakeIntReply(int64(proposalIndex)) +} + +// invoker should provide with raft.mu lock +func (raft *Raft) persist() error { + if raft.persistFile == "" { + return nil + } + tmpFile, err := os.CreateTemp(config.Properties.Dir, "tmp-cluster-conf-*.conf") + if err != nil { + return err + } + snapshot := raft.makeSnapshot() + buf := bytes.NewBuffer(nil) + for _, line := range snapshot { + buf.Write(line) + buf.WriteByte('\n') + } + _, err = tmpFile.Write(buf.Bytes()) + if err != nil { + return err + } + err = os.Rename(tmpFile.Name(), raft.persistFile) + if err != nil { + return err + } + return nil +} + +// execRaftPropose handles requests from other nodes (follower or learner) to propose a change +// command line: raft propose +func execRaftPropose(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + raft := cluster.asRaft() + if raft.state != leader { + leaderNode := raft.nodes[raft.leaderId] + return protocol.MakeErrReply("NOT LEADER " + leaderNode.ID + " " + leaderNode.Addr) + } + if len(args) != 1 { + return protocol.MakeArgNumErrReply("raft propose") + } + + e := &logEntry{} + err := e.unmarshal(args[0]) + if err != nil { + return protocol.MakeErrReply(err.Error()) + } + if errReply := raft.propose(e); errReply != nil { 
+		return errReply
+	}
+	return protocol.MakeOkReply()
+}
+
+func (raft *Raft) propose(e *logEntry) protocol.ErrorReply {
+	switch e.Event {
+	case eventNewNode:
+		raft.mu.Lock()
+		_, ok := raft.nodes[e.Addr]
+		raft.mu.Unlock()
+		if ok {
+			return protocol.MakeErrReply("node exists")
+		}
+	}
+	wg := wgPool.Get().(*sync.WaitGroup)
+	defer wgPool.Put(wg)
+	e.wg = wg
+	raft.mu.Lock()
+	raft.proposedIndex++
+	raft.log = append(raft.log, e)
+	raft.nodeIndexMap[raft.selfNodeID].receivedIndex = raft.proposedIndex
+	e.Term = raft.term
+	e.Index = raft.proposedIndex
+	raft.mu.Unlock()
+	e.wg.Add(1)
+	e.wg.Wait() // wait for the raft group to reach a consensus
+	return nil
+}
+
+func (raft *Raft) Join(seed string) protocol.ErrorReply {
+	cluster := raft.cluster
+
+	/* STEP1: get leader from seed */
+	seedCli, err := cluster.clientFactory.GetPeerClient(seed)
+	if err != nil {
+		return protocol.MakeErrReply("connect with seed failed: " + err.Error())
+	}
+	defer cluster.clientFactory.ReturnPeerClient(seed, seedCli)
+	ret := seedCli.Send(utils.ToCmdLine("raft", "get-leader"))
+	if protocol.IsErrorReply(ret) {
+		return ret.(protocol.ErrorReply)
+	}
+	leaderInfo, ok := ret.(*protocol.MultiBulkReply)
+	if !ok || len(leaderInfo.Args) != 2 {
+		return protocol.MakeErrReply("ERR get-leader returns wrong reply")
+	}
+	leaderAddr := string(leaderInfo.Args[1])
+
+	/* STEP2: join raft group */
+	leaderCli, err := cluster.clientFactory.GetPeerClient(leaderAddr)
+	if err != nil {
+		return protocol.MakeErrReply("connect with leader failed: " + err.Error())
+	}
+	defer cluster.clientFactory.ReturnPeerClient(leaderAddr, leaderCli)
+	ret = leaderCli.Send(utils.ToCmdLine("raft", "join", cluster.addr))
+	if protocol.IsErrorReply(ret) {
+		return ret.(protocol.ErrorReply)
+	}
+	snapshot, ok := ret.(*protocol.MultiBulkReply)
+	if !ok || len(snapshot.Args) < 4 {
+		return protocol.MakeErrReply("ERR raft join returns wrong reply")
+	}
+	raft.mu.Lock()
+	defer raft.mu.Unlock()
+	if errReply := raft.loadSnapshot(snapshot.Args); errReply != nil {
+		return errReply
+	}
+	cluster.self = raft.selfNodeID
+	raft.start(follower)
+	return nil
+}
+
+func (raft *Raft) LoadConfigFile() protocol.ErrorReply {
+	f, err := os.Open(raft.persistFile)
+	if os.IsNotExist(err) {
+		return errConfigFileNotExist
+	}
+	defer func() {
+		if err := f.Close(); err != nil {
+			logger.Errorf("close cluster config file error: %v", err)
+		}
+	}()
+	scanner := bufio.NewScanner(f)
+	var snapshot [][]byte
+	for scanner.Scan() {
+		line := append([]byte{}, scanner.Bytes()...) // copy the line, scanner.Bytes() may be overwritten by the next Scan
+		snapshot = append(snapshot, line)
+	}
+	raft.mu.Lock()
+	defer raft.mu.Unlock()
+	if errReply := raft.loadSnapshot(snapshot); errReply != nil {
+		return errReply
+	}
+	raft.cluster.self = raft.selfNodeID
+	raft.start(raft.state)
+	return nil
+}
diff --git a/cluster/raft_event.go b/cluster/raft_event.go
new file mode 100644
index 0000000..ffa4a76
--- /dev/null
+++ b/cluster/raft_event.go
@@ -0,0 +1,134 @@
+package cluster
+
+// raft event handlers
+
+import (
+	"errors"
+	"github.com/hdt3213/godis/interface/redis"
+	"github.com/hdt3213/godis/lib/logger"
+	"github.com/hdt3213/godis/lib/utils"
+	"github.com/hdt3213/godis/redis/connection"
+	"github.com/hdt3213/godis/redis/protocol"
+)
+
+const (
+	eventNewNode = iota + 1
+	eventSetSlot
+)
+
+// invoker should hold the raft.mu lock
+func (raft *Raft) applyLogEntries(entries []*logEntry) {
+	for _, entry := range entries {
+		switch entry.Event {
+		case eventNewNode:
+			node := &Node{
+				ID:   entry.NodeID,
+				Addr: entry.Addr,
+			}
+			raft.nodes[node.ID] = node
+			if raft.state == leader {
+				raft.nodeIndexMap[entry.NodeID] = &nodeStatus{
+					receivedIndex: entry.Index, // the new node should not receive its own join event
+				}
+			}
+		case eventSetSlot:
+			for _, slotID := range entry.SlotIDs {
+				slot := raft.slots[slotID]
+				oldNode := raft.nodes[slot.NodeID]
+				// remove the slot from the old node
+				for i, s := range oldNode.Slots {
+					if s.ID == slot.ID {
+						copy(oldNode.Slots[i:], oldNode.Slots[i+1:])
+						oldNode.Slots = oldNode.Slots[:len(oldNode.Slots)-1]
+						break
+					}
+				}
+				newNodeID := entry.NodeID
+				slot.NodeID = newNodeID
+				// fixme: newNode may be nil when re-balancing after several nodes joined at the same time
+				newNode := raft.nodes[slot.NodeID]
+				newNode.Slots = append(newNode.Slots, slot)
+			}
+		}
+	}
+	if err := raft.persist(); err != nil {
+		logger.Errorf("persist raft error: %v", err)
+	}
+
+}
+
+// NewNode creates a new Node when another node requests to join the cluster
+func (raft *Raft) NewNode(addr string) (*Node, error) {
+	if _, ok := raft.nodes[addr]; ok {
+		return nil, errors.New("node already exists")
+	}
+	node := &Node{
+		ID:   addr,
+		Addr: addr,
+	}
+	raft.nodes[node.ID] = node
+	proposal := &logEntry{
+		Event:  eventNewNode,
+		NodeID: node.ID,
+		Addr:   node.Addr,
+	}
+	conn := connection.NewFakeConn()
+	resp := raft.cluster.relay(raft.leaderId, conn,
+		utils.ToCmdLine("raft", "propose", string(proposal.marshal())))
+	if err, ok := resp.(protocol.ErrorReply); ok {
+		return nil, err
+	}
+	return node, nil
+}
+
+// SetSlot proposes moving the given slots to newNodeID through the raft leader
+func (raft *Raft) SetSlot(slotIDs []uint32, newNodeID string) protocol.ErrorReply {
+	proposal := &logEntry{
+		Event:   eventSetSlot,
+		NodeID:  newNodeID,
+		SlotIDs: slotIDs,
+	}
+	conn := connection.NewFakeConn()
+	resp := raft.cluster.relay(raft.leaderId, conn,
+		utils.ToCmdLine("raft", "propose", string(proposal.marshal())))
+	if err, ok := resp.(protocol.ErrorReply); ok {
+		return err
+	}
+	return nil
+}
+
+// execRaftJoin handles a request from a new node to join the raft group, the current node must be the leader
+// command line: raft join addr
+func execRaftJoin(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
+	if len(args) != 1 {
+		return protocol.MakeArgNumErrReply("raft join")
+	}
+	raft := cluster.asRaft()
+	if raft.state != leader {
+		leaderNode := raft.nodes[raft.leaderId]
+		return protocol.MakeErrReply("NOT LEADER " + leaderNode.ID + " " + leaderNode.Addr)
+	}
+	addr := string(args[0])
+	nodeID := addr
+
+	raft.mu.RLock()
+	_, exist := raft.nodes[addr]
+	raft.mu.RUnlock()
+	// if the node has joined the cluster but terminated before persisting the
cluster config, + // it may try to join at next start. + // In this case, we only have to send a snapshot for it + if !exist { + proposal := &logEntry{ + Event: eventNewNode, + NodeID: nodeID, + Addr: addr, + } + if err := raft.propose(proposal); err != nil { + return err + } + } + raft.mu.RLock() + snapshot := raft.makeSnapshotForFollower(nodeID) + raft.mu.RUnlock() + return protocol.MakeMultiBulkReply(snapshot) +} diff --git a/cluster/raft_snapshot.go b/cluster/raft_snapshot.go new file mode 100644 index 0000000..b97145c --- /dev/null +++ b/cluster/raft_snapshot.go @@ -0,0 +1,204 @@ +package cluster + +import ( + "encoding/json" + "fmt" + "github.com/hdt3213/godis/redis/protocol" + "sort" + "strconv" + "strings" +) + +// marshalSlotIds serializes slot ids +// For example, 1, 2, 3, 5, 7, 8 -> 1-3, 5, 7-8 +func marshalSlotIds(slots []*Slot) []string { + sort.Slice(slots, func(i, j int) bool { + return slots[i].ID < slots[j].ID + }) + // find continuous scopes + var scopes [][]uint32 + buf := make([]uint32, 2) + var scope []uint32 + for i, slot := range slots { + if len(scope) == 0 { // outside scope + if i+1 < len(slots) && + slots[i+1].ID == slot.ID+1 { // if continuous, then start one + scope = buf + scope[0] = slot.ID + } else { // discrete number + scopes = append(scopes, []uint32{slot.ID}) + } + } else { // within a scope + if i == len(slots)-1 || slots[i+1].ID != slot.ID+1 { // reach end or not continuous, stop current scope + scope[1] = slot.ID + scopes = append(scopes, []uint32{scope[0], scope[1]}) + scope = nil + } + } + + } + + // marshal scopes + result := make([]string, 0, len(scopes)) + for _, scope := range scopes { + if len(scope) == 1 { + s := strconv.Itoa(int(scope[0])) + result = append(result, s) + } else { // assert len(scope) == 2 + beg := strconv.Itoa(int(scope[0])) + end := strconv.Itoa(int(scope[1])) + result = append(result, beg+"-"+end) + } + } + return result +} + +// unmarshalSlotIds deserializes lines generated by marshalSlotIds +func unmarshalSlotIds(args []string) ([]uint32, error) { + var result []uint32 + for i, line := range args { + if pivot := strings.IndexByte(line, '-'); pivot > 0 { + // line is a scope + beg, err := strconv.Atoi(line[:pivot]) + if err != nil { + return nil, fmt.Errorf("illegal at slot line %d", i+1) + } + end, err := strconv.Atoi(line[pivot+1:]) + if err != nil { + return nil, fmt.Errorf("illegal at slot line %d", i+1) + } + for j := beg; j <= end; j++ { + result = append(result, uint32(j)) + } + } else { + // line is a number + v, err := strconv.Atoi(line) + if err != nil { + return nil, fmt.Errorf("illegal at slot line %d", i) + } + result = append(result, uint32(v)) + } + } + return result, nil +} + +type nodePayload struct { + ID string `json:"id"` + Addr string `json:"addr"` + SlotDesc []string `json:"slotDesc"` + Flags uint32 `json:"flags"` +} + +func marshalNodes(nodes map[string]*Node) [][]byte { + var args [][]byte + for _, node := range nodes { + slotLines := marshalSlotIds(node.Slots) + payload := &nodePayload{ + ID: node.ID, + Addr: node.Addr, + SlotDesc: slotLines, + Flags: node.Flags, + } + bin, _ := json.Marshal(payload) + args = append(args, bin) + } + return args +} + +func unmarshalNodes(args [][]byte) (map[string]*Node, error) { + nodeMap := make(map[string]*Node) + for i, bin := range args { + payload := &nodePayload{} + err := json.Unmarshal(bin, payload) + if err != nil { + return nil, fmt.Errorf("unmarshal node failed at line %d: %v", i+1, err) + } + slotIds, err := unmarshalSlotIds(payload.SlotDesc) + if err != 
nil { + return nil, err + } + node := &Node{ + ID: payload.ID, + Addr: payload.Addr, + Flags: payload.Flags, + } + for _, slotId := range slotIds { + node.Slots = append(node.Slots, &Slot{ + ID: slotId, + NodeID: node.ID, + Flags: 0, + }) + } + nodeMap[node.ID] = node + } + return nodeMap, nil +} + +// genSnapshot +// invoker provide lock +func (raft *Raft) makeSnapshot() [][]byte { + topology := marshalNodes(raft.nodes) + snapshot := [][]byte{ + []byte(raft.selfNodeID), + []byte(strconv.Itoa(int(raft.state))), + []byte(raft.leaderId), + []byte(strconv.Itoa(raft.term)), + []byte(strconv.Itoa(raft.committedIndex)), + } + snapshot = append(snapshot, topology...) + return snapshot +} + +// makeSnapshotForFollower used by leader node to generate snapshot for follower +// invoker provide with lock +func (raft *Raft) makeSnapshotForFollower(followerId string) [][]byte { + snapshot := raft.makeSnapshot() + snapshot[0] = []byte(followerId) + snapshot[1] = []byte(strconv.Itoa(int(follower))) + return snapshot +} + +// invoker provide with lock +func (raft *Raft) loadSnapshot(snapshot [][]byte) protocol.ErrorReply { + // make sure raft.slots and node.Slots is the same object + selfNodeId := string(snapshot[0]) + state0, err := strconv.Atoi(string(snapshot[1])) + if err != nil { + return protocol.MakeErrReply("illegal state: " + string(snapshot[1])) + } + state := raftState(state0) + if _, ok := stateNames[state]; !ok { + return protocol.MakeErrReply("unknown state: " + strconv.Itoa(int(state))) + } + leaderId := string(snapshot[2]) + term, err := strconv.Atoi(string(snapshot[3])) + if err != nil { + return protocol.MakeErrReply("illegal term: " + string(snapshot[3])) + } + commitIndex, err := strconv.Atoi(string(snapshot[4])) + if err != nil { + return protocol.MakeErrReply("illegal commit index: " + string(snapshot[3])) + } + nodes, err := unmarshalNodes(snapshot[5:]) + if err != nil { + return protocol.MakeErrReply(err.Error()) + } + raft.selfNodeID = selfNodeId + raft.state = state + raft.leaderId = leaderId + raft.term = term + raft.committedIndex = commitIndex + raft.proposedIndex = commitIndex + raft.initLog(term, commitIndex, nil) + raft.slots = make([]*Slot, slotCount) + for _, node := range nodes { + for _, slot := range node.Slots { + raft.slots[int(slot.ID)] = slot + } + if node.getState() == leader { + raft.leaderId = node.ID + } + } + raft.nodes = nodes + return nil +} diff --git a/cluster/raft_snapshot_test.go b/cluster/raft_snapshot_test.go new file mode 100644 index 0000000..916b1b5 --- /dev/null +++ b/cluster/raft_snapshot_test.go @@ -0,0 +1 @@ +package cluster diff --git a/cluster/rename.go b/cluster/rename.go index 74fbec7..b663867 100644 --- a/cluster/rename.go +++ b/cluster/rename.go @@ -8,16 +8,16 @@ import ( ) // Rename renames a key, the origin and the destination must within the same node -func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 3 { +func Rename(cluster *Cluster, c redis.Connection, cmdLine [][]byte) redis.Reply { + if len(cmdLine) != 3 { return protocol.MakeErrReply("ERR wrong number of arguments for 'rename' command") } - srcKey := string(args[1]) - destKey := string(args[2]) - srcNode := cluster.peerPicker.PickNode(srcKey) - destNode := cluster.peerPicker.PickNode(destKey) + srcKey := string(cmdLine[1]) + destKey := string(cmdLine[2]) + srcNode := cluster.pickNodeAddrByKey(srcKey) + destNode := cluster.pickNodeAddrByKey(destKey) if srcNode == destNode { // do fast - return cluster.relay(srcNode, c, args) + return 
cluster.relay(srcNode, c, modifyCmd(cmdLine, "Rename_")) } groupMap := map[string][]string{ srcNode: {srcKey}, @@ -26,7 +26,7 @@ func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { txID := cluster.idGenerator.NextID() txIDStr := strconv.FormatInt(txID, 10) // prepare rename from - srcPrepareResp := cluster.relayPrepare(srcNode, c, makeArgs("Prepare", txIDStr, "RenameFrom", srcKey)) + srcPrepareResp := cluster.relay(srcNode, c, makeArgs("Prepare", txIDStr, "RenameFrom", srcKey)) if protocol.IsErrorReply(srcPrepareResp) { // rollback src node requestRollback(cluster, c, txID, map[string][]string{srcNode: {srcKey}}) @@ -38,7 +38,7 @@ func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return protocol.MakeErrReply("ERR invalid prepare response") } // prepare rename to - destPrepareResp := cluster.relayPrepare(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), + destPrepareResp := cluster.relay(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), []byte("RenameTo"), []byte(destKey), srcPrepareMBR.Args[0], srcPrepareMBR.Args[1])) if protocol.IsErrorReply(destPrepareResp) { // rollback src node @@ -92,16 +92,16 @@ func init() { // RenameNx renames a key, only if the new key does not exist. // The origin and the destination must within the same node -func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 3 { +func RenameNx(cluster *Cluster, c redis.Connection, cmdLine [][]byte) redis.Reply { + if len(cmdLine) != 3 { return protocol.MakeErrReply("ERR wrong number of arguments for 'renamenx' command") } - srcKey := string(args[1]) - destKey := string(args[2]) - srcNode := cluster.peerPicker.PickNode(srcKey) - destNode := cluster.peerPicker.PickNode(destKey) + srcKey := string(cmdLine[1]) + destKey := string(cmdLine[2]) + srcNode := cluster.pickNodeAddrByKey(srcKey) + destNode := cluster.pickNodeAddrByKey(destKey) if srcNode == destNode { - return cluster.relay(srcNode, c, args) + return cluster.relay(srcNode, c, modifyCmd(cmdLine, "RenameNX_")) } groupMap := map[string][]string{ srcNode: {srcKey}, @@ -110,7 +110,7 @@ func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { txID := cluster.idGenerator.NextID() txIDStr := strconv.FormatInt(txID, 10) // prepare rename from - srcPrepareResp := cluster.relayPrepare(srcNode, c, makeArgs("Prepare", txIDStr, "RenameFrom", srcKey)) + srcPrepareResp := cluster.relay(srcNode, c, makeArgs("Prepare", txIDStr, "RenameFrom", srcKey)) if protocol.IsErrorReply(srcPrepareResp) { // rollback src node requestRollback(cluster, c, txID, map[string][]string{srcNode: {srcKey}}) @@ -122,7 +122,7 @@ func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return protocol.MakeErrReply("ERR invalid prepare response") } // prepare rename to - destPrepareResp := cluster.relayPrepare(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), + destPrepareResp := cluster.relay(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), []byte("RenameNxTo"), []byte(destKey), srcPrepareMBR.Args[0], srcPrepareMBR.Args[1])) if protocol.IsErrorReply(destPrepareResp) { // rollback src node diff --git a/cluster/rename_test.go b/cluster/rename_test.go index ca8b070..966a939 100644 --- a/cluster/rename_test.go +++ b/cluster/rename_test.go @@ -8,152 +8,52 @@ import ( ) func TestRename(t *testing.T) { + testNodeA := testCluster[0] conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FlushALL")) // cross node rename - 
key := testNodeA.self + utils.RandString(10) - value := utils.RandString(10) - newKey := testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result := Rename(testNodeA, conn, utils.ToCmdLine("RENAME", key, newKey)) - asserts.AssertStatusReply(t, result, "OK") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("TTL", newKey)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - - // same node rename - key = testNodeA.self + utils.RandString(10) - value = utils.RandString(10) - newKey = key + utils.RandString(2) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = Rename(testNodeA, conn, utils.ToCmdLine("RENAME", key, newKey)) - asserts.AssertStatusReply(t, result, "OK") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 0) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", newKey)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - - // test src prepare failed - *simulateATimout = true - key = testNodeA.self + utils.RandString(10) - newKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = Rename(testNodeB, conn, utils.ToCmdLine("RENAME", key, newKey)) - asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", key)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 0) - *simulateATimout = false - - // test dest prepare failed - *simulateBTimout = true - key = testNodeA.self + utils.RandString(10) - newKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = Rename(testNodeA, conn, utils.ToCmdLine("RENAME", key, newKey)) - asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", key)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 0) - *simulateBTimout = false - - result = Rename(testNodeA, conn, utils.ToCmdLine("RENAME", key)) - asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'rename' command") + for i := 0; i < 10; i++ { + testNodeA.Exec(conn, utils.ToCmdLine("FlushALL")) + key := utils.RandString(10) + value := utils.RandString(10) + newKey := utils.RandString(10) + testNodeA.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "100000")) + result := testNodeA.Exec(conn, utils.ToCmdLine("RENAME", key, newKey)) + asserts.AssertStatusReply(t, result, "OK") + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", key)) + 
asserts.AssertIntReply(t, result, 0) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", newKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + } } func TestRenameNx(t *testing.T) { + testNodeA := testCluster[0] conn := new(connection.FakeConn) - testNodeA.db.Exec(conn, utils.ToCmdLine("FlushALL")) + // cross node rename - key := testNodeA.self + utils.RandString(10) - value := utils.RandString(10) - newKey := testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result := RenameNx(testNodeA, conn, utils.ToCmdLine("RENAMENX", key, newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("TTL", newKey)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - - // cross node rename, dest key exist - key = testNodeA.self + utils.RandString(10) - value = utils.RandString(10) - newKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - testNodeB.db.Exec(conn, utils.ToCmdLine("SET", newKey, newKey)) - result = RenameNx(testNodeA, conn, utils.ToCmdLine("RENAMENX", key, newKey)) - asserts.AssertIntReply(t, result, 0) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", key)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("GET", newKey)) - asserts.AssertBulkReply(t, result, newKey) - - // same node rename - key = testNodeA.self + utils.RandString(10) - value = utils.RandString(10) - newKey = key + utils.RandString(2) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = RenameNx(testNodeA, conn, utils.ToCmdLine("RENAMENX", key, newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 0) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", newKey)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - - // test src prepare failed - *simulateATimout = true - key = testNodeA.self + utils.RandString(10) - newKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = RenameNx(testNodeB, conn, utils.ToCmdLine("RENAMENX", key, newKey)) - asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", key)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 0) - *simulateATimout = false - - // test dest prepare failed - *simulateBTimout = true - key = testNodeA.self + utils.RandString(10) - newKey = testNodeB.self + 
utils.RandString(10) // route to testNodeB, see mockPicker.PickNode - value = utils.RandString(10) - testNodeA.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000")) - result = RenameNx(testNodeA, conn, utils.ToCmdLine("RENAMENX", key, newKey)) - asserts.AssertErrReply(t, result, "ERR timeout") - result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", key)) - asserts.AssertIntReply(t, result, 1) - result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", key)) - asserts.AssertIntReplyGreaterThan(t, result, 0) - result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) - asserts.AssertIntReply(t, result, 0) - *simulateBTimout = false - - result = RenameNx(testNodeA, conn, utils.ToCmdLine("RENAMENX", key)) - asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'renamenx' command") + for i := 0; i < 10; i++ { + testNodeA.Exec(conn, utils.ToCmdLine("FlushALL")) + key := utils.RandString(10) + value := utils.RandString(10) + newKey := utils.RandString(10) + testNodeA.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "100000")) + result := testNodeA.Exec(conn, utils.ToCmdLine("RENAMENX", key, newKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", key)) + asserts.AssertIntReply(t, result, 0) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", newKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.Exec(conn, utils.ToCmdLine("TTL", newKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + value2 := value + "ccc" + testNodeA.Exec(conn, utils.ToCmdLine("SET", key, value2, "ex", "100000")) + result = testNodeA.Exec(conn, utils.ToCmdLine("RENAMENX", key, newKey)) + asserts.AssertIntReply(t, result, 0) + result = testNodeA.Exec(conn, utils.ToCmdLine("EXISTS", key)) + asserts.AssertIntReply(t, result, 1) + } } diff --git a/cluster/router.go b/cluster/router.go index d053b12..e8778be 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -1,135 +1,166 @@ package cluster -import "github.com/hdt3213/godis/interface/redis" +import ( + "github.com/hdt3213/godis/interface/redis" + "strings" +) // CmdLine is alias for [][]byte, represents a command line type CmdLine = [][]byte -func makeRouter() map[string]CmdFunc { - routerMap := make(map[string]CmdFunc) - routerMap["ping"] = ping - routerMap["info"] = info +var router = make(map[string]CmdFunc) - routerMap["prepare"] = execPrepare - routerMap["commit"] = execCommit - routerMap["rollback"] = execRollback - routerMap["del"] = Del +func registerCmd(name string, cmd CmdFunc) { + name = strings.ToLower(name) + router[name] = cmd +} - routerMap["expire"] = defaultFunc - routerMap["expireat"] = defaultFunc - routerMap["expiretime"] = defaultFunc - routerMap["pexpire"] = defaultFunc - routerMap["pexpireat"] = defaultFunc - routerMap["pexpiretime"] = defaultFunc - routerMap["ttl"] = defaultFunc - routerMap["pttl"] = defaultFunc - routerMap["persist"] = defaultFunc - routerMap["exists"] = defaultFunc - routerMap["type"] = defaultFunc - routerMap["rename"] = Rename - routerMap["renamenx"] = RenameNx - routerMap["copy"] = Copy - - routerMap["set"] = defaultFunc - routerMap["setnx"] = defaultFunc - routerMap["setex"] = defaultFunc - routerMap["psetex"] = defaultFunc - routerMap["mset"] = MSet - routerMap["mget"] = MGet - routerMap["msetnx"] = MSetNX - routerMap["get"] = defaultFunc - routerMap["getex"] = defaultFunc - routerMap["getset"] = defaultFunc - routerMap["getdel"] = defaultFunc - routerMap["incr"] = defaultFunc - routerMap["incrby"] = defaultFunc - 
routerMap["incrbyfloat"] = defaultFunc - routerMap["decr"] = defaultFunc - routerMap["decrby"] = defaultFunc - routerMap["randomkey"] = randomkey - - routerMap["lpush"] = defaultFunc - routerMap["lpushx"] = defaultFunc - routerMap["rpush"] = defaultFunc - routerMap["rpushx"] = defaultFunc - routerMap["lpop"] = defaultFunc - routerMap["rpop"] = defaultFunc - //routerMap["rpoplpush"] = RPopLPush - routerMap["lrem"] = defaultFunc - routerMap["llen"] = defaultFunc - routerMap["lindex"] = defaultFunc - routerMap["lset"] = defaultFunc - routerMap["lrange"] = defaultFunc - - routerMap["hset"] = defaultFunc - routerMap["hsetnx"] = defaultFunc - routerMap["hget"] = defaultFunc - routerMap["hexists"] = defaultFunc - routerMap["hdel"] = defaultFunc - routerMap["hlen"] = defaultFunc - routerMap["hstrlen"] = defaultFunc - routerMap["hmget"] = defaultFunc - routerMap["hmset"] = defaultFunc - routerMap["hkeys"] = defaultFunc - routerMap["hvals"] = defaultFunc - routerMap["hgetall"] = defaultFunc - routerMap["hincrby"] = defaultFunc - routerMap["hincrbyfloat"] = defaultFunc - routerMap["hrandfield"] = defaultFunc - - routerMap["sadd"] = defaultFunc - routerMap["sismember"] = defaultFunc - routerMap["srem"] = defaultFunc - routerMap["spop"] = defaultFunc - routerMap["scard"] = defaultFunc - routerMap["smembers"] = defaultFunc - routerMap["sinter"] = defaultFunc - routerMap["sinterstore"] = defaultFunc - routerMap["sunion"] = defaultFunc - routerMap["sunionstore"] = defaultFunc - routerMap["sdiff"] = defaultFunc - routerMap["sdiffstore"] = defaultFunc - routerMap["srandmember"] = defaultFunc - - routerMap["zadd"] = defaultFunc - routerMap["zscore"] = defaultFunc - routerMap["zincrby"] = defaultFunc - routerMap["zrank"] = defaultFunc - routerMap["zcount"] = defaultFunc - routerMap["zrevrank"] = defaultFunc - routerMap["zcard"] = defaultFunc - routerMap["zrange"] = defaultFunc - routerMap["zrevrange"] = defaultFunc - routerMap["zrangebyscore"] = defaultFunc - routerMap["zrevrangebyscore"] = defaultFunc - routerMap["zrem"] = defaultFunc - routerMap["zremrangebyscore"] = defaultFunc - routerMap["zremrangebyrank"] = defaultFunc - - routerMap["geoadd"] = defaultFunc - routerMap["geopos"] = defaultFunc - routerMap["geodist"] = defaultFunc - routerMap["geohash"] = defaultFunc - routerMap["georadius"] = defaultFunc - routerMap["georadiusbymember"] = defaultFunc - - routerMap["publish"] = Publish - routerMap[relayPublish] = onRelayedPublish - routerMap["subscribe"] = Subscribe - routerMap["unsubscribe"] = UnSubscribe - - routerMap["flushdb"] = FlushDB - routerMap["flushall"] = FlushAll - routerMap[relayMulti] = execRelayedMulti - routerMap["getver"] = defaultFunc - routerMap["watch"] = execWatch - - return routerMap +func registerDefaultCmd(name string) { + registerCmd(name, defaultFunc) } // relay command to responsible peer, and return its protocol to client func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { key := string(args[1]) - peer := cluster.peerPicker.PickNode(key) - return cluster.relay(peer, c, args) + slotId := getSlot(key) + peer := cluster.pickNode(slotId) + if peer.ID == cluster.self { + err := cluster.ensureKeyWithoutLock(key) + if err != nil { + return err + } + // to self db + //return cluster.db.Exec(c, cmdLine) + return cluster.db.Exec(c, args) + } + return cluster.relay(peer.ID, c, args) +} + +func init() { + registerCmd("Ping", ping) + registerCmd("Prepare", execPrepare) + registerCmd("Commit", execCommit) + registerCmd("Rollback", execRollback) + 
registerCmd("Del", Del) + registerCmd("Rename", Rename) + registerCmd("RenameNx", RenameNx) + registerCmd("Copy", Copy) + registerCmd("MSet", MSet) + registerCmd("MGet", MGet) + registerCmd("MSetNx", MSetNX) + registerCmd("Publish", Publish) + registerCmd("Subscribe", Subscribe) + registerCmd("Unsubscribe", UnSubscribe) + registerCmd("FlushDB", FlushDB) + registerCmd("FlushAll", FlushAll) + registerCmd(relayMulti, execRelayedMulti) + registerCmd("Watch", execWatch) + registerCmd("FlushDB_", genPenetratingExecutor("FlushDB")) + registerCmd("Copy_", genPenetratingExecutor("Copy")) + registerCmd("Watch_", genPenetratingExecutor("Watch")) + registerCmd(relayPublish, genPenetratingExecutor("Publish")) + registerCmd("Del_", genPenetratingExecutor("Del")) + registerCmd("MSet_", genPenetratingExecutor("MSet")) + registerCmd("MSetNx_", genPenetratingExecutor("MSetNx")) + registerCmd("MGet_", genPenetratingExecutor("MGet")) + registerCmd("Rename_", genPenetratingExecutor("Rename")) + registerCmd("RenameNx_", genPenetratingExecutor("RenameNx")) + registerCmd("DumpKey_", genPenetratingExecutor("DumpKey")) + + defaultCmds := []string{ + "expire", + "expireAt", + "pExpire", + "pExpireAt", + "ttl", + "PTtl", + "persist", + "exists", + "type", + "set", + "setNx", + "setEx", + "pSetEx", + "get", + "getEx", + "getSet", + "getDel", + "incr", + "incrBy", + "incrByFloat", + "decr", + "decrBy", + "lPush", + "lPushX", + "rPush", + "rPushX", + "LPop", + "RPop", + "LRem", + "LLen", + "LIndex", + "LSet", + "LRange", + "HSet", + "HSetNx", + "HGet", + "HExists", + "HDel", + "HLen", + "HStrLen", + "HMGet", + "HMSet", + "HKeys", + "HVals", + "HGetAll", + "HIncrBy", + "HIncrByFloat", + "HRandField", + "SAdd", + "SIsMember", + "SRem", + "SPop", + "SCard", + "SMembers", + "SInter", + "SInterStore", + "SUnion", + "SUnionStore", + "SDiff", + "SDiffStore", + "SRandMember", + "ZAdd", + "ZScore", + "ZIncrBy", + "ZRank", + "ZCount", + "ZRevRank", + "ZCard", + "ZRange", + "ZRevRange", + "ZRangeByScore", + "ZRevRangeByScore", + "ZRem", + "ZRemRangeByScore", + "ZRemRangeByRank", + "GeoAdd", + "GeoPos", + "GeoDist", + "GeoHash", + "GeoRadius", + "GeoRadiusByMember", + "GetVer", + "DumpKey", + } + for _, name := range defaultCmds { + registerDefaultCmd(name) + } + +} + +// genPenetratingExecutor generates an executor that can reach directly to the database layer +func genPenetratingExecutor(realCmd string) CmdFunc { + return func(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + return cluster.db.Exec(c, modifyCmd(cmdLine, realCmd)) + } } diff --git a/cluster/tcc.go b/cluster/tcc.go index e8d69b0..b4ffe7c 100644 --- a/cluster/tcc.go +++ b/cluster/tcc.go @@ -90,6 +90,18 @@ func (tx *Transaction) prepare() error { // lock writeKeys tx.lockKeys() + for _, key := range tx.writeKeys { + err := tx.cluster.ensureKey(key) + if err != nil { + return err + } + } + for _, key := range tx.readKeys { + err := tx.cluster.ensureKey(key) + if err != nil { + return err + } + } // build undoLog tx.undoLog = tx.cluster.db.GetUndoLogs(tx.dbIndex, tx.cmdLine) tx.status = preparedStatus @@ -131,7 +143,9 @@ func execPrepare(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Re txID := string(cmdLine[1]) cmdName := strings.ToLower(string(cmdLine[2])) tx := NewTransaction(cluster, c, txID, cmdLine[2:]) + cluster.transactionMu.Lock() cluster.transactions.Put(txID, tx) + cluster.transactionMu.Unlock() err := tx.prepare() if err != nil { return protocol.MakeErrReply(err.Error()) @@ -149,7 +163,9 @@ func execRollback(cluster 
*Cluster, c redis.Connection, cmdLine CmdLine) redis.R return protocol.MakeErrReply("ERR wrong number of arguments for 'rollback' command") } txID := string(cmdLine[1]) + cluster.transactionMu.RLock() raw, ok := cluster.transactions.Get(txID) + cluster.transactionMu.RUnlock() if !ok { return protocol.MakeIntReply(0) } @@ -163,7 +179,9 @@ func execRollback(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.R } // clean transaction timewheel.Delay(waitBeforeCleanTx, "", func() { + cluster.transactionMu.Lock() cluster.transactions.Remove(tx.id) + cluster.transactionMu.Unlock() }) return protocol.MakeIntReply(1) } @@ -174,7 +192,9 @@ func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Rep return protocol.MakeErrReply("ERR wrong number of arguments for 'commit' command") } txID := string(cmdLine[1]) + cluster.transactionMu.RLock() raw, ok := cluster.transactions.Get(txID) + cluster.transactionMu.RUnlock() if !ok { return protocol.MakeIntReply(0) } @@ -196,7 +216,9 @@ func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Rep // clean finished transaction // do not clean immediately, in case rollback timewheel.Delay(waitBeforeCleanTx, "", func() { + cluster.transactionMu.Lock() cluster.transactions.Remove(tx.id) + cluster.transactionMu.Unlock() }) return result } @@ -207,12 +229,7 @@ func requestCommit(cluster *Cluster, c redis.Connection, txID int64, groupMap ma txIDStr := strconv.FormatInt(txID, 10) respList := make([]redis.Reply, 0, len(groupMap)) for node := range groupMap { - var resp redis.Reply - if node == cluster.self { - resp = execCommit(cluster, c, makeArgs("commit", txIDStr)) - } else { - resp = cluster.relay(node, c, makeArgs("commit", txIDStr)) - } + resp := cluster.relay(node, c, makeArgs("commit", txIDStr)) if protocol.IsErrorReply(resp) { errReply = resp.(protocol.ErrorReply) break @@ -231,18 +248,6 @@ func requestCommit(cluster *Cluster, c redis.Connection, txID int64, groupMap ma func requestRollback(cluster *Cluster, c redis.Connection, txID int64, groupMap map[string][]string) { txIDStr := strconv.FormatInt(txID, 10) for node := range groupMap { - if node == cluster.self { - execRollback(cluster, c, makeArgs("rollback", txIDStr)) - } else { - cluster.relay(node, c, makeArgs("rollback", txIDStr)) - } - } -} - -func (cluster *Cluster) relayPrepare(node string, c redis.Connection, cmdLine CmdLine) redis.Reply { - if node == cluster.self { - return execPrepare(cluster, c, cmdLine) - } else { - return cluster.relay(node, c, cmdLine) + cluster.relay(node, c, makeArgs("rollback", txIDStr)) } } diff --git a/cluster/tcc_test.go b/cluster/tcc_test.go index f8c2ec8..11ee3a3 100644 --- a/cluster/tcc_test.go +++ b/cluster/tcc_test.go @@ -10,28 +10,31 @@ import ( func TestRollback(t *testing.T) { // rollback uncommitted transaction + testNodeA := testCluster[0] conn := new(connection.FakeConn) FlushAll(testNodeA, conn, toArgs("FLUSHALL")) txID := rand.Int63() txIDStr := strconv.FormatInt(txID, 10) - keys := []string{"a", "b"} - groupMap := testNodeA.groupBy(keys) + keys := []string{"a", "{a}1"} + groupMap := map[string][]string{ + testNodeA.self: keys, + } args := []string{txIDStr, "DEL"} args = append(args, keys...) 
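The tests above exercise the TCC (try-confirm-cancel) path that this diff hardens with transactionMu: a coordinator sends Prepare to every involved node, and only if all prepares succeed does it send commit; any failure triggers a rollback on every prepared node. A simplified coordinator sketch under assumed interfaces (sendToNode and its wiring are placeholders, not the project's API; only the Prepare/commit/rollback command names come from the diff):

```go
package main

import (
	"errors"
	"fmt"
)

// send is a placeholder for relaying a command line to one node
// (the real cluster relays RESP commands such as Prepare, commit and rollback).
type send func(node string, cmd ...string) error

// runTCC coordinates a transaction across the nodes in groupMap:
// prepare everywhere first, then commit; roll back on any prepare failure.
func runTCC(txID string, groupMap map[string][]string, relay send) error {
	prepared := make([]string, 0, len(groupMap))
	for node, keys := range groupMap {
		if err := relay(node, append([]string{"Prepare", txID}, keys...)...); err != nil {
			// cancel phase: undo the nodes that already prepared (including this one)
			for _, p := range append(prepared, node) {
				_ = relay(p, "rollback", txID)
			}
			return err
		}
		prepared = append(prepared, node)
	}
	for node := range groupMap {
		if err := relay(node, "commit", txID); err != nil {
			return err // the real implementation also rolls back here
		}
	}
	return nil
}

func main() {
	group := map[string][]string{"nodeA": {"a"}, "nodeB": {"b"}}
	err := runTCC("42", group, func(node string, cmd ...string) error {
		if node == "nodeB" && cmd[0] == "Prepare" {
			return errors.New("timeout") // simulate a failed prepare
		}
		fmt.Println(node, cmd)
		return nil
	})
	fmt.Println("result:", err)
}
```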
- testNodeA.Exec(conn, toArgs("SET", "a", "a")) + testNodeA.db.Exec(conn, toArgs("SET", "a", "a")) ret := execPrepare(testNodeA, conn, makeArgs("Prepare", args...)) asserts.AssertNotError(t, ret) requestRollback(testNodeA, conn, txID, groupMap) - ret = testNodeA.Exec(conn, toArgs("GET", "a")) + ret = testNodeA.db.Exec(conn, toArgs("GET", "a")) asserts.AssertBulkReply(t, ret, "a") // rollback committed transaction FlushAll(testNodeA, conn, toArgs("FLUSHALL")) + testNodeA.db.Exec(conn, toArgs("SET", "a", "a")) txID = rand.Int63() txIDStr = strconv.FormatInt(txID, 10) args = []string{txIDStr, "DEL"} args = append(args, keys...) - testNodeA.Exec(conn, toArgs("SET", "a", "a")) ret = execPrepare(testNodeA, conn, makeArgs("Prepare", args...)) asserts.AssertNotError(t, ret) _, err := requestCommit(testNodeA, conn, txID, groupMap) @@ -39,9 +42,9 @@ func TestRollback(t *testing.T) { t.Errorf("del failed %v", err) return } - ret = testNodeA.Exec(conn, toArgs("GET", "a")) + ret = testNodeA.db.Exec(conn, toArgs("GET", "a")) // call db.Exec to skip key router asserts.AssertNullBulk(t, ret) requestRollback(testNodeA, conn, txID, groupMap) - ret = testNodeA.Exec(conn, toArgs("GET", "a")) + ret = testNodeA.db.Exec(conn, toArgs("GET", "a")) asserts.AssertBulkReply(t, ret, "a") } diff --git a/cluster/topo.go b/cluster/topo.go new file mode 100644 index 0000000..316715f --- /dev/null +++ b/cluster/topo.go @@ -0,0 +1,189 @@ +package cluster + +import ( + "fmt" + "github.com/hdt3213/godis/database" + "github.com/hdt3213/godis/lib/logger" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/protocol" + "strconv" + "time" +) + +func (cluster *Cluster) startAsSeed(listenAddr string) protocol.ErrorReply { + err := cluster.topology.StartAsSeed(listenAddr) + if err != nil { + return err + } + for i := 0; i < slotCount; i++ { + cluster.initSlot(uint32(i), slotStateHost) + } + return nil +} + +// Join send `gcluster join` to node in cluster to join +func (cluster *Cluster) Join(seed string) protocol.ErrorReply { + err := cluster.topology.Join(seed) + if err != nil { + return nil + } + /* STEP3: asynchronous migrating slots */ + go func() { + time.Sleep(time.Second) // let the cluster started + cluster.reBalance() + }() + return nil +} + +var errConfigFileNotExist = protocol.MakeErrReply("cluster config file not exist") + +// LoadConfig try to load cluster-config-file and re-join the cluster +func (cluster *Cluster) LoadConfig() protocol.ErrorReply { + err := cluster.topology.LoadConfigFile() + if err != nil { + return err + } + selfNodeId := cluster.topology.GetSelfNodeID() + selfNode := cluster.topology.GetNode(selfNodeId) + if selfNode == nil { + return protocol.MakeErrReply("ERR self node info not found") + } + for _, slot := range selfNode.Slots { + cluster.initSlot(slot.ID, slotStateHost) + } + return nil +} + +func (cluster *Cluster) reBalance() { + nodes := cluster.topology.GetNodes() + var slotIDs []uint32 + var slots []*Slot + reqDonateCmdLine := utils.ToCmdLine("gcluster", "request-donate", cluster.self) + for _, node := range nodes { + if node.ID == cluster.self { + continue + } + node := node + peerCli, err := cluster.clientFactory.GetPeerClient(node.Addr) + if err != nil { + logger.Errorf("get client of %s failed: %v", node.Addr, err) + continue + } + resp := peerCli.Send(reqDonateCmdLine) + payload, ok := resp.(*protocol.MultiBulkReply) + if !ok { + logger.Errorf("request donate to %s failed: %v", node.Addr, err) + continue + } + for _, bin 
:= range payload.Args {
+			slotID64, err := strconv.ParseUint(string(bin), 10, 64)
+			if err != nil {
+				continue
+			}
+			slotID := uint32(slotID64)
+			slotIDs = append(slotIDs, slotID)
+			slots = append(slots, &Slot{
+				ID:     slotID,
+				NodeID: node.ID,
+			})
+			// Raft cannot guarantee that the change is applied on the source node and the destination node at the same time or in the same order
+			// In some cases the source node may believe the slot already belongs to the destination node, while the destination node still believes it belongs to the source node
+			// To avoid this, the source node and the destination node must reach an agreement before proposing the change to raft
+			cluster.setLocalSlotImporting(slotID, node.ID)
+		}
+	}
+	if len(slots) == 0 {
+		return
+	}
+	logger.Infof("received %d donated slots", len(slots))
+
+	// change route
+	err := cluster.topology.SetSlot(slotIDs, cluster.self)
+	if err != nil {
+		logger.Errorf("set slot route failed: %v", err)
+		return
+	}
+	slotChan := make(chan *Slot, len(slots))
+	for _, slot := range slots {
+		slotChan <- slot
+	}
+	close(slotChan)
+	for i := 0; i < 4; i++ {
+		i := i
+		go func() {
+			for slot := range slotChan {
+				logger.Info("start import slot ", slot.ID)
+				err := cluster.importSlot(slot)
+				if err != nil {
+					logger.Error(fmt.Sprintf("import slot %d error: %v", slot.ID, err))
+					// delete all imported keys in slot
+					cluster.cleanDroppedSlot(slot.ID)
+					// todo: recover route
+					return
+				}
+				logger.Infof("finish import slot: %d, about %d slots remain", slot.ID, len(slotChan))
+			}
+			logger.Infof("import worker %d exited", i)
+		}()
+	}
+}
+
+// importSlot migrates the given slot into the current node
+// the `slot` parameter is only used to carry the slot ID and its former host node
+func (cluster *Cluster) importSlot(slot *Slot) error {
+	node := cluster.topology.GetNode(slot.NodeID)
+
+	/* get migrate stream */
+	migrateCmdLine := utils.ToCmdLine(
+		"gcluster", "migrate", strconv.Itoa(int(slot.ID)))
+	migrateStream, err := cluster.clientFactory.NewStream(node.Addr, migrateCmdLine)
+	if err != nil {
+		return err
+	}
+	defer migrateStream.Close()
+
+	fakeConn := connection.NewFakeConn()
+slotLoop:
+	for proto := range migrateStream.Stream() {
+		if proto.Err != nil {
+			return fmt.Errorf("migrate slot %d error: %v", slot.ID, proto.Err)
+		}
+		switch reply := proto.Data.(type) {
+		case *protocol.MultiBulkReply:
+			// todo: handle exec error
+			keys, _ := database.GetRelatedKeys(reply.Args)
+			// assert len(keys) == 1
+			key := keys[0]
+			// the key may have been imported by Cluster.ensureKey or by a former failed migration attempt
+			if !cluster.isImportedKey(key) {
+				cluster.setImportedKey(key)
+				_ = cluster.db.Exec(fakeConn, reply.Args)
+			}
+		case *protocol.StatusReply:
+			if protocol.IsOKReply(reply) {
+				break slotLoop
+			} else {
+				// todo: return slot to former host node
+				msg := fmt.Sprintf("migrate slot %d error: %s", slot.ID, reply.Status)
+				logger.Errorf(msg)
+				return protocol.MakeErrReply(msg)
+			}
+		case protocol.ErrorReply:
+			// todo: return slot to former host node
+			msg := fmt.Sprintf("migrate slot %d error: %s", slot.ID, reply.Error())
+			logger.Errorf(msg)
+			return protocol.MakeErrReply(msg)
+		}
+	}
+	cluster.finishSlotImport(slot.ID)
+
+	// notify the former host node that migration has finished
+	peerCli, err := cluster.clientFactory.GetPeerClient(node.Addr)
+	if err != nil {
+		return err
+	}
+	defer cluster.clientFactory.ReturnPeerClient(node.Addr, peerCli)
+	peerCli.Send(utils.ToCmdLine("gcluster", "migrate-done", strconv.Itoa(int(slot.ID))))
+	return nil
+}
diff --git a/cluster/topo_gcluster.go b/cluster/topo_gcluster.go
new file mode 100644
index 0000000..8b9fbc1
--- /dev/null
+++ b/cluster/topo_gcluster.go
@@ -0,0 +1,146 @@
+package cluster
+
+import (
+	"fmt"
+	"github.com/hdt3213/godis/aof"
+	"github.com/hdt3213/godis/interface/redis"
+	"github.com/hdt3213/godis/lib/logger"
+	"github.com/hdt3213/godis/redis/protocol"
+	"strconv"
+	"strings"
+)
+
+func init() {
+	registerCmd("gcluster", execGCluster)
+}
+
+func execGCluster(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
+	if len(args) < 2 {
+		return protocol.MakeArgNumErrReply("gcluster")
+	}
+	subCmd := strings.ToLower(string(args[1]))
+	switch subCmd {
+	case "set-slot":
+		// Command line: gcluster set-slot
+		// Another node requests the current node to migrate a slot to it.
+		// The current node will mark the slot as migrating (moving out).
+		// After this function returns, all requests for the target slot will be routed to the target node
+		return execGClusterSetSlot(cluster, c, args[2:])
+	case "migrate":
+		// Command line: gcluster migrate
+		// The current node will dump the given slot to the node sending this request
+		// The given slot must be in migrating state
+		return execGClusterMigrate(cluster, c, args[2:])
+	case "migrate-done":
+		// Command line: gcluster migrate-done
+		// The new node hosting the given slot tells the current node that migration has finished and the remaining data can be deleted
+		return execGClusterMigrateDone(cluster, c, args[2:])
+	case "request-donate":
+		// Command line: gcluster request-donate
+		// Picks some slots and gives them to the calling node for load balancing
+		return execGClusterDonateSlot(cluster, c, args[2:])
+	}
+	return protocol.MakeErrReply("ERR unknown gcluster sub command '" + subCmd + "'")
+}
+
+// execGClusterSetSlot marks a slot hosted by the current node as moving out
+// args is [slotID, newNodeId]
+func execGClusterSetSlot(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
+	if len(args) != 2 {
+		return protocol.MakeArgNumErrReply("gcluster")
+	}
+	slotId0, err := strconv.Atoi(string(args[0]))
+	if err != nil || slotId0 >= slotCount {
+		return protocol.MakeErrReply("ERR value is not a valid slot id")
+	}
+	slotId := uint32(slotId0)
+	targetNodeID := string(args[1])
+	targetNode := cluster.topology.GetNode(targetNodeID)
+	if targetNode == nil {
+		return protocol.MakeErrReply("ERR node not found")
+	}
+	cluster.setSlotMovingOut(slotId, targetNodeID)
+	logger.Info(fmt.Sprintf("set slot %d to node %s", slotId, targetNodeID))
+	return protocol.MakeOkReply()
+}
+
+// execGClusterDonateSlot picks some slots and gives them to the calling node for load balancing
+// args is [callingNodeId]
+func execGClusterDonateSlot(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
+	targetNodeID := string(args[0])
+	nodes := cluster.topology.GetNodes() // including the new node
+	avgSlot := slotCount / len(nodes)
+	cluster.slotMu.Lock()
+	defer cluster.slotMu.Unlock()
+	limit := len(cluster.slots) - avgSlot
+	if limit <= 0 {
+		return protocol.MakeEmptyMultiBulkReply()
+	}
+	result := make([][]byte, 0, limit)
+	// rely on the random iteration order of Go maps to pick slots at random
+	for slotID, slot := range cluster.slots {
+		if slot.state == slotStateHost {
+			slot.state = slotStateMovingOut
+			slot.newNodeID = targetNodeID
+			slotIDBin := []byte(strconv.FormatUint(uint64(slotID), 10))
+			result = append(result, slotIDBin)
+			if len(result) == limit {
+				break
+			}
+		}
+	}
+	return protocol.MakeMultiBulkReply(result)
+}
+
+// execGClusterMigrate Command line: gcluster migrate slotId
+// The current node will dump data in the given slot to the node sending this request
+// The given slot must be in migrating 
state +func execGClusterMigrate(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + slotId0, err := strconv.Atoi(string(args[0])) + if err != nil || slotId0 >= slotCount { + return protocol.MakeErrReply("ERR value is not a valid slot id") + } + slotId := uint32(slotId0) + slot := cluster.getHostSlot(slotId) + if slot == nil || slot.state != slotStateMovingOut { + return protocol.MakeErrReply("ERR only dump migrating slot") + } + // migrating slot is immutable + logger.Info("start dump slot", slotId) + slot.keys.ForEach(func(key string) bool { + entity, ok := cluster.db.GetEntity(0, key) + if ok { + ret := aof.EntityToCmd(key, entity) + // todo: handle error and close connection + _, _ = c.Write(ret.ToBytes()) + expire := cluster.db.GetExpiration(0, key) + if expire != nil { + ret = aof.MakeExpireCmd(key, *expire) + _, _ = c.Write(ret.ToBytes()) + } + + } + return true + }) + logger.Info("finish dump slot ", slotId) + // send a ok reply to tell requesting node dump finished + return protocol.MakeOkReply() +} + +// execGClusterMigrateDone command line: gcluster migrate-done +func execGClusterMigrateDone(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + slotId0, err := strconv.Atoi(string(args[0])) + if err != nil || slotId0 >= slotCount { + return protocol.MakeErrReply("ERR value is not a valid slot id") + } + slotId := uint32(slotId0) + slot := cluster.getHostSlot(slotId) + if slot == nil || slot.state != slotStateMovingOut { + return protocol.MakeErrReply("ERR slot is not moving out") + } + cluster.cleanDroppedSlot(slotId) + cluster.slotMu.Lock() + delete(cluster.slots, slotId) + cluster.slotMu.Unlock() + return protocol.MakeOkReply() +} diff --git a/cluster/topo_interface.go b/cluster/topo_interface.go new file mode 100644 index 0000000..63103f2 --- /dev/null +++ b/cluster/topo_interface.go @@ -0,0 +1,58 @@ +package cluster + +import ( + "github.com/hdt3213/godis/redis/protocol" + "hash/crc32" + "strings" + "time" +) + +// Slot represents a hash slot, used in cluster internal messages +type Slot struct { + // ID is uint between 0 and 16383 + ID uint32 + // NodeID is id of the hosting node + // If the slot is migrating, NodeID is the id of the node importing this slot (target node) + NodeID string + // Flags stores more information of slot + Flags uint32 +} + +// getPartitionKey extract hashtag +func getPartitionKey(key string) string { + beg := strings.Index(key, "{") + if beg == -1 { + return key + } + end := strings.Index(key, "}") + if end == -1 || end == beg+1 { + return key + } + return key[beg+1 : end] +} + +func getSlot(key string) uint32 { + partitionKey := getPartitionKey(key) + return crc32.ChecksumIEEE([]byte(partitionKey)) % uint32(slotCount) +} + +// Node represents a node and its slots, used in cluster internal messages +type Node struct { + ID string + Addr string + Slots []*Slot // ascending order by slot id + Flags uint32 + lastHeard time.Time +} + +type topology interface { + GetSelfNodeID() string + GetNodes() []*Node // return a copy + GetNode(nodeID string) *Node + GetSlots() []*Slot + StartAsSeed(addr string) protocol.ErrorReply + SetSlot(slotIDs []uint32, newNodeID string) protocol.ErrorReply + LoadConfigFile() protocol.ErrorReply + Join(seed string) protocol.ErrorReply + Close() error +} diff --git a/cluster/topo_test.go b/cluster/topo_test.go new file mode 100644 index 0000000..d1ebfa0 --- /dev/null +++ b/cluster/topo_test.go @@ -0,0 +1,152 @@ +package cluster + +import ( + "github.com/hdt3213/godis/config" + database2 
"github.com/hdt3213/godis/database" + "github.com/hdt3213/godis/datastruct/dict" + "github.com/hdt3213/godis/lib/idgenerator" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/protocol/asserts" + "path" + "strconv" + "testing" + "time" +) + +func makeTestRaft(addresses []string, timeoutFlags []bool, persistFilenames []string) ([]*Cluster, error) { + nodes := make([]*Cluster, len(addresses)) + factory := &testClientFactory{ + nodes: nodes, + timeoutFlags: timeoutFlags, + } + for i, addr := range addresses { + addr := addr + nodes[i] = &Cluster{ + self: addr, + addr: addr, + db: database2.NewStandaloneServer(), + transactions: dict.MakeSimple(), + idGenerator: idgenerator.MakeGenerator(config.Properties.Self), + clientFactory: factory, + slots: make(map[uint32]*hostSlot), + } + topologyPersistFile := persistFilenames[i] + nodes[i].topology = newRaft(nodes[i], topologyPersistFile) + } + + err := nodes[0].startAsSeed(addresses[0]) + if err != nil { + return nil, err + } + err = nodes[1].Join(addresses[0]) + if err != nil { + return nil, err + } + err = nodes[2].Join(addresses[0]) + if err != nil { + return nil, err + } + return nodes, nil +} + +func TestRaftStart(t *testing.T) { + addresses := []string{"127.0.0.1:6399", "127.0.0.1:7379", "127.0.0.1:7369"} + timeoutFlags := []bool{false, false, false} + persistFilenames := []string{"", "", ""} + nodes, err := makeTestRaft(addresses, timeoutFlags, persistFilenames) + if err != nil { + t.Error(err) + return + } + if nodes[0].asRaft().state != leader { + t.Error("expect leader") + return + } + if nodes[1].asRaft().state != follower { + t.Error("expect follower") + return + } + if nodes[2].asRaft().state != follower { + t.Error("expect follower") + return + } + size := 100 + conn := connection.NewFakeConn() + for i := 0; i < size; i++ { + str := strconv.Itoa(i) + result := nodes[0].Exec(conn, utils.ToCmdLine("SET", str, str)) + asserts.AssertNotError(t, result) + } + for i := 0; i < size; i++ { + str := strconv.Itoa(i) + result := nodes[0].Exec(conn, utils.ToCmdLine("Get", str)) + asserts.AssertBulkReply(t, result, str) + } + for _, node := range nodes { + _ = node.asRaft().Close() + } +} + +func TestRaftElection(t *testing.T) { + addresses := []string{"127.0.0.1:6399", "127.0.0.1:7379", "127.0.0.1:7369"} + timeoutFlags := []bool{false, false, false} + persistFilenames := []string{"", "", ""} + nodes, err := makeTestRaft(addresses, timeoutFlags, persistFilenames) + if err != nil { + t.Error(err) + return + } + nodes[0].asRaft().Close() + time.Sleep(3 * electionTimeoutMaxMs * time.Millisecond) // wait for leader timeout + //<-make(chan struct{}) // wait for leader timeout + for i := 0; i < 10; i++ { + leaderCount := 0 + for _, node := range nodes { + if node.asRaft().closed { + continue + } + switch node.asRaft().state { + case leader: + leaderCount++ + } + } + if leaderCount == 1 { + break + } else if leaderCount > 1 { + t.Errorf("get %d leaders, split brain", leaderCount) + break + } + time.Sleep(time.Second) + } +} + +func TestRaftPersist(t *testing.T) { + addresses := []string{"127.0.0.1:6399", "127.0.0.1:7379", "127.0.0.1:7369"} + timeoutFlags := []bool{false, false, false} + persistFilenames := []string{ + path.Join(config.Properties.Dir, "test6399.conf"), + path.Join(config.Properties.Dir, "test7379.conf"), + path.Join(config.Properties.Dir, "test7369.conf"), + } + nodes, err := makeTestRaft(addresses, timeoutFlags, persistFilenames) + if err != nil { + t.Error(err) + return 
+ } + node1 := nodes[0].asRaft() + err = node1.persist() + if err != nil { + t.Error(err) + return + } + for _, node := range nodes { + _ = node.asRaft().Close() + } + + err = node1.LoadConfigFile() + if err != nil { + t.Error(err) + return + } +} diff --git a/cluster/topo_utils.go b/cluster/topo_utils.go new file mode 100644 index 0000000..0c1017d --- /dev/null +++ b/cluster/topo_utils.go @@ -0,0 +1,99 @@ +package cluster + +import ( + "github.com/hdt3213/godis/datastruct/set" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" +) + +func (cluster *Cluster) isImportedKey(key string) bool { + slotId := getSlot(key) + cluster.slotMu.RLock() + slot := cluster.slots[slotId] + cluster.slotMu.RUnlock() + return slot.importedKeys.Has(key) +} + +func (cluster *Cluster) setImportedKey(key string) { + slotId := getSlot(key) + cluster.slotMu.Lock() + slot := cluster.slots[slotId] + cluster.slotMu.Unlock() + slot.importedKeys.Add(key) +} + +// initSlot init a slot when start as seed or import slot from other node +func (cluster *Cluster) initSlot(slotId uint32, state uint32) { + cluster.slotMu.Lock() + defer cluster.slotMu.Unlock() + cluster.slots[slotId] = &hostSlot{ + importedKeys: set.Make(), + keys: set.Make(), + state: state, + } +} + +func (cluster *Cluster) getHostSlot(slotId uint32) *hostSlot { + cluster.slotMu.RLock() + defer cluster.slotMu.RUnlock() + return cluster.slots[slotId] +} + +func (cluster *Cluster) finishSlotImport(slotID uint32) { + cluster.slotMu.Lock() + defer cluster.slotMu.Unlock() + slot := cluster.slots[slotID] + slot.state = slotStateHost + slot.importedKeys = nil + slot.oldNodeID = "" +} + +func (cluster *Cluster) setLocalSlotImporting(slotID uint32, oldNodeID string) { + cluster.slotMu.Lock() + defer cluster.slotMu.Unlock() + slot := cluster.slots[slotID] + if slot == nil { + slot = &hostSlot{ + importedKeys: set.Make(), + keys: set.Make(), + } + cluster.slots[slotID] = slot + } + slot.state = slotStateImporting + slot.oldNodeID = oldNodeID +} + +func (cluster *Cluster) setSlotMovingOut(slotID uint32, newNodeID string) { + cluster.slotMu.Lock() + defer cluster.slotMu.Unlock() + slot := cluster.slots[slotID] + if slot == nil { + slot = &hostSlot{ + importedKeys: set.Make(), + keys: set.Make(), + } + cluster.slots[slotID] = slot + } + slot.state = slotStateMovingOut + slot.newNodeID = newNodeID +} + +// cleanDroppedSlot deletes keys when slot has moved out or failed to import +func (cluster *Cluster) cleanDroppedSlot(slotID uint32) { + cluster.slotMu.RLock() + if cluster.slots[slotID] == nil { + cluster.slotMu.RUnlock() + return + } + keys := cluster.slots[slotID].importedKeys + cluster.slotMu.RUnlock() + c := connection.NewFakeConn() + go func() { + if keys != nil { + keys.ForEach(func(key string) bool { + cluster.db.Exec(c, utils.ToCmdLine("DEL", key)) + return true + }) + } + }() +} diff --git a/cluster/utils.go b/cluster/utils.go index 6b850f8..84982c3 100644 --- a/cluster/utils.go +++ b/cluster/utils.go @@ -1,10 +1,7 @@ package cluster import ( - "github.com/hdt3213/godis/config" "github.com/hdt3213/godis/interface/redis" - "github.com/hdt3213/godis/redis/protocol" - "strconv" ) func ping(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { @@ -34,7 +31,7 @@ func makeArgs(cmd string, args ...string) [][]byte { func (cluster *Cluster) groupBy(keys []string) map[string][]string { result := make(map[string][]string) for _, key := range keys { - peer := cluster.peerPicker.PickNode(key) + peer := 
cluster.pickNodeAddrByKey(key) group, ok := result[peer] if !ok { group = make([]string, 0) @@ -45,14 +42,34 @@ func (cluster *Cluster) groupBy(keys []string) map[string][]string { return result } -func execSelect(c redis.Connection, args [][]byte) redis.Reply { - dbIndex, err := strconv.Atoi(string(args[1])) - if err != nil { - return protocol.MakeErrReply("ERR invalid DB index") +// pickNode returns the node id hosting the given slot. +// If the slot is migrating, return the node which is importing the slot +func (cluster *Cluster) pickNode(slotID uint32) *Node { + // check cluster.slot to avoid errors caused by inconsistent status on follower nodes during raft commits + // see cluster.reBalance() + hSlot := cluster.getHostSlot(slotID) + if hSlot != nil { + switch hSlot.state { + case slotStateMovingOut: + return cluster.topology.GetNode(hSlot.newNodeID) + case slotStateImporting, slotStateHost: + return cluster.topology.GetNode(cluster.self) + } } - if dbIndex >= config.Properties.Databases || dbIndex < 0 { - return protocol.MakeErrReply("ERR DB index is out of range") - } - c.SelectDB(dbIndex) - return protocol.MakeOkReply() + + slot := cluster.topology.GetSlots()[int(slotID)] + node := cluster.topology.GetNode(slot.NodeID) + return node +} + +func (cluster *Cluster) pickNodeAddrByKey(key string) string { + slotId := getSlot(key) + return cluster.pickNode(slotId).Addr +} + +func modifyCmd(cmdLine CmdLine, newCmd string) CmdLine { + var cmdLine2 CmdLine + cmdLine2 = append(cmdLine2, cmdLine...) + cmdLine2[0] = []byte(newCmd) + return cmdLine2 } diff --git a/cluster/utils_test.go b/cluster/utils_test.go index 8aa701d..f222d7b 100644 --- a/cluster/utils_test.go +++ b/cluster/utils_test.go @@ -1,97 +1,145 @@ package cluster import ( + "errors" "github.com/hdt3213/godis/config" + database2 "github.com/hdt3213/godis/database" + "github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/idgenerator" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/parser" "github.com/hdt3213/godis/redis/protocol" "math/rand" - "strings" + "sync" ) -var testNodeA, testNodeB *Cluster -var simulateATimout, simulateBTimout *bool - -type mockPicker struct { - nodes []string +type testClientFactory struct { + nodes []*Cluster + timeoutFlags []bool } -func (picker *mockPicker) AddNode(keys ...string) { - picker.nodes = append(picker.nodes, keys...) 
+type testClient struct { + targetNode *Cluster + timeoutFlag *bool + conn redis.Connection } -func (picker *mockPicker) PickNode(key string) string { - for _, n := range picker.nodes { - if strings.Contains(key, n) { - return n - } +func (cli *testClient) Send(cmdLine [][]byte) redis.Reply { + if *cli.timeoutFlag { + return protocol.MakeErrReply("ERR timeout") } - return picker.nodes[0] + return cli.targetNode.Exec(cli.conn, cmdLine) } -func makeMockRelay(peer *Cluster) (*bool, func(cluster *Cluster, node string, c redis.Connection, cmdLine CmdLine) redis.Reply) { - simulateTimeout0 := false - simulateTimeout := &simulateTimeout0 - return simulateTimeout, func(cluster *Cluster, node string, c redis.Connection, cmdLine CmdLine) redis.Reply { - if len(cmdLine) == 0 { - return protocol.MakeErrReply("ERR command required") - } - if node == cluster.self { - // to self db - cmdName := strings.ToLower(string(cmdLine[0])) - if cmdName == "prepare" { - return execPrepare(cluster, c, cmdLine) - } else if cmdName == "commit" { - return execCommit(cluster, c, cmdLine) - } else if cmdName == "rollback" { - return execRollback(cluster, c, cmdLine) +func (factory *testClientFactory) GetPeerClient(peerAddr string) (peerClient, error) { + for i, n := range factory.nodes { + if n.self == peerAddr { + cli := &testClient{ + targetNode: n, + timeoutFlag: &factory.timeoutFlags[i], + conn: connection.NewFakeConn(), } - return cluster.db.Exec(c, cmdLine) + if config.Properties.RequirePass != "" { + cli.Send(utils.ToCmdLine("AUTH", config.Properties.RequirePass)) + } + return cli, nil } - if *simulateTimeout { - return protocol.MakeErrReply("ERR timeout") - } - cmdName := strings.ToLower(string(cmdLine[0])) - if cmdName == "prepare" { - return execPrepare(peer, c, cmdLine) - } else if cmdName == "commit" { - return execCommit(peer, c, cmdLine) - } else if cmdName == "rollback" { - return execRollback(peer, c, cmdLine) - } - return peer.db.Exec(c, cmdLine) } + return nil, errors.New("peer not found") } -func init() { - if config.Properties == nil { - config.Properties = &config.ServerProperties{} - } - addrA := "127.0.0.1:6399" - addrB := "127.0.0.1:7379" - config.Properties.Self = addrA - config.Properties.Peers = []string{addrB} - testNodeA = MakeCluster() - config.Properties.Self = addrB - config.Properties.Peers = []string{addrA} - testNodeB = MakeCluster() - - simulateBTimout, testNodeA.relayImpl = makeMockRelay(testNodeB) - testNodeA.peerPicker = &mockPicker{} - testNodeA.peerPicker.AddNode(addrA, addrB) - simulateATimout, testNodeB.relayImpl = makeMockRelay(testNodeA) - testNodeB.peerPicker = &mockPicker{} - testNodeB.peerPicker.AddNode(addrB, addrA) +type mockStream struct { + targetNode *Cluster + ch <-chan *parser.Payload } -func MakeTestCluster(peers []string) *Cluster { - if config.Properties == nil { - config.Properties = &config.ServerProperties{} - } - config.Properties.Self = "127.0.0.1:6399" - config.Properties.Peers = peers - return MakeCluster() +func (s *mockStream) Stream() <-chan *parser.Payload { + return s.ch } +func (s *mockStream) Close() error { + return nil +} + +func (factory *testClientFactory) NewStream(peerAddr string, cmdLine CmdLine) (peerStream, error) { + for _, n := range factory.nodes { + if n.self == peerAddr { + conn := connection.NewFakeConn() + if config.Properties.RequirePass != "" { + n.Exec(conn, utils.ToCmdLine("AUTH", config.Properties.RequirePass)) + } + result := n.Exec(conn, cmdLine) + conn.Write(result.ToBytes()) + ch := parser.ParseStream(conn) + return 
&mockStream{ + targetNode: n, + ch: ch, + }, nil + } + } + return nil, errors.New("node not found") +} + +func (factory *testClientFactory) ReturnPeerClient(peer string, peerClient peerClient) error { + return nil +} + +func (factory *testClientFactory) Close() error { + return nil +} + +// mockClusterNodes creates a fake cluster for test +// timeoutFlags should have the same length as addresses, set timeoutFlags[i] == true could simulate addresses[i] timeout +func mockClusterNodes(addresses []string, timeoutFlags []bool) []*Cluster { + nodes := make([]*Cluster, len(addresses)) + // build fixedTopology + slots := make([]*Slot, slotCount) + nodeMap := make(map[string]*Node) + for _, addr := range addresses { + nodeMap[addr] = &Node{ + ID: addr, + Addr: addr, + Slots: nil, + } + } + for i := range slots { + addr := addresses[i%len(addresses)] + slots[i] = &Slot{ + ID: uint32(i), + NodeID: addr, + Flags: 0, + } + nodeMap[addr].Slots = append(nodeMap[addr].Slots, slots[i]) + } + factory := &testClientFactory{ + nodes: nodes, + timeoutFlags: timeoutFlags, + } + for i, addr := range addresses { + topo := &fixedTopology{ + mu: sync.RWMutex{}, + nodeMap: nodeMap, + slots: slots, + selfNodeID: addr, + } + nodes[i] = &Cluster{ + self: addr, + db: database2.NewStandaloneServer(), + transactions: dict.MakeSimple(), + idGenerator: idgenerator.MakeGenerator(config.Properties.Self), + topology: topo, + clientFactory: factory, + } + } + return nodes +} + +var addresses = []string{"127.0.0.1:6399", "127.0.0.1:7379"} +var timeoutFlags = []bool{false, false} +var testCluster = mockClusterNodes(addresses, timeoutFlags) + func toArgs(cmd ...string) [][]byte { args := make([][]byte, len(cmd)) for i, s := range cmd { diff --git a/config/config.go b/config/config.go index b6dcc3e..6b4ec6e 100644 --- a/config/config.go +++ b/config/config.go @@ -27,6 +27,7 @@ type ServerProperties struct { Bind string `cfg:"bind"` Port int `cfg:"port"` Dir string `cfg:"dir"` + AnnounceHost string `cfg:"announce-host"` AppendOnly bool `cfg:"appendonly"` AppendFilename string `cfg:"appendfilename"` AppendFsync string `cfg:"appendfsync"` @@ -39,6 +40,10 @@ type ServerProperties struct { SlaveAnnouncePort int `cfg:"slave-announce-port"` SlaveAnnounceIP string `cfg:"slave-announce-ip"` ReplTimeout int `cfg:"repl-timeout"` + ClusterEnable bool `cfg:"cluster-enable"` + ClusterAsSeed bool `cfg:"cluster-as-seed"` + ClusterSeed string `cfg:"cluster-seed"` + ClusterConfigFile string `cfg:"cluster-config-file"` // for cluster mode configuration ClusterEnabled string `cfg:"cluster-enabled"` // Not used at present. 
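
Key-to-slot mapping drives all of the routing in this patch (`groupBy`, `pickNodeAddrByKey`, the `gcluster` migration commands): a key is hashed by its hash tag, so keys sharing a `{tag}` always land in the same slot and can be used together in a same-slot `MULTI` transaction. Below is a minimal standalone sketch, not part of the patch, mirroring `getPartitionKey`/`getSlot` from `cluster/topo_interface.go` above; it assumes `slotCount` is 16384, consistent with the 0-16383 range documented on `Slot.ID`.

```go
package main

import (
	"fmt"
	"hash/crc32"
	"strings"
)

const slotCount = 16384 // assumption: matches the 0-16383 range documented on Slot.ID

// hashTag mirrors getPartitionKey: if the key contains a non-empty {tag}, only the tag is hashed
func hashTag(key string) string {
	beg := strings.Index(key, "{")
	if beg == -1 {
		return key
	}
	end := strings.Index(key, "}")
	if end == -1 || end == beg+1 {
		return key
	}
	return key[beg+1 : end]
}

// slotOf mirrors getSlot: CRC32 of the hash tag modulo the slot count
func slotOf(key string) uint32 {
	return crc32.ChecksumIEEE([]byte(hashTag(key))) % slotCount
}

func main() {
	// both keys hash only "user:1", so they map to the same slot and
	// may be used together in a same-slot MULTI transaction
	fmt.Println(slotOf("{user:1}:name"), slotOf("{user:1}:orders"))
	fmt.Println(slotOf("plain-key")) // no hash tag: the whole key is hashed
}
```
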
@@ -53,6 +58,10 @@ type ServerInfo struct { StartUpTime time.Time } +func (p *ServerProperties) AnnounceAddress() string { + return p.AnnounceHost + ":" + strconv.Itoa(p.Port) +} + // Properties holds global config properties var Properties *ServerProperties var EachTimeServerInfo *ServerInfo diff --git a/database/database.go b/database/database.go index 859327a..9b48eab 100644 --- a/database/database.go +++ b/database/database.go @@ -30,6 +30,10 @@ type DB struct { // addaof is used to add command to aof addAof func(CmdLine) + + // callbacks + insertCallback database.KeyEventCallback + deleteCallback database.KeyEventCallback } // ExecFunc is interface for command executor @@ -159,7 +163,13 @@ func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { // PutEntity a DataEntity into DB func (db *DB) PutEntity(key string, entity *database.DataEntity) int { - return db.data.PutWithLock(key, entity) + ret := db.data.PutWithLock(key, entity) + // db.insertCallback may be set as nil, during `if` and actually callback + // so introduce a local variable `cb` + if cb := db.insertCallback; ret > 0 && cb != nil { + cb(db.index, key, entity) + } + return ret } // PutIfExists edit an existing DataEntity @@ -169,15 +179,28 @@ func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { // PutIfAbsent insert an DataEntity only if the key not exists func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { - return db.data.PutIfAbsentWithLock(key, entity) + ret := db.data.PutIfAbsentWithLock(key, entity) + // db.insertCallback may be set as nil, during `if` and actually callback + // so introduce a local variable `cb` + if cb := db.insertCallback; ret > 0 && cb != nil { + cb(db.index, key, entity) + } + return ret } // Remove the given key from db func (db *DB) Remove(key string) { - db.data.RemoveWithLock(key) + raw, deleted := db.data.RemoveWithLock(key) db.ttlMap.Remove(key) taskKey := genExpireTask(key) timewheel.Cancel(taskKey) + if cb := db.deleteCallback; cb != nil { + var entity *database.DataEntity + if deleted > 0 { + entity = raw.(*database.DataEntity) + } + cb(db.index, key, entity) + } } // Removes the given keys from db diff --git a/database/hash.go b/database/hash.go index 6ca5256..811661f 100644 --- a/database/hash.go +++ b/database/hash.go @@ -149,7 +149,7 @@ func execHDel(db *DB, args [][]byte) redis.Reply { deleted := 0 for _, field := range fields { - result := dict.Remove(field) + _, result := dict.Remove(field) deleted += result } if dict.Len() == 0 { diff --git a/database/server.go b/database/server.go index b74b94b..1254463 100644 --- a/database/server.go +++ b/database/server.go @@ -34,6 +34,10 @@ type Server struct { role int32 slaveStatus *slaveStatus masterStatus *masterStatus + + // hooks + insertCallback database.KeyEventCallback + deleteCallback database.KeyEventCallback } func fileExists(filename string) bool { @@ -286,6 +290,20 @@ func (server *Server) ForEach(dbIndex int, cb func(key string, data *database.Da server.mustSelectDB(dbIndex).ForEach(cb) } +// GetEntity returns the data entity to the given key +func (server *Server) GetEntity(dbIndex int, key string) (*database.DataEntity, bool) { + return server.mustSelectDB(dbIndex).GetEntity(key) +} + +func (server *Server) GetExpiration(dbIndex int, key string) *time.Time { + raw, ok := server.mustSelectDB(dbIndex).ttlMap.Get(key) + if !ok { + return nil + } + expireTime, _ := raw.(time.Time) + return &expireTime +} + // ExecMulti executes multi commands transaction Atomically and 
Isolated func (server *Server) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { selectedDB, errReply := server.selectDB(conn.GetDBIndex()) @@ -408,3 +426,20 @@ func (server *Server) GetAvgTTL(dbIndex, randomKeyCount int) int64 { } return ttlCount / int64(len(keys)) } + +func (server *Server) SetKeyInsertedCallback(cb database.KeyEventCallback) { + server.insertCallback = cb + for i := range server.dbSet { + db := server.mustSelectDB(i) + db.insertCallback = cb + } + +} + +func (server *Server) SetKeyDeletedCallback(cb database.KeyEventCallback) { + server.deleteCallback = cb + for i := range server.dbSet { + db := server.mustSelectDB(i) + db.deleteCallback = cb + } +} diff --git a/datastruct/dict/concurrent.go b/datastruct/dict/concurrent.go index eb6d5ce..2ece9cd 100644 --- a/datastruct/dict/concurrent.go +++ b/datastruct/dict/concurrent.go @@ -219,7 +219,7 @@ func (dict *ConcurrentDict) PutIfExistsWithLock(key string, val interface{}) (re } // Remove removes the key and return the number of deleted key-value -func (dict *ConcurrentDict) Remove(key string) (result int) { +func (dict *ConcurrentDict) Remove(key string) (val interface{}, result int) { if dict == nil { panic("dict is nil") } @@ -229,15 +229,15 @@ func (dict *ConcurrentDict) Remove(key string) (result int) { s.mutex.Lock() defer s.mutex.Unlock() - if _, ok := s.m[key]; ok { + if val, ok := s.m[key]; ok { delete(s.m, key) dict.decreaseCount() - return 1 + return val, 1 } - return 0 + return nil, 0 } -func (dict *ConcurrentDict) RemoveWithLock(key string) (result int) { +func (dict *ConcurrentDict) RemoveWithLock(key string) (val interface{}, result int) { if dict == nil { panic("dict is nil") } @@ -245,12 +245,12 @@ func (dict *ConcurrentDict) RemoveWithLock(key string) (result int) { index := dict.spread(hashCode) s := dict.getShard(index) - if _, ok := s.m[key]; ok { + if val, ok := s.m[key]; ok { delete(s.m, key) dict.decreaseCount() - return 1 + return val, 1 } - return 0 + return val, 0 } func (dict *ConcurrentDict) addCount() int32 { diff --git a/datastruct/dict/concurrent_test.go b/datastruct/dict/concurrent_test.go index ba046b2..0581ef5 100644 --- a/datastruct/dict/concurrent_test.go +++ b/datastruct/dict/concurrent_test.go @@ -258,7 +258,7 @@ func TestConcurrentRemove(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.Remove(key) + _, ret := d.Remove(key) if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) } @@ -269,7 +269,7 @@ func TestConcurrentRemove(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.Remove(key) + _, ret = d.Remove(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } @@ -298,7 +298,7 @@ func TestConcurrentRemove(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.Remove(key) + _, ret := d.Remove(key) if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) } @@ -306,7 +306,7 @@ func TestConcurrentRemove(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.Remove(key) + _, ret = d.Remove(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } @@ -334,7 +334,7 @@ func TestConcurrentRemove(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.Remove(key) + _, ret := d.Remove(key) if ret != 1 
{ t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) } @@ -342,7 +342,7 @@ func TestConcurrentRemove(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.Remove(key) + _, ret = d.Remove(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } @@ -374,7 +374,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.RemoveWithLock(key) + _, ret := d.RemoveWithLock(key) if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) } @@ -385,7 +385,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.RemoveWithLock(key) + _, ret = d.RemoveWithLock(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } @@ -414,7 +414,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.RemoveWithLock(key) + _, ret := d.RemoveWithLock(key) if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) } @@ -422,7 +422,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.RemoveWithLock(key) + _, ret = d.RemoveWithLock(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } @@ -450,7 +450,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { t.Error("put test failed: expected true, actual: false") } - ret := d.RemoveWithLock(key) + _, ret := d.RemoveWithLock(key) if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) } @@ -458,7 +458,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { if ok { t.Error("remove test failed: expected true, actual false") } - ret = d.RemoveWithLock(key) + _, ret = d.RemoveWithLock(key) if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } diff --git a/datastruct/dict/dict.go b/datastruct/dict/dict.go index 0b98577..28ecd4c 100644 --- a/datastruct/dict/dict.go +++ b/datastruct/dict/dict.go @@ -10,7 +10,7 @@ type Dict interface { Put(key string, val interface{}) (result int) PutIfAbsent(key string, val interface{}) (result int) PutIfExists(key string, val interface{}) (result int) - Remove(key string) (result int) + Remove(key string) (val interface{}, result int) ForEach(consumer Consumer) Keys() []string RandomKeys(limit int) []string diff --git a/datastruct/dict/simple.go b/datastruct/dict/simple.go index 1ad9402..6183965 100644 --- a/datastruct/dict/simple.go +++ b/datastruct/dict/simple.go @@ -57,13 +57,13 @@ func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) { } // Remove removes the key and return the number of deleted key-value -func (dict *SimpleDict) Remove(key string) (result int) { - _, existed := dict.m[key] +func (dict *SimpleDict) Remove(key string) (val interface{}, result int) { + val, existed := dict.m[key] delete(dict.m, key) if existed { - return 1 + return val, 1 } - return 0 + return nil, 0 } // Keys returns all keys in dict diff --git a/datastruct/set/set.go b/datastruct/set/set.go index 2b1648b..11236e6 100644 --- a/datastruct/set/set.go +++ b/datastruct/set/set.go @@ -27,7 +27,8 @@ func (set *Set) Add(val string) int { // Remove removes member from set func (set *Set) Remove(val string) 
int { - return set.dict.Remove(val) + _, ret := set.dict.Remove(val) + return ret } // Has returns true if the val exists in the set diff --git a/interface/database/db.go b/interface/database/db.go index 6de606b..9fc2d89 100644 --- a/interface/database/db.go +++ b/interface/database/db.go @@ -18,6 +18,10 @@ type DB interface { LoadRDB(dec *core.Decoder) error } +// KeyEventCallback will be called back on key event, such as key inserted or deleted +// may be called concurrently +type KeyEventCallback func(dbIndex int, key string, entity *DataEntity) + // DBEngine is the embedding storage engine exposing more methods for complex application type DBEngine interface { DB @@ -28,6 +32,10 @@ type DBEngine interface { RWLocks(dbIndex int, writeKeys []string, readKeys []string) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) GetDBSize(dbIndex int) (int, int) + GetEntity(dbIndex int, key string) (*DataEntity, bool) + GetExpiration(dbIndex int, key string) *time.Time + SetKeyInsertedCallback(cb KeyEventCallback) + SetKeyDeletedCallback(cb KeyEventCallback) } // DataEntity stores data bound to a key, including a string, list, hash, set and so on diff --git a/interface/redis/conn.go b/interface/redis/conn.go index 0f8f920..282a8da 100644 --- a/interface/redis/conn.go +++ b/interface/redis/conn.go @@ -4,6 +4,7 @@ package redis type Connection interface { Write([]byte) (int, error) Close() error + RemoteAddr() string SetPassword(string) GetPassword() string diff --git a/lib/logger/logger.go b/lib/logger/logger.go index f268df4..31dc3fc 100644 --- a/lib/logger/logger.go +++ b/lib/logger/logger.go @@ -57,7 +57,7 @@ func Setup(settings *Settings) { logFile, err = mustOpen(fileName, dir) if err != nil { - log.Fatalf("logging.Setup err: %s", err) + log.Fatalf("logging.Join err: %s", err) } mw := io.MultiWriter(os.Stdout, logFile) @@ -83,6 +83,13 @@ func Debug(v ...interface{}) { logger.Println(v...) } +func Debugf(format string, v ...interface{}) { + mu.Lock() + defer mu.Unlock() + setPrefix(DEBUG) + logger.Println(fmt.Sprintf(format, v...)) +} + // Info prints normal log func Info(v ...interface{}) { mu.Lock() @@ -91,6 +98,14 @@ func Info(v ...interface{}) { logger.Println(v...) 
} +// Infof prints normal log +func Infof(format string, v ...interface{}) { + mu.Lock() + defer mu.Unlock() + setPrefix(INFO) + logger.Println(fmt.Sprintf(format, v...)) +} + // Warn prints warning log func Warn(v ...interface{}) { mu.Lock() diff --git a/lib/utils/rand_string.go b/lib/utils/rand_string.go index 2aed53c..c1be83d 100644 --- a/lib/utils/rand_string.go +++ b/lib/utils/rand_string.go @@ -5,14 +5,14 @@ import ( "time" ) +var r = rand.New(rand.NewSource(time.Now().UnixNano())) var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") // RandString create a random string no longer than n func RandString(n int) string { - nR := rand.New(rand.NewSource(time.Now().UnixNano())) b := make([]rune, n) for i := range b { - b[i] = letters[nR.Intn(len(letters))] + b[i] = letters[r.Intn(len(letters))] } return string(b) } @@ -22,7 +22,19 @@ var hexLetters = []rune("0123456789abcdef") func RandHexString(n int) string { b := make([]rune, n) for i := range b { - b[i] = hexLetters[rand.Intn(len(hexLetters))] + b[i] = hexLetters[r.Intn(len(hexLetters))] } return string(b) } + +// RandIndex returns random indexes to random pick elements from slice +func RandIndex(size int) []int { + result := make([]int, size) + for i := range result { + result[i] = i + } + rand.Shuffle(size, func(i, j int) { + result[i], result[j] = result[j], result[i] + }) + return result +} diff --git a/node1.conf b/node1.conf index cb00c26..6040ce8 100644 --- a/node1.conf +++ b/node1.conf @@ -5,5 +5,7 @@ maxclients 128 appendonly no appendfilename appendonly.aof -peers localhost:7379 -self localhost:6399 +announce-host 127.0.0.1 +cluster-enable yes +cluster-as-seed yes +cluster-config-file 6399.conf diff --git a/node2.conf b/node2.conf index 9a5330a..88432f6 100644 --- a/node2.conf +++ b/node2.conf @@ -5,5 +5,7 @@ maxclients 128 appendonly no appendfilename appendonly.aof -peers localhost:6399 -self localhost:7379 +announce-host 127.0.0.1 +cluster-enable yes +cluster-seed 127.0.0.1:6399 +cluster-config-file 7379.conf \ No newline at end of file diff --git a/node3.conf b/node3.conf new file mode 100644 index 0000000..c4a55e6 --- /dev/null +++ b/node3.conf @@ -0,0 +1,11 @@ +bind 0.0.0.0 +port 7369 +maxclients 128 + +appendonly no +appendfilename appendonly.aof + +announce-host 127.0.0.1 +cluster-enable yes +cluster-seed 127.0.0.1:6399 +cluster-config-file 7369.conf \ No newline at end of file diff --git a/redis/client/client.go b/redis/client/client.go index 6263d52..2c526a0 100644 --- a/redis/client/client.go +++ b/redis/client/client.go @@ -136,23 +136,23 @@ func (client *Client) Send(args [][]byte) redis.Reply { if atomic.LoadInt32(&client.status) != running { return protocol.MakeErrReply("client closed") } - request := &request{ + req := &request{ args: args, heartbeat: false, waiting: &wait.Wait{}, } - request.waiting.Add(1) + req.waiting.Add(1) client.working.Add(1) defer client.working.Done() - client.pendingReqs <- request - timeout := request.waiting.WaitWithTimeout(maxWait) + client.pendingReqs <- req + timeout := req.waiting.WaitWithTimeout(maxWait) if timeout { return protocol.MakeErrReply("server time out") } - if request.err != nil { - return protocol.MakeErrReply("request failed") + if req.err != nil { + return protocol.MakeErrReply("request failed " + req.err.Error()) } - return request.reply + return req.reply } func (client *Client) doHeartbeat() { diff --git a/redis/connection/conn.go b/redis/connection/conn.go index 0292f42..8e1755d 100644 --- a/redis/connection/conn.go +++ 
b/redis/connection/conn.go @@ -50,8 +50,8 @@ var connPool = sync.Pool{ } // RemoteAddr returns the remote network address -func (c *Connection) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() +func (c *Connection) RemoteAddr() string { + return c.conn.RemoteAddr().String() } // Close disconnect with the client diff --git a/redis/connection/fake.go b/redis/connection/fake.go index 0b28c2f..8508199 100644 --- a/redis/connection/fake.go +++ b/redis/connection/fake.go @@ -104,3 +104,7 @@ func (c *FakeConn) Close() error { c.notify() return nil } + +func (c *FakeConn) RemoteAddr() string { + return "" +} diff --git a/redis/protocol/consts.go b/redis/protocol/consts.go index 73d46a4..6b45430 100644 --- a/redis/protocol/consts.go +++ b/redis/protocol/consts.go @@ -1,5 +1,10 @@ package protocol +import ( + "bytes" + "github.com/hdt3213/godis/interface/redis" +) + // PongReply is +PONG type PongReply struct{} @@ -57,6 +62,10 @@ func MakeEmptyMultiBulkReply() *EmptyMultiBulkReply { return &EmptyMultiBulkReply{} } +func IsEmptyMultiBulkReply(reply redis.Reply) bool { + return bytes.Equal(reply.ToBytes(), emptyMultiBulkBytes) +} + // NoReply respond nothing, for commands like subscribe type NoReply struct{} diff --git a/redis/server/server.go b/redis/server/server.go index b00646c..29dd10b 100644 --- a/redis/server/server.go +++ b/redis/server/server.go @@ -36,8 +36,7 @@ type Handler struct { // MakeHandler creates a Handler instance func MakeHandler() *Handler { var db database.DB - if config.Properties.Self != "" && - len(config.Properties.Peers) > 0 { + if config.Properties.ClusterEnable { db = cluster.MakeCluster() } else { db = database2.NewStandaloneServer() @@ -72,7 +71,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { strings.Contains(payload.Err.Error(), "use of closed network connection") { // connection closed h.closeClient(client) - logger.Info("connection closed: " + client.RemoteAddr().String()) + logger.Info("connection closed: " + client.RemoteAddr()) return } // protocol err @@ -80,7 +79,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { _, err := client.Write(errReply.ToBytes()) if err != nil { h.closeClient(client) - logger.Info("connection closed: " + client.RemoteAddr().String()) + logger.Info("connection closed: " + client.RemoteAddr()) return } continue
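
The key-event hooks added to `interface/database/db.go` and `database/server.go` fire whenever a key is inserted into or removed from the storage engine, and may be called concurrently. The following is a minimal consumer sketch, not part of the patch; it assumes the minimal `config.Properties` shown is enough for `NewStandaloneServer` in this setting.

```go
package main

import (
	"fmt"

	"github.com/hdt3213/godis/config"
	database2 "github.com/hdt3213/godis/database"
	"github.com/hdt3213/godis/interface/database"
	"github.com/hdt3213/godis/lib/utils"
	"github.com/hdt3213/godis/redis/connection"
)

func main() {
	// assumption: a minimal config is sufficient for a standalone engine in this sketch
	config.Properties = &config.ServerProperties{Databases: 16}
	server := database2.NewStandaloneServer()

	// the callbacks receive the db index, the key and its entity;
	// they may run concurrently, so real consumers must do their own locking
	server.SetKeyInsertedCallback(func(dbIndex int, key string, entity *database.DataEntity) {
		fmt.Printf("inserted db=%d key=%s\n", dbIndex, key)
	})
	server.SetKeyDeletedCallback(func(dbIndex int, key string, entity *database.DataEntity) {
		fmt.Printf("deleted  db=%d key=%s\n", dbIndex, key)
	})

	conn := connection.NewFakeConn()
	server.Exec(conn, utils.ToCmdLine("SET", "k1", "v1")) // triggers the inserted callback
	server.Exec(conn, utils.ToCmdLine("DEL", "k1"))       // triggers the deleted callback
}
```
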