diff --git a/cluster/copy.go b/cluster/copy.go new file mode 100644 index 0000000..49a5673 --- /dev/null +++ b/cluster/copy.go @@ -0,0 +1,115 @@ +package cluster + +import ( + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/protocol" + "strconv" + "strings" +) + +const copyToAnotherDBErr = "ERR Copying to another database is not allowed in cluster mode" +const noReplace = "NoReplace" +const useReplace = "UseReplace" + +// Copy copies the value stored at the source key to the destination key. +// the origin and the destination must within the same node. +func Copy(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { + if len(args) < 3 { + return protocol.MakeErrReply("ERR wrong number of arguments for 'copy' command") + } + srcKey := string(args[1]) + destKey := string(args[2]) + srcNode := cluster.peerPicker.PickNode(srcKey) + destNode := cluster.peerPicker.PickNode(destKey) + replaceFlag := noReplace + if len(args) > 3 { + for i := 3; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "db" { + return protocol.MakeErrReply(copyToAnotherDBErr) + } else if arg == "replace" { + replaceFlag = useReplace + } else { + return protocol.MakeSyntaxErrReply() + } + } + } + + if srcNode == destNode { + return cluster.relay(srcNode, c, args) + } + groupMap := map[string][]string{ + srcNode: {srcKey}, + destNode: {destKey}, + } + + txID := cluster.idGenerator.NextID() + txIDStr := strconv.FormatInt(txID, 10) + // prepare Copy from + srcPrepareResp := cluster.relayPrepare(srcNode, c, makeArgs("Prepare", txIDStr, "CopyFrom", srcKey)) + if protocol.IsErrorReply(srcPrepareResp) { + // rollback src node + requestRollback(cluster, c, txID, map[string][]string{srcNode: {srcKey}}) + return srcPrepareResp + } + srcPrepareMBR, ok := srcPrepareResp.(*protocol.MultiBulkReply) + if !ok || len(srcPrepareMBR.Args) < 2 { + requestRollback(cluster, c, txID, map[string][]string{srcNode: {srcKey}}) + return protocol.MakeErrReply("ERR invalid prepare response") + } + // prepare Copy to + destPrepareResp := cluster.relayPrepare(destNode, c, utils.ToCmdLine3("Prepare", []byte(txIDStr), + []byte("CopyTo"), []byte(destKey), srcPrepareMBR.Args[0], srcPrepareMBR.Args[1], []byte(replaceFlag))) + if protocol.IsErrorReply(destPrepareResp) { + // rollback src node + requestRollback(cluster, c, txID, groupMap) + return destPrepareResp + } + if _, errReply := requestCommit(cluster, c, txID, groupMap); errReply != nil { + requestRollback(cluster, c, txID, groupMap) + return errReply + } + return protocol.MakeIntReply(1) +} + +// prepareCopyFrom is prepare-function for CopyFrom, see prepareFuncMap +func prepareCopyFrom(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) != 2 { + return protocol.MakeArgNumErrReply("CopyFrom") + } + key := string(cmdLine[1]) + existResp := cluster.db.ExecWithLock(conn, utils.ToCmdLine("Exists", key)) + if protocol.IsErrorReply(existResp) { + return existResp + } + existIntResp := existResp.(*protocol.IntReply) + if existIntResp.Code == 0 { + return protocol.MakeErrReply("ERR no such key") + } + return cluster.db.ExecWithLock(conn, utils.ToCmdLine2("DumpKey", key)) +} + +func prepareCopyTo(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) != 5 { + return protocol.MakeArgNumErrReply("CopyTo") + } + key := string(cmdLine[1]) + replaceFlag := string(cmdLine[4]) + existResp := cluster.db.ExecWithLock(conn, utils.ToCmdLine("Exists", key)) + if protocol.IsErrorReply(existResp) { + return existResp + } + existIntResp := existResp.(*protocol.IntReply) + if existIntResp.Code == 1 { + if replaceFlag == noReplace { + return protocol.MakeErrReply(keyExistsErr) + } + } + return protocol.MakeOkReply() +} + +func init() { + registerPrepareFunc("CopyFrom", prepareCopyFrom) + registerPrepareFunc("CopyTo", prepareCopyTo) +} diff --git a/cluster/copy_test.go b/cluster/copy_test.go new file mode 100644 index 0000000..11771fe --- /dev/null +++ b/cluster/copy_test.go @@ -0,0 +1,120 @@ +package cluster + +import ( + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/protocol/asserts" + "testing" +) + +func TestCopy(t *testing.T) { + conn := new(connection.FakeConn) + testNodeA.db.Exec(conn, utils.ToCmdLine("FlushALL")) + + // cross node copy + srcKey := testNodeA.self + utils.RandString(10) + value := utils.RandString(10) + destKey := testNodeB.self + utils.RandString(10) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + result := Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + asserts.AssertBulkReply(t, result, value) + result = testNodeB.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + asserts.AssertBulkReply(t, result, value) + // key exists + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) + asserts.AssertErrReply(t, result, keyExistsErr) + // replace + value = utils.RandString(10) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + asserts.AssertBulkReply(t, result, value) + result = testNodeB.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + asserts.AssertBulkReply(t, result, value) + // test copy expire time + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + result = testNodeB.db.Exec(conn, utils.ToCmdLine("TTL", destKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + + // same node copy + srcKey = testNodeA.self + utils.RandString(10) + value = utils.RandString(10) + destKey = srcKey + utils.RandString(2) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + asserts.AssertBulkReply(t, result, value) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + asserts.AssertBulkReply(t, result, value) + // key exists + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) + asserts.AssertIntReply(t, result, 0) + // replace + value = utils.RandString(10) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", srcKey)) + asserts.AssertBulkReply(t, result, value) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("GET", destKey)) + asserts.AssertBulkReply(t, result, value) + // test copy expire time + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "EX", "1000")) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "REPLACE")) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", destKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + + // test src prepare failed + *simulateATimout = true + srcKey = testNodeA.self + utils.RandString(10) + destKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode + value = utils.RandString(10) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "ex", "1000")) + result = Rename(testNodeB, conn, utils.ToCmdLine("RENAME", srcKey, destKey)) + asserts.AssertErrReply(t, result, "ERR timeout") + result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) + asserts.AssertIntReply(t, result, 0) + *simulateATimout = false + + // test dest prepare failed + *simulateBTimout = true + srcKey = testNodeA.self + utils.RandString(10) + destKey = testNodeB.self + utils.RandString(10) // route to testNodeB, see mockPicker.PickNode + value = utils.RandString(10) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value, "ex", "1000")) + result = Rename(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey)) + asserts.AssertErrReply(t, result, "ERR timeout") + result = testNodeA.db.Exec(conn, utils.ToCmdLine("EXISTS", srcKey)) + asserts.AssertIntReply(t, result, 1) + result = testNodeA.db.Exec(conn, utils.ToCmdLine("TTL", srcKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + result = testNodeB.db.Exec(conn, utils.ToCmdLine("EXISTS", destKey)) + asserts.AssertIntReply(t, result, 0) + *simulateBTimout = false + + // Copying to another database + srcKey = testNodeA.self + utils.RandString(10) + value = utils.RandString(10) + destKey = srcKey + utils.RandString(2) + testNodeA.db.Exec(conn, utils.ToCmdLine("SET", srcKey, value)) + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey, destKey, "db", "1")) + asserts.AssertErrReply(t, result, copyToAnotherDBErr) + + result = Copy(testNodeA, conn, utils.ToCmdLine("COPY", srcKey)) + asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'copy' command") +} diff --git a/cluster/router.go b/cluster/router.go index 9aea963..677d02c 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -25,6 +25,7 @@ func makeRouter() map[string]CmdFunc { routerMap["type"] = defaultFunc routerMap["rename"] = Rename routerMap["renamenx"] = RenameNx + routerMap["copy"] = Copy routerMap["set"] = defaultFunc routerMap["setnx"] = defaultFunc diff --git a/cluster/utils.go b/cluster/utils.go index f9de730..fdeb669 100644 --- a/cluster/utils.go +++ b/cluster/utils.go @@ -42,7 +42,7 @@ func execSelect(c redis.Connection, args [][]byte) redis.Reply { if err != nil { return protocol.MakeErrReply("ERR invalid DB index") } - if dbIndex >= config.Properties.Databases { + if dbIndex >= config.Properties.Databases || dbIndex < 0 { return protocol.MakeErrReply("ERR DB index is out of range") } c.SelectDB(dbIndex) diff --git a/commands.md b/commands.md index 10b0a1a..adc4d5e 100644 --- a/commands.md +++ b/commands.md @@ -18,6 +18,7 @@ - flushall - keys - bgrewriteaof + - copy - String - set - setnx diff --git a/database/cluster_helper.go b/database/cluster_helper.go index d5204d9..c85eafb 100644 --- a/database/cluster_helper.go +++ b/database/cluster_helper.go @@ -88,11 +88,52 @@ func execRenameNxTo(db *DB, args [][]byte) redis.Reply { return execRenameTo(db, args) } +// execCopyFrom just reply "OK" message, used for cluster.Copy +func execCopyFrom(db *DB, args [][]byte) redis.Reply { + return protocol.MakeOkReply() +} + +// execCopyTo accepts result of execDumpKey and load the dumped key +// args format: key dumpCmd ttlCmd +// execCopyTo may be partially successful, do not use it without transaction +func execCopyTo(db *DB, args [][]byte) redis.Reply { + key := args[0] + dumpRawCmd, err := parser.ParseOne(args[1]) + if err != nil { + return protocol.MakeErrReply("illegal dump cmd: " + err.Error()) + } + dumpCmd, ok := dumpRawCmd.(*protocol.MultiBulkReply) + if !ok { + return protocol.MakeErrReply("dump cmd is not multi bulk reply") + } + dumpCmd.Args[1] = key // change key + ttlRawCmd, err := parser.ParseOne(args[2]) + if err != nil { + return protocol.MakeErrReply("illegal ttl cmd: " + err.Error()) + } + ttlCmd, ok := ttlRawCmd.(*protocol.MultiBulkReply) + if !ok { + return protocol.MakeErrReply("ttl cmd is not multi bulk reply") + } + ttlCmd.Args[1] = key + db.Remove(string(key)) + dumpResult := db.execWithLock(dumpCmd.Args) + if protocol.IsErrorReply(dumpResult) { + return dumpResult + } + ttlResult := db.execWithLock(ttlCmd.Args) + if protocol.IsErrorReply(ttlResult) { + return ttlResult + } + return protocol.MakeOkReply() +} + func init() { RegisterCommand("DumpKey", execDumpKey, writeAllKeys, undoDel, 2) RegisterCommand("ExistIn", execExistIn, readAllKeys, nil, -1) RegisterCommand("RenameFrom", execRenameFrom, readFirstKey, nil, 2) RegisterCommand("RenameTo", execRenameTo, writeFirstKey, rollbackFirstKey, 4) RegisterCommand("RenameNxTo", execRenameTo, writeFirstKey, rollbackFirstKey, 4) - + RegisterCommand("CopyFrom", execCopyFrom, readFirstKey, nil, 2) + RegisterCommand("CopyTo", execCopyTo, writeFirstKey, rollbackFirstKey, 5) } diff --git a/database/database.go b/database/database.go index b330b30..0ab2cc5 100644 --- a/database/database.go +++ b/database/database.go @@ -122,6 +122,11 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep return protocol.MakeArgNumErrReply("select") } return execSelect(c, mdb, cmdLine[1:]) + } else if cmdName == "copy" { + if len(cmdLine) < 3 { + return protocol.MakeArgNumErrReply("copy") + } + return execCopy(mdb, c, cmdLine[1:]) } // todo: support multi database transaction @@ -151,7 +156,7 @@ func execSelect(c redis.Connection, mdb *MultiDB, args [][]byte) redis.Reply { if err != nil { return protocol.MakeErrReply("ERR invalid DB index") } - if dbIndex >= len(mdb.dbSet) { + if dbIndex >= len(mdb.dbSet) || dbIndex < 0 { return protocol.MakeErrReply("ERR DB index is out of range") } c.SelectDB(dbIndex) diff --git a/database/keys.go b/database/keys.go index 344e2bd..c345f84 100644 --- a/database/keys.go +++ b/database/keys.go @@ -11,6 +11,7 @@ import ( "github.com/hdt3213/godis/lib/wildcard" "github.com/hdt3213/godis/redis/protocol" "strconv" + "strings" "time" ) @@ -308,6 +309,68 @@ func undoExpire(db *DB, args [][]byte) []CmdLine { } } +// execCopy usage: COPY source destination [DB destination-db] [REPLACE] +// This command copies the value stored at the source key to the destination key. +func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply { + dbIndex := conn.GetDBIndex() + db := mdb.dbSet[dbIndex] // Current DB + replaceFlag := false + srcKey := string(args[0]) + destKey := string(args[1]) + + // Parse options + if len(args) > 2 { + for i := 2; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "db" { + if i+1 >= len(args) { + return &protocol.SyntaxErrReply{} + } + idx, err := strconv.Atoi(string(args[i+1])) + if err != nil { + return &protocol.SyntaxErrReply{} + } + if idx >= len(mdb.dbSet) || idx < 0 { + return protocol.MakeErrReply("ERR DB index is out of range") + } + dbIndex = idx + i++ + } else if arg == "replace" { + replaceFlag = true + } else { + return &protocol.SyntaxErrReply{} + } + } + } + + if srcKey == destKey && dbIndex == conn.GetDBIndex() { + return protocol.MakeErrReply("ERR source and destination objects are the same") + } + + // source key does not exist + src, exists := db.GetEntity(srcKey) + if !exists { + return protocol.MakeIntReply(0) + } + + destDB := mdb.dbSet[dbIndex] + if _, exists = destDB.GetEntity(destKey); exists != false { + // If destKey exists and there is no "replace" option + if replaceFlag == false { + return protocol.MakeIntReply(0) + } + } + + destDB.PutEntity(destKey, src) + raw, exists := db.ttlMap.Get(srcKey) + if exists { + expire := raw.(time.Time) + destDB.Expire(destKey, expire) + } + mdb.aofHandler.AddAof(conn.GetDBIndex(), utils.ToCmdLine3("copy", args...)) + return protocol.MakeIntReply(1) +} + func init() { RegisterCommand("Del", execDel, writeAllKeys, undoDel, -2) RegisterCommand("Expire", execExpire, writeFirstKey, undoExpire, 3) diff --git a/database/keys_test.go b/database/keys_test.go index 9f2539d..37b83cd 100644 --- a/database/keys_test.go +++ b/database/keys_test.go @@ -3,6 +3,7 @@ package database import ( "fmt" "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol/asserts" "strconv" @@ -205,3 +206,48 @@ func TestKeys(t *testing.T) { result = testDB.Exec(nil, utils.ToCmdLine("keys", "?:*")) asserts.AssertMultiBulkReplySize(t, result, 2) } + +func TestCopy(t *testing.T) { + testDB.Flush() + testMDB := NewStandaloneServer() + srcKey := utils.RandString(10) + destKey := "from:" + srcKey + value := utils.RandString(10) + conn := new(connection.FakeConn) + + testMDB.Exec(conn, utils.ToCmdLine("set", srcKey, value)) + + // normal copy + result := testMDB.Exec(conn, utils.ToCmdLine("copy", srcKey, destKey)) + asserts.AssertIntReply(t, result, 1) + result = testMDB.Exec(conn, utils.ToCmdLine("get", destKey)) + asserts.AssertBulkReply(t, result, value) + + // copy srcKey(DB 0) to destKey(DB 1) + testMDB.Exec(conn, utils.ToCmdLine("copy", srcKey, destKey, "db", "1")) + testMDB.Exec(conn, utils.ToCmdLine("select", "1")) + result = testMDB.Exec(conn, utils.ToCmdLine("get", destKey)) + asserts.AssertBulkReply(t, result, value) + + // test destKey already exists + testMDB.Exec(conn, utils.ToCmdLine("select", "0")) + result = testMDB.Exec(conn, utils.ToCmdLine("copy", srcKey, destKey)) + asserts.AssertIntReply(t, result, 0) + + // copy srcKey(DB 0) to destKey(DB 0) with "Replace" + value = "new:" + value + testMDB.Exec(conn, utils.ToCmdLine("set", srcKey, value)) // reset srcKey + result = testMDB.Exec(conn, utils.ToCmdLine("copy", srcKey, destKey, "replace")) + asserts.AssertIntReply(t, result, 1) + result = testMDB.Exec(conn, utils.ToCmdLine("get", destKey)) + asserts.AssertBulkReply(t, result, value) + + // test copy expire time + testMDB.Exec(conn, utils.ToCmdLine("set", srcKey, value, "ex", "1000")) + result = testMDB.Exec(conn, utils.ToCmdLine("copy", srcKey, destKey, "replace")) + asserts.AssertIntReply(t, result, 1) + result = testMDB.Exec(conn, utils.ToCmdLine("ttl", srcKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) + result = testMDB.Exec(conn, utils.ToCmdLine("ttl", destKey)) + asserts.AssertIntReplyGreaterThan(t, result, 0) +}