diff --git a/src/cluster/cluster.go b/src/cluster/cluster.go index 6d6752e..9ec54e8 100644 --- a/src/cluster/cluster.go +++ b/src/cluster/cluster.go @@ -129,56 +129,6 @@ func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) re } } -// rollback local transaction -func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") - } - txId := string(args[1]) - raw, ok := cluster.transactions.Get(txId) - if !ok { - return reply.MakeIntReply(0) - } - tx, _ := raw.(*Transaction) - err := tx.rollback() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return reply.MakeIntReply(1) -} - -func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") - } - txId := string(args[1]) - raw, ok := cluster.transactions.Get(txId) - if !ok { - return reply.MakeIntReply(0) - } - tx, _ := raw.(*Transaction) - - // finish transaction - defer func() { - cluster.db.UnLocks(tx.keys...) - cluster.transactions.Remove(tx.id) - }() - - cmd := strings.ToLower(string(tx.args[0])) - var result redis.Reply - if cmd == "del" { - result = CommitDel(cluster, c, tx) - } - - if reply.IsErrorReply(result) { - // failed - err2 := tx.rollback() - return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result)) - } - - return &reply.OkReply{} -} - /*----- utils -------*/ func makeArgs(cmd string, args ...string) [][]byte { diff --git a/src/cluster/del.go b/src/cluster/del.go index 25c6838..4fc047e 100644 --- a/src/cluster/del.go +++ b/src/cluster/del.go @@ -23,51 +23,41 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { // prepare var errReply redis.Reply txId := cluster.idGenerator.NextId() + txIdStr := strconv.FormatInt(txId, 10) rollback := false for peer, group := range groupMap { - args := []string{strconv.FormatInt(txId, 10)} + args := []string{txIdStr} args = append(args, group...) - var ret redis.Reply + var resp redis.Reply if peer == cluster.self { - ret = PrepareDel(cluster, c, makeArgs("PrepareDel", args...)) + resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...)) } else { - ret = cluster.Relay(peer, c, makeArgs("PrepareDel", args...)) + resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...)) } - if reply.IsErrorReply(ret) { - errReply = ret + if reply.IsErrorReply(resp) { + errReply = resp rollback = true break } } + var respList []redis.Reply if rollback { // rollback - for peer := range groupMap { - cluster.Relay(peer, c, makeArgs("rollback", strconv.FormatInt(txId, 10))) - } + RequestRollback(cluster, c, txId, groupMap) } else { // commit - rollback = false - for peer := range groupMap { - var ret redis.Reply - if peer == cluster.self { - ret = Commit(cluster, c, makeArgs("commit", strconv.FormatInt(txId, 10))) - } else { - ret = cluster.Relay(peer, c, makeArgs("commit", strconv.FormatInt(txId, 10))) - } - if reply.IsErrorReply(ret) { - errReply = ret - rollback = true - break - } - } - if rollback { - for peer := range groupMap { - cluster.Relay(peer, c, makeArgs("rollback", strconv.FormatInt(txId, 10))) - } + respList, errReply = RequestCommit(cluster, c, txId, groupMap) + if errReply != nil { + rollback = true } } if !rollback { - return reply.MakeIntReply(int64(len(keys))) + var deleted int64 = 0 + for _, resp := range respList { + intResp := resp.(*reply.IntReply) + deleted += intResp.Code + } + return reply.MakeIntReply(int64(deleted)) } return errReply } @@ -105,5 +95,5 @@ func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Repl if deleted > 0 { cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) } - return &reply.OkReply{} + return reply.MakeIntReply(int64(deleted)) } diff --git a/src/cluster/mset.go b/src/cluster/mset.go index 294ffcc..736ba5c 100644 --- a/src/cluster/mset.go +++ b/src/cluster/mset.go @@ -2,9 +2,10 @@ package cluster import ( "fmt" + "github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/redis/reply" - "strings" + "strconv" ) func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { @@ -37,6 +38,48 @@ func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return reply.MakeMultiBulkReply(result) } +// args: PrepareMSet id keys... +func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command") + } + txId := string(args[1]) + size := (len(args) - 2) / 2 + keys := make([]string, size) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i+2]) + } + + txArgs := [][]byte{ + []byte("MSet"), + } // actual args for cluster.db + txArgs = append(txArgs, args[2:]...) + tx := NewTransaction(cluster, c, txId, txArgs, keys) + cluster.transactions.Put(txId, tx) + err := tx.prepare() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return &reply.OkReply{} +} + +// invoker should provide lock +func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { + size := len(tx.args) / 2 + keys := make([]string, size) + values := make([][]byte, size) + for i := 0; i < size; i++ { + keys[i] = string(tx.args[2*i+1]) + values[i] = tx.args[2*i+2] + } + for i, key := range keys { + value := values[i] + cluster.db.Put(key, &db.DataEntity{Data: value}) + } + cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) + return &reply.OkReply{} +} + func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { argCount := len(args) - 1 if argCount%2 != 0 || argCount < 1 { @@ -47,29 +90,50 @@ func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { keys := make([]string, size) valueMap := make(map[string]string) for i := 0; i < size; i++ { - keys[i] = string(args[2*i]) - valueMap[keys[i]] = string(args[2*i+1]) + keys[i] = string(args[2*i+1]) + valueMap[keys[i]] = string(args[2*i+2]) } - failedKeys := make([]string, 0) groupMap := cluster.groupBy(keys) - for peer, groupKeys := range groupMap { - peerArgs := make([][]byte, 2*len(groupKeys)+1) - peerArgs[0] = []byte("MSET") - for i, k := range groupKeys { - peerArgs[2*i+1] = []byte(k) - value := valueMap[k] - peerArgs[2*i+2] = []byte(value) + if len(groupMap) == 1 { // do fast + for peer := range groupMap { + return cluster.Relay(peer, c, args) + } + } + + //prepare + var errReply redis.Reply + txId := cluster.idGenerator.NextId() + txIdStr := strconv.FormatInt(txId, 10) + rollback := false + for peer, group := range groupMap { + peerArgs := []string{txIdStr} + for _, k := range group { + peerArgs = append(peerArgs, k, valueMap[k]) + } + var resp redis.Reply + if peer == cluster.self { + resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...)) + } else { + resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...)) } - resp := cluster.Relay(peer, c, peerArgs) if reply.IsErrorReply(resp) { - failedKeys = append(failedKeys, groupKeys...) + errReply = resp + rollback = true + break } } - if len(failedKeys) > 0 { - return reply.MakeErrReply("ERR part failure: " + strings.Join(failedKeys, ",")) + if rollback { + // rollback + RequestRollback(cluster, c, txId, groupMap) + } else { + _, errReply = RequestCommit(cluster, c, txId, groupMap) + rollback = errReply != nil } - return &reply.OkReply{} + if !rollback { + return &reply.OkReply{} + } + return errReply } diff --git a/src/cluster/router.go b/src/cluster/router.go index 1166af9..5b2f0ae 100644 --- a/src/cluster/router.go +++ b/src/cluster/router.go @@ -10,6 +10,7 @@ func MakeRouter() map[string]CmdFunc { routerMap["rollback"] = Rollback routerMap["del"] = Del routerMap["preparedel"] = PrepareDel + routerMap["preparemset"] = PrepareMSet routerMap["expire"] = defaultFunc routerMap["expireat"] = defaultFunc diff --git a/src/cluster/transaction.go b/src/cluster/transaction.go index fe0a4ef..20dcae8 100644 --- a/src/cluster/transaction.go +++ b/src/cluster/transaction.go @@ -2,9 +2,13 @@ package cluster import ( "context" + "fmt" "github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/lib/marshal/gob" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "strings" "time" ) @@ -85,21 +89,100 @@ func (tx *Transaction) rollback() error { tx.cluster.db.Remove(key) } } - tx.cluster.db.UnLocks(tx.keys...) + if tx.status != CommitedStatus { + tx.cluster.db.UnLocks(tx.keys...) + } tx.status = RollbackedStatus return nil } -//func (tx *Transaction) commit(cmd CmdFunc) error { -// finished := make(chan int) -// go func() { -// cmd(tx.cluster, tx.conn, tx.args) -// finished <- 1 -// }() -// select { -// case <- tx.ctx.Done(): -// return tx.rollback() -// case <- finished: -// tx.cluster.db.UnLocks(tx.keys...) -// } -//} +// rollback local transaction +func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") + } + txId := string(args[1]) + raw, ok := cluster.transactions.Get(txId) + if !ok { + return reply.MakeIntReply(0) + } + tx, _ := raw.(*Transaction) + err := tx.rollback() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return reply.MakeIntReply(1) +} + +// commit local transaction as a worker +func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") + } + txId := string(args[1]) + raw, ok := cluster.transactions.Get(txId) + if !ok { + return reply.MakeIntReply(0) + } + tx, _ := raw.(*Transaction) + + // finish transaction + defer func() { + cluster.db.UnLocks(tx.keys...) + tx.status = CommitedStatus + //cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit + }() + + cmd := strings.ToLower(string(tx.args[0])) + var result redis.Reply + if cmd == "del" { + result = CommitDel(cluster, c, tx) + } else if cmd == "mset" { + result = CommitMSet(cluster, c, tx) + } + + if reply.IsErrorReply(result) { + // failed + err2 := tx.rollback() + return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result)) + } + + return result +} + +// request all node commit transaction as leader +func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) { + var errReply reply.ErrorReply + txIdStr := strconv.FormatInt(txId, 10) + respList := make([]redis.Reply, 0, len(peers)) + for peer := range peers { + var resp redis.Reply + if peer == cluster.self { + resp = Commit(cluster, c, makeArgs("commit", txIdStr)) + } else { + resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr)) + } + if reply.IsErrorReply(resp) { + errReply = resp.(reply.ErrorReply) + break + } + respList = append(respList, resp) + } + if errReply != nil { + RequestRollback(cluster, c, txId, peers) + return nil, errReply + } + return respList, nil +} + +// request all node rollback transaction as leader +func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) { + txIdStr := strconv.FormatInt(txId, 10) + for peer := range peers { + if peer == cluster.self { + Rollback(cluster, c, makeArgs("rollback", txIdStr)) + } else { + cluster.Relay(peer, c, makeArgs("rollback", txIdStr)) + } + } +} \ No newline at end of file