diff --git a/cluster/router.go b/cluster/router.go index f65aabb..9aea963 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -34,7 +34,9 @@ func makeRouter() map[string]CmdFunc { routerMap["mget"] = MGet routerMap["msetnx"] = MSetNX routerMap["get"] = defaultFunc + routerMap["getex"] = defaultFunc routerMap["getset"] = defaultFunc + routerMap["getdel"] = defaultFunc routerMap["incr"] = defaultFunc routerMap["incrby"] = defaultFunc routerMap["incrbyfloat"] = defaultFunc diff --git a/commands.md b/commands.md index 9b246ad..10b0a1a 100644 --- a/commands.md +++ b/commands.md @@ -27,7 +27,9 @@ - mget - msetnx - get + - getex - getset + - getdel - incr - incrby - incrbyfloat diff --git a/database/string.go b/database/string.go index e18b645..69f340e 100644 --- a/database/string.go +++ b/database/string.go @@ -47,6 +47,79 @@ const ( const unlimitedTTL int64 = 0 +// execGetEX Get the value of key and optionally set its expiration +func execGetEX(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + bytes, err := db.getAsString(key) + ttl := unlimitedTTL + + if err != nil { + return err + } + if bytes == nil { + return &protocol.NullBulkReply{} + } + + for i := 1; i < len(args); i++ { + arg := strings.ToUpper(string(args[i])) + if arg == "EX" { // ttl in seconds + if ttl != unlimitedTTL { + // ttl has been set + return &protocol.SyntaxErrReply{} + } + if i+1 >= len(args) { + return &protocol.SyntaxErrReply{} + } + ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return &protocol.SyntaxErrReply{} + } + if ttlArg <= 0 { + return protocol.MakeErrReply("ERR invalid expire time in getex") + } + ttl = ttlArg * 1000 + i++ // skip next arg + } else if arg == "PX" { // ttl in milliseconds + if ttl != unlimitedTTL { + return &protocol.SyntaxErrReply{} + } + if i+1 >= len(args) { + return &protocol.SyntaxErrReply{} + } + ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return &protocol.SyntaxErrReply{} + } + if ttlArg <= 0 { + return protocol.MakeErrReply("ERR invalid expire time in getex") + } + ttl = ttlArg + i++ // skip next arg + } else if arg == "PERSIST" { + if ttl != unlimitedTTL { // PERSIST Cannot be used with EX | PX + return &protocol.SyntaxErrReply{} + } + if i+1 > len(args) { + return &protocol.SyntaxErrReply{} + } + db.Persist(key) + } + } + + if len(args) > 1 { + if ttl != unlimitedTTL { // EX | PX + expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) + db.Expire(key, expireTime) + db.addAof(aof.MakeExpireCmd(key, expireTime).Args) + } else { // PERSIST + db.Persist(key) // override ttl + // we convert to persist command to write aof + db.addAof(utils.ToCmdLine3("persist", args[0])) + } + } + return protocol.MakeBulkReply(bytes) +} + // execSet sets string value and time to live to the given key func execSet(db *DB, args [][]byte) redis.Reply { key := string(args[0]) @@ -324,6 +397,22 @@ func execGetSet(db *DB, args [][]byte) redis.Reply { return protocol.MakeBulkReply(old) } +// execGetDel Get the value of key and delete the key. +func execGetDel(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + + old, err := db.getAsString(key) + if err != nil { + return err + } + if old == nil { + return new(protocol.NullBulkReply) + } + db.Remove(key) + db.addAof(utils.ToCmdLine3("getdel", args...)) + return protocol.MakeBulkReply(old) +} + // execIncr increments the integer value of a key by one func execIncr(db *DB, args [][]byte) redis.Reply { key := string(args[0]) @@ -743,7 +832,9 @@ func init() { RegisterCommand("MGet", execMGet, prepareMGet, nil, -2) RegisterCommand("MSetNX", execMSetNX, prepareMSet, undoMSet, -3) RegisterCommand("Get", execGet, readFirstKey, nil, 2) + RegisterCommand("GetEX", execGetEX, writeFirstKey, rollbackFirstKey, -2) RegisterCommand("GetSet", execGetSet, writeFirstKey, rollbackFirstKey, 3) + RegisterCommand("GetDel", execGetDel, writeFirstKey, rollbackFirstKey, 2) RegisterCommand("Incr", execIncr, writeFirstKey, rollbackFirstKey, 2) RegisterCommand("IncrBy", execIncrBy, writeFirstKey, rollbackFirstKey, 3) RegisterCommand("IncrByFloat", execIncrByFloat, writeFirstKey, rollbackFirstKey, 3) diff --git a/database/string_test.go b/database/string_test.go index 1a28bb3..d1a379d 100644 --- a/database/string_test.go +++ b/database/string_test.go @@ -283,6 +283,62 @@ func TestDecr(t *testing.T) { } } +func TestGetEX(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + ttl := "1000" + + testDB.Exec(nil, utils.ToCmdLine("SET", key, value)) + + // Normal Get + actual := testDB.Exec(nil, utils.ToCmdLine("GETEX", key)) + asserts.AssertBulkReply(t, actual, value) + + // Test GetEX Key EX Seconds + actual = testDB.Exec(nil, utils.ToCmdLine("GETEX", key, "EX", ttl)) + asserts.AssertBulkReply(t, actual, value) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) + intResult, ok := actual.(*protocol.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int protocol, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000 { + t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code)) + return + } + + // Test GetEX Key Persist + actual = testDB.Exec(nil, utils.ToCmdLine("GETEX", key, "PERSIST")) + asserts.AssertBulkReply(t, actual, value) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) + intResult, ok = actual.(*protocol.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int protocol, actually %s", actual.ToBytes())) + return + } + if intResult.Code != -1 { + t.Error(fmt.Sprintf("expected int equals -1, actually %d", intResult.Code)) + return + } + + // Test GetEX Key NX Milliseconds + ttl = "1000000" + actual = testDB.Exec(nil, utils.ToCmdLine("GETEX", key, "PX", ttl)) + asserts.AssertBulkReply(t, actual, value) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) + intResult, ok = actual.(*protocol.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int protocol, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000000 { + t.Error(fmt.Sprintf("expected int between [0, 1000000], actually %d", intResult.Code)) + return + } +} + func TestGetSet(t *testing.T) { testDB.Flush() key := utils.RandString(10) @@ -300,6 +356,17 @@ func TestGetSet(t *testing.T) { asserts.AssertBulkReply(t, actual, value) actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value2) + + // Test GetDel + actual = testDB.Exec(nil, utils.ToCmdLine("GETDEL", key)) + asserts.AssertBulkReply(t, actual, value2) + + actual = testDB.Exec(nil, utils.ToCmdLine("GETDEL", key)) + _, ok = actual.(*protocol.NullBulkReply) + if !ok { + t.Errorf("expect null bulk protocol, get: %s", string(actual.ToBytes())) + return + } } func TestMSetNX(t *testing.T) {