Use 2PC to execute DEL command in cluster mode

hdt3213
2020-09-19 20:32:21 +08:00
parent a806f8e64f
commit 684760696d
20 changed files with 979 additions and 292 deletions
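The change replaces the best-effort multi-node DEL with a coordinator-driven two-phase commit: the node receiving the command groups the keys by owner, asks every owner to prepare (lock the keys and record an undo log), and commits only if all prepares succeed; any failure triggers a rollback on every participant. A minimal sketch of the coordinator's control flow, where Node, Call and twoPhaseDel are hypothetical stand-ins for the relay plumbing implemented below:

// Sketch of the 2PC loop that Del in src/cluster/del.go implements.
type Node interface {
	Call(cmd string, args ...string) error
}

func twoPhaseDel(txID string, groups map[Node][]string) error {
	rollbackAll := func() {
		for n := range groups {
			_ = n.Call("rollback", txID) // best effort
		}
	}
	// phase 1: every owner locks its keys and writes an undo log
	for node, keys := range groups {
		if err := node.Call("PrepareDel", append([]string{txID}, keys...)...); err != nil {
			rollbackAll()
			return err
		}
	}
	// phase 2: all prepares succeeded, ask every owner to commit
	for node := range groups {
		if err := node.Call("commit", txID); err != nil {
			rollbackAll()
			return err
		}
	}
	return nil
}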

src/cluster/client.go Normal file

@@ -0,0 +1,307 @@
package cluster
import (
"bufio"
"context"
"errors"
"github.com/HDT3213/godis/src/cluster/idgenerator"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/wait"
"github.com/HDT3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"sync"
"time"
)
const (
timeout = 2 * time.Second
CRLF = "\r\n"
)
type InternalClient struct {
idGen *idgenerator.IdGenerator
conn net.Conn
sendingReqs chan *AsyncRequest
ticker *time.Ticker
addr string
waitingMap *sync.Map // key -> request
ctx context.Context
cancelFunc context.CancelFunc
writing *sync.WaitGroup
}
type AsyncRequest struct {
id int64
args [][]byte
reply redis.Reply
heartbeat bool
waiting *wait.Wait
}
// AsyncMultiBulkReply is a multi-bulk frame used by the cluster's internal
// asynchronous protocol; it is framed with '@' instead of RESP's '*'
type AsyncMultiBulkReply struct {
Args [][]byte
}
func MakeAsyncMultiBulkReply(args [][]byte) *AsyncMultiBulkReply {
return &AsyncMultiBulkReply{
Args: args,
}
}
func (r *AsyncMultiBulkReply) ToBytes() []byte {
argLen := len(r.Args)
res := "@" + strconv.Itoa(argLen) + CRLF
for _, arg := range r.Args {
if arg == nil {
res += "$-1" + CRLF
} else {
res += "$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF
}
}
return []byte(res)
}
func MakeInternalClient(addr string, idGen *idgenerator.IdGenerator) (*InternalClient, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
return &InternalClient{
addr: addr,
conn: conn,
sendingReqs: make(chan *AsyncRequest, 256),
waitingMap: &sync.Map{},
ctx: ctx,
cancelFunc: cancel,
writing: &sync.WaitGroup{},
idGen: idGen,
}, nil
}
func (client *InternalClient) Start() {
client.ticker = time.NewTicker(10 * time.Second)
go client.handleWrite()
go func() {
err := client.handleRead()
logger.Warn(err)
}()
go client.heartbeat()
}
func (client *InternalClient) Close() {
// send stop signal
client.cancelFunc()
// wait for the write goroutine to finish
client.writing.Wait()
// clean up
client.ticker.Stop()
_ = client.conn.Close()
close(client.sendingReqs)
}
func (client *InternalClient) handleConnectionError(err error) error {
err1 := client.conn.Close()
if err1 != nil {
if opErr, ok := err1.(*net.OpError); ok {
if opErr.Err.Error() != "use of closed network connection" {
return err1
}
} else {
return err1
}
}
conn, err1 := net.Dial("tcp", client.addr)
if err1 != nil {
logger.Error(err1)
return err1
}
client.conn = conn
go func() {
_ = client.handleRead()
}()
return nil
}
func (client *InternalClient) heartbeat() {
loop:
for {
select {
case <-client.ticker.C:
client.sendingReqs <- &AsyncRequest{
args: [][]byte{[]byte("PING")},
heartbeat: true,
}
case <-client.ctx.Done():
break loop
}
}
}
func (client *InternalClient) handleWrite() {
client.writing.Add(1)
loop:
for {
select {
case req := <-client.sendingReqs:
client.doRequest(req)
case <-client.ctx.Done():
break loop
}
}
client.writing.Done()
}
func (client *InternalClient) Send(args [][]byte) redis.Reply {
request := &AsyncRequest{
id: client.idGen.NextId(),
args: args,
heartbeat: false,
waiting: &wait.Wait{},
}
request.waiting.Add(1)
client.sendingReqs <- request
client.waitingMap.Store(request.id, request)
timeUp := request.waiting.WaitWithTimeout(timeout)
if timeUp {
client.waitingMap.Delete(request.id)
return reply.MakeErrReply("ERR request timeout")
} else {
return request.reply
}
}
func (client *InternalClient) doRequest(req *AsyncRequest) {
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
_, err := client.conn.Write(bytes)
i := 0
for err != nil && i < 3 {
err = client.handleConnectionError(err)
if err == nil {
_, err = client.conn.Write(bytes)
}
i++
}
}
func (client *InternalClient) finishRequest(reply *AsyncMultiBulkReply) {
if reply == nil || len(reply.Args) == 0 {
return
}
reqId, err := strconv.ParseInt(string(reply.Args[0]), 10, 64)
if err != nil {
logger.Warn(err)
return
}
raw, ok := client.waitingMap.Load(reqId)
if !ok {
return
}
request := raw.(*AsyncRequest)
request.reply = reply
if request.waiting != nil {
request.waiting.Done()
}
}
func (client *InternalClient) handleRead() error {
reader := bufio.NewReader(client.conn)
downloading := false
expectedArgsCount := 0
receivedCount := 0
var args [][]byte
var fixedLen int64 = 0
var err error
var msg []byte
for {
// read line
if fixedLen == 0 { // read normal line
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
logger.Info("connection close")
} else {
logger.Warn(err)
}
return errors.New("connection closed")
}
if len(msg) < 2 || msg[len(msg)-2] != '\r' {
return errors.New("protocol error")
}
} else { // read bulk line (binary safe)
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return errors.New("connection closed")
} else {
return err
}
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
return errors.New("protocol error")
}
fixedLen = 0
}
// parse line
if !downloading {
// receive new response
if msg[0] == '@' { // customized multi bulk response
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
return errors.New("protocol error: " + err.Error())
}
if expectedLine == 0 {
client.finishRequest(nil)
} else if expectedLine > 0 {
downloading = true
expectedArgsCount = int(expectedLine)
receivedCount = 0
args = make([][]byte, expectedLine)
} else {
return errors.New("protocol error")
}
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return err
}
if fixedLen <= 0 { // null bulk in multi bulks
args[receivedCount] = []byte{}
receivedCount++
fixedLen = 0
}
} else {
args[receivedCount] = line
receivedCount++
}
// if sending finished
if receivedCount == expectedArgsCount {
downloading = false // finish downloading progress
client.finishRequest(&AsyncMultiBulkReply{Args: args})
// finish reply
expectedArgsCount = 0
receivedCount = 0
args = nil
}
}
}
}
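A note on the framing above: AsyncMultiBulkReply mirrors RESP's multi-bulk array but is introduced by '@' instead of '*', and finishRequest expects the first element of each frame to carry the request id so the reply can be matched to the waiting Send call (doRequest does not yet put the id on the wire in this commit, so the server side has to complete that convention). For example, a reply to request 7 carrying the single payload "OK" is framed like this (illustration only, fmt import assumed):

r := MakeAsyncMultiBulkReply([][]byte{[]byte("7"), []byte("OK")})
fmt.Printf("%q\n", r.ToBytes())
// prints "@2\r\n$1\r\n7\r\n$2\r\nOK\r\n"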


@@ -2,7 +2,9 @@ package cluster
import (
"fmt"
"github.com/HDT3213/godis/src/cluster/idgenerator"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/consistenthash"
@@ -16,26 +18,34 @@ import (
type Cluster struct {
self string
peerPicker *consistenthash.Map
peers map[string]*client.Client
db *db.DB
transactions *dict.SimpleDict // id -> Transaction
idGenerator *idgenerator.IdGenerator
}
const (
replicas = 4
lockSize = 64
)
func MakeCluster() *Cluster {
cluster := &Cluster{
self: config.Properties.Self,
db: db.MakeDB(),
transactions: dict.MakeSimple(),
peerPicker: consistenthash.New(replicas, nil),
peers: make(map[string]*client.Client),
idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self),
}
if len(config.Properties.Peers) > 0 && config.Properties.Self != "" {
contains := make(map[string]bool)
peers := make([]string, 0, len(config.Properties.Peers)+1)
for _, peer := range config.Properties.Peers {
if _, ok := contains[peer]; ok {
continue
@@ -78,3 +88,119 @@ func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Re
func (cluster *Cluster) AfterClientClose(c redis.Connection) {
}
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
peerClient, ok := cluster.peers[peer]
// lazy init
if !ok {
var err error
peerClient, err = client.MakeClient(peer)
if err != nil {
return nil, err
}
peerClient.Start()
cluster.peers[peer] = peerClient
}
return peerClient, nil
}
func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) == 1 {
return &reply.PongReply{}
} else if len(args) == 2 {
return reply.MakeStatusReply("\"" + string(args[1]) + "\"")
} else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
}
}
// relay command to peer
// cannot be used to send Prepare, Commit or Rollback to the self node
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
return peerClient.Send(args)
}
}
// rollback local transaction
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
err := tx.rollback()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return reply.MakeIntReply(1)
}
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
// finish transaction
defer func() {
cluster.db.UnLocks(tx.keys...)
cluster.transactions.Remove(tx.id)
}()
cmd := strings.ToLower(string(tx.args[0]))
var result redis.Reply
if cmd == "del" {
result = CommitDel(cluster, c, tx)
}
if reply.IsErrorReply(result) {
// failed
err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
}
return &reply.OkReply{}
}
/*----- utils -------*/
func makeArgs(cmd string, args ...string) [][]byte {
result := make([][]byte, len(args)+1)
result[0] = []byte(cmd)
for i, arg := range args {
result[i+1] = []byte(arg)
}
return result
}
// groupBy returns a map from peer address to the keys it owns
func (cluster *Cluster) groupBy(keys []string) map[string][]string {
result := make(map[string][]string)
for _, key := range keys {
peer := cluster.peerPicker.Get(key)
group, ok := result[peer]
if !ok {
group = make([]string, 0)
}
group = append(group, key)
result[peer] = group
}
return result
}
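groupBy is what turns a multi-key command into per-node work: every key is routed through the consistent-hash ring and collected under the address of its owner. A hypothetical illustration (addresses invented):

groups := cluster.groupBy([]string{"k1", "k2", "k3"})
// might yield, with two peers on the ring:
//   "127.0.0.1:6379": ["k1", "k3"]
//   "127.0.0.1:6380": ["k2"]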

src/cluster/del.go Normal file

@@ -0,0 +1,109 @@
package cluster
import (
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast
for peer, group := range groupMap { // only one group
return cluster.Relay(peer, c, makeArgs("DEL", group...))
}
}
// prepare
var errReply redis.Reply
txId := cluster.idGenerator.NextId()
rollback := false
for peer, group := range groupMap {
args := []string{strconv.FormatInt(txId, 10)}
args = append(args, group...)
var ret redis.Reply
if peer == cluster.self {
ret = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
} else {
ret = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
}
if reply.IsErrorReply(ret) {
errReply = ret
rollback = true
break
}
}
if rollback {
// rollback
for peer := range groupMap {
cluster.Relay(peer, c, makeArgs("rollback", strconv.FormatInt(txId, 10)))
}
} else {
// commit
rollback = false
for peer := range groupMap {
var ret redis.Reply
if peer == cluster.self {
ret = Commit(cluster, c, makeArgs("commit", strconv.FormatInt(txId, 10)))
} else {
ret = cluster.Relay(peer, c, makeArgs("commit", strconv.FormatInt(txId, 10)))
}
if reply.IsErrorReply(ret) {
errReply = ret
rollback = true
break
}
}
if rollback {
for peer := range groupMap {
cluster.Relay(peer, c, makeArgs("rollback", strconv.FormatInt(txId, 10)))
}
}
}
if !rollback {
return reply.MakeIntReply(int64(len(keys)))
}
return errReply
}
// args: PrepareDel id keys...
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
}
txId := string(args[1])
keys := make([]string, 0, len(args)-2)
for i := 2; i < len(args); i++ {
arg := args[i]
keys = append(keys, string(arg))
}
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx)
err := tx.prepare()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return &reply.OkReply{}
}
// the invoker should hold the locks of the related keys
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
keys := make([]string, 0, len(tx.args)-1)
for _, v := range tx.args[1:] {
keys = append(keys, string(v))
}
deleted := cluster.db.Removes(keys...)
if deleted > 0 {
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
}
return &reply.OkReply{}
}
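Seen from a participant node, a DEL whose keys span several peers expands into the internal command sequence below (transaction id invented); the single-group fast path at the top of Del skips all of this and relays a plain DEL:

PrepareDel 3193036850769920 k1    (lock k1 and snapshot it into the undo log)
commit 3193036850769920           (delete under the held locks, then unlock)
rollback 3193036850769920         (sent only when a prepare or commit failed)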


@@ -0,0 +1,85 @@
package idgenerator
import (
"hash/fnv"
"log"
"sync"
"time"
)
const (
workerIdBits int64 = 5
datacenterIdBits int64 = 5
sequenceBits int64 = 12
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
timeLeft uint8 = 22
dataLeft uint8 = 17
workLeft uint8 = 12
twepoch int64 = 1525705533000
)
type IdGenerator struct {
mu *sync.Mutex
lastStamp int64
workerId int64
dataCenterId int64
sequence int64
}
func MakeGenerator(cluster string, node string) *IdGenerator {
fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(cluster))
// mask the hashes into their 5-bit fields so they cannot overflow
// into neighbouring bits of the generated id
dataCenterId := int64(fnv64.Sum64()) & maxDatacenterId
fnv64.Reset()
_, _ = fnv64.Write([]byte(node))
workerId := int64(fnv64.Sum64()) & maxWorkerId
return &IdGenerator{
mu: &sync.Mutex{},
lastStamp: -1,
dataCenterId: dataCenterId,
workerId: workerId,
sequence: 1,
}
}
func (w *IdGenerator) getCurrentTime() int64 {
return time.Now().UnixNano() / 1e6
}
func (w *IdGenerator) NextId() int64 {
w.mu.Lock()
defer w.mu.Unlock()
timestamp := w.getCurrentTime()
if timestamp < w.lastStamp {
log.Fatal("can not generate id")
}
if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 {
for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
}
} else {
w.sequence = 0
}
w.lastStamp = timestamp
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
}
func (w *IdGenerator) tilNextMillis() int64 {
timestamp := w.getCurrentTime()
for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
return timestamp
}
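NextId packs four fields into one int64, snowflake style: 41 bits of milliseconds since twepoch at the top, then 5 bits of datacenter id, 5 bits of worker id and a 12-bit per-millisecond sequence. A hedged helper, not part of this commit, that inverts the packing using the same shifts and masks:

// decodeId splits an id into its four fields; it assumes the datacenter
// and worker ids fit their 5-bit fields (see the masking above).
func decodeId(id int64) (ms, dataCenter, worker, seq int64) {
	ms = (id >> timeLeft) + twepoch
	dataCenter = (id >> dataLeft) & maxDatacenterId
	worker = (id >> workLeft) & maxWorkerId
	seq = id & maxSequence
	return
}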


@@ -1,221 +0,0 @@
package cluster
import (
"fmt"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/client"
"github.com/HDT3213/godis/src/redis/reply"
"strings"
)
func makeArgs(cmd string, args ...string) [][]byte {
result := make([][]byte, len(args)+1)
result[0] = []byte(cmd)
for i, arg := range args {
result[i+1] = []byte(arg)
}
return result
}
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
peerClient, ok := cluster.peers[peer]
// lazy init
if !ok {
var err error
peerClient, err = client.MakeClient(peer)
if err != nil {
return nil, err
}
peerClient.Start()
cluster.peers[peer] = peerClient
}
return peerClient, nil
}
// return peer -> keys
func (cluster *Cluster) groupBy(keys []string) map[string][]string {
result := make(map[string][]string)
for _, key := range keys {
peer := cluster.peerPicker.Get(key)
group, ok := result[peer]
if !ok {
group = make([]string, 0)
}
group = append(group, key)
result[peer] = group
}
return result
}
// relay command to peer
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
return peerClient.Send(args)
}
}
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
key := string(args[1])
peer := cluster.peerPicker.Get(key)
return cluster.Relay(peer, c, args)
}
func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) == 1 {
return &reply.PongReply{}
} else if len(args) == 2 {
return reply.MakeStatusReply("\"" + string(args[1]) + "\"")
} else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
}
}
// TODO: support multiplex slots
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
var peer string
size := argCount / 2
for i := 0; i < size; i++ {
key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key)
if peer == "" {
peer = currentPeer
} else {
if peer != currentPeer {
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
}
}
}
return cluster.Relay(peer, c, args)
}
// TODO: avoid part failure
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
failedKeys := make([]string, 0)
groupMap := cluster.groupBy(keys)
for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("DEL", group...))
if reply.IsErrorReply(resp) {
failedKeys = append(failedKeys, group...)
}
}
if len(failedKeys) > 0 {
return reply.MakeErrReply("ERR part failure: " + strings.Join(failedKeys, ","))
}
return reply.MakeIntReply(int64(len(keys)))
}
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
resultMap := make(map[string][]byte)
groupMap := cluster.groupBy(keys)
for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
if reply.IsErrorReply(resp) {
errReply := resp.(reply.ErrorReply)
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
}
arrReply, _ := resp.(*reply.MultiBulkReply)
for i, v := range arrReply.Args {
key := group[i]
resultMap[key] = v
}
}
result := make([][]byte, len(keys))
for i, k := range keys {
result[i] = resultMap[k]
}
return reply.MakeMultiBulkReply(result)
}
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
size := argCount / 2
keys := make([]string, size)
valueMap := make(map[string]string)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i])
valueMap[keys[i]] = string(args[2*i+1])
}
failedKeys := make([]string, 0)
groupMap := cluster.groupBy(keys)
for peer, groupKeys := range groupMap {
peerArgs := make([][]byte, 2*len(groupKeys)+1)
peerArgs[0] = []byte("MSET")
for i, k := range groupKeys {
peerArgs[2*i+1] = []byte(k)
value := valueMap[k]
peerArgs[2*i+2] = []byte(value)
}
resp := cluster.Relay(peer, c, peerArgs)
if reply.IsErrorReply(resp) {
failedKeys = append(failedKeys, groupKeys...)
}
}
if len(failedKeys) > 0 {
return reply.MakeErrReply("ERR part failure: " + strings.Join(failedKeys, ","))
}
return &reply.OkReply{}
}

src/cluster/mset.go Normal file

@@ -0,0 +1,95 @@
package cluster
import (
"fmt"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strings"
)
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
resultMap := make(map[string][]byte)
groupMap := cluster.groupBy(keys)
for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
if reply.IsErrorReply(resp) {
errReply := resp.(reply.ErrorReply)
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
}
arrReply, _ := resp.(*reply.MultiBulkReply)
for i, v := range arrReply.Args {
key := group[i]
resultMap[key] = v
}
}
result := make([][]byte, len(keys))
for i, k := range keys {
result[i] = resultMap[k]
}
return reply.MakeMultiBulkReply(result)
}
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
size := argCount / 2
keys := make([]string, size)
valueMap := make(map[string]string)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i])
valueMap[keys[i]] = string(args[2*i+1])
}
failedKeys := make([]string, 0)
groupMap := cluster.groupBy(keys)
for peer, groupKeys := range groupMap {
peerArgs := make([][]byte, 2*len(groupKeys)+1)
peerArgs[0] = []byte("MSET")
for i, k := range groupKeys {
peerArgs[2*i+1] = []byte(k)
value := valueMap[k]
peerArgs[2*i+2] = []byte(value)
}
resp := cluster.Relay(peer, c, peerArgs)
if reply.IsErrorReply(resp) {
failedKeys = append(failedKeys, groupKeys...)
}
}
if len(failedKeys) > 0 {
return reply.MakeErrReply("ERR part failure: " + strings.Join(failedKeys, ","))
}
return &reply.OkReply{}
}
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
var peer string
size := argCount / 2
for i := 0; i < size; i++ {
key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key)
if peer == "" {
peer = currentPeer
} else {
if peer != currentPeer {
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
}
}
}
return cluster.Relay(peer, c, args)
}

src/cluster/rename.go Normal file

@@ -0,0 +1,40 @@
package cluster
import (
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
)
// TODO: support multiplex slots
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}


@@ -1,10 +1,16 @@
package cluster
import "github.com/HDT3213/godis/src/interface/redis"
func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping
routerMap["commit"] = Commit
routerMap["rollback"] = Rollback
routerMap["del"] = Del
routerMap["preparedel"] = PrepareDel
routerMap["expire"] = defaultFunc
routerMap["expireat"] = defaultFunc
routerMap["pexpire"] = defaultFunc
@@ -93,3 +99,9 @@ func MakeRouter() map[string]CmdFunc {
return routerMap
}
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
key := string(args[1])
peer := cluster.peerPicker.Get(key)
return cluster.Relay(peer, c, args)
}
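defaultFunc routes by the first key only, which is exactly why multi-key commands such as del, mget, mset and rename get dedicated handlers instead. The elided middle of MakeRouter registers single-key commands through it; the pattern is equivalent to (hypothetical illustration):

routerMap["get"] = defaultFunc
// i.e.
routerMap["get"] = func(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
	return cluster.Relay(cluster.peerPicker.Get(string(args[1])), c, args)
}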

src/cluster/transaction.go Normal file

@@ -0,0 +1,105 @@
package cluster
import (
"context"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/marshal/gob"
"time"
)
type Transaction struct {
id string // transaction id
args [][]byte // cmd args
cluster *Cluster
conn redis.Connection
keys []string // related keys
undoLog map[string][]byte // store data for undoLog
lockUntil time.Time
ctx context.Context
cancel context.CancelFunc
status int8
}
const (
maxLockTime = 3 * time.Second
CreatedStatus = 0
PreparedStatus = 1
CommitedStatus = 2
RollbackedStatus = 3
)
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
return &Transaction{
id: id,
args: args,
cluster: cluster,
conn: c,
keys: keys,
status: CreatedStatus,
}
}
// the transaction should carry its related keys before prepare is called
func (tx *Transaction) prepare() error {
// lock keys
tx.cluster.db.Locks(tx.keys...)
// use context to manage
//tx.lockUntil = time.Now().Add(maxLockTime)
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
//tx.ctx = ctx
//tx.cancel = cancel
// build undoLog
tx.undoLog = make(map[string][]byte)
for _, key := range tx.keys {
entity, ok := tx.cluster.db.Get(key)
if ok {
blob, err := gob.Marshal(entity)
if err != nil {
return err
}
tx.undoLog[key] = blob
} else {
tx.undoLog[key] = []byte{} // the key did not exist; remove it on rollback
}
}
tx.status = PreparedStatus
return nil
}
func (tx *Transaction) rollback() error {
for key, blob := range tx.undoLog {
if len(blob) > 0 {
entity := &db.DataEntity{}
err := gob.UnMarshal(blob, entity)
if err != nil {
return err
}
tx.cluster.db.Put(key, entity)
} else {
tx.cluster.db.Remove(key)
}
}
tx.cluster.db.UnLocks(tx.keys...)
tx.status = RollbackedStatus
return nil
}
//func (tx *Transaction) commit(cmd CmdFunc) error {
// finished := make(chan int)
// go func() {
// cmd(tx.cluster, tx.conn, tx.args)
// finished <- 1
// }()
// select {
// case <- tx.ctx.Done():
// return tx.rollback()
// case <- finished:
// tx.cluster.db.UnLocks(tx.keys...)
// }
//}
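Putting the pieces together, the per-participant lifecycle that del.go drives looks roughly like this (error handling elided, txId and conn invented, identifiers otherwise from this commit):

txId := "42" // normally strconv.FormatInt(cluster.idGenerator.NextId(), 10)
tx := NewTransaction(cluster, conn, txId, makeArgs("DEL", "k1"), []string{"k1"})
cluster.transactions.Put(txId, tx)
if err := tx.prepare(); err != nil { // lock k1, snapshot it into undoLog
	_ = tx.rollback() // restore the snapshot and release the locks
}
// on success the coordinator later sends "commit": CommitDel deletes the
// keys under the held locks, then Commit unlocks and drops the transaction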