From 23b70325f970252f626b759d4085d20db055b1d4 Mon Sep 17 00:00:00 2001 From: finley Date: Sun, 23 Mar 2025 20:44:13 +0800 Subject: [PATCH] support msetnx in cluster by tcc pre-check hook --- cluster/commands/mset.go | 77 +++++++++++++++++++++++++++++++++++ cluster/commands/mset_test.go | 19 +++++++++ cluster/core/tcc.go | 37 ++++++++++++++--- cluster/core/utils.go | 12 ++++++ 4 files changed, 139 insertions(+), 6 deletions(-) diff --git a/cluster/commands/mset.go b/cluster/commands/mset.go index a4f381b..49a7bf4 100644 --- a/cluster/commands/mset.go +++ b/cluster/commands/mset.go @@ -1,6 +1,8 @@ package commands import ( + "strings" + "github.com/hdt3213/godis/cluster/core" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/utils" @@ -12,6 +14,8 @@ func init() { core.RegisterCmd("mset", execMSet) core.RegisterCmd("mget_", execMGetInLocal) core.RegisterCmd("mget", execMGet) + core.RegisterCmd("msetnx_", execMSetNxInLocal) + core.RegisterCmd("msetnx", execMSet) } // execMSetInLocal executes msets in local node @@ -32,6 +36,15 @@ func execMGetInLocal(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) return cluster.LocalExec(c, cmdLine) } +// execMSetInLocal executes msets in local node +func execMSetNxInLocal(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) < 3 { + return protocol.MakeArgNumErrReply("msetnx") + } + cmdLine[0] = []byte("msetnx") + return cluster.LocalExec(c, cmdLine) +} + func execMSet(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { if len(cmdLine) < 3 || len(cmdLine)%2 != 1 { return protocol.MakeArgNumErrReply("mset") @@ -126,3 +139,67 @@ func execMGet(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis. } return protocol.MakeMultiBulkReply(result) } + +const someKeysExistsErr = "Some Keys Exists" + +func init() { + core.RegisterPreCheck("msetnx", msetNxPrecheck) +} + +func msetNxPrecheck(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + var keys []string + for i := 1; i < len(cmdLine); i+=2 { + keys = append(keys, string(cmdLine[i])) + } + exists := cluster.LocalExists(keys) + if len(exists) > 0 { + return protocol.MakeErrReply(someKeysExistsErr) + } + return protocol.MakeOkReply() +} + +func execMSetNx(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) < 3 || len(cmdLine)%2 != 1 { + return protocol.MakeArgNumErrReply("mset") + } + var keys []string + keyValues := make(map[string][]byte) + for i := 1; i < len(cmdLine); i += 2 { + key := string(cmdLine[i]) + value := cmdLine[i+1] + keyValues[key] = value + keys = append(keys, key) + } + routeMap := getRouteMap(cluster, keys) + if len(routeMap) == 1 { + // only one node, do it fast + for node := range routeMap { + cmdLine[0] = []byte("msetnx_") + return cluster.Relay(node, c, cmdLine) + } + } + + // tcc + cmdLineMap := make(map[string]CmdLine) + for node, keys := range routeMap { + nodeCmdLine := utils.ToCmdLine("msetnx") + for _, key := range keys { + val := keyValues[key] + nodeCmdLine = append(nodeCmdLine, []byte(key), val) + } + cmdLineMap[node] = nodeCmdLine + } + tx := &TccTx{ + rawCmdLine: cmdLine, + routeMap: routeMap, + cmdLines: cmdLineMap, + } + _, err := doTcc(cluster, c, tx) + if err != nil { + if strings.Contains(err.Error(), someKeysExistsErr) { + return protocol.MakeIntReply(0) + } + return err + } + return protocol.MakeIntReply(1) +} \ No newline at end of file diff --git a/cluster/commands/mset_test.go b/cluster/commands/mset_test.go index 8713584..c7d381c 100644 --- a/cluster/commands/mset_test.go +++ b/cluster/commands/mset_test.go @@ -20,4 +20,23 @@ func TestMset(t *testing.T) { asserts.AssertStatusReply(t, res, "OK") res2 := execMGet(node1, c, utils.ToCmdLine("mget", "1", "2")) asserts.AssertMultiBulkReply(t, res2, []string{"1", "2"}) +} + +func TestMsetNx(t *testing.T) { + id1 := "1" + id2 := "2" + nodes := core.MakeTestCluster([]string{id1, id2}) + node1 := nodes[id1] + c := connection.NewFakeConn() + // 1, 2 will be routed to node1 and node2, see MakeTestCluster + res := execMSetNx(node1, c, utils.ToCmdLine("mset", "1", "1", "2", "2")) + asserts.AssertIntReply(t, res, 1) + res2 := execMGet(node1, c, utils.ToCmdLine("mget", "1", "2")) + asserts.AssertMultiBulkReply(t, res2, []string{"1", "2"}) + + res = execMSetNx(node1, c, utils.ToCmdLine("mset", "3", "3", "2", "2")) + asserts.AssertIntReply(t, res, 0) + core.RegisterDefaultCmd("get") + res = node1.Exec(c, utils.ToCmdLine("get", "3")) + asserts.AssertNullBulk(t, res) } \ No newline at end of file diff --git a/cluster/core/tcc.go b/cluster/core/tcc.go index 4bf901e..e6f2cda 100644 --- a/cluster/core/tcc.go +++ b/cluster/core/tcc.go @@ -1,13 +1,19 @@ package core import ( + "strings" "sync" + "time" "github.com/hdt3213/godis/database" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/redis/protocol" ) +// transaction info will be deleted after transactionTTL since commit +const transactionTTL = time.Minute + type TransactionManager struct { txs map[string]*TCC mu sync.RWMutex @@ -54,7 +60,14 @@ func execPrepare(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Re cluster.transactions.txs[txId] = tx cluster.transactions.mu.Unlock() - // todo: pre-execute check + // pre-execute check + validator := preChecks[string(realCmdLine[0])] + if validator != nil { + validateResult := validator(cluster, c, realCmdLine) + if protocol.IsErrorReply(validateResult) { + return validateResult + } + } // prepare lock and undo locks tx.writeKeys, tx.readKeys = database.GetRelatedKeys(realCmdLine) @@ -89,11 +102,12 @@ func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Rep return resp } - // todo: delete transaction after deadline - // cluster.transactions.mu.Lock() - // delete(cluster.transactions.txs, txId) - // cluster.transactions.mu.Unlock() - + // delete transaction after deadline + timewheel.At(time.Now().Add(transactionTTL), txId, func() { + cluster.transactions.mu.Lock() + delete(cluster.transactions.txs, txId) + cluster.transactions.mu.Unlock() + }) return resp } @@ -127,3 +141,14 @@ func execRollback(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.R return protocol.MakeOkReply() } + +// PreCheckFunc do validation during tcc preparing period +type PreCheckFunc func(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply + +var preChecks = make(map[string]PreCheckFunc) + +// RegisterCmd add tcc preparing validator +func RegisterPreCheck(name string, fn PreCheckFunc) { + name = strings.ToLower(name) + preChecks[name] = fn +} \ No newline at end of file diff --git a/cluster/core/utils.go b/cluster/core/utils.go index b02eb30..f9ce454 100644 --- a/cluster/core/utils.go +++ b/cluster/core/utils.go @@ -79,3 +79,15 @@ func execRaftCommittedIndex(cluster *Cluster, c redis.Connection, cmdLine CmdLin } return protocol.MakeIntReply(int64(index)) } + +// LocalExists returns existed ones from `keys` in local node +func (cluster *Cluster) LocalExists(keys []string) []string { + var exists []string + for _, key := range keys { + _, ok := cluster.db.GetEntity(0, key) + if ok { + exists = append(exists, key) + } + } + return exists +}