add some unit tests

This commit is contained in:
hdt3213
2021-03-31 17:11:46 +08:00
parent 4b01bbb52a
commit bf913a5aca
16 changed files with 2866 additions and 2198 deletions

View File

@@ -2,7 +2,6 @@ package cluster
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/HDT3213/godis/src/cluster/idgenerator" "github.com/HDT3213/godis/src/cluster/idgenerator"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/config"
@@ -11,7 +10,6 @@ import (
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/consistenthash" "github.com/HDT3213/godis/src/lib/consistenthash"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/redis/client"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"github.com/jolestar/go-commons-pool/v2" "github.com/jolestar/go-commons-pool/v2"
"runtime/debug" "runtime/debug"
@@ -98,30 +96,6 @@ func (cluster *Cluster) AfterClientClose(c redis.Connection) {
} }
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return nil, errors.New("connection factory not found")
}
raw, err := connectionFactory.BorrowObject(context.Background())
if err != nil {
return nil, err
}
conn, ok := raw.(*client.Client)
if !ok {
return nil, errors.New("connection factory make wrong type")
}
return conn, nil
}
func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return errors.New("connection factory not found")
}
return connectionFactory.ReturnObject(context.Background(), peerClient)
}
func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) == 1 { if len(args) == 1 {
return &reply.PongReply{} return &reply.PongReply{}
@@ -132,24 +106,6 @@ func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
} }
} }
// relay command to peer
// cannot call Prepare, Commit, Rollback of self node
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
defer func() {
_ = cluster.returnPeerClient(peer, peerClient)
}()
return peerClient.Send(args)
}
}
/*----- utils -------*/ /*----- utils -------*/
func makeArgs(cmd string, args ...string) [][]byte { func makeArgs(cmd string, args ...string) [][]byte {

52
src/cluster/com.go Normal file
View File

@@ -0,0 +1,52 @@
// communicate with peers within cluster
package cluster
import (
"context"
"errors"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/client"
"github.com/HDT3213/godis/src/redis/reply"
)
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return nil, errors.New("connection factory not found")
}
raw, err := connectionFactory.BorrowObject(context.Background())
if err != nil {
return nil, err
}
conn, ok := raw.(*client.Client)
if !ok {
return nil, errors.New("connection factory make wrong type")
}
return conn, nil
}
func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return errors.New("connection factory not found")
}
return connectionFactory.ReturnObject(context.Background(), peerClient)
}
// relay command to peer
// cannot call Prepare, Commit, Rollback of self node
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
defer func() {
_ = cluster.returnPeerClient(peer, peerClient)
}()
return peerClient.Send(args)
}
}

View File

@@ -23,7 +23,7 @@ func MakeFromVals(members ...string)*Set {
} }
func (set *Set) Add(val string) int { func (set *Set) Add(val string) int {
return set.dict.Put(val, true) return set.dict.Put(val, nil)
} }
func (set *Set) Remove(val string) int { func (set *Set) Remove(val string) int {

View File

@@ -120,7 +120,7 @@ func GeoDist(db *DB, args [][]byte) redis.Reply {
if len(args) == 4 { if len(args) == 4 {
unit = strings.ToLower(string(args[3])) unit = strings.ToLower(string(args[3]))
} }
dis := geohash.Distance(positions[0][1], positions[0][0], positions[1][1], positions[1][0]) dis := geohash.Distance(positions[0][0], positions[0][1], positions[1][0], positions[1][1])
switch unit { switch unit {
case "m": case "m":
disStr := strconv.FormatFloat(dis, 'f', -1, 64) disStr := strconv.FormatFloat(dis, 'f', -1, 64)
@@ -200,7 +200,7 @@ func GeoRadius(db *DB, args [][]byte) redis.Reply {
func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply { func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 4 { if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command") return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command")
} }
@@ -224,13 +224,15 @@ func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
unit := strings.ToLower(string(args[4])) if len(args) > 3 {
unit := strings.ToLower(string(args[3]))
if unit == "m" { if unit == "m" {
} else if unit == "km" { } else if unit == "km" {
radius *= 1000 radius *= 1000
} else { } else {
return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") return reply.MakeErrReply("ERR unsupported unit provided. please use m, km")
} }
}
return geoRadius0(sortedSet, lat, lng, radius) return geoRadius0(sortedSet, lat, lng, radius)
} }

87
src/db/geo_test.go Normal file
View File

@@ -0,0 +1,87 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
)
func TestGeoHash(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos := RandString(10)
result := GeoAdd(testDB, toArgs(key, "13.361389", "38.115556", pos))
asserts.AssertIntReply(t, result, 1)
result = GeoHash(testDB, toArgs(key, pos))
asserts.AssertMultiBulkReply(t, result, []string{"sqc8b49rnys00"})
}
func TestGeoRadius(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"15.087269", "37.502669", pos2,
))
result := GeoRadius(testDB, toArgs(key, "15", "37", "200", "km"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}
func TestGeoRadiusByMember(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
pivot := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"17.087269", "38.502669", pos2,
"13.583333", "37.316667", pivot,
))
result := GeoRadiusByMember(testDB, toArgs(key, pivot, "100", "km"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}
func TestGeoPos(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
))
result := GeoPos(testDB, toArgs(key, pos1, pos2))
expected := "*2\r\n*2\r\n$18\r\n13.361386698670685\r\n$17\r\n38.11555536696687\r\n*0\r\n"
if string(result.ToBytes()) != expected {
t.Error("test failed")
}
}
func TestGeoDist(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"15.087269", "37.502669", pos2,
))
result := GeoDist(testDB, toArgs(key, pos1, pos2, "km"))
bulkReply, ok := result.(*reply.BulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
dist, err := strconv.ParseFloat(string(bulkReply.Arg), 10)
if err != nil {
t.Error(err)
return
}
if dist < 166.274 || dist > 166.275 {
t.Errorf("expected 166.274, actual: %f", dist)
}
}

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"math/rand" "github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv" "strconv"
"testing" "testing"
) )
@@ -14,10 +14,10 @@ func TestHSet(t *testing.T) {
size := 100 size := 100
// test hset // test hset
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make(map[string][]byte, size) values := make(map[string][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
field := strconv.Itoa(i) field := strconv.Itoa(i)
values[field] = []byte(value) values[field] = []byte(value)
result := HSet(testDB, toArgs(key, field, value)) result := HSet(testDB, toArgs(key, field, value))
@@ -51,10 +51,10 @@ func TestHDel(t *testing.T) {
size := 100 size := 100
// set values // set values
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
field := strconv.Itoa(i) field := strconv.Itoa(i)
fields[i] = field fields[i] = field
HSet(testDB, toArgs(key, field, value)) HSet(testDB, toArgs(key, field, value))
@@ -79,13 +79,13 @@ func TestHMSet(t *testing.T) {
size := 100 size := 100
// test hset // test hset
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
values := make([]string, size) values := make([]string, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = strconv.FormatInt(int64(rand.Int()), 10) fields[i] = RandString(10)
values[i] = strconv.FormatInt(int64(rand.Int()), 10) values[i] = RandString(10)
setArgs = append(setArgs, fields[i], values[i]) setArgs = append(setArgs, fields[i], values[i])
} }
result := HMSet(testDB, toArgs(setArgs...)) result := HMSet(testDB, toArgs(setArgs...))
@@ -106,14 +106,14 @@ func TestHMSet(t *testing.T) {
func TestHGetAll(t *testing.T) { func TestHGetAll(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
valueSet := make(map[string]bool, size) valueSet := make(map[string]bool, size)
valueMap := make(map[string]string) valueMap := make(map[string]string)
all := make([]string, 0) all := make([]string, 0)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = strconv.FormatInt(int64(rand.Int()), 10) fields[i] = RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
all = append(all, fields[i], value) all = append(all, fields[i], value)
valueMap[fields[i]] = value valueMap[fields[i]] = value
valueSet[value] = true valueSet[value] = true
@@ -179,7 +179,7 @@ func TestHGetAll(t *testing.T) {
func TestHIncrBy(t *testing.T) { func TestHIncrBy(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := HIncrBy(testDB, toArgs(key, "a", "1")) result := HIncrBy(testDB, toArgs(key, "a", "1"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" {
t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg)))
@@ -198,3 +198,18 @@ func TestHIncrBy(t *testing.T) {
t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg))) t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg)))
} }
} }
func TestHSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
field := RandString(10)
value := RandString(10)
result := HSetNX(testDB, toArgs(key, field, value))
asserts.AssertIntReply(t, result, 1)
value2 := RandString(10)
result = HSetNX(testDB, toArgs(key, field, value2))
asserts.AssertIntReply(t, result, 0)
result = HGet(testDB, toArgs(key, field))
asserts.AssertBulkReply(t, result, value)
}

View File

@@ -76,7 +76,7 @@ func Type(db *DB, args [][]byte) redis.Reply {
return reply.MakeStatusReply("string") return reply.MakeStatusReply("string")
case *list.LinkedList: case *list.LinkedList:
return reply.MakeStatusReply("list") return reply.MakeStatusReply("list")
case *dict.Dict: case dict.Dict:
return reply.MakeStatusReply("hash") return reply.MakeStatusReply("hash")
case *set.Set: case *set.Set:
return reply.MakeStatusReply("set") return reply.MakeStatusReply("set")
@@ -101,10 +101,11 @@ func Rename(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("no such key") return reply.MakeErrReply("no such key")
} }
rawTTL, hasTTL := db.TTLMap.Get(src) rawTTL, hasTTL := db.TTLMap.Get(src)
db.Put(dest, entity)
db.Remove(src)
if hasTTL {
db.Persist(src) // clean src and dest with their ttl db.Persist(src) // clean src and dest with their ttl
db.Persist(dest) db.Persist(dest)
db.Put(dest, entity)
if hasTTL {
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
@@ -135,6 +136,8 @@ func RenameNx(db *DB, args [][]byte) redis.Reply {
db.Removes(src, dest) // clean src and dest with their ttl db.Removes(src, dest) // clean src and dest with their ttl
db.Put(dest, entity) db.Put(dest, entity)
if hasTTL { if hasTTL {
db.Persist(src) // clean src and dest with their ttl
db.Persist(dest)
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
@@ -161,7 +164,7 @@ func Expire(db *DB, args [][]byte) redis.Reply {
expireAt := time.Now().Add(ttl) expireAt := time.Now().Add(ttl)
db.Expire(key, expireAt) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireAt), ) db.AddAof(makeExpireCmd(key, expireAt))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }

197
src/db/keys_test.go Normal file
View File

@@ -0,0 +1,197 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
"time"
)
func TestExists(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 1)
key = RandString(10)
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
}
func TestType(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "string")
Del(testDB, toArgs(key))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "none")
RPush(testDB, toArgs(key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "list")
Del(testDB, toArgs(key))
HSet(testDB, toArgs(key, key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "hash")
Del(testDB, toArgs(key))
SAdd(testDB, toArgs(key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "set")
Del(testDB, toArgs(key))
ZAdd(testDB, toArgs(key, "1", value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "zset")
}
func TestRename(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
newKey := key + RandString(2)
Set(testDB, toArgs(key, value, "ex", "1000"))
result := Rename(testDB, toArgs(key, newKey))
if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok")
return
}
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
result = Exists(testDB, toArgs(newKey))
asserts.AssertIntReply(t, result, 1)
// check ttl
result = TTL(testDB, toArgs(newKey))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestRenameNx(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
newKey := key + RandString(2)
Set(testDB, toArgs(key, value, "ex", "1000"))
result := RenameNx(testDB, toArgs(key, newKey))
if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok")
return
}
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
result = Exists(testDB, toArgs(newKey))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(newKey))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestTTL(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Expire(testDB, toArgs(key, "1000"))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
result = Persist(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
asserts.AssertIntReply(t, result, -1)
result = PExpire(testDB, toArgs(key, "1000000"))
asserts.AssertIntReply(t, result, 1)
result = PTTL(testDB, toArgs(key))
intResult, ok = result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestExpireAt(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
expireAt := time.Now().Add(time.Minute).Unix()
result := ExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt, 10)))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
expireAt = time.Now().Add(time.Minute).Unix()
result = PExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt*1000, 10)))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok = result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestKeys(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
Set(testDB, toArgs("a:"+key, value))
Set(testDB, toArgs("b:"+key, value))
result := Keys(testDB, toArgs("*"))
asserts.AssertMultiBulkReplySize(t, result, 3)
result = Keys(testDB, toArgs("a:*"))
asserts.AssertMultiBulkReplySize(t, result, 1)
result = Keys(testDB, toArgs("?:*"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"math/rand"
"strconv" "strconv"
"testing" "testing"
) )
@@ -14,10 +13,10 @@ func TestPush(t *testing.T) {
size := 100 size := 100
// rpush single // rpush single
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i] = []byte(value) values[i] = []byte(value)
result := RPush(testDB, toArgs(key, value)) result := RPush(testDB, toArgs(key, value))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) {
@@ -32,11 +31,11 @@ func TestPush(t *testing.T) {
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// rpush multi // rpush multi
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size+1) values = make([][]byte, size+1)
values[0] = []byte(key) values[0] = []byte(key)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i+1] = []byte(value) values[i+1] = []byte(value)
} }
result := RPush(testDB, values) result := RPush(testDB, values)
@@ -51,10 +50,10 @@ func TestPush(t *testing.T) {
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// left push single // left push single
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size) values = make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[size-i-1] = []byte(value) values[size-i-1] = []byte(value)
result = LPush(testDB, toArgs(key, value)) result = LPush(testDB, toArgs(key, value))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) {
@@ -69,12 +68,12 @@ func TestPush(t *testing.T) {
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// left push multi // left push multi
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size+1) values = make([][]byte, size+1)
values[0] = []byte(key) values[0] = []byte(key)
expectedValues := make([][]byte, size) expectedValues := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i+1] = []byte(value) values[i+1] = []byte(value)
expectedValues[size-i-1] = []byte(value) expectedValues[size-i-1] = []byte(value)
} }
@@ -94,10 +93,10 @@ func TestLRange(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
RPush(testDB, toArgs(key, value)) RPush(testDB, toArgs(key, value))
values[i] = []byte(value) values[i] = []byte(value)
} }
@@ -147,10 +146,10 @@ func TestLIndex(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
RPush(testDB, toArgs(key, value)) RPush(testDB, toArgs(key, value))
values[i] = []byte(value) values[i] = []byte(value)
} }
@@ -180,7 +179,7 @@ func TestLIndex(t *testing.T) {
func TestLRem(t *testing.T) { func TestLRem(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "a", "a", "c", "a", "a"} values := []string{key, "a", "b", "a", "a", "c", "a", "a"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
@@ -214,7 +213,7 @@ func TestLRem(t *testing.T) {
func TestLSet(t *testing.T) { func TestLSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
@@ -222,7 +221,7 @@ func TestLSet(t *testing.T) {
size := len(values) - 1 size := len(values) - 1
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
indexStr := strconv.Itoa(i) indexStr := strconv.Itoa(i)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, indexStr, value)) result := LSet(testDB, toArgs(key, indexStr, value))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes())))
@@ -235,7 +234,7 @@ func TestLSet(t *testing.T) {
} }
// test negative index // test negative index
for i := 1; i <= size; i++ { for i := 1; i <= size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value)) result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes())))
@@ -248,7 +247,7 @@ func TestLSet(t *testing.T) {
} }
// test illegal index // test illegal index
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value)) result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value))
expected := reply.MakeErrReply("ERR index out of range") expected := reply.MakeErrReply("ERR index out of range")
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
@@ -267,7 +266,7 @@ func TestLSet(t *testing.T) {
func TestLPop(t *testing.T) { func TestLPop(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
@@ -288,7 +287,7 @@ func TestLPop(t *testing.T) {
func TestRPop(t *testing.T) { func TestRPop(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
@@ -309,8 +308,8 @@ func TestRPop(t *testing.T) {
func TestRPopLPush(t *testing.T) { func TestRPopLPush(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key1 := strconv.FormatInt(int64(rand.Int()), 10) key1 := RandString(10)
key2 := strconv.FormatInt(int64(rand.Int()), 10) key2 := RandString(10)
values := []string{key1, "a", "b", "c", "d", "e", "f"} values := []string{key1, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
@@ -335,7 +334,7 @@ func TestRPopLPush(t *testing.T) {
func TestRPushX(t *testing.T) { func TestRPushX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := RPushX(testDB, toArgs(key, "1")) result := RPushX(testDB, toArgs(key, "1"))
expected := reply.MakeIntReply(int64(0)) expected := reply.MakeIntReply(int64(0))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
@@ -344,7 +343,7 @@ func TestRPushX(t *testing.T) {
RPush(testDB, toArgs(key, "1")) RPush(testDB, toArgs(key, "1"))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := RPushX(testDB, toArgs(key, value)) result := RPushX(testDB, toArgs(key, value))
expected := reply.MakeIntReply(int64(i + 2)) expected := reply.MakeIntReply(int64(i + 2))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
@@ -360,7 +359,7 @@ func TestRPushX(t *testing.T) {
func TestLPushX(t *testing.T) { func TestLPushX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := RPushX(testDB, toArgs(key, "1")) result := RPushX(testDB, toArgs(key, "1"))
expected := reply.MakeIntReply(int64(0)) expected := reply.MakeIntReply(int64(0))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
@@ -369,7 +368,7 @@ func TestLPushX(t *testing.T) {
LPush(testDB, toArgs(key, "1")) LPush(testDB, toArgs(key, "1"))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LPushX(testDB, toArgs(key, value)) result := LPushX(testDB, toArgs(key, value))
expected := reply.MakeIntReply(int64(i + 2)) expected := reply.MakeIntReply(int64(i + 2))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
@@ -383,4 +382,3 @@ func TestLPushX(t *testing.T) {
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"strconv" "strconv"
) )
func (db *DB) getAsSet(key string) (*HashSet.Set, reply.ErrorReply) { func (db *DB) getAsSet(key string) (*HashSet.Set, reply.ErrorReply) {
entity, exists := db.Get(key) entity, exists := db.Get(key)
if !exists { if !exists {
@@ -151,7 +150,6 @@ func SMembers(db *DB, args [][]byte) redis.Reply {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
arr := make([][]byte, set.Len()) arr := make([][]byte, set.Len())
i := 0 i := 0
set.ForEach(func(member string) bool { set.ForEach(func(member string) bool {
@@ -490,7 +488,7 @@ func SRandMember(db *DB, args [][]byte) redis.Reply {
count := int(count64) count := int(count64)
if count > 0 { if count > 0 {
members := set.RandomMembers(count) members := set.RandomDistinctMembers(count)
result := make([][]byte, len(members)) result := make([][]byte, len(members))
for i, v := range members { for i, v := range members {
@@ -498,7 +496,7 @@ func SRandMember(db *DB, args [][]byte) redis.Reply {
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} else if count < 0 { } else if count < 0 {
members := set.RandomDistinctMembers(-count) members := set.RandomMembers(-count)
result := make([][]byte, len(members)) result := make([][]byte, len(members))
for i, v := range members { for i, v := range members {
result[i] = []byte(v) result[i] = []byte(v)

181
src/db/set_test.go Normal file
View File

@@ -0,0 +1,181 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
)
// basic add get and remove
func TestSAdd(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
// test sadd
key := RandString(10)
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
result := SAdd(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 1)
}
// test scard
result := SCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size)
// test is member
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
result := SIsMember(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 1)
}
// test members
result = SMembers(testDB, toArgs(key))
multiBulk, ok := result.(*reply.MultiBulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
if len(multiBulk.Args) != size {
t.Error(fmt.Sprintf("expected %d elements, actually %d", size, len(multiBulk.Args)))
return
}
}
func TestSRem(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
// mock data
key := RandString(10)
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
SAdd(testDB, toArgs(key, member))
}
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
SRem(testDB, toArgs(key, member))
result := SIsMember(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 0)
}
}
func TestSInter(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 10
keys := make([]string, 0)
start := 0
for i := 0; i < 4; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SInter(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, 70)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SInterStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, 70)
}
func TestSUnion(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 10
keys := make([]string, 0)
start := 0
for i := 0; i < 4; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SUnion(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, 130)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SUnionStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, 130)
}
func TestSDiff(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 20
keys := make([]string, 0)
start := 0
for i := 0; i < 3; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SDiff(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, step)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SDiffStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, step)
}
func TestSRandMember(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
for j := 0; j < 100; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
result := SRandMember(testDB, toArgs(key))
br, ok := result.(*reply.BulkReply)
if !ok && len(br.Arg) > 0 {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
result = SRandMember(testDB, toArgs(key, "10"))
asserts.AssertMultiBulkReplySize(t, result, 10)
multiBulk, ok := result.(*reply.MultiBulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
m := make(map[string]struct{})
for _, arg := range multiBulk.Args {
m[string(arg)] = struct{}{}
}
if len(m) != 10 {
t.Error(fmt.Sprintf("expected 10 members, actually %d", len(m)))
return
}
result = SRandMember(testDB, toArgs(key, "110"))
asserts.AssertMultiBulkReplySize(t, result, 100)
result = SRandMember(testDB, toArgs(key, "-10"))
asserts.AssertMultiBulkReplySize(t, result, 10)
result = SRandMember(testDB, toArgs(key, "-110"))
asserts.AssertMultiBulkReplySize(t, result, 110)
}

View File

@@ -12,12 +12,12 @@ func TestZAdd(t *testing.T) {
size := 100 size := 100
// add new members // add new members
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]float64, size) scores := make([]float64, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = rand.Float64() scores[i] = rand.Float64()
setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i])
} }
@@ -55,12 +55,12 @@ func TestZAdd(t *testing.T) {
func TestZRank(t *testing.T) { func TestZRank(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
@@ -80,12 +80,12 @@ func TestZRange(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
@@ -143,7 +143,7 @@ func TestZRangeByScore(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
@@ -193,7 +193,7 @@ func TestZRangeByScore(t *testing.T) {
func TestZRem(t *testing.T) { func TestZRem(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
@@ -214,7 +214,7 @@ func TestZRem(t *testing.T) {
// test ZRemRangeByRank // test ZRemRangeByRank
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size = 100 size = 100
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
members = make([]string, size) members = make([]string, size)
scores = make([]int, size) scores = make([]int, size)
setArgs = []string{key} setArgs = []string{key}
@@ -233,7 +233,7 @@ func TestZRem(t *testing.T) {
// test ZRemRangeByScore // test ZRemRangeByScore
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size = 100 size = 100
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
members = make([]string, size) members = make([]string, size)
scores = make([]int, size) scores = make([]int, size)
setArgs = []string{key} setArgs = []string{key}
@@ -254,7 +254,7 @@ func TestZCount(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
@@ -288,7 +288,7 @@ func TestZCount(t *testing.T) {
func TestZIncrBy(t *testing.T) { func TestZIncrBy(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
ZAdd(testDB, toArgs(key, "10", "a")) ZAdd(testDB, toArgs(key, "10", "a"))
result := ZIncrBy(testDB, toArgs(key, "10", "a")) result := ZIncrBy(testDB, toArgs(key, "10", "a"))
asserts.AssertBulkReply(t, result, "20") asserts.AssertBulkReply(t, result, "20")

View File

@@ -159,9 +159,7 @@ func SetNX(db *DB, args [][]byte) redis.Reply {
Data: value, Data: value,
} }
result := db.PutIfAbsent(key, entity) result := db.PutIfAbsent(key, entity)
if result > 0 {
db.AddAof(makeAofCmd("setnx", args)) db.AddAof(makeAofCmd("setnx", args))
}
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
} }
@@ -188,43 +186,42 @@ func SetEX(db *DB, args [][]byte) redis.Reply {
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
result := db.Put(key, entity) db.Put(key, entity)
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
if result > 0 {
db.AddAof(makeAofCmd("setex", args)) db.AddAof(makeAofCmd("setex", args))
db.AddAof(makeExpireCmd(key, expireTime)) db.AddAof(makeExpireCmd(key, expireTime))
}
return &reply.OkReply{} return &reply.OkReply{}
} }
func PSetEX(db *DB, args [][]byte) redis.Reply { func PSetEX(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'psetex' command") return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command")
} }
key := string(args[0]) key := string(args[0])
value := args[1] value := args[2]
ttl, err := strconv.ParseInt(string(args[1]), 10, 64) ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if ttl <= 0 { if ttlArg <= 0 {
return reply.MakeErrReply("ERR invalid expire time in psetex") return reply.MakeErrReply("ERR invalid expire time in setex")
} }
entity := &DataEntity{ entity := &DataEntity{
Data: value, Data: value,
} }
result := db.PutIfExists(key, entity)
if result > 0 { db.Lock(key)
db.AddAof(makeAofCmd("psetex", args)) defer db.UnLock(key)
if ttl != unlimitedTTL {
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) db.Put(key, entity)
expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
db.AddAof(makeAofCmd("setex", args))
db.AddAof(makeExpireCmd(key, expireTime)) db.AddAof(makeExpireCmd(key, expireTime))
}
}
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -326,7 +323,9 @@ func GetSet(db *DB, args [][]byte) redis.Reply {
db.Put(key, &DataEntity{Data: value}) db.Put(key, &DataEntity{Data: value})
db.Persist(key) // override ttl db.Persist(key) // override ttl
db.AddAof(makeAofCmd("getset", args)) db.AddAof(makeAofCmd("getset", args))
if old == nil {
return new(reply.NullBulkReply)
}
return reply.MakeBulkReply(old) return reply.MakeBulkReply(old)
} }

View File

@@ -1,9 +1,10 @@
package db package db
import ( import (
"fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"math/rand" "github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv" "strconv"
"testing" "testing"
) )
@@ -12,8 +13,8 @@ var testDB = makeTestDB()
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
// normal set // normal set
Set(testDB, toArgs(key, value)) Set(testDB, toArgs(key, value))
@@ -30,8 +31,8 @@ func TestSet(t *testing.T) {
} }
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
value = strconv.FormatInt(int64(rand.Int()), 10) value = RandString(10)
Set(testDB, toArgs(key, value, "NX")) Set(testDB, toArgs(key, value, "NX"))
actual = Get(testDB, toArgs(key)) actual = Get(testDB, toArgs(key))
expected = reply.MakeBulkReply([]byte(value)) expected = reply.MakeBulkReply([]byte(value))
@@ -41,8 +42,8 @@ func TestSet(t *testing.T) {
// set xx // set xx
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
value = strconv.FormatInt(int64(rand.Int()), 10) value = RandString(10)
actual = Set(testDB, toArgs(key, value, "XX")) actual = Set(testDB, toArgs(key, value, "XX"))
if _, ok := actual.(*reply.NullBulkReply); !ok { if _, ok := actual.(*reply.NullBulkReply); !ok {
t.Error("expected true actually false ") t.Error("expected true actually false ")
@@ -51,17 +52,47 @@ func TestSet(t *testing.T) {
Set(testDB, toArgs(key, value)) Set(testDB, toArgs(key, value))
Set(testDB, toArgs(key, value, "XX")) Set(testDB, toArgs(key, value, "XX"))
actual = Get(testDB, toArgs(key)) actual = Get(testDB, toArgs(key))
expected = reply.MakeBulkReply([]byte(value)) asserts.AssertBulkReply(t, actual, value)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) // set ex
Del(testDB, toArgs(key))
ttl := "1000"
Set(testDB, toArgs(key, value, "EX", ttl))
actual = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = TTL(testDB, toArgs(key))
intResult, ok := actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
} }
// set px
Del(testDB, toArgs(key))
ttlPx := "1000000"
Set(testDB, toArgs(key, value, "PX", ttlPx))
actual = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = TTL(testDB, toArgs(key))
intResult, ok = actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
} }
func TestSetNX(t *testing.T) { func TestSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
SetNX(testDB, toArgs(key, value)) SetNX(testDB, toArgs(key, value))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(value)) expected := reply.MakeBulkReply([]byte(value))
@@ -78,15 +109,43 @@ func TestSetNX(t *testing.T) {
func TestSetEX(t *testing.T) { func TestSetEX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
ttl := "1000" ttl := "1000"
SetEX(testDB, toArgs(key, ttl, value)) SetEX(testDB, toArgs(key, ttl, value))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected2 := reply.MakeBulkReply([]byte(value)) asserts.AssertBulkReply(t, actual, value)
if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { actual = TTL(testDB, toArgs(key))
t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) intResult, ok := actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
}
func TestPSetEX(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
ttl := "1000000"
PSetEX(testDB, toArgs(key, ttl, value))
actual := Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = PTTL(testDB, toArgs(key))
intResult, ok := actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
} }
} }
@@ -97,8 +156,8 @@ func TestMSet(t *testing.T) {
values := make([][]byte, size) values := make([][]byte, size)
args := make([]string, 0, size*2) args := make([]string, 0, size*2)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = strconv.FormatInt(int64(rand.Int()), 10) keys[i] = RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i] = []byte(value) values[i] = []byte(value)
args = append(args, keys[i], value) args = append(args, keys[i], value)
} }
@@ -113,7 +172,7 @@ func TestMSet(t *testing.T) {
func TestIncr(t *testing.T) { func TestIncr(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 10 size := 10
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
Incr(testDB, toArgs(key)) Incr(testDB, toArgs(key))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
@@ -132,7 +191,7 @@ func TestIncr(t *testing.T) {
} }
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
IncrBy(testDB, toArgs(key, "1")) IncrBy(testDB, toArgs(key, "1"))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
@@ -141,12 +200,89 @@ func TestIncr(t *testing.T) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
Del(testDB, toArgs(key))
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
IncrByFloat(testDB, toArgs(key, "-1.0")) IncrByFloat(testDB, toArgs(key, "-1.0"))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) expected := -i - 1
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { bulk, ok := actual.(*reply.BulkReply)
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes()))
return
}
val, err := strconv.ParseFloat(string(bulk.Arg), 10)
if err != nil {
t.Error(err)
return
}
if int(val) != expected {
t.Errorf("expect %d, actual: %d", expected, int(val))
return
} }
} }
} }
func TestDecr(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 10
key := RandString(10)
for i := 0; i < size; i++ {
Decr(testDB, toArgs(key))
actual := Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, strconv.Itoa(-i-1))
}
Del(testDB, toArgs(key))
for i := 0; i < size; i++ {
DecrBy(testDB, toArgs(key, "1"))
actual := Get(testDB, toArgs(key))
expected := -i - 1
bulk, ok := actual.(*reply.BulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes()))
return
}
val, err := strconv.ParseFloat(string(bulk.Arg), 10)
if err != nil {
t.Error(err)
return
}
if int(val) != expected {
t.Errorf("expect %d, actual: %d", expected, int(val))
return
}
}
}
func TestGetSet(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
result := GetSet(testDB, toArgs(key, value))
_, ok := result.(*reply.NullBulkReply)
if !ok {
t.Errorf("expect null bulk reply, get: %s", string(result.ToBytes()))
return
}
value2 := RandString(10)
result = GetSet(testDB, toArgs(key, value2))
asserts.AssertBulkReply(t, result, value)
result = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, result, value2)
}
func TestMSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 10
args := make([]string, 0, size*2)
for i := 0; i < size; i++ {
str := RandString(10)
args = append(args, str, str)
}
result := MSetNX(testDB, toArgs(args...))
asserts.AssertIntReply(t, result, 1)
result = MSetNX(testDB, toArgs(args[0:4]...))
asserts.AssertIntReply(t, result, 0)
}

View File

@@ -3,6 +3,7 @@ package db
import ( import (
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/datastruct/lock"
"math/rand"
"time" "time"
) )
@@ -22,3 +23,13 @@ func toArgs(cmd ...string) [][]byte {
} }
return args return args
} }
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
func RandString(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}

View File

@@ -5,45 +5,78 @@ import (
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"runtime"
"testing" "testing"
) )
func AssertIntReply(t *testing.T, actual redis.Reply, expected int) { func AssertIntReply(t *testing.T, actual redis.Reply, expected int) {
intResult, ok := actual.(*reply.IntReply) intResult, ok := actual.(*reply.IntReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) t.Errorf("expected int reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if intResult.Code != int64(expected) { if intResult.Code != int64(expected) {
t.Error(fmt.Sprintf("expected %d, actually %d", expected, intResult.Code)) t.Errorf("expected %d, actually %d, %s", expected, intResult.Code, printStack())
} }
} }
func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) { func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) {
bulkReply, ok := actual.(*reply.BulkReply) bulkReply, ok := actual.(*reply.BulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) { if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) {
t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual.ToBytes())) t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack())
}
}
func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) {
statusReply, ok := actual.(*reply.StatusReply)
if !ok {
t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return
}
if statusReply.Status != expected {
t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack())
} }
} }
func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) { func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) {
multiBulk, ok := actual.(*reply.MultiBulkReply) multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if len(multiBulk.Args) != len(expected) { if len(multiBulk.Args) != len(expected) {
t.Error(fmt.Sprintf("expected %d elements, actually %d", len(expected), len(multiBulk.Args))) t.Errorf("expected %d elements, actually %d, %s",
len(expected), len(multiBulk.Args), printStack())
return return
} }
for i, v := range multiBulk.Args { for i, v := range multiBulk.Args {
actual := string(v) str := string(v)
if actual != expected[i] { if str != expected[i] {
t.Error(fmt.Sprintf("expected %s, actually %s", expected[i], actual)) t.Errorf("expected %s, actually %s, %s", expected[i], actual, printStack())
} }
} }
} }
func AssertMultiBulkReplySize(t *testing.T, actual redis.Reply, expected int) {
multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok {
t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return
}
if len(multiBulk.Args) != expected {
t.Errorf("expected %d elements, actually %d, %s", expected, len(multiBulk.Args), printStack())
return
}
}
func printStack() string {
_, file, no, ok := runtime.Caller(2)
if ok {
return fmt.Sprintf("at %s#%d", file, no)
}
return ""
}