Files
redis-go/cluster/transaction.go
2021-05-02 14:54:42 +08:00

226 lines
5.3 KiB
Go

package cluster
import (
"errors"
"fmt"
"github.com/hdt3213/godis/db"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/timewheel"
"github.com/hdt3213/godis/redis/reply"
"strconv"
"strings"
"sync"
"time"
)
type Transaction struct {
id string // transaction id
args [][]byte // cmd args
cluster *Cluster
conn redis.Connection
keys []string // related keys
lockedKeys bool
undoLog map[string][][]byte // store data for undoLog
status int8
mu *sync.Mutex
}
const (
maxLockTime = 3 * time.Second
waitBeforeCleanTx = 2 * maxLockTime
CreatedStatus = 0
PreparedStatus = 1
CommittedStatus = 2
RolledBackStatus = 3
)
func genTaskKey(txId string) string {
return "tx:" + txId
}
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
return &Transaction{
id: id,
args: args,
cluster: cluster,
conn: c,
keys: keys,
status: CreatedStatus,
mu: new(sync.Mutex),
}
}
// Reentrant
// invoker should hold tx.mu
func (tx *Transaction) lockKeys() {
if !tx.lockedKeys {
tx.cluster.db.Locks(tx.keys...)
tx.lockedKeys = true
}
}
func (tx *Transaction) unLockKeys() {
if tx.lockedKeys {
tx.cluster.db.UnLocks(tx.keys...)
tx.lockedKeys = false
}
}
// t should contains Keys and Id field
func (tx *Transaction) prepare() error {
tx.mu.Lock()
defer tx.mu.Unlock()
// lock keys
tx.lockKeys()
// build undoLog
tx.undoLog = make(map[string][][]byte)
for _, key := range tx.keys {
entity, ok := tx.cluster.db.Get(key)
if ok {
blob := db.EntityToCmd(key, entity)
tx.undoLog[key] = blob.Args
} else {
tx.undoLog[key] = nil // entity was nil, should be removed while rollback
}
}
tx.status = PreparedStatus
taskKey := genTaskKey(tx.id)
timewheel.Delay(maxLockTime, taskKey, func() {
if tx.status == PreparedStatus { // rollback transaction uncommitted until expire
logger.Info("abort transaction: " + tx.id)
_ = tx.rollback()
}
})
return nil
}
func (tx *Transaction) rollback() error {
curStatus := tx.status
tx.mu.Lock()
defer tx.mu.Unlock()
if tx.status != curStatus { // ensure status not changed by other goroutine
return errors.New(fmt.Sprintf("tx %s status changed", tx.id))
}
if tx.status == RolledBackStatus { // no need to rollback a rolled-back transaction
return nil
}
tx.lockKeys()
for key, blob := range tx.undoLog {
if len(blob) > 0 {
tx.cluster.db.Remove(key)
tx.cluster.db.Exec(nil, blob)
} else {
tx.cluster.db.Remove(key)
}
}
tx.unLockKeys()
tx.status = RolledBackStatus
return nil
}
// rollback local transaction
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
err := tx.rollback()
if err != nil {
return reply.MakeErrReply(err.Error())
}
// clean transaction
timewheel.Delay(waitBeforeCleanTx, "", func() {
cluster.transactions.Remove(tx.id)
})
return reply.MakeIntReply(1)
}
// commit local transaction as a worker
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
tx.mu.Lock()
defer tx.mu.Unlock()
cmd := strings.ToLower(string(tx.args[0]))
var result redis.Reply
if cmd == "del" {
result = CommitDel(cluster, c, tx)
} else if cmd == "mset" {
result = CommitMSet(cluster, c, tx)
}
if reply.IsErrorReply(result) {
// failed
err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
} else {
// after committed
tx.unLockKeys()
tx.status = CommittedStatus
// clean transaction
// do not clean immediately, in case rollback
timewheel.Delay(waitBeforeCleanTx, "", func() {
cluster.transactions.Remove(tx.id)
})
}
return result
}
// request all node commit transaction as leader
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
var errReply reply.ErrorReply
txIdStr := strconv.FormatInt(txId, 10)
respList := make([]redis.Reply, 0, len(peers))
for peer := range peers {
var resp redis.Reply
if peer == cluster.self {
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
} else {
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
}
if reply.IsErrorReply(resp) {
errReply = resp.(reply.ErrorReply)
break
}
respList = append(respList, resp)
}
if errReply != nil {
RequestRollback(cluster, c, txId, peers)
return nil, errReply
}
return respList, nil
}
// request all node rollback transaction as leader
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
txIdStr := strconv.FormatInt(txId, 10)
for peer := range peers {
if peer == cluster.self {
Rollback(cluster, c, makeArgs("rollback", txIdStr))
} else {
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
}
}
}