mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-06 09:17:10 +08:00
support msetnx in cluster by tcc pre-check hook
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/hdt3213/godis/cluster/core"
|
||||
"github.com/hdt3213/godis/interface/redis"
|
||||
"github.com/hdt3213/godis/lib/utils"
|
||||
@@ -12,6 +14,8 @@ func init() {
|
||||
core.RegisterCmd("mset", execMSet)
|
||||
core.RegisterCmd("mget_", execMGetInLocal)
|
||||
core.RegisterCmd("mget", execMGet)
|
||||
core.RegisterCmd("msetnx_", execMSetNxInLocal)
|
||||
core.RegisterCmd("msetnx", execMSet)
|
||||
}
|
||||
|
||||
// execMSetInLocal executes msets in local node
|
||||
@@ -32,6 +36,15 @@ func execMGetInLocal(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine)
|
||||
return cluster.LocalExec(c, cmdLine)
|
||||
}
|
||||
|
||||
// execMSetInLocal executes msets in local node
|
||||
func execMSetNxInLocal(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply {
|
||||
if len(cmdLine) < 3 {
|
||||
return protocol.MakeArgNumErrReply("msetnx")
|
||||
}
|
||||
cmdLine[0] = []byte("msetnx")
|
||||
return cluster.LocalExec(c, cmdLine)
|
||||
}
|
||||
|
||||
func execMSet(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply {
|
||||
if len(cmdLine) < 3 || len(cmdLine)%2 != 1 {
|
||||
return protocol.MakeArgNumErrReply("mset")
|
||||
@@ -126,3 +139,67 @@ func execMGet(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.
|
||||
}
|
||||
return protocol.MakeMultiBulkReply(result)
|
||||
}
|
||||
|
||||
const someKeysExistsErr = "Some Keys Exists"
|
||||
|
||||
func init() {
|
||||
core.RegisterPreCheck("msetnx", msetNxPrecheck)
|
||||
}
|
||||
|
||||
func msetNxPrecheck(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply {
|
||||
var keys []string
|
||||
for i := 1; i < len(cmdLine); i+=2 {
|
||||
keys = append(keys, string(cmdLine[i]))
|
||||
}
|
||||
exists := cluster.LocalExists(keys)
|
||||
if len(exists) > 0 {
|
||||
return protocol.MakeErrReply(someKeysExistsErr)
|
||||
}
|
||||
return protocol.MakeOkReply()
|
||||
}
|
||||
|
||||
func execMSetNx(cluster *core.Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply {
|
||||
if len(cmdLine) < 3 || len(cmdLine)%2 != 1 {
|
||||
return protocol.MakeArgNumErrReply("mset")
|
||||
}
|
||||
var keys []string
|
||||
keyValues := make(map[string][]byte)
|
||||
for i := 1; i < len(cmdLine); i += 2 {
|
||||
key := string(cmdLine[i])
|
||||
value := cmdLine[i+1]
|
||||
keyValues[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
routeMap := getRouteMap(cluster, keys)
|
||||
if len(routeMap) == 1 {
|
||||
// only one node, do it fast
|
||||
for node := range routeMap {
|
||||
cmdLine[0] = []byte("msetnx_")
|
||||
return cluster.Relay(node, c, cmdLine)
|
||||
}
|
||||
}
|
||||
|
||||
// tcc
|
||||
cmdLineMap := make(map[string]CmdLine)
|
||||
for node, keys := range routeMap {
|
||||
nodeCmdLine := utils.ToCmdLine("msetnx")
|
||||
for _, key := range keys {
|
||||
val := keyValues[key]
|
||||
nodeCmdLine = append(nodeCmdLine, []byte(key), val)
|
||||
}
|
||||
cmdLineMap[node] = nodeCmdLine
|
||||
}
|
||||
tx := &TccTx{
|
||||
rawCmdLine: cmdLine,
|
||||
routeMap: routeMap,
|
||||
cmdLines: cmdLineMap,
|
||||
}
|
||||
_, err := doTcc(cluster, c, tx)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), someKeysExistsErr) {
|
||||
return protocol.MakeIntReply(0)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return protocol.MakeIntReply(1)
|
||||
}
|
@@ -21,3 +21,22 @@ func TestMset(t *testing.T) {
|
||||
res2 := execMGet(node1, c, utils.ToCmdLine("mget", "1", "2"))
|
||||
asserts.AssertMultiBulkReply(t, res2, []string{"1", "2"})
|
||||
}
|
||||
|
||||
func TestMsetNx(t *testing.T) {
|
||||
id1 := "1"
|
||||
id2 := "2"
|
||||
nodes := core.MakeTestCluster([]string{id1, id2})
|
||||
node1 := nodes[id1]
|
||||
c := connection.NewFakeConn()
|
||||
// 1, 2 will be routed to node1 and node2, see MakeTestCluster
|
||||
res := execMSetNx(node1, c, utils.ToCmdLine("mset", "1", "1", "2", "2"))
|
||||
asserts.AssertIntReply(t, res, 1)
|
||||
res2 := execMGet(node1, c, utils.ToCmdLine("mget", "1", "2"))
|
||||
asserts.AssertMultiBulkReply(t, res2, []string{"1", "2"})
|
||||
|
||||
res = execMSetNx(node1, c, utils.ToCmdLine("mset", "3", "3", "2", "2"))
|
||||
asserts.AssertIntReply(t, res, 0)
|
||||
core.RegisterDefaultCmd("get")
|
||||
res = node1.Exec(c, utils.ToCmdLine("get", "3"))
|
||||
asserts.AssertNullBulk(t, res)
|
||||
}
|
@@ -1,13 +1,19 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hdt3213/godis/database"
|
||||
"github.com/hdt3213/godis/interface/redis"
|
||||
"github.com/hdt3213/godis/lib/timewheel"
|
||||
"github.com/hdt3213/godis/redis/protocol"
|
||||
)
|
||||
|
||||
// transaction info will be deleted after transactionTTL since commit
|
||||
const transactionTTL = time.Minute
|
||||
|
||||
type TransactionManager struct {
|
||||
txs map[string]*TCC
|
||||
mu sync.RWMutex
|
||||
@@ -54,7 +60,14 @@ func execPrepare(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Re
|
||||
cluster.transactions.txs[txId] = tx
|
||||
cluster.transactions.mu.Unlock()
|
||||
|
||||
// todo: pre-execute check
|
||||
// pre-execute check
|
||||
validator := preChecks[string(realCmdLine[0])]
|
||||
if validator != nil {
|
||||
validateResult := validator(cluster, c, realCmdLine)
|
||||
if protocol.IsErrorReply(validateResult) {
|
||||
return validateResult
|
||||
}
|
||||
}
|
||||
|
||||
// prepare lock and undo locks
|
||||
tx.writeKeys, tx.readKeys = database.GetRelatedKeys(realCmdLine)
|
||||
@@ -89,11 +102,12 @@ func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Rep
|
||||
return resp
|
||||
}
|
||||
|
||||
// todo: delete transaction after deadline
|
||||
// cluster.transactions.mu.Lock()
|
||||
// delete(cluster.transactions.txs, txId)
|
||||
// cluster.transactions.mu.Unlock()
|
||||
|
||||
// delete transaction after deadline
|
||||
timewheel.At(time.Now().Add(transactionTTL), txId, func() {
|
||||
cluster.transactions.mu.Lock()
|
||||
delete(cluster.transactions.txs, txId)
|
||||
cluster.transactions.mu.Unlock()
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -127,3 +141,14 @@ func execRollback(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.R
|
||||
|
||||
return protocol.MakeOkReply()
|
||||
}
|
||||
|
||||
// PreCheckFunc do validation during tcc preparing period
|
||||
type PreCheckFunc func(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply
|
||||
|
||||
var preChecks = make(map[string]PreCheckFunc)
|
||||
|
||||
// RegisterCmd add tcc preparing validator
|
||||
func RegisterPreCheck(name string, fn PreCheckFunc) {
|
||||
name = strings.ToLower(name)
|
||||
preChecks[name] = fn
|
||||
}
|
@@ -79,3 +79,15 @@ func execRaftCommittedIndex(cluster *Cluster, c redis.Connection, cmdLine CmdLin
|
||||
}
|
||||
return protocol.MakeIntReply(int64(index))
|
||||
}
|
||||
|
||||
// LocalExists returns existed ones from `keys` in local node
|
||||
func (cluster *Cluster) LocalExists(keys []string) []string {
|
||||
var exists []string
|
||||
for _, key := range keys {
|
||||
_, ok := cluster.db.GetEntity(0, key)
|
||||
if ok {
|
||||
exists = append(exists, key)
|
||||
}
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
Reference in New Issue
Block a user