diff --git a/database/string.go b/database/string.go index 12aaf72..05be64f 100644 --- a/database/string.go +++ b/database/string.go @@ -2,11 +2,13 @@ package database import ( "github.com/hdt3213/godis/aof" + "github.com/hdt3213/godis/datastruct/bitmap" "github.com/hdt3213/godis/interface/database" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/protocol" "github.com/shopspring/decimal" + "math/bits" "strconv" "strings" "time" @@ -536,46 +538,200 @@ func execSetRange(db *DB, args [][]byte) redis.Reply { func execGetRange(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - startIdx, errNative := strconv.ParseInt(string(args[1]), 10, 64) - if errNative != nil { - return protocol.MakeErrReply(errNative.Error()) + startIdx, err2 := strconv.ParseInt(string(args[1]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") } - endIdx, errNative := strconv.ParseInt(string(args[2]), 10, 64) - if errNative != nil { - return protocol.MakeErrReply(errNative.Error()) + endIdx, err2 := strconv.ParseInt(string(args[2]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") } - bytes, err := db.getAsString(key) + bs, err := db.getAsString(key) if err != nil { return err } - - if bytes == nil { + if bs == nil { return protocol.MakeNullBulkReply() } - - bytesLen := int64(len(bytes)) - if startIdx < -1*bytesLen { - return &protocol.NullBulkReply{} - } else if startIdx < 0 { - startIdx = bytesLen + startIdx - } else if startIdx >= bytesLen { - return &protocol.NullBulkReply{} + bytesLen := int64(len(bs)) + beg, end := utils.ConvertRange(startIdx, endIdx, bytesLen) + if beg < 0 { + return protocol.MakeNullBulkReply() } - if endIdx < -1*bytesLen { - return &protocol.NullBulkReply{} - } else if endIdx < 0 { - endIdx = bytesLen + endIdx + 1 - } else if endIdx < bytesLen { - endIdx = endIdx + 1 + return protocol.MakeBulkReply(bs[beg:end]) +} + +func execSetBit(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + offset, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR bit offset is not an integer or out of range") + } + valStr := string(args[2]) + var v byte + if valStr == "1" { + v = 1 + } else if valStr == "0" { + v = 0 } else { - endIdx = bytesLen + return protocol.MakeErrReply("ERR bit is not an integer or out of range\n") } - if startIdx > endIdx { - return protocol.MakeNullBulkReply() + bs, errReply := db.getAsString(key) + if errReply != nil { + return errReply } + bm := bitmap.FromBytes(bs) + former := bm.GetBit(offset) + bm.SetBit(offset, v) + db.PutEntity(key, &database.DataEntity{Data: bm.ToBytes()}) + return protocol.MakeIntReply(int64(former)) +} - return protocol.MakeBulkReply(bytes[startIdx:endIdx]) +func execGetBit(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + offset, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR bit offset is not an integer or out of range") + } + bs, errReply := db.getAsString(key) + if errReply != nil { + return errReply + } + if bs == nil { + return protocol.MakeIntReply(0) + } + bm := bitmap.FromBytes(bs) + return protocol.MakeIntReply(int64(bm.GetBit(offset))) +} + +func execBitCount(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + bs, err := db.getAsString(key) + if err != nil { + return err + } + if bs == nil { + return protocol.MakeIntReply(0) + } + byteMode := true + if len(args) > 3 { + mode := strings.ToLower(string(args[3])) + if mode == "bit" { + byteMode = false + } else if mode == "byte" { + byteMode = true + } else { + return protocol.MakeErrReply("ERR syntax error") + } + } + var size int64 + bm := bitmap.FromBytes(bs) + if byteMode { + size = int64(len(*bm)) + } else { + size = int64(bm.BitSize()) + } + var beg, end int + if len(args) > 1 { + var err2 error + var startIdx, endIdx int64 + startIdx, err2 = strconv.ParseInt(string(args[1]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + endIdx, err2 = strconv.ParseInt(string(args[2]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + beg, end = utils.ConvertRange(startIdx, endIdx, size) + if beg < 0 { + return protocol.MakeIntReply(0) + } + } + var count int64 + if byteMode { + bm.ForEachByte(beg, end, func(offset int64, val byte) bool { + count += int64(bits.OnesCount8(val)) + return true + }) + } else { + bm.ForEachBit(int64(beg), int64(end), func(offset int64, val byte) bool { + if val > 0 { + count++ + } + return true + }) + } + return protocol.MakeIntReply(count) +} + +func execBitPos(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + bs, err := db.getAsString(key) + if err != nil { + return err + } + if bs == nil { + return protocol.MakeIntReply(-1) + } + valStr := string(args[1]) + var v byte + if valStr == "1" { + v = 1 + } else if valStr == "0" { + v = 0 + } else { + return protocol.MakeErrReply("ERR bit is not an integer or out of range") + } + byteMode := true + if len(args) > 4 { + mode := strings.ToLower(string(args[4])) + if mode == "bit" { + byteMode = false + } else if mode == "byte" { + byteMode = true + } else { + return protocol.MakeErrReply("ERR syntax error") + } + } + var size int64 + bm := bitmap.FromBytes(bs) + if byteMode { + size = int64(len(*bm)) + } else { + size = int64(bm.BitSize()) + } + var beg, end int + if len(args) > 2 { + var err2 error + var startIdx, endIdx int64 + startIdx, err2 = strconv.ParseInt(string(args[2]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + endIdx, err2 = strconv.ParseInt(string(args[3]), 10, 64) + if err2 != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + beg, end = utils.ConvertRange(startIdx, endIdx, size) + if beg < 0 { + return protocol.MakeIntReply(0) + } + } + if byteMode { + beg *= 8 + end *= 8 + } + var offset = int64(-1) + bm.ForEachBit(int64(beg), int64(end), func(o int64, val byte) bool { + if val == v { + offset = o + return false + } + return true + }) + return protocol.MakeIntReply(offset) } func init() { @@ -597,4 +753,9 @@ func init() { RegisterCommand("Append", execAppend, writeFirstKey, rollbackFirstKey, 3) RegisterCommand("SetRange", execSetRange, writeFirstKey, rollbackFirstKey, 4) RegisterCommand("GetRange", execGetRange, readFirstKey, nil, 4) + RegisterCommand("SetBit", execSetBit, writeFirstKey, rollbackFirstKey, 4) + RegisterCommand("GetBit", execGetBit, readFirstKey, nil, 3) + RegisterCommand("BitCount", execBitCount, readFirstKey, nil, -2) + RegisterCommand("BitPos", execBitPos, readFirstKey, nil, -3) + } diff --git a/database/string_test.go b/database/string_test.go index 3618105..eb8c57c 100644 --- a/database/string_test.go +++ b/database/string_test.go @@ -323,12 +323,12 @@ func TestStrLen(t *testing.T) { testDB.Exec(nil, utils.ToCmdLine2("SET", key, key)) actual := testDB.Exec(nil, utils.ToCmdLine("StrLen", key)) - len, ok := actual.(*protocol.IntReply) + size, ok := actual.(*protocol.IntReply) if !ok { t.Errorf("expect int bulk protocol, get: %s", string(actual.ToBytes())) return } - asserts.AssertIntReply(t, len, 10) + asserts.AssertIntReply(t, size, 10) } func TestStrLen_KeyNotExist(t *testing.T) { @@ -648,3 +648,90 @@ func TestGetRange_StringExist_EndIdxIncorrectFormat(t *testing.T) { errorMsg := fmt.Sprintf("strconv.ParseInt: parsing \"%s\": invalid syntax", incorrectValue) asserts.AssertErrReply(t, val, errorMsg) } + +func TestSetBit(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + actual := testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "15", "1")) + asserts.AssertIntReply(t, actual, 0) + actual = testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "15", "0")) + asserts.AssertIntReply(t, actual, 1) + testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "13", "1")) + actual = testDB.Exec(nil, utils.ToCmdLine("GetBit", key, "13")) + asserts.AssertIntReply(t, actual, 1) + actual = testDB.Exec(nil, utils.ToCmdLine("GetBit", key+"1", "13")) // test not exist key + asserts.AssertIntReply(t, actual, 0) + + actual = testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "13", "a")) + asserts.AssertErrReply(t, actual, "ERR bit is not an integer or out of range") + actual = testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "a", "1")) + asserts.AssertErrReply(t, actual, "ERR bit offset is not an integer or out of range") + actual = testDB.Exec(nil, utils.ToCmdLine("GetBit", key, "a")) + asserts.AssertErrReply(t, actual, "ERR bit offset is not an integer or out of range") + + key2 := utils.RandString(11) + testDB.Exec(nil, utils.ToCmdLine("rpush", key2, "1")) + actual = testDB.Exec(nil, utils.ToCmdLine("SetBit", key2, "15", "0")) + asserts.AssertErrReply(t, actual, "WRONGTYPE Operation against a key holding the wrong kind of value") + actual = testDB.Exec(nil, utils.ToCmdLine("GetBit", key2, "15")) + asserts.AssertErrReply(t, actual, "WRONGTYPE Operation against a key holding the wrong kind of value") +} + +func TestBitCount(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "15", "1")) + testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "13", "1")) + actual := testDB.Exec(nil, utils.ToCmdLine("BitCount", key)) + asserts.AssertIntReply(t, actual, 2) + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "14", "15", "BIT")) + asserts.AssertIntReply(t, actual, 1) + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "16", "20", "BIT")) + asserts.AssertIntReply(t, actual, 0) + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "1", "1", "BYTE")) + asserts.AssertIntReply(t, actual, 2) + + key2 := utils.RandString(11) + testDB.Exec(nil, utils.ToCmdLine("rpush", key2, "1")) + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key2)) + asserts.AssertErrReply(t, actual, "WRONGTYPE Operation against a key holding the wrong kind of value") + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key+"a")) + asserts.AssertIntReply(t, actual, 0) + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "14", "15", "B")) + asserts.AssertErrReply(t, actual, "ERR syntax error") + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "14", "A")) + asserts.AssertErrReply(t, actual, "ERR value is not an integer or out of range") + actual = testDB.Exec(nil, utils.ToCmdLine("BitCount", key, "A", "-1")) + asserts.AssertErrReply(t, actual, "ERR value is not an integer or out of range") +} + +func TestBitPos(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SetBit", key, "15", "1")) + actual := testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "0")) + asserts.AssertIntReply(t, actual, 0) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1")) + asserts.AssertIntReply(t, actual, 15) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1", "0", "-1", "BIT")) + asserts.AssertIntReply(t, actual, 15) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1", "1", "1", "BYTE")) + asserts.AssertIntReply(t, actual, 15) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "0", "1", "1", "BYTE")) + asserts.AssertIntReply(t, actual, 8) + + key2 := utils.RandString(12) + testDB.Exec(nil, utils.ToCmdLine("rpush", key2, "1")) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key2, "1")) + asserts.AssertErrReply(t, actual, "WRONGTYPE Operation against a key holding the wrong kind of value") + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key+"a", "1")) + asserts.AssertIntReply(t, actual, -1) + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1", "1", "15", "B")) + asserts.AssertErrReply(t, actual, "ERR syntax error") + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1", "14", "A")) + asserts.AssertErrReply(t, actual, "ERR value is not an integer or out of range") + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "1", "a", "14")) + asserts.AssertErrReply(t, actual, "ERR value is not an integer or out of range") + actual = testDB.Exec(nil, utils.ToCmdLine("BitPos", key, "-1")) + asserts.AssertErrReply(t, actual, "ERR bit is not an integer or out of range") +} diff --git a/datastruct/bitmap/bitmap.go b/datastruct/bitmap/bitmap.go new file mode 100644 index 0000000..bce4eba --- /dev/null +++ b/datastruct/bitmap/bitmap.go @@ -0,0 +1,97 @@ +package bitmap + +type BitMap []byte + +func New() *BitMap { + b := BitMap(make([]byte, 0)) + return &b +} + +func toByteSize(bitSize int64) int64 { + if bitSize%8 == 0 { + return bitSize / 8 + } + return bitSize/8 + 1 +} + +func (b *BitMap) grow(bitSize int64) { + byteSize := toByteSize(bitSize) + gap := byteSize - int64(len(*b)) + if gap <= 0 { + return + } + *b = append(*b, make([]byte, gap)...) +} + +func (b *BitMap) BitSize() int { + return len(*b) * 8 +} + +func FromBytes(bytes []byte) *BitMap { + bm := BitMap(bytes) + return &bm +} + +func (b *BitMap) ToBytes() []byte { + return *b +} + +func (b *BitMap) SetBit(offset int64, val byte) { + byteIndex := offset / 8 + bitOffset := offset % 8 + mask := byte(1 << bitOffset) + b.grow(offset + 1) + if val > 0 { + // set bit + (*b)[byteIndex] |= mask + } else { + // clear bit + (*b)[byteIndex] &^= mask + } +} + +func (b *BitMap) GetBit(offset int64) byte { + byteIndex := offset / 8 + bitOffset := offset % 8 + if byteIndex >= int64(len(*b)) { + return 0 + } + return ((*b)[byteIndex] >> bitOffset) & 0x01 +} + +type Callback func(offset int64, val byte) bool + +func (b *BitMap) ForEachBit(begin int64, end int64, cb Callback) { + offset := begin + byteIndex := offset / 8 + bitOffset := offset % 8 + for byteIndex < int64(len(*b)) { + b := (*b)[byteIndex] + for bitOffset < 8 { + bit := byte(b >> bitOffset & 0x01) + if !cb(offset, bit) { + return + } + bitOffset++ + offset++ + } + byteIndex++ + bitOffset = 0 + if end > 0 && offset == end { + break + } + } +} + +func (b *BitMap) ForEachByte(begin int, end int, cb Callback) { + if end == 0 { + end = len(*b) + } else if end > len(*b) { + end = len(*b) + } + for i := begin; i < end; i++ { + if !cb(int64(i), (*b)[i]) { + return + } + } +} diff --git a/datastruct/bitmap/bitmap_test.go b/datastruct/bitmap/bitmap_test.go new file mode 100644 index 0000000..45111e8 --- /dev/null +++ b/datastruct/bitmap/bitmap_test.go @@ -0,0 +1,168 @@ +package bitmap + +import ( + "bytes" + "math/rand" + "testing" +) + +func TestSetBit(t *testing.T) { + // generate distinct rand offsets + size := 1000 + offsets := make([]int64, size) + for i := 0; i < size; i++ { + offsets[i] = int64(i) + } + rand.Shuffle(size, func(i, j int) { + offsets[i], offsets[j] = offsets[j], offsets[i] + }) + offsets = offsets[0 : size/5] + // set bit + offsetMap := make(map[int64]struct{}) + bm := New() + for _, offset := range offsets { + offsetMap[offset] = struct{}{} + bm.SetBit(offset, 1) + } + // get bit + for i := 0; i < bm.BitSize(); i++ { + offset := int64(i) + _, expect := offsetMap[offset] + actual := bm.GetBit(offset) > 0 + if expect != actual { + t.Errorf("wrong value at %d", offset) + } + } + + bm2 := New() + bm2.SetBit(15, 1) + if bm2.GetBit(15) != 1 { + t.Error("wrong value") + } + if bm2.GetBit(16) != 0 { + t.Error("wrong value") + } +} + +func TestFromBytes(t *testing.T) { + bs := []byte{0xff, 0xff} + bm := FromBytes(bs) + bm.SetBit(8, 0) + expect := []byte{0xff, 0xfe} + if !bytes.Equal(bs, expect) { + t.Error("wrong value") + } + ret := bm.ToBytes() + if !bytes.Equal(ret, expect) { + t.Error("wrong value") + } +} + +func TestForEachBit(t *testing.T) { + bm := New() + for i := 0; i < 1000; i++ { + if i%2 == 0 { + bm.SetBit(int64(i), 1) + } + } + expectOffset := int64(100) + count := 0 + bm.ForEachBit(100, 200, func(offset int64, val byte) bool { + if offset != expectOffset { + t.Error("wrong offset") + } + expectOffset++ + if offset%2 == 0 && val == 0 { + t.Error("expect 1") + } + count++ + return true + }) + if count != 100 { + t.Error("wrong count") + } + bm = New() + size := 1000 + offsets := make([]int64, size) + for i := 0; i < size; i++ { + offsets[i] = int64(i) + } + rand.Shuffle(size, func(i, j int) { + offsets[i], offsets[j] = offsets[j], offsets[i] + }) + offsets = offsets[0 : size/5] + offsetMap := make(map[int64]struct{}) + for _, offset := range offsets { + bm.SetBit(offset, 1) + offsetMap[offset] = struct{}{} + } + bm.ForEachBit(int64(size/20+1), 0, func(offset int64, val byte) bool { + _, expect := offsetMap[offset] + actual := bm.GetBit(offset) > 0 + if expect != actual { + t.Errorf("wrong value at %d", offset) + } + return true + }) + count = 0 + bm.ForEachBit(0, 0, func(offset int64, val byte) bool { + count++ + return false + }) + if count != 1 { + t.Error("break failed") + } +} + +func TestBitMap_ForEachByte(t *testing.T) { + bm := New() + for i := 0; i < 1000; i++ { + if i%16 == 0 { + bm.SetBit(int64(i), 1) + } + } + bm.ForEachByte(0, 0, func(offset int64, val byte) bool { + if offset%2 == 0 { + if val != 1 { + t.Error("wrong value") + } + } else { + if val != 0 { + t.Error("wrong value") + } + } + return true + }) + bm.ForEachByte(0, 2000, func(offset int64, val byte) bool { + if offset%2 == 0 { + if val != 1 { + t.Error("wrong value") + } + } else { + if val != 0 { + t.Error("wrong value") + } + } + return true + }) + bm.ForEachByte(0, 500, func(offset int64, val byte) bool { + if offset%2 == 0 { + if val != 1 { + t.Error("wrong value") + } + } else { + if val != 0 { + t.Error("wrong value") + } + } + return true + }) + count := 0 + bm.ForEachByte(0, 0, func(offset int64, val byte) bool { + count++ + return false + }) + if count != 1 { + t.Error("break failed") + } +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index c92fbe7..a912e3e 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -57,3 +57,30 @@ func BytesEquals(a []byte, b []byte) bool { } return true } + +// ConvertRange converts redis index to go slice index +// -1 => size-1 +// both inclusive [0, 10] => left inclusive right exclusive [0, 9) +// out of bound to max inbound [size, size+1] => [-1, -1] +func ConvertRange(start int64, end int64, size int64) (int, int) { + if start < -size { + return -1, -1 + } else if start < 0 { + start = size + start + } else if start >= size { + return -1, -1 + } + if end < -size { + return -1, -1 + } else if end < 0 { + end = size + end + 1 + } else if end < size { + end = end + 1 + } else { + end = size + } + if start > end { + return -1, -1 + } + return int(start), int(end) +}