diff --git a/database/list.go b/database/list.go index 416a459..4360a53 100644 --- a/database/list.go +++ b/database/list.go @@ -103,6 +103,27 @@ func execLPop(db *DB, args [][]byte) redis.Reply { return &protocol.NullBulkReply{} } + if len(args) == 2 { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + count := int(count64) + if count > list.Len() { + count = list.Len() + } + vals := make([][]byte, count) + for i := 0; i < count; i++ { + val := list.Remove(0).([]byte) + vals[i] = val + } + if list.Len() == 0 { + db.Remove(key) + } + db.addAof(utils.ToCmdLine3("lpop", args...)) + return protocol.MakeMultiBulkReply(vals) + } + val, _ := list.Remove(0).([]byte) if list.Len() == 0 { db.Remove(key) @@ -122,6 +143,26 @@ func undoLPop(db *DB, args [][]byte) []CmdLine { if list == nil || list.Len() == 0 { return nil } + if len(args) == 2 { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return nil + } + count := int(count64) + if count > list.Len() { + count = list.Len() + } + vals := list.Range(0, count) + var cmds []CmdLine + for i := len(vals) - 1; i >= 0; i-- { + cmds = append(cmds, CmdLine{ + lPushCmd, + args[0], + vals[i].([]byte), + }) + } + return cmds + } element, _ := list.Get(0).([]byte) return []CmdLine{ { @@ -366,6 +407,27 @@ func execRPop(db *DB, args [][]byte) redis.Reply { return &protocol.NullBulkReply{} } + if len(args) == 2 { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + count := int(count64) + if count > list.Len() { + count = list.Len() + } + vals := make([][]byte, count) + for i := 0; i < count; i++ { + val := list.RemoveLast().([]byte) + vals[i] = val + } + if list.Len() == 0 { + db.Remove(key) + } + db.addAof(utils.ToCmdLine3("rpop", args...)) + return protocol.MakeMultiBulkReply(vals) + } + val, _ := list.RemoveLast().([]byte) if list.Len() == 0 { db.Remove(key) @@ -385,6 +447,26 @@ func undoRPop(db *DB, args [][]byte) []CmdLine { if list == nil || list.Len() == 0 { return nil } + if len(args) == 2 { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return nil + } + count := int(count64) + if count > list.Len() { + count = list.Len() + } + vals := list.Range(list.Len()-count, list.Len()) + var cmds []CmdLine + for i := 0; i < count; i++ { + cmds = append(cmds, CmdLine{ + rPushCmd, + args[0], + vals[i].([]byte), + }) + } + return cmds + } element, _ := list.Get(list.Len() - 1).([]byte) return []CmdLine{ { @@ -614,9 +696,9 @@ func init() { attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) registerCommand("RPushX", execRPushX, writeFirstKey, undoRPush, -3, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) - registerCommand("LPop", execLPop, writeFirstKey, undoLPop, 2, flagWrite). + registerCommand("LPop", execLPop, writeFirstKey, undoLPop, -2, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1) - registerCommand("RPop", execRPop, writeFirstKey, undoRPop, 2, flagWrite). + registerCommand("RPop", execRPop, writeFirstKey, undoRPop, -2, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1) registerCommand("RPopLPush", execRPopLPush, prepareRPopLPush, undoRPopLPush, 3, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, 1, 1) diff --git a/database/list_test.go b/database/list_test.go index b46ef82..14e8501 100644 --- a/database/list_test.go +++ b/database/list_test.go @@ -287,6 +287,21 @@ func TestLPop(t *testing.T) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Errorf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())) } + + // test count + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) + result = testDB.Exec(nil, utils.ToCmdLine("lpop", key, strconv.Itoa(10))) + reply, ok := result.(*protocol.MultiBulkReply) + if !ok { + t.Errorf("parse error expected *protocol.MultiBulkReply, actually %s", result) + } + for i, arg := range reply.Args { + result := string(arg) + expected := values[i+1] + if result != expected { + t.Errorf("expected %s, actually %s", expected, result) + } + } } func TestRPop(t *testing.T) { @@ -308,6 +323,21 @@ func TestRPop(t *testing.T) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Errorf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())) } + + // test count + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) + result = testDB.Exec(nil, utils.ToCmdLine("rpop", key, strconv.Itoa(10))) + reply, ok := result.(*protocol.MultiBulkReply) + if !ok { + t.Errorf("parse error expected *protocol.MultiBulkReply, actually %s", result) + } + for i, arg := range reply.Args { + result := string(arg) + expected := values[len(values)-i-1] + if result != expected { + t.Errorf("expected %s, actually %s", expected, result) + } + } } func TestRPopLPush(t *testing.T) { @@ -476,6 +506,14 @@ func TestUndoLPop(t *testing.T) { } result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) asserts.AssertIntReply(t, result, 2) + cmdLine = utils.ToCmdLine("lpop", key, strconv.Itoa(2)) + undoCmdLines = undoRPop(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result = testDB.Exec(nil, utils.ToCmdLine("llen", key)) + asserts.AssertIntReply(t, result, 2) } func TestUndoLSet(t *testing.T) { @@ -507,6 +545,14 @@ func TestUndoRPop(t *testing.T) { } result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) asserts.AssertIntReply(t, result, 2) + cmdLine = utils.ToCmdLine("rpop", key, strconv.Itoa(2)) + undoCmdLines = undoRPop(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result = testDB.Exec(nil, utils.ToCmdLine("llen", key)) + asserts.AssertIntReply(t, result, 2) } func TestUndoRPopLPush(t *testing.T) {