add some unit tests

This commit is contained in:
hdt3213
2021-03-31 17:11:46 +08:00
parent 4b01bbb52a
commit bf913a5aca
16 changed files with 2866 additions and 2198 deletions

View File

@@ -1,177 +1,133 @@
package cluster package cluster
import ( import (
"context" "context"
"errors" "fmt"
"fmt" "github.com/HDT3213/godis/src/cluster/idgenerator"
"github.com/HDT3213/godis/src/cluster/idgenerator" "github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/lib/consistenthash"
"github.com/HDT3213/godis/src/lib/consistenthash" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/client" "github.com/jolestar/go-commons-pool/v2"
"github.com/HDT3213/godis/src/redis/reply" "runtime/debug"
"github.com/jolestar/go-commons-pool/v2" "strings"
"runtime/debug"
"strings"
) )
type Cluster struct { type Cluster struct {
self string self string
peerPicker *consistenthash.Map peerPicker *consistenthash.Map
peerConnection map[string]*pool.ObjectPool peerConnection map[string]*pool.ObjectPool
db *db.DB db *db.DB
transactions *dict.SimpleDict // id -> Transaction transactions *dict.SimpleDict // id -> Transaction
idGenerator *idgenerator.IdGenerator idGenerator *idgenerator.IdGenerator
} }
const ( const (
replicas = 4 replicas = 4
lockSize = 64 lockSize = 64
) )
func MakeCluster() *Cluster { func MakeCluster() *Cluster {
cluster := &Cluster{ cluster := &Cluster{
self: config.Properties.Self, self: config.Properties.Self,
db: db.MakeDB(), db: db.MakeDB(),
transactions: dict.MakeSimple(), transactions: dict.MakeSimple(),
peerPicker: consistenthash.New(replicas, nil), peerPicker: consistenthash.New(replicas, nil),
peerConnection: make(map[string]*pool.ObjectPool), peerConnection: make(map[string]*pool.ObjectPool),
idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self), idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self),
} }
if config.Properties.Peers != nil && len(config.Properties.Peers) > 0 && config.Properties.Self != "" { if config.Properties.Peers != nil && len(config.Properties.Peers) > 0 && config.Properties.Self != "" {
contains := make(map[string]bool) contains := make(map[string]bool)
peers := make([]string, 0, len(config.Properties.Peers)+1) peers := make([]string, 0, len(config.Properties.Peers)+1)
for _, peer := range config.Properties.Peers { for _, peer := range config.Properties.Peers {
if _, ok := contains[peer]; ok { if _, ok := contains[peer]; ok {
continue continue
} }
contains[peer] = true contains[peer] = true
peers = append(peers, peer) peers = append(peers, peer)
} }
peers = append(peers, config.Properties.Self) peers = append(peers, config.Properties.Self)
cluster.peerPicker.Add(peers...) cluster.peerPicker.Add(peers...)
ctx := context.Background() ctx := context.Background()
for _, peer := range peers { for _, peer := range peers {
cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{ cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{
Peer: peer, Peer: peer,
}) })
} }
} }
return cluster return cluster
} }
// args contains all // args contains all
type CmdFunc func(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply type CmdFunc func(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply
func (cluster *Cluster) Close() { func (cluster *Cluster) Close() {
cluster.db.Close() cluster.db.Close()
} }
var router = MakeRouter() var router = MakeRouter()
func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{} result = &reply.UnknownErrReply{}
} }
}() }()
cmd := strings.ToLower(string(args[0])) cmd := strings.ToLower(string(args[0]))
cmdFunc, ok := router[cmd] cmdFunc, ok := router[cmd]
if !ok { if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmd + "', or not supported in cluster mode") return reply.MakeErrReply("ERR unknown command '" + cmd + "', or not supported in cluster mode")
} }
result = cmdFunc(cluster, c, args) result = cmdFunc(cluster, c, args)
return return
} }
func (cluster *Cluster) AfterClientClose(c redis.Connection) { func (cluster *Cluster) AfterClientClose(c redis.Connection) {
} }
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return nil, errors.New("connection factory not found")
}
raw, err := connectionFactory.BorrowObject(context.Background())
if err != nil {
return nil, err
}
conn, ok := raw.(*client.Client)
if !ok {
return nil, errors.New("connection factory make wrong type")
}
return conn, nil
}
func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return errors.New("connection factory not found")
}
return connectionFactory.ReturnObject(context.Background(), peerClient)
}
func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) == 1 { if len(args) == 1 {
return &reply.PongReply{} return &reply.PongReply{}
} else if len(args) == 2 { } else if len(args) == 2 {
return reply.MakeStatusReply("\"" + string(args[1]) + "\"") return reply.MakeStatusReply("\"" + string(args[1]) + "\"")
} else { } else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
} }
}
// relay command to peer
// cannot call Prepare, Commit, Rollback of self node
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
defer func() {
_ = cluster.returnPeerClient(peer, peerClient)
}()
return peerClient.Send(args)
}
} }
/*----- utils -------*/ /*----- utils -------*/
func makeArgs(cmd string, args ...string) [][]byte { func makeArgs(cmd string, args ...string) [][]byte {
result := make([][]byte, len(args)+1) result := make([][]byte, len(args)+1)
result[0] = []byte(cmd) result[0] = []byte(cmd)
for i, arg := range args { for i, arg := range args {
result[i+1] = []byte(arg) result[i+1] = []byte(arg)
} }
return result return result
} }
// return peer -> keys // return peer -> keys
func (cluster *Cluster) groupBy(keys []string) map[string][]string { func (cluster *Cluster) groupBy(keys []string) map[string][]string {
result := make(map[string][]string) result := make(map[string][]string)
for _, key := range keys { for _, key := range keys {
peer := cluster.peerPicker.Get(key) peer := cluster.peerPicker.Get(key)
group, ok := result[peer] group, ok := result[peer]
if !ok { if !ok {
group = make([]string, 0) group = make([]string, 0)
} }
group = append(group, key) group = append(group, key)
result[peer] = group result[peer] = group
} }
return result return result
} }

52
src/cluster/com.go Normal file
View File

@@ -0,0 +1,52 @@
// communicate with peers within cluster
package cluster
import (
"context"
"errors"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/client"
"github.com/HDT3213/godis/src/redis/reply"
)
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return nil, errors.New("connection factory not found")
}
raw, err := connectionFactory.BorrowObject(context.Background())
if err != nil {
return nil, err
}
conn, ok := raw.(*client.Client)
if !ok {
return nil, errors.New("connection factory make wrong type")
}
return conn, nil
}
func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error {
connectionFactory, ok := cluster.peerConnection[peer]
if !ok {
return errors.New("connection factory not found")
}
return connectionFactory.ReturnObject(context.Background(), peerClient)
}
// relay command to peer
// cannot call Prepare, Commit, Rollback of self node
func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self {
// to self db
return cluster.db.Exec(c, args)
} else {
peerClient, err := cluster.getPeerClient(peer)
if err != nil {
return reply.MakeErrReply(err.Error())
}
defer func() {
_ = cluster.returnPeerClient(peer, peerClient)
}()
return peerClient.Send(args)
}
}

View File

@@ -3,114 +3,114 @@ package set
import "github.com/HDT3213/godis/src/datastruct/dict" import "github.com/HDT3213/godis/src/datastruct/dict"
type Set struct { type Set struct {
dict dict.Dict dict dict.Dict
} }
func Make() *Set { func Make() *Set {
return &Set{ return &Set{
dict: dict.MakeSimple(), dict: dict.MakeSimple(),
} }
} }
func MakeFromVals(members ...string)*Set { func MakeFromVals(members ...string) *Set {
set := &Set{ set := &Set{
dict: dict.MakeConcurrent(len(members)), dict: dict.MakeConcurrent(len(members)),
} }
for _, member := range members { for _, member := range members {
set.Add(member) set.Add(member)
} }
return set return set
} }
func (set *Set)Add(val string)int { func (set *Set) Add(val string) int {
return set.dict.Put(val, true) return set.dict.Put(val, nil)
} }
func (set *Set)Remove(val string)int { func (set *Set) Remove(val string) int {
return set.dict.Remove(val) return set.dict.Remove(val)
} }
func (set *Set)Has(val string)bool { func (set *Set) Has(val string) bool {
_, exists := set.dict.Get(val) _, exists := set.dict.Get(val)
return exists return exists
} }
func (set *Set)Len()int { func (set *Set) Len() int {
return set.dict.Len() return set.dict.Len()
} }
func (set *Set)ToSlice()[]string { func (set *Set) ToSlice() []string {
slice := make([]string, set.Len()) slice := make([]string, set.Len())
i := 0 i := 0
set.dict.ForEach(func(key string, val interface{})bool { set.dict.ForEach(func(key string, val interface{}) bool {
if i < len(slice) { if i < len(slice) {
slice[i] = key slice[i] = key
} else { } else {
// set extended during traversal // set extended during traversal
slice = append(slice, key) slice = append(slice, key)
} }
i++ i++
return true return true
}) })
return slice return slice
} }
func (set *Set)ForEach(consumer func(member string)bool) { func (set *Set) ForEach(consumer func(member string) bool) {
set.dict.ForEach(func(key string, val interface{})bool { set.dict.ForEach(func(key string, val interface{}) bool {
return consumer(key) return consumer(key)
}) })
} }
func (set *Set)Intersect(another *Set)*Set { func (set *Set) Intersect(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
} }
result := Make() result := Make()
another.ForEach(func(member string)bool { another.ForEach(func(member string) bool {
if set.Has(member) { if set.Has(member) {
result.Add(member) result.Add(member)
} }
return true return true
}) })
return result return result
} }
func (set *Set)Union(another *Set)*Set { func (set *Set) Union(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
} }
result := Make() result := Make()
another.ForEach(func(member string)bool { another.ForEach(func(member string) bool {
result.Add(member) result.Add(member)
return true return true
}) })
set.ForEach(func(member string)bool { set.ForEach(func(member string) bool {
result.Add(member) result.Add(member)
return true return true
}) })
return result return result
} }
func (set *Set)Diff(another *Set)*Set { func (set *Set) Diff(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
} }
result := Make() result := Make()
set.ForEach(func(member string)bool { set.ForEach(func(member string) bool {
if !another.Has(member) { if !another.Has(member) {
result.Add(member) result.Add(member)
} }
return true return true
}) })
return result return result
} }
func (set *Set)RandomMembers(limit int)[]string { func (set *Set) RandomMembers(limit int) []string {
return set.dict.RandomKeys(limit) return set.dict.RandomKeys(limit)
} }
func (set *Set)RandomDistinctMembers(limit int)[]string { func (set *Set) RandomDistinctMembers(limit int) []string {
return set.dict.RandomDistinctKeys(limit) return set.dict.RandomDistinctKeys(limit)
} }

View File

@@ -1,249 +1,251 @@
package db package db
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/sortedset" "github.com/HDT3213/godis/src/datastruct/sortedset"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/geohash" "github.com/HDT3213/godis/src/lib/geohash"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
"strings" "strings"
) )
func GeoAdd(db *DB, args [][]byte) redis.Reply { func GeoAdd(db *DB, args [][]byte) redis.Reply {
if len(args) < 4 || len(args)%3 != 1 { if len(args) < 4 || len(args)%3 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'geoadd' command") return reply.MakeErrReply("ERR wrong number of arguments for 'geoadd' command")
} }
key := string(args[0]) key := string(args[0])
size := (len(args) - 1) / 3 size := (len(args) - 1) / 3
elements := make([]*sortedset.Element, size) elements := make([]*sortedset.Element, size)
for i := 0; i < size; i += 1 { for i := 0; i < size; i += 1 {
lngStr := string(args[3*i+1]) lngStr := string(args[3*i+1])
latStr := string(args[3*i+2]) latStr := string(args[3*i+2])
lng, err := strconv.ParseFloat(lngStr, 64) lng, err := strconv.ParseFloat(lngStr, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
lat, err := strconv.ParseFloat(latStr, 64) lat, err := strconv.ParseFloat(latStr, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
if lat < -90 || lat > 90 || lng < -180 || lng > 180 { if lat < -90 || lat > 90 || lng < -180 || lng > 180 {
return reply.MakeErrReply(fmt.Sprintf("ERR invalid longitude,latitude pair %s,%s", latStr, lngStr)) return reply.MakeErrReply(fmt.Sprintf("ERR invalid longitude,latitude pair %s,%s", latStr, lngStr))
} }
code := float64(geohash.Encode(lat, lng)) code := float64(geohash.Encode(lat, lng))
elements[i] = &sortedset.Element{ elements[i] = &sortedset.Element{
Member: string(args[3*i+3]), Member: string(args[3*i+3]),
Score: code, Score: code,
} }
} }
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
sortedSet, _, errReply := db.getOrInitSortedSet(key) sortedSet, _, errReply := db.getOrInitSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
i := 0 i := 0
for _, e := range elements { for _, e := range elements {
if sortedSet.Add(e.Member, e.Score) { if sortedSet.Add(e.Member, e.Score) {
i++ i++
} }
} }
db.AddAof(makeAofCmd("geoadd", args)) db.AddAof(makeAofCmd("geoadd", args))
return reply.MakeIntReply(int64(i)) return reply.MakeIntReply(int64(i))
} }
func GeoPos(db *DB, args [][]byte) redis.Reply { func GeoPos(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 1 { if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'geopos' command") return reply.MakeErrReply("ERR wrong number of arguments for 'geopos' command")
} }
key := string(args[0]) key := string(args[0])
sortedSet, errReply := db.getAsSortedSet(key) sortedSet, errReply := db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sortedSet == nil { if sortedSet == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
positions := make([][]byte, len(args)-1) positions := make([][]byte, len(args)-1)
for i := 0; i < len(args)-1; i++ { for i := 0; i < len(args)-1; i++ {
member := string(args[i+1]) member := string(args[i+1])
elem, exists := sortedSet.Get(member) elem, exists := sortedSet.Get(member)
if !exists { if !exists {
positions[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() positions[i] = (&reply.EmptyMultiBulkReply{}).ToBytes()
continue continue
} }
lat, lng := geohash.Decode(uint64(elem.Score)) lat, lng := geohash.Decode(uint64(elem.Score))
lngStr := strconv.FormatFloat(lng, 'f', -1, 64) lngStr := strconv.FormatFloat(lng, 'f', -1, 64)
latStr := strconv.FormatFloat(lat, 'f', -1, 64) latStr := strconv.FormatFloat(lat, 'f', -1, 64)
positions[i] = reply.MakeMultiBulkReply([][]byte{ positions[i] = reply.MakeMultiBulkReply([][]byte{
[]byte(lngStr), []byte(latStr), []byte(lngStr), []byte(latStr),
}).ToBytes() }).ToBytes()
} }
return reply.MakeMultiRawReply(positions) return reply.MakeMultiRawReply(positions)
} }
func GeoDist(db *DB, args [][]byte) redis.Reply { func GeoDist(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 && len(args) != 4 { if len(args) != 3 && len(args) != 4 {
return reply.MakeErrReply("ERR wrong number of arguments for 'geodist' command") return reply.MakeErrReply("ERR wrong number of arguments for 'geodist' command")
} }
key := string(args[0]) key := string(args[0])
sortedSet, errReply := db.getAsSortedSet(key) sortedSet, errReply := db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sortedSet == nil { if sortedSet == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
positions := make([][]float64, 2) positions := make([][]float64, 2)
for i := 1; i < 3; i++ { for i := 1; i < 3; i++ {
member := string(args[i]) member := string(args[i])
elem, exists := sortedSet.Get(member) elem, exists := sortedSet.Get(member)
if !exists { if !exists {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
lat, lng := geohash.Decode(uint64(elem.Score)) lat, lng := geohash.Decode(uint64(elem.Score))
positions[i-1] = []float64{lat, lng} positions[i-1] = []float64{lat, lng}
} }
unit := "m" unit := "m"
if len(args) == 4 { if len(args) == 4 {
unit = strings.ToLower(string(args[3])) unit = strings.ToLower(string(args[3]))
} }
dis := geohash.Distance(positions[0][1], positions[0][0], positions[1][1], positions[1][0]) dis := geohash.Distance(positions[0][0], positions[0][1], positions[1][0], positions[1][1])
switch unit { switch unit {
case "m": case "m":
disStr := strconv.FormatFloat(dis, 'f', -1, 64) disStr := strconv.FormatFloat(dis, 'f', -1, 64)
return reply.MakeBulkReply([]byte(disStr)) return reply.MakeBulkReply([]byte(disStr))
case "km": case "km":
disStr := strconv.FormatFloat(dis/1000, 'f', -1, 64) disStr := strconv.FormatFloat(dis/1000, 'f', -1, 64)
return reply.MakeBulkReply([]byte(disStr)) return reply.MakeBulkReply([]byte(disStr))
} }
return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") return reply.MakeErrReply("ERR unsupported unit provided. please use m, km")
} }
func GeoHash(db *DB, args [][]byte) redis.Reply { func GeoHash(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 1 { if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'geohash' command") return reply.MakeErrReply("ERR wrong number of arguments for 'geohash' command")
} }
key := string(args[0]) key := string(args[0])
sortedSet, errReply := db.getAsSortedSet(key) sortedSet, errReply := db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sortedSet == nil { if sortedSet == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
strs := make([][]byte, len(args)-1) strs := make([][]byte, len(args)-1)
for i := 0; i < len(args)-1; i++ { for i := 0; i < len(args)-1; i++ {
member := string(args[i+1]) member := string(args[i+1])
elem, exists := sortedSet.Get(member) elem, exists := sortedSet.Get(member)
if !exists { if !exists {
strs[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() strs[i] = (&reply.EmptyMultiBulkReply{}).ToBytes()
continue continue
} }
str := geohash.ToString(geohash.FromInt(uint64(elem.Score))) str := geohash.ToString(geohash.FromInt(uint64(elem.Score)))
strs[i] = []byte(str) strs[i] = []byte(str)
} }
return reply.MakeMultiBulkReply(strs) return reply.MakeMultiBulkReply(strs)
} }
func GeoRadius(db *DB, args [][]byte) redis.Reply { func GeoRadius(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 5 { if len(args) < 5 {
return reply.MakeErrReply("ERR wrong number of arguments for 'georadius' command") return reply.MakeErrReply("ERR wrong number of arguments for 'georadius' command")
} }
key := string(args[0]) key := string(args[0])
sortedSet, errReply := db.getAsSortedSet(key) sortedSet, errReply := db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sortedSet == nil { if sortedSet == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
lng, err := strconv.ParseFloat(string(args[1]), 64) lng, err := strconv.ParseFloat(string(args[1]), 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
lat, err := strconv.ParseFloat(string(args[2]), 64) lat, err := strconv.ParseFloat(string(args[2]), 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
radius, err := strconv.ParseFloat(string(args[3]), 64) radius, err := strconv.ParseFloat(string(args[3]), 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
unit := strings.ToLower(string(args[4])) unit := strings.ToLower(string(args[4]))
if unit == "m" { if unit == "m" {
} else if unit == "km" { } else if unit == "km" {
radius *= 1000 radius *= 1000
} else { } else {
return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") return reply.MakeErrReply("ERR unsupported unit provided. please use m, km")
} }
return geoRadius0(sortedSet, lat, lng, radius) return geoRadius0(sortedSet, lat, lng, radius)
} }
func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply { func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 4 { if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command") return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command")
} }
key := string(args[0]) key := string(args[0])
sortedSet, errReply := db.getAsSortedSet(key) sortedSet, errReply := db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sortedSet == nil { if sortedSet == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
member := string(args[1]) member := string(args[1])
elem, ok := sortedSet.Get(member) elem, ok := sortedSet.Get(member)
if !ok { if !ok {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
lat, lng := geohash.Decode(uint64(elem.Score)) lat, lng := geohash.Decode(uint64(elem.Score))
radius, err := strconv.ParseFloat(string(args[2]), 64) radius, err := strconv.ParseFloat(string(args[2]), 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
unit := strings.ToLower(string(args[4])) if len(args) > 3 {
if unit == "m" { unit := strings.ToLower(string(args[3]))
} else if unit == "km" { if unit == "m" {
radius *= 1000 } else if unit == "km" {
} else { radius *= 1000
return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") } else {
} return reply.MakeErrReply("ERR unsupported unit provided. please use m, km")
return geoRadius0(sortedSet, lat, lng, radius) }
}
return geoRadius0(sortedSet, lat, lng, radius)
} }
func geoRadius0(sortedSet *sortedset.SortedSet, lat float64, lng float64, radius float64) redis.Reply { func geoRadius0(sortedSet *sortedset.SortedSet, lat float64, lng float64, radius float64) redis.Reply {
areas := geohash.GetNeighbours(lat, lng, radius) areas := geohash.GetNeighbours(lat, lng, radius)
members := make([][]byte, 0) members := make([][]byte, 0)
for _, area := range areas { for _, area := range areas {
lower := &sortedset.ScoreBorder{Value: float64(area[0])} lower := &sortedset.ScoreBorder{Value: float64(area[0])}
upper := &sortedset.ScoreBorder{Value: float64(area[1])} upper := &sortedset.ScoreBorder{Value: float64(area[1])}
elements := sortedSet.RangeByScore(lower, upper, 0, -1, true) elements := sortedSet.RangeByScore(lower, upper, 0, -1, true)
for _, elem := range elements { for _, elem := range elements {
members = append(members, []byte(elem.Member)) members = append(members, []byte(elem.Member))
} }
} }
return reply.MakeMultiBulkReply(members) return reply.MakeMultiBulkReply(members)
} }

87
src/db/geo_test.go Normal file
View File

@@ -0,0 +1,87 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
)
func TestGeoHash(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos := RandString(10)
result := GeoAdd(testDB, toArgs(key, "13.361389", "38.115556", pos))
asserts.AssertIntReply(t, result, 1)
result = GeoHash(testDB, toArgs(key, pos))
asserts.AssertMultiBulkReply(t, result, []string{"sqc8b49rnys00"})
}
func TestGeoRadius(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"15.087269", "37.502669", pos2,
))
result := GeoRadius(testDB, toArgs(key, "15", "37", "200", "km"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}
func TestGeoRadiusByMember(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
pivot := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"17.087269", "38.502669", pos2,
"13.583333", "37.316667", pivot,
))
result := GeoRadiusByMember(testDB, toArgs(key, pivot, "100", "km"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}
func TestGeoPos(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
))
result := GeoPos(testDB, toArgs(key, pos1, pos2))
expected := "*2\r\n*2\r\n$18\r\n13.361386698670685\r\n$17\r\n38.11555536696687\r\n*0\r\n"
if string(result.ToBytes()) != expected {
t.Error("test failed")
}
}
func TestGeoDist(t *testing.T) {
FlushDB(testDB, toArgs())
key := RandString(10)
pos1 := RandString(10)
pos2 := RandString(10)
GeoAdd(testDB, toArgs(key,
"13.361389", "38.115556", pos1,
"15.087269", "37.502669", pos2,
))
result := GeoDist(testDB, toArgs(key, pos1, pos2, "km"))
bulkReply, ok := result.(*reply.BulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
dist, err := strconv.ParseFloat(string(bulkReply.Arg), 10)
if err != nil {
t.Error(err)
return
}
if dist < 166.274 || dist > 166.275 {
t.Errorf("expected 166.274, actual: %f", dist)
}
}

View File

@@ -1,200 +1,215 @@
package db package db
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"math/rand" "github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv" "strconv"
"testing" "testing"
) )
func TestHSet(t *testing.T) { func TestHSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
// test hset // test hset
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make(map[string][]byte, size) values := make(map[string][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
field := strconv.Itoa(i) field := strconv.Itoa(i)
values[field] = []byte(value) values[field] = []byte(value)
result := HSet(testDB, toArgs(key, field, value)) result := HSet(testDB, toArgs(key, field, value))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) {
t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code))
} }
} }
// test hget and hexists // test hget and hexists
for field, v := range values { for field, v := range values {
actual := HGet(testDB, toArgs(key, field)) actual := HGet(testDB, toArgs(key, field))
expected := reply.MakeBulkReply(v) expected := reply.MakeBulkReply(v)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes())))
} }
actual = HExists(testDB, toArgs(key, field)) actual = HExists(testDB, toArgs(key, field))
if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) { if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) {
t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code))
} }
} }
// test hlen // test hlen
actual := HLen(testDB, toArgs(key)) actual := HLen(testDB, toArgs(key))
if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) { if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) {
t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code))
} }
} }
func TestHDel(t *testing.T) { func TestHDel(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
// set values // set values
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
field := strconv.Itoa(i) field := strconv.Itoa(i)
fields[i] = field fields[i] = field
HSet(testDB, toArgs(key, field, value)) HSet(testDB, toArgs(key, field, value))
} }
// test HDel // test HDel
args := []string{key} args := []string{key}
args = append(args, fields...) args = append(args, fields...)
actual := HDel(testDB, toArgs(args...)) actual := HDel(testDB, toArgs(args...))
if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) { if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) {
t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code))
} }
actual = HLen(testDB, toArgs(key)) actual = HLen(testDB, toArgs(key))
if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) { if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) {
t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code))
} }
} }
func TestHMSet(t *testing.T) { func TestHMSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
// test hset // test hset
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
values := make([]string, size) values := make([]string, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = strconv.FormatInt(int64(rand.Int()), 10) fields[i] = RandString(10)
values[i] = strconv.FormatInt(int64(rand.Int()), 10) values[i] = RandString(10)
setArgs = append(setArgs, fields[i], values[i]) setArgs = append(setArgs, fields[i], values[i])
} }
result := HMSet(testDB, toArgs(setArgs...)) result := HMSet(testDB, toArgs(setArgs...))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes())))
} }
// test HMGet // test HMGet
getArgs := []string{key} getArgs := []string{key}
getArgs = append(getArgs, fields...) getArgs = append(getArgs, fields...)
actual := HMGet(testDB, toArgs(getArgs...)) actual := HMGet(testDB, toArgs(getArgs...))
expected := reply.MakeMultiBulkReply(toArgs(values...)) expected := reply.MakeMultiBulkReply(toArgs(values...))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes())))
} }
} }
func TestHGetAll(t *testing.T) { func TestHGetAll(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
fields := make([]string, size) fields := make([]string, size)
valueSet := make(map[string]bool, size) valueSet := make(map[string]bool, size)
valueMap := make(map[string]string) valueMap := make(map[string]string)
all := make([]string, 0) all := make([]string, 0)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = strconv.FormatInt(int64(rand.Int()), 10) fields[i] = RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
all = append(all, fields[i], value) all = append(all, fields[i], value)
valueMap[fields[i]] = value valueMap[fields[i]] = value
valueSet[value] = true valueSet[value] = true
HSet(testDB, toArgs(key, fields[i], value)) HSet(testDB, toArgs(key, fields[i], value))
} }
// test HGetAll // test HGetAll
result := HGetAll(testDB, toArgs(key)) result := HGetAll(testDB, toArgs(key))
multiBulk, ok := result.(*reply.MultiBulkReply) multiBulk, ok := result.(*reply.MultiBulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes())))
} }
if 2*len(fields) != len(multiBulk.Args) { if 2*len(fields) != len(multiBulk.Args) {
t.Error(fmt.Sprintf("expected %d items , actually %d ", 2*len(fields), len(multiBulk.Args))) t.Error(fmt.Sprintf("expected %d items , actually %d ", 2*len(fields), len(multiBulk.Args)))
} }
for i := range fields { for i := range fields {
field := string(multiBulk.Args[2*i]) field := string(multiBulk.Args[2*i])
actual := string(multiBulk.Args[2*i+1]) actual := string(multiBulk.Args[2*i+1])
expected, ok := valueMap[field] expected, ok := valueMap[field]
if !ok { if !ok {
t.Error(fmt.Sprintf("unexpected field %s", field)) t.Error(fmt.Sprintf("unexpected field %s", field))
continue continue
} }
if actual != expected { if actual != expected {
t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual)) t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual))
} }
} }
// test HKeys // test HKeys
result = HKeys(testDB, toArgs(key)) result = HKeys(testDB, toArgs(key))
multiBulk, ok = result.(*reply.MultiBulkReply) multiBulk, ok = result.(*reply.MultiBulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes())))
} }
if len(fields) != len(multiBulk.Args) { if len(fields) != len(multiBulk.Args) {
t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args)))
} }
for _, v := range multiBulk.Args { for _, v := range multiBulk.Args {
field := string(v) field := string(v)
if _, ok := valueMap[field]; !ok { if _, ok := valueMap[field]; !ok {
t.Error(fmt.Sprintf("unexpected field %s", field)) t.Error(fmt.Sprintf("unexpected field %s", field))
} }
} }
// test HVals // test HVals
result = HVals(testDB, toArgs(key)) result = HVals(testDB, toArgs(key))
multiBulk, ok = result.(*reply.MultiBulkReply) multiBulk, ok = result.(*reply.MultiBulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes())))
} }
if len(fields) != len(multiBulk.Args) { if len(fields) != len(multiBulk.Args) {
t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args)))
} }
for _, v := range multiBulk.Args { for _, v := range multiBulk.Args {
value := string(v) value := string(v)
_, ok := valueSet[value] _, ok := valueSet[value]
if !ok { if !ok {
t.Error(fmt.Sprintf("unexpected value %s", value)) t.Error(fmt.Sprintf("unexpected value %s", value))
} }
} }
} }
func TestHIncrBy(t *testing.T) { func TestHIncrBy(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := HIncrBy(testDB, toArgs(key, "a", "1")) result := HIncrBy(testDB, toArgs(key, "a", "1"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" {
t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg)))
} }
result = HIncrBy(testDB, toArgs(key, "a", "1")) result = HIncrBy(testDB, toArgs(key, "a", "1"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" { if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" {
t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg))) t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg)))
} }
result = HIncrByFloat(testDB, toArgs(key, "b", "1.2"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" {
t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg)))
}
result = HIncrByFloat(testDB, toArgs(key, "b", "1.2"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" {
t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg)))
}
}
func TestHSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
field := RandString(10)
value := RandString(10)
result := HSetNX(testDB, toArgs(key, field, value))
asserts.AssertIntReply(t, result, 1)
value2 := RandString(10)
result = HSetNX(testDB, toArgs(key, field, value2))
asserts.AssertIntReply(t, result, 0)
result = HGet(testDB, toArgs(key, field))
asserts.AssertBulkReply(t, result, value)
result = HIncrByFloat(testDB, toArgs(key, "b", "1.2"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" {
t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg)))
}
result = HIncrByFloat(testDB, toArgs(key, "b", "1.2"))
if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" {
t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg)))
}
} }

View File

@@ -76,7 +76,7 @@ func Type(db *DB, args [][]byte) redis.Reply {
return reply.MakeStatusReply("string") return reply.MakeStatusReply("string")
case *list.LinkedList: case *list.LinkedList:
return reply.MakeStatusReply("list") return reply.MakeStatusReply("list")
case *dict.Dict: case dict.Dict:
return reply.MakeStatusReply("hash") return reply.MakeStatusReply("hash")
case *set.Set: case *set.Set:
return reply.MakeStatusReply("set") return reply.MakeStatusReply("set")
@@ -101,10 +101,11 @@ func Rename(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("no such key") return reply.MakeErrReply("no such key")
} }
rawTTL, hasTTL := db.TTLMap.Get(src) rawTTL, hasTTL := db.TTLMap.Get(src)
db.Persist(src) // clean src and dest with their ttl
db.Persist(dest)
db.Put(dest, entity) db.Put(dest, entity)
db.Remove(src)
if hasTTL { if hasTTL {
db.Persist(src) // clean src and dest with their ttl
db.Persist(dest)
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
@@ -135,6 +136,8 @@ func RenameNx(db *DB, args [][]byte) redis.Reply {
db.Removes(src, dest) // clean src and dest with their ttl db.Removes(src, dest) // clean src and dest with their ttl
db.Put(dest, entity) db.Put(dest, entity)
if hasTTL { if hasTTL {
db.Persist(src) // clean src and dest with their ttl
db.Persist(dest)
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
@@ -161,7 +164,7 @@ func Expire(db *DB, args [][]byte) redis.Reply {
expireAt := time.Now().Add(ttl) expireAt := time.Now().Add(ttl)
db.Expire(key, expireAt) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireAt), ) db.AddAof(makeExpireCmd(key, expireAt))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }

197
src/db/keys_test.go Normal file
View File

@@ -0,0 +1,197 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
"time"
)
func TestExists(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 1)
key = RandString(10)
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
}
func TestType(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "string")
Del(testDB, toArgs(key))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "none")
RPush(testDB, toArgs(key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "list")
Del(testDB, toArgs(key))
HSet(testDB, toArgs(key, key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "hash")
Del(testDB, toArgs(key))
SAdd(testDB, toArgs(key, value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "set")
Del(testDB, toArgs(key))
ZAdd(testDB, toArgs(key, "1", value))
result = Type(testDB, toArgs(key))
asserts.AssertStatusReply(t, result, "zset")
}
func TestRename(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
newKey := key + RandString(2)
Set(testDB, toArgs(key, value, "ex", "1000"))
result := Rename(testDB, toArgs(key, newKey))
if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok")
return
}
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
result = Exists(testDB, toArgs(newKey))
asserts.AssertIntReply(t, result, 1)
// check ttl
result = TTL(testDB, toArgs(newKey))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestRenameNx(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
newKey := key + RandString(2)
Set(testDB, toArgs(key, value, "ex", "1000"))
result := RenameNx(testDB, toArgs(key, newKey))
if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok")
return
}
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
result = Exists(testDB, toArgs(newKey))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(newKey))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestTTL(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
result := Expire(testDB, toArgs(key, "1000"))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
result = Persist(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
asserts.AssertIntReply(t, result, -1)
result = PExpire(testDB, toArgs(key, "1000000"))
asserts.AssertIntReply(t, result, 1)
result = PTTL(testDB, toArgs(key))
intResult, ok = result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestExpireAt(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
expireAt := time.Now().Add(time.Minute).Unix()
result := ExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt, 10)))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok := result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
expireAt = time.Now().Add(time.Minute).Unix()
result = PExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt*1000, 10)))
asserts.AssertIntReply(t, result, 1)
result = TTL(testDB, toArgs(key))
intResult, ok = result.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
return
}
if intResult.Code <= 0 {
t.Errorf("expected ttl more than 0, actual: %d", intResult.Code)
return
}
}
func TestKeys(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
Set(testDB, toArgs(key, value))
Set(testDB, toArgs("a:"+key, value))
Set(testDB, toArgs("b:"+key, value))
result := Keys(testDB, toArgs("*"))
asserts.AssertMultiBulkReplySize(t, result, 3)
result = Keys(testDB, toArgs("a:*"))
asserts.AssertMultiBulkReplySize(t, result, 1)
result = Keys(testDB, toArgs("?:*"))
asserts.AssertMultiBulkReplySize(t, result, 2)
}

View File

@@ -1,386 +1,384 @@
package db package db
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"math/rand" "strconv"
"strconv" "testing"
"testing"
) )
func TestPush(t *testing.T) { func TestPush(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
// rpush single // rpush single
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i] = []byte(value) values[i] = []byte(value)
result := RPush(testDB, toArgs(key, value)) result := RPush(testDB, toArgs(key, value))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) {
t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code))
} }
} }
actual := LRange(testDB, toArgs(key, "0", "-1")) actual := LRange(testDB, toArgs(key, "0", "-1"))
expected := reply.MakeMultiBulkReply(values) expected := reply.MakeMultiBulkReply(values)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("push error") t.Error("push error")
} }
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// rpush multi // rpush multi
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size+1) values = make([][]byte, size+1)
values[0] = []byte(key) values[0] = []byte(key)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i+1] = []byte(value) values[i+1] = []byte(value)
} }
result := RPush(testDB, values) result := RPush(testDB, values)
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) {
t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code))
} }
actual = LRange(testDB, toArgs(key, "0", "-1")) actual = LRange(testDB, toArgs(key, "0", "-1"))
expected = reply.MakeMultiBulkReply(values[1:]) expected = reply.MakeMultiBulkReply(values[1:])
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("push error") t.Error("push error")
} }
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// left push single // left push single
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size) values = make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[size-i-1] = []byte(value) values[size-i-1] = []byte(value)
result = LPush(testDB, toArgs(key, value)) result = LPush(testDB, toArgs(key, value))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) {
t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code))
} }
} }
actual = LRange(testDB, toArgs(key, "0", "-1")) actual = LRange(testDB, toArgs(key, "0", "-1"))
expected = reply.MakeMultiBulkReply(values) expected = reply.MakeMultiBulkReply(values)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("push error") t.Error("push error")
} }
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
// left push multi // left push multi
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
values = make([][]byte, size+1) values = make([][]byte, size+1)
values[0] = []byte(key) values[0] = []byte(key)
expectedValues := make([][]byte, size) expectedValues := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i+1] = []byte(value) values[i+1] = []byte(value)
expectedValues[size-i-1] = []byte(value) expectedValues[size-i-1] = []byte(value)
} }
result = LPush(testDB, values) result = LPush(testDB, values)
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) {
t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code))
} }
actual = LRange(testDB, toArgs(key, "0", "-1")) actual = LRange(testDB, toArgs(key, "0", "-1"))
expected = reply.MakeMultiBulkReply(expectedValues) expected = reply.MakeMultiBulkReply(expectedValues)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("push error") t.Error("push error")
} }
Del(testDB, toArgs(key)) Del(testDB, toArgs(key))
} }
func TestLRange(t *testing.T) { func TestLRange(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
RPush(testDB, toArgs(key, value)) RPush(testDB, toArgs(key, value))
values[i] = []byte(value) values[i] = []byte(value)
} }
start := "0" start := "0"
end := "9" end := "9"
actual := LRange(testDB, toArgs(key, start, end)) actual := LRange(testDB, toArgs(key, start, end))
expected := reply.MakeMultiBulkReply(values[0:10]) expected := reply.MakeMultiBulkReply(values[0:10])
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) t.Error(fmt.Sprintf("range error [%s, %s]", start, end))
} }
start = "0" start = "0"
end = "200" end = "200"
actual = LRange(testDB, toArgs(key, start, end)) actual = LRange(testDB, toArgs(key, start, end))
expected = reply.MakeMultiBulkReply(values) expected = reply.MakeMultiBulkReply(values)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) t.Error(fmt.Sprintf("range error [%s, %s]", start, end))
} }
start = "0" start = "0"
end = "-10" end = "-10"
actual = LRange(testDB, toArgs(key, start, end)) actual = LRange(testDB, toArgs(key, start, end))
expected = reply.MakeMultiBulkReply(values[0 : size-10+1]) expected = reply.MakeMultiBulkReply(values[0 : size-10+1])
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) t.Error(fmt.Sprintf("range error [%s, %s]", start, end))
} }
start = "0" start = "0"
end = "-200" end = "-200"
actual = LRange(testDB, toArgs(key, start, end)) actual = LRange(testDB, toArgs(key, start, end))
expected = reply.MakeMultiBulkReply(values[0:0]) expected = reply.MakeMultiBulkReply(values[0:0])
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) t.Error(fmt.Sprintf("range error [%s, %s]", start, end))
} }
start = "-10" start = "-10"
end = "-1" end = "-1"
actual = LRange(testDB, toArgs(key, start, end)) actual = LRange(testDB, toArgs(key, start, end))
expected = reply.MakeMultiBulkReply(values[90:]) expected = reply.MakeMultiBulkReply(values[90:])
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) t.Error(fmt.Sprintf("range error [%s, %s]", start, end))
} }
} }
func TestLIndex(t *testing.T) { func TestLIndex(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
RPush(testDB, toArgs(key, value)) RPush(testDB, toArgs(key, value))
values[i] = []byte(value) values[i] = []byte(value)
} }
result := LLen(testDB, toArgs(key)) result := LLen(testDB, toArgs(key))
if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) {
t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code))
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
result = LIndex(testDB, toArgs(key, strconv.Itoa(i))) result = LIndex(testDB, toArgs(key, strconv.Itoa(i)))
expected := reply.MakeBulkReply(values[i]) expected := reply.MakeBulkReply(values[i])
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
for i := 1; i <= size; i++ { for i := 1; i <= size; i++ {
result = LIndex(testDB, toArgs(key, strconv.Itoa(-i))) result = LIndex(testDB, toArgs(key, strconv.Itoa(-i)))
expected := reply.MakeBulkReply(values[size-i]) expected := reply.MakeBulkReply(values[size-i])
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
} }
func TestLRem(t *testing.T) { func TestLRem(t *testing.T) {
// prepare list // prepare list
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "a", "a", "c", "a", "a"} values := []string{key, "a", "b", "a", "a", "c", "a", "a"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
result := LRem(testDB, toArgs(key, "1", "a")) result := LRem(testDB, toArgs(key, "1", "a"))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 1 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 1 {
t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code))
} }
result = LLen(testDB, toArgs(key)) result = LLen(testDB, toArgs(key))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 6 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 6 {
t.Error(fmt.Sprintf("expected %d, actually %d", 6, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 6, intResult.Code))
} }
result = LRem(testDB, toArgs(key, "-2", "a")) result = LRem(testDB, toArgs(key, "-2", "a"))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 {
t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code))
} }
result = LLen(testDB, toArgs(key)) result = LLen(testDB, toArgs(key))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 4 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 4 {
t.Error(fmt.Sprintf("expected %d, actually %d", 4, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 4, intResult.Code))
} }
result = LRem(testDB, toArgs(key, "0", "a")) result = LRem(testDB, toArgs(key, "0", "a"))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 {
t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code))
} }
result = LLen(testDB, toArgs(key)) result = LLen(testDB, toArgs(key))
if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 {
t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code))
} }
} }
func TestLSet(t *testing.T) { func TestLSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
// test positive index // test positive index
size := len(values) - 1 size := len(values) - 1
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
indexStr := strconv.Itoa(i) indexStr := strconv.Itoa(i)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, indexStr, value)) result := LSet(testDB, toArgs(key, indexStr, value))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes())))
} }
result = LIndex(testDB, toArgs(key, indexStr)) result = LIndex(testDB, toArgs(key, indexStr))
expected := reply.MakeBulkReply([]byte(value)) expected := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
// test negative index // test negative index
for i := 1; i <= size; i++ { for i := 1; i <= size; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value)) result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes())))
} }
result = LIndex(testDB, toArgs(key, strconv.Itoa(len(values)-i-1))) result = LIndex(testDB, toArgs(key, strconv.Itoa(len(values)-i-1)))
expected := reply.MakeBulkReply([]byte(value)) expected := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
// test illegal index // test illegal index
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value)) result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value))
expected := reply.MakeErrReply("ERR index out of range") expected := reply.MakeErrReply("ERR index out of range")
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
result = LSet(testDB, toArgs(key, strconv.Itoa(len(values)), value)) result = LSet(testDB, toArgs(key, strconv.Itoa(len(values)), value))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
result = LSet(testDB, toArgs(key, "a", value)) result = LSet(testDB, toArgs(key, "a", value))
expected = reply.MakeErrReply("ERR value is not an integer or out of range") expected = reply.MakeErrReply("ERR value is not an integer or out of range")
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
func TestLPop(t *testing.T) { func TestLPop(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
result := LPop(testDB, toArgs(key)) result := LPop(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(values[i+1])) expected := reply.MakeBulkReply([]byte(values[i+1]))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
result := RPop(testDB, toArgs(key)) result := RPop(testDB, toArgs(key))
expected := &reply.NullBulkReply{} expected := &reply.NullBulkReply{}
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
func TestRPop(t *testing.T) { func TestRPop(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
values := []string{key, "a", "b", "c", "d", "e", "f"} values := []string{key, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
result := RPop(testDB, toArgs(key)) result := RPop(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) expected := reply.MakeBulkReply([]byte(values[len(values)-i-1]))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
result := RPop(testDB, toArgs(key)) result := RPop(testDB, toArgs(key))
expected := &reply.NullBulkReply{} expected := &reply.NullBulkReply{}
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
func TestRPopLPush(t *testing.T) { func TestRPopLPush(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key1 := strconv.FormatInt(int64(rand.Int()), 10) key1 := RandString(10)
key2 := strconv.FormatInt(int64(rand.Int()), 10) key2 := RandString(10)
values := []string{key1, "a", "b", "c", "d", "e", "f"} values := []string{key1, "a", "b", "c", "d", "e", "f"}
RPush(testDB, toArgs(values...)) RPush(testDB, toArgs(values...))
size := len(values) - 1 size := len(values) - 1
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
result := RPopLPush(testDB, toArgs(key1, key2)) result := RPopLPush(testDB, toArgs(key1, key2))
expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) expected := reply.MakeBulkReply([]byte(values[len(values)-i-1]))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
result = LIndex(testDB, toArgs(key2, "0")) result = LIndex(testDB, toArgs(key2, "0"))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
result := RPop(testDB, toArgs(key1)) result := RPop(testDB, toArgs(key1))
expected := &reply.NullBulkReply{} expected := &reply.NullBulkReply{}
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
} }
func TestRPushX(t *testing.T) { func TestRPushX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := RPushX(testDB, toArgs(key, "1")) result := RPushX(testDB, toArgs(key, "1"))
expected := reply.MakeIntReply(int64(0)) expected := reply.MakeIntReply(int64(0))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
RPush(testDB, toArgs(key, "1")) RPush(testDB, toArgs(key, "1"))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := RPushX(testDB, toArgs(key, value)) result := RPushX(testDB, toArgs(key, value))
expected := reply.MakeIntReply(int64(i + 2)) expected := reply.MakeIntReply(int64(i + 2))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
result = LIndex(testDB, toArgs(key, "-1")) result = LIndex(testDB, toArgs(key, "-1"))
expected2 := reply.MakeBulkReply([]byte(value)) expected2 := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes())))
} }
} }
} }
func TestLPushX(t *testing.T) { func TestLPushX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
result := RPushX(testDB, toArgs(key, "1")) result := RPushX(testDB, toArgs(key, "1"))
expected := reply.MakeIntReply(int64(0)) expected := reply.MakeIntReply(int64(0))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
LPush(testDB, toArgs(key, "1")) LPush(testDB, toArgs(key, "1"))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
result := LPushX(testDB, toArgs(key, value)) result := LPushX(testDB, toArgs(key, value))
expected := reply.MakeIntReply(int64(i + 2)) expected := reply.MakeIntReply(int64(i + 2))
if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes())))
} }
result = LIndex(testDB, toArgs(key, "0")) result = LIndex(testDB, toArgs(key, "0"))
expected2 := reply.MakeBulkReply([]byte(value)) expected2 := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) {
t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes())))
} }
} }
} }

View File

@@ -1,511 +1,509 @@
package db package db
import ( import (
HashSet "github.com/HDT3213/godis/src/datastruct/set" HashSet "github.com/HDT3213/godis/src/datastruct/set"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
) )
func (db *DB) getAsSet(key string) (*HashSet.Set, reply.ErrorReply) {
func (db *DB)getAsSet(key string)(*HashSet.Set, reply.ErrorReply) { entity, exists := db.Get(key)
entity, exists := db.Get(key) if !exists {
if !exists { return nil, nil
return nil, nil }
} set, ok := entity.Data.(*HashSet.Set)
set, ok := entity.Data.(*HashSet.Set) if !ok {
if !ok { return nil, &reply.WrongTypeErrReply{}
return nil, &reply.WrongTypeErrReply{} }
} return set, nil
return set, nil
} }
func (db *DB) getOrInitSet(key string)(set *HashSet.Set, inited bool, errReply reply.ErrorReply) { func (db *DB) getOrInitSet(key string) (set *HashSet.Set, inited bool, errReply reply.ErrorReply) {
set, errReply = db.getAsSet(key) set, errReply = db.getAsSet(key)
if errReply != nil { if errReply != nil {
return nil, false, errReply return nil, false, errReply
} }
inited = false inited = false
if set == nil { if set == nil {
set = HashSet.Make() set = HashSet.Make()
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: set, Data: set,
}) })
inited = true inited = true
} }
return set, inited, nil return set, inited, nil
} }
func SAdd(db *DB, args [][]byte) redis.Reply { func SAdd(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command")
} }
key := string(args[0]) key := string(args[0])
members := args[1:] members := args[1:]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
set, _, errReply := db.getOrInitSet(key) set, _, errReply := db.getOrInitSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
counter := 0 counter := 0
for _, member := range members { for _, member := range members {
counter += set.Add(string(member)) counter += set.Add(string(member))
} }
db.AddAof(makeAofCmd("sadd", args)) db.AddAof(makeAofCmd("sadd", args))
return reply.MakeIntReply(int64(counter)) return reply.MakeIntReply(int64(counter))
} }
func SIsMember(db *DB, args [][]byte) redis.Reply { func SIsMember(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command")
} }
key := string(args[0]) key := string(args[0])
member := string(args[1]) member := string(args[1])
// get set // get set
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
has := set.Has(member) has := set.Has(member)
if has { if has {
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} else { } else {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
} }
func SRem(db *DB, args [][]byte) redis.Reply { func SRem(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command") return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command")
} }
key := string(args[0]) key := string(args[0])
members := args[1:] members := args[1:]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
counter := 0 counter := 0
for _, member := range members { for _, member := range members {
counter += set.Remove(string(member)) counter += set.Remove(string(member))
} }
if set.Len() == 0 { if set.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
if counter > 0 { if counter > 0 {
db.AddAof(makeAofCmd("srem", args)) db.AddAof(makeAofCmd("srem", args))
} }
return reply.MakeIntReply(int64(counter)) return reply.MakeIntReply(int64(counter))
} }
func SCard(db *DB, args [][]byte) redis.Reply { func SCard(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command") return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command")
} }
key := string(args[0]) key := string(args[0])
// get or init entity // get or init entity
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
func SMembers(db *DB, args [][]byte) redis.Reply { func SMembers(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command") return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command")
} }
key := string(args[0]) key := string(args[0])
// lock // lock
db.Locker.RLock(key) db.Locker.RLock(key)
defer db.Locker.RUnLock(key) defer db.Locker.RUnLock(key)
// get or init entity // get or init entity
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
arr := make([][]byte, set.Len())
arr := make([][]byte, set.Len()) i := 0
i := 0 set.ForEach(func(member string) bool {
set.ForEach(func (member string)bool { arr[i] = []byte(member)
arr[i] = []byte(member) i++
i++ return true
return true })
}) return reply.MakeMultiBulkReply(arr)
return reply.MakeMultiBulkReply(arr)
} }
func SInter(db *DB, args [][]byte) redis.Reply { func SInter(db *DB, args [][]byte) redis.Reply {
if len(args) < 1 { if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command")
} }
keys := make([]string, len(args)) keys := make([]string, len(args))
for i, arg := range args { for i, arg := range args {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.Locker.RLocks(keys...) db.Locker.RLocks(keys...)
defer db.Locker.RUnLocks(keys...) defer db.Locker.RUnLocks(keys...)
var result *HashSet.Set var result *HashSet.Set
for _, key := range keys { for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Intersect(set) result = result.Intersect(set)
if result.Len() == 0 { if result.Len() == 0 {
// early termination // early termination
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
} }
} }
arr := make([][]byte, result.Len()) arr := make([][]byte, result.Len())
i := 0 i := 0
result.ForEach(func (member string)bool { result.ForEach(func(member string) bool {
arr[i] = []byte(member) arr[i] = []byte(member)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(arr) return reply.MakeMultiBulkReply(arr)
} }
func SInterStore(db *DB, args [][]byte) redis.Reply { func SInterStore(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command")
} }
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args) - 1) keys := make([]string, len(args)-1)
keyArgs := args[1:] keyArgs := args[1:]
for i, arg := range keyArgs { for i, arg := range keyArgs {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.RLocks(keys...) db.RLocks(keys...)
defer db.RUnLocks(keys...) defer db.RUnLocks(keys...)
db.Lock(dest) db.Lock(dest)
defer db.UnLock(dest) defer db.UnLock(dest)
var result *HashSet.Set var result *HashSet.Set
for _, key := range keys { for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
db.Remove(dest) // clean ttl and old value db.Remove(dest) // clean ttl and old value
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Intersect(set) result = result.Intersect(set)
if result.Len() == 0 { if result.Len() == 0 {
// early termination // early termination
db.Remove(dest) // clean ttl and old value db.Remove(dest) // clean ttl and old value
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
} }
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.MakeFromVals(result.ToSlice()...)
db.Put(dest, &DataEntity{ db.Put(dest, &DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sinterstore", args)) db.AddAof(makeAofCmd("sinterstore", args))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
func SUnion(db *DB, args [][]byte) redis.Reply { func SUnion(db *DB, args [][]byte) redis.Reply {
if len(args) < 1 { if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command")
} }
keys := make([]string, len(args)) keys := make([]string, len(args))
for i, arg := range args { for i, arg := range args {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.RLocks(keys...) db.RLocks(keys...)
defer db.RUnLocks(keys...) defer db.RUnLocks(keys...)
var result *HashSet.Set var result *HashSet.Set
for _, key := range keys { for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
continue continue
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Union(set) result = result.Union(set)
} }
} }
if result == nil { if result == nil {
// all keys are empty set // all keys are empty set
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
arr := make([][]byte, result.Len()) arr := make([][]byte, result.Len())
i := 0 i := 0
result.ForEach(func (member string)bool { result.ForEach(func(member string) bool {
arr[i] = []byte(member) arr[i] = []byte(member)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(arr) return reply.MakeMultiBulkReply(arr)
} }
func SUnionStore(db *DB, args [][]byte) redis.Reply { func SUnionStore(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command")
} }
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args) - 1) keys := make([]string, len(args)-1)
keyArgs := args[1:] keyArgs := args[1:]
for i, arg := range keyArgs { for i, arg := range keyArgs {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.RLocks(keys...) db.RLocks(keys...)
defer db.RUnLocks(keys...) defer db.RUnLocks(keys...)
db.Lock(dest) db.Lock(dest)
defer db.UnLock(dest) defer db.UnLock(dest)
var result *HashSet.Set var result *HashSet.Set
for _, key := range keys { for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
continue continue
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Union(set) result = result.Union(set)
} }
} }
db.Remove(dest) // clean ttl db.Remove(dest) // clean ttl
if result == nil { if result == nil {
// all keys are empty set // all keys are empty set
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.MakeFromVals(result.ToSlice()...)
db.Put(dest, &DataEntity{ db.Put(dest, &DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sunionstore", args)) db.AddAof(makeAofCmd("sunionstore", args))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
func SDiff(db *DB, args [][]byte) redis.Reply { func SDiff(db *DB, args [][]byte) redis.Reply {
if len(args) < 1 { if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command")
} }
keys := make([]string, len(args)) keys := make([]string, len(args))
for i, arg := range args { for i, arg := range args {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.RLocks(keys...) db.RLocks(keys...)
defer db.RUnLocks(keys...) defer db.RUnLocks(keys...)
var result *HashSet.Set var result *HashSet.Set
for i, key := range keys { for i, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
if i == 0 { if i == 0 {
// early termination // early termination
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} else { } else {
continue continue
} }
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Diff(set) result = result.Diff(set)
if result.Len() == 0 { if result.Len() == 0 {
// early termination // early termination
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
} }
} }
if result == nil { if result == nil {
// all keys are nil // all keys are nil
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
arr := make([][]byte, result.Len()) arr := make([][]byte, result.Len())
i := 0 i := 0
result.ForEach(func (member string)bool { result.ForEach(func(member string) bool {
arr[i] = []byte(member) arr[i] = []byte(member)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(arr) return reply.MakeMultiBulkReply(arr)
} }
func SDiffStore(db *DB, args [][]byte) redis.Reply { func SDiffStore(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command") return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command")
} }
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args) - 1) keys := make([]string, len(args)-1)
keyArgs := args[1:] keyArgs := args[1:]
for i, arg := range keyArgs { for i, arg := range keyArgs {
keys[i] = string(arg) keys[i] = string(arg)
} }
// lock // lock
db.RLocks(keys...) db.RLocks(keys...)
defer db.RUnLocks(keys...) defer db.RUnLocks(keys...)
db.Lock(dest) db.Lock(dest)
defer db.Locker.UnLock(dest) defer db.Locker.UnLock(dest)
var result *HashSet.Set var result *HashSet.Set
for i, key := range keys { for i, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
if i == 0 { if i == 0 {
// early termination // early termination
db.Remove(dest) db.Remove(dest)
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} else { } else {
continue continue
} }
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.MakeFromVals(set.ToSlice()...)
} else { } else {
result = result.Diff(set) result = result.Diff(set)
if result.Len() == 0 { if result.Len() == 0 {
// early termination // early termination
db.Remove(dest) db.Remove(dest)
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
} }
} }
if result == nil { if result == nil {
// all keys are nil // all keys are nil
db.Remove(dest) db.Remove(dest)
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.MakeFromVals(result.ToSlice()...)
db.Put(dest, &DataEntity{ db.Put(dest, &DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sdiffstore", args)) db.AddAof(makeAofCmd("sdiffstore", args))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
func SRandMember(db *DB, args [][]byte) redis.Reply { func SRandMember(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 && len(args) != 2 { if len(args) != 1 && len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command") return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command")
} }
key := string(args[0]) key := string(args[0])
// lock // lock
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
// get or init entity // get or init entity
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
if len(args) == 1 { if len(args) == 1 {
members := set.RandomMembers(1) members := set.RandomMembers(1)
return reply.MakeBulkReply([]byte(members[0])) return reply.MakeBulkReply([]byte(members[0]))
} else { } else {
count64, err := strconv.ParseInt(string(args[1]), 10, 64) count64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
count := int(count64) count := int(count64)
if count > 0 { if count > 0 {
members := set.RandomMembers(count) members := set.RandomDistinctMembers(count)
result := make([][]byte, len(members)) result := make([][]byte, len(members))
for i, v := range members { for i, v := range members {
result[i] = []byte(v) result[i] = []byte(v)
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} else if count < 0 { } else if count < 0 {
members := set.RandomDistinctMembers(-count) members := set.RandomMembers(-count)
result := make([][]byte, len(members)) result := make([][]byte, len(members))
for i, v := range members { for i, v := range members {
result[i] = []byte(v) result[i] = []byte(v)
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} else { } else {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
} }
} }

181
src/db/set_test.go Normal file
View File

@@ -0,0 +1,181 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/redis/reply/asserts"
"strconv"
"testing"
)
// basic add get and remove
func TestSAdd(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
// test sadd
key := RandString(10)
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
result := SAdd(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 1)
}
// test scard
result := SCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size)
// test is member
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
result := SIsMember(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 1)
}
// test members
result = SMembers(testDB, toArgs(key))
multiBulk, ok := result.(*reply.MultiBulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
if len(multiBulk.Args) != size {
t.Error(fmt.Sprintf("expected %d elements, actually %d", size, len(multiBulk.Args)))
return
}
}
func TestSRem(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
// mock data
key := RandString(10)
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
SAdd(testDB, toArgs(key, member))
}
for i := 0; i < size; i++ {
member := strconv.Itoa(i)
SRem(testDB, toArgs(key, member))
result := SIsMember(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, 0)
}
}
func TestSInter(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 10
keys := make([]string, 0)
start := 0
for i := 0; i < 4; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SInter(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, 70)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SInterStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, 70)
}
func TestSUnion(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 10
keys := make([]string, 0)
start := 0
for i := 0; i < 4; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SUnion(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, 130)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SUnionStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, 130)
}
func TestSDiff(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 100
step := 20
keys := make([]string, 0)
start := 0
for i := 0; i < 3; i++ {
key := RandString(10)
keys = append(keys, key)
for j := start; j < size+start; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
start += step
}
result := SDiff(testDB, toArgs(keys...))
asserts.AssertMultiBulkReplySize(t, result, step)
destKey := RandString(10)
keysWithDest := []string{destKey}
keysWithDest = append(keysWithDest, keys...)
result = SDiffStore(testDB, toArgs(keysWithDest...))
asserts.AssertIntReply(t, result, step)
}
func TestSRandMember(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
for j := 0; j < 100; j++ {
member := strconv.Itoa(j)
SAdd(testDB, toArgs(key, member))
}
result := SRandMember(testDB, toArgs(key))
br, ok := result.(*reply.BulkReply)
if !ok && len(br.Arg) > 0 {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
result = SRandMember(testDB, toArgs(key, "10"))
asserts.AssertMultiBulkReplySize(t, result, 10)
multiBulk, ok := result.(*reply.MultiBulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes()))
return
}
m := make(map[string]struct{})
for _, arg := range multiBulk.Args {
m[string(arg)] = struct{}{}
}
if len(m) != 10 {
t.Error(fmt.Sprintf("expected 10 members, actually %d", len(m)))
return
}
result = SRandMember(testDB, toArgs(key, "110"))
asserts.AssertMultiBulkReplySize(t, result, 100)
result = SRandMember(testDB, toArgs(key, "-10"))
asserts.AssertMultiBulkReplySize(t, result, 10)
result = SRandMember(testDB, toArgs(key, "-110"))
asserts.AssertMultiBulkReplySize(t, result, 110)
}

View File

@@ -1,298 +1,298 @@
package db package db
import ( import (
"github.com/HDT3213/godis/src/redis/reply/asserts" "github.com/HDT3213/godis/src/redis/reply/asserts"
"math/rand" "math/rand"
"strconv" "strconv"
"testing" "testing"
) )
func TestZAdd(t *testing.T) { func TestZAdd(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
// add new members // add new members
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]float64, size) scores := make([]float64, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = rand.Float64() scores[i] = rand.Float64()
setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i])
} }
result := ZAdd(testDB, toArgs(setArgs...)) result := ZAdd(testDB, toArgs(setArgs...))
asserts.AssertIntReply(t, result, size) asserts.AssertIntReply(t, result, size)
// test zscore and zrank // test zscore and zrank
for i, member := range members { for i, member := range members {
result := ZScore(testDB, toArgs(key, member)) result := ZScore(testDB, toArgs(key, member))
score := strconv.FormatFloat(scores[i], 'f', -1, 64) score := strconv.FormatFloat(scores[i], 'f', -1, 64)
asserts.AssertBulkReply(t, result, score) asserts.AssertBulkReply(t, result, score)
} }
// test zcard // test zcard
result = ZCard(testDB, toArgs(key)) result = ZCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size) asserts.AssertIntReply(t, result, size)
// update members // update members
setArgs = []string{key} setArgs = []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
scores[i] = rand.Float64() + 100 scores[i] = rand.Float64() + 100
setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i])
} }
result = ZAdd(testDB, toArgs(setArgs...)) result = ZAdd(testDB, toArgs(setArgs...))
asserts.AssertIntReply(t, result, 0) // return number of new members asserts.AssertIntReply(t, result, 0) // return number of new members
// test updated score // test updated score
for i, member := range members { for i, member := range members {
result := ZScore(testDB, toArgs(key, member)) result := ZScore(testDB, toArgs(key, member))
score := strconv.FormatFloat(scores[i], 'f', -1, 64) score := strconv.FormatFloat(scores[i], 'f', -1, 64)
asserts.AssertBulkReply(t, result, score) asserts.AssertBulkReply(t, result, score)
} }
} }
func TestZRank(t *testing.T) { func TestZRank(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
ZAdd(testDB, toArgs(setArgs...)) ZAdd(testDB, toArgs(setArgs...))
// test zrank // test zrank
for i, member := range members { for i, member := range members {
result := ZRank(testDB, toArgs(key, member)) result := ZRank(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, i) asserts.AssertIntReply(t, result, i)
result = ZRevRank(testDB, toArgs(key, member)) result = ZRevRank(testDB, toArgs(key, member))
asserts.AssertIntReply(t, result, size-i-1) asserts.AssertIntReply(t, result, size-i-1)
} }
} }
func TestZRange(t *testing.T) { func TestZRange(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(rand.Int()), 10) members[i] = RandString(10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
result := ZAdd(testDB, toArgs(setArgs...)) result := ZAdd(testDB, toArgs(setArgs...))
reverseMembers := make([]string, size) reverseMembers := make([]string, size)
for i, v := range members { for i, v := range members {
reverseMembers[size-i-1] = v reverseMembers[size-i-1] = v
} }
start := "0" start := "0"
end := "9" end := "9"
result = ZRange(testDB, toArgs(key, start, end)) result = ZRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, members[0:10]) asserts.AssertMultiBulkReply(t, result, members[0:10])
result = ZRevRange(testDB, toArgs(key, start, end)) result = ZRevRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10]) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10])
start = "0" start = "0"
end = "200" end = "200"
result = ZRange(testDB, toArgs(key, start, end)) result = ZRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, members) asserts.AssertMultiBulkReply(t, result, members)
result = ZRevRange(testDB, toArgs(key, start, end)) result = ZRevRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, reverseMembers) asserts.AssertMultiBulkReply(t, result, reverseMembers)
start = "0" start = "0"
end = "-10" end = "-10"
result = ZRange(testDB, toArgs(key, start, end)) result = ZRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, members[0:size-10+1]) asserts.AssertMultiBulkReply(t, result, members[0:size-10+1])
result = ZRevRange(testDB, toArgs(key, start, end)) result = ZRevRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1]) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1])
start = "0" start = "0"
end = "-200" end = "-200"
result = ZRange(testDB, toArgs(key, start, end)) result = ZRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, members[0:0]) asserts.AssertMultiBulkReply(t, result, members[0:0])
result = ZRevRange(testDB, toArgs(key, start, end)) result = ZRevRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0]) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0])
start = "-10" start = "-10"
end = "-1" end = "-1"
result = ZRange(testDB, toArgs(key, start, end)) result = ZRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, members[90:]) asserts.AssertMultiBulkReply(t, result, members[90:])
result = ZRevRange(testDB, toArgs(key, start, end)) result = ZRevRange(testDB, toArgs(key, start, end))
asserts.AssertMultiBulkReply(t, result, reverseMembers[90:]) asserts.AssertMultiBulkReply(t, result, reverseMembers[90:])
} }
func reverse(src []string) []string { func reverse(src []string) []string {
result := make([]string, len(src)) result := make([]string, len(src))
for i, v := range src { for i, v := range src {
result[len(src)-i-1] = v result[len(src)-i-1] = v
} }
return result return result
} }
func TestZRangeByScore(t *testing.T) { func TestZRangeByScore(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(i), 10) members[i] = strconv.FormatInt(int64(i), 10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
result := ZAdd(testDB, toArgs(setArgs...)) result := ZAdd(testDB, toArgs(setArgs...))
min := "20" min := "20"
max := "30" max := "30"
result = ZRangeByScore(testDB, toArgs(key, min, max)) result = ZRangeByScore(testDB, toArgs(key, min, max))
asserts.AssertMultiBulkReply(t, result, members[20:31]) asserts.AssertMultiBulkReply(t, result, members[20:31])
result = ZRevRangeByScore(testDB, toArgs(key, max, min)) result = ZRevRangeByScore(testDB, toArgs(key, max, min))
asserts.AssertMultiBulkReply(t, result, reverse(members[20:31])) asserts.AssertMultiBulkReply(t, result, reverse(members[20:31]))
min = "-10" min = "-10"
max = "10" max = "10"
result = ZRangeByScore(testDB, toArgs(key, min, max)) result = ZRangeByScore(testDB, toArgs(key, min, max))
asserts.AssertMultiBulkReply(t, result, members[0:11]) asserts.AssertMultiBulkReply(t, result, members[0:11])
result = ZRevRangeByScore(testDB, toArgs(key, max, min)) result = ZRevRangeByScore(testDB, toArgs(key, max, min))
asserts.AssertMultiBulkReply(t, result, reverse(members[0:11])) asserts.AssertMultiBulkReply(t, result, reverse(members[0:11]))
min = "90" min = "90"
max = "110" max = "110"
result = ZRangeByScore(testDB, toArgs(key, min, max)) result = ZRangeByScore(testDB, toArgs(key, min, max))
asserts.AssertMultiBulkReply(t, result, members[90:]) asserts.AssertMultiBulkReply(t, result, members[90:])
result = ZRevRangeByScore(testDB, toArgs(key, max, min)) result = ZRevRangeByScore(testDB, toArgs(key, max, min))
asserts.AssertMultiBulkReply(t, result, reverse(members[90:])) asserts.AssertMultiBulkReply(t, result, reverse(members[90:]))
min = "(20" min = "(20"
max = "(30" max = "(30"
result = ZRangeByScore(testDB, toArgs(key, min, max)) result = ZRangeByScore(testDB, toArgs(key, min, max))
asserts.AssertMultiBulkReply(t, result, members[21:30]) asserts.AssertMultiBulkReply(t, result, members[21:30])
result = ZRevRangeByScore(testDB, toArgs(key, max, min)) result = ZRevRangeByScore(testDB, toArgs(key, max, min))
asserts.AssertMultiBulkReply(t, result, reverse(members[21:30])) asserts.AssertMultiBulkReply(t, result, reverse(members[21:30]))
min = "20" min = "20"
max = "40" max = "40"
result = ZRangeByScore(testDB, toArgs(key, min, max, "LIMIT", "5", "5")) result = ZRangeByScore(testDB, toArgs(key, min, max, "LIMIT", "5", "5"))
asserts.AssertMultiBulkReply(t, result, members[25:30]) asserts.AssertMultiBulkReply(t, result, members[25:30])
result = ZRevRangeByScore(testDB, toArgs(key, max, min, "LIMIT", "5", "5")) result = ZRevRangeByScore(testDB, toArgs(key, max, min, "LIMIT", "5", "5"))
asserts.AssertMultiBulkReply(t, result, reverse(members[31:36])) asserts.AssertMultiBulkReply(t, result, reverse(members[31:36]))
} }
func TestZRem(t *testing.T) { func TestZRem(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(i), 10) members[i] = strconv.FormatInt(int64(i), 10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
ZAdd(testDB, toArgs(setArgs...)) ZAdd(testDB, toArgs(setArgs...))
args := []string{key} args := []string{key}
args = append(args, members[0:10]...) args = append(args, members[0:10]...)
result := ZRem(testDB, toArgs(args...)) result := ZRem(testDB, toArgs(args...))
asserts.AssertIntReply(t, result, 10) asserts.AssertIntReply(t, result, 10)
result = ZCard(testDB, toArgs(key)) result = ZCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size-10) asserts.AssertIntReply(t, result, size-10)
// test ZRemRangeByRank // test ZRemRangeByRank
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size = 100 size = 100
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
members = make([]string, size) members = make([]string, size)
scores = make([]int, size) scores = make([]int, size)
setArgs = []string{key} setArgs = []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(i), 10) members[i] = strconv.FormatInt(int64(i), 10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
ZAdd(testDB, toArgs(setArgs...)) ZAdd(testDB, toArgs(setArgs...))
result = ZRemRangeByRank(testDB, toArgs(key, "0", "9")) result = ZRemRangeByRank(testDB, toArgs(key, "0", "9"))
asserts.AssertIntReply(t, result, 10) asserts.AssertIntReply(t, result, 10)
result = ZCard(testDB, toArgs(key)) result = ZCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size-10) asserts.AssertIntReply(t, result, size-10)
// test ZRemRangeByScore // test ZRemRangeByScore
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size = 100 size = 100
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
members = make([]string, size) members = make([]string, size)
scores = make([]int, size) scores = make([]int, size)
setArgs = []string{key} setArgs = []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(i), 10) members[i] = strconv.FormatInt(int64(i), 10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
ZAdd(testDB, toArgs(setArgs...)) ZAdd(testDB, toArgs(setArgs...))
result = ZRemRangeByScore(testDB, toArgs(key, "0", "9")) result = ZRemRangeByScore(testDB, toArgs(key, "0", "9"))
asserts.AssertIntReply(t, result, 10) asserts.AssertIntReply(t, result, 10)
result = ZCard(testDB, toArgs(key)) result = ZCard(testDB, toArgs(key))
asserts.AssertIntReply(t, result, size-10) asserts.AssertIntReply(t, result, size-10)
} }
func TestZCount(t *testing.T) { func TestZCount(t *testing.T) {
// prepare // prepare
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 100 size := 100
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
members := make([]string, size) members := make([]string, size)
scores := make([]int, size) scores := make([]int, size)
setArgs := []string{key} setArgs := []string{key}
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
members[i] = strconv.FormatInt(int64(i), 10) members[i] = strconv.FormatInt(int64(i), 10)
scores[i] = i scores[i] = i
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
result := ZAdd(testDB, toArgs(setArgs...)) result := ZAdd(testDB, toArgs(setArgs...))
min := "20" min := "20"
max := "30" max := "30"
result = ZCount(testDB, toArgs(key, min, max)) result = ZCount(testDB, toArgs(key, min, max))
asserts.AssertIntReply(t, result, 11) asserts.AssertIntReply(t, result, 11)
min = "-10" min = "-10"
max = "10" max = "10"
result = ZCount(testDB, toArgs(key, min, max)) result = ZCount(testDB, toArgs(key, min, max))
asserts.AssertIntReply(t, result, 11) asserts.AssertIntReply(t, result, 11)
min = "90" min = "90"
max = "110" max = "110"
result = ZCount(testDB, toArgs(key, min, max)) result = ZCount(testDB, toArgs(key, min, max))
asserts.AssertIntReply(t, result, 10) asserts.AssertIntReply(t, result, 10)
min = "(20" min = "(20"
max = "(30" max = "(30"
result = ZCount(testDB, toArgs(key, min, max)) result = ZCount(testDB, toArgs(key, min, max))
asserts.AssertIntReply(t, result, 9) asserts.AssertIntReply(t, result, 9)
} }
func TestZIncrBy(t *testing.T) { func TestZIncrBy(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
ZAdd(testDB, toArgs(key, "10", "a")) ZAdd(testDB, toArgs(key, "10", "a"))
result := ZIncrBy(testDB, toArgs(key, "10", "a")) result := ZIncrBy(testDB, toArgs(key, "10", "a"))
asserts.AssertBulkReply(t, result, "20") asserts.AssertBulkReply(t, result, "20")
result = ZScore(testDB, toArgs(key, "a")) result = ZScore(testDB, toArgs(key, "a"))
asserts.AssertBulkReply(t, result, "20") asserts.AssertBulkReply(t, result, "20")
} }

View File

@@ -1,509 +1,508 @@
package db package db
import ( import (
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
func (db *DB) getAsString(key string) ([]byte, reply.ErrorReply) { func (db *DB) getAsString(key string) ([]byte, reply.ErrorReply) {
entity, ok := db.Get(key) entity, ok := db.Get(key)
if !ok { if !ok {
return nil, nil return nil, nil
} }
bytes, ok := entity.Data.([]byte) bytes, ok := entity.Data.([]byte)
if !ok { if !ok {
return nil, &reply.WrongTypeErrReply{} return nil, &reply.WrongTypeErrReply{}
} }
return bytes, nil return bytes, nil
} }
func Get(db *DB, args [][]byte) redis.Reply { func Get(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'get' command") return reply.MakeErrReply("ERR wrong number of arguments for 'get' command")
} }
key := string(args[0]) key := string(args[0])
bytes, err := db.getAsString(key) bytes, err := db.getAsString(key)
if err != nil { if err != nil {
return err return err
} }
if bytes == nil { if bytes == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
return reply.MakeBulkReply(bytes) return reply.MakeBulkReply(bytes)
} }
const ( const (
upsertPolicy = iota // default upsertPolicy = iota // default
insertPolicy // set nx insertPolicy // set nx
updatePolicy // set ex updatePolicy // set ex
) )
const unlimitedTTL int64 = 0 const unlimitedTTL int64 = 0
// SET key value [EX seconds] [PX milliseconds] [NX|XX] // SET key value [EX seconds] [PX milliseconds] [NX|XX]
func Set(db *DB, args [][]byte) redis.Reply { func Set(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'set' command") return reply.MakeErrReply("ERR wrong number of arguments for 'set' command")
} }
key := string(args[0]) key := string(args[0])
value := args[1] value := args[1]
policy := upsertPolicy policy := upsertPolicy
ttl := unlimitedTTL ttl := unlimitedTTL
// parse options // parse options
if len(args) > 2 { if len(args) > 2 {
for i := 2; i < len(args); i++ { for i := 2; i < len(args); i++ {
arg := strings.ToUpper(string(args[i])) arg := strings.ToUpper(string(args[i]))
if arg == "NX" { // insert if arg == "NX" { // insert
if policy == updatePolicy { if policy == updatePolicy {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
policy = insertPolicy policy = insertPolicy
} else if arg == "XX" { // update policy } else if arg == "XX" { // update policy
if policy == insertPolicy { if policy == insertPolicy {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
policy = updatePolicy policy = updatePolicy
} else if arg == "EX" { // ttl in seconds } else if arg == "EX" { // ttl in seconds
if ttl != unlimitedTTL { if ttl != unlimitedTTL {
// ttl has been set // ttl has been set
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if i+1 >= len(args) { if i+1 >= len(args) {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
if err != nil { if err != nil {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if ttlArg <= 0 { if ttlArg <= 0 {
return reply.MakeErrReply("ERR invalid expire time in set") return reply.MakeErrReply("ERR invalid expire time in set")
} }
ttl = ttlArg * 1000 ttl = ttlArg * 1000
i++ // skip next arg i++ // skip next arg
} else if arg == "PX" { // ttl in milliseconds } else if arg == "PX" { // ttl in milliseconds
if ttl != unlimitedTTL { if ttl != unlimitedTTL {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if i+1 >= len(args) { if i+1 >= len(args) {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
if err != nil { if err != nil {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if ttlArg <= 0 { if ttlArg <= 0 {
return reply.MakeErrReply("ERR invalid expire time in set") return reply.MakeErrReply("ERR invalid expire time in set")
} }
ttl = ttlArg ttl = ttlArg
i++ // skip next arg i++ // skip next arg
} else { } else {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
} }
} }
entity := &DataEntity{ entity := &DataEntity{
Data: value, Data: value,
} }
db.Persist(key) // clean ttl db.Persist(key) // clean ttl
var result int var result int
switch policy { switch policy {
case upsertPolicy: case upsertPolicy:
result = db.Put(key, entity) result = db.Put(key, entity)
case insertPolicy: case insertPolicy:
result = db.PutIfAbsent(key, entity) result = db.PutIfAbsent(key, entity)
case updatePolicy: case updatePolicy:
result = db.PutIfExists(key, entity) result = db.PutIfExists(key, entity)
} }
/* /*
* 如果设置了ttl 则以最新的ttl为准 * 如果设置了ttl 则以最新的ttl为准
* 如果没有设置ttl 是新增key的情况不设置ttl。 * 如果没有设置ttl 是新增key的情况不设置ttl。
* 如果没有设置ttl 且已存在key的 不修改ttl 但需要增加aof * 如果没有设置ttl 且已存在key的 不修改ttl 但需要增加aof
*/ */
if ttl != unlimitedTTL { if ttl != unlimitedTTL {
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
db.AddAof(reply.MakeMultiBulkReply([][]byte{ db.AddAof(reply.MakeMultiBulkReply([][]byte{
[]byte("SET"), []byte("SET"),
args[0], args[0],
args[1], args[1],
})) }))
db.AddAof(makeExpireCmd(key, expireTime)) db.AddAof(makeExpireCmd(key, expireTime))
} else if result > 0{ } else if result > 0 {
db.Persist(key) // override ttl db.Persist(key) // override ttl
db.AddAof(makeAofCmd("set", args)) db.AddAof(makeAofCmd("set", args))
}else{ } else {
db.AddAof(makeAofCmd("set", args)) db.AddAof(makeAofCmd("set", args))
} }
if policy == upsertPolicy || result > 0 { if policy == upsertPolicy || result > 0 {
return &reply.OkReply{} return &reply.OkReply{}
} else { } else {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
} }
func SetNX(db *DB, args [][]byte) redis.Reply { func SetNX(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command") return reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command")
} }
key := string(args[0]) key := string(args[0])
value := args[1] value := args[1]
entity := &DataEntity{ entity := &DataEntity{
Data: value, Data: value,
} }
result := db.PutIfAbsent(key, entity) result := db.PutIfAbsent(key, entity)
if result > 0 { db.AddAof(makeAofCmd("setnx", args))
db.AddAof(makeAofCmd("setnx", args)) return reply.MakeIntReply(int64(result))
}
return reply.MakeIntReply(int64(result))
} }
func SetEX(db *DB, args [][]byte) redis.Reply { func SetEX(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command") return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command")
} }
key := string(args[0]) key := string(args[0])
value := args[2] value := args[2]
ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64) ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if ttlArg <= 0 { if ttlArg <= 0 {
return reply.MakeErrReply("ERR invalid expire time in setex") return reply.MakeErrReply("ERR invalid expire time in setex")
} }
ttl := ttlArg * 1000 ttl := ttlArg * 1000
entity := &DataEntity{ entity := &DataEntity{
Data: value, Data: value,
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
result := db.Put(key, entity) db.Put(key, entity)
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
if result > 0 { db.AddAof(makeAofCmd("setex", args))
db.AddAof(makeAofCmd("setex", args)) db.AddAof(makeExpireCmd(key, expireTime))
db.AddAof(makeExpireCmd(key, expireTime)) return &reply.OkReply{}
}
return &reply.OkReply{}
} }
func PSetEX(db *DB, args [][]byte) redis.Reply { func PSetEX(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'psetex' command") return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command")
} }
key := string(args[0]) key := string(args[0])
value := args[1] value := args[2]
ttl, err := strconv.ParseInt(string(args[1]), 10, 64) ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return &reply.SyntaxErrReply{} return &reply.SyntaxErrReply{}
} }
if ttl <= 0 { if ttlArg <= 0 {
return reply.MakeErrReply("ERR invalid expire time in psetex") return reply.MakeErrReply("ERR invalid expire time in setex")
} }
entity := &DataEntity{ entity := &DataEntity{
Data: value, Data: value,
} }
result := db.PutIfExists(key, entity)
if result > 0 { db.Lock(key)
db.AddAof(makeAofCmd("psetex", args)) defer db.UnLock(key)
if ttl != unlimitedTTL {
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) db.Put(key, entity)
db.Expire(key, expireTime) expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
db.AddAof(makeExpireCmd(key, expireTime)) db.Expire(key, expireTime)
} db.AddAof(makeAofCmd("setex", args))
} db.AddAof(makeExpireCmd(key, expireTime))
return &reply.OkReply{}
return &reply.OkReply{}
} }
func MSet(db *DB, args [][]byte) redis.Reply { func MSet(db *DB, args [][]byte) redis.Reply {
if len(args)%2 != 0 || len(args) == 0 { if len(args)%2 != 0 || len(args) == 0 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
} }
size := len(args) / 2 size := len(args) / 2
keys := make([]string, size) keys := make([]string, size)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = string(args[2*i]) keys[i] = string(args[2*i])
values[i] = args[2*i+1] values[i] = args[2*i+1]
} }
db.Locks(keys...) db.Locks(keys...)
defer db.UnLocks(keys...) defer db.UnLocks(keys...)
for i, key := range keys { for i, key := range keys {
value := values[i] value := values[i]
db.Put(key, &DataEntity{Data: value}) db.Put(key, &DataEntity{Data: value})
} }
db.AddAof(makeAofCmd("mset", args)) db.AddAof(makeAofCmd("mset", args))
return &reply.OkReply{} return &reply.OkReply{}
} }
func MGet(db *DB, args [][]byte) redis.Reply { func MGet(db *DB, args [][]byte) redis.Reply {
if len(args) == 0 { if len(args) == 0 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
} }
keys := make([]string, len(args)) keys := make([]string, len(args))
for i, v := range args { for i, v := range args {
keys[i] = string(v) keys[i] = string(v)
} }
result := make([][]byte, len(args)) result := make([][]byte, len(args))
for i, key := range keys { for i, key := range keys {
bytes, err := db.getAsString(key) bytes, err := db.getAsString(key)
if err != nil { if err != nil {
_, isWrongType := err.(*reply.WrongTypeErrReply) _, isWrongType := err.(*reply.WrongTypeErrReply)
if isWrongType { if isWrongType {
result[i] = nil result[i] = nil
continue continue
} else { } else {
return err return err
} }
} }
result[i] = bytes // nil or []byte result[i] = bytes // nil or []byte
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
func MSetNX(db *DB, args [][]byte) redis.Reply { func MSetNX(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args)%2 != 0 || len(args) == 0 { if len(args)%2 != 0 || len(args) == 0 {
return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command") return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command")
} }
size := len(args) / 2 size := len(args) / 2
values := make([][]byte, size) values := make([][]byte, size)
keys := make([]string, size) keys := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = string(args[2*i]) keys[i] = string(args[2*i])
values[i] = args[2*i+1] values[i] = args[2*i+1]
} }
// lock keys // lock keys
db.Locks(keys...) db.Locks(keys...)
defer db.UnLocks(keys...) defer db.UnLocks(keys...)
for _, key := range keys { for _, key := range keys {
_, exists := db.Get(key) _, exists := db.Get(key)
if exists { if exists {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
} }
for i, key := range keys { for i, key := range keys {
value := values[i] value := values[i]
db.Put(key, &DataEntity{Data: value}) db.Put(key, &DataEntity{Data: value})
} }
db.AddAof(makeAofCmd("msetnx", args)) db.AddAof(makeAofCmd("msetnx", args))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
func GetSet(db *DB, args [][]byte) redis.Reply { func GetSet(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'getset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'getset' command")
} }
key := string(args[0]) key := string(args[0])
value := args[1] value := args[1]
old, err := db.getAsString(key) old, err := db.getAsString(key)
if err != nil { if err != nil {
return err return err
} }
db.Put(key, &DataEntity{Data: value}) db.Put(key, &DataEntity{Data: value})
db.Persist(key) // override ttl db.Persist(key) // override ttl
db.AddAof(makeAofCmd("getset", args)) db.AddAof(makeAofCmd("getset", args))
if old == nil {
return reply.MakeBulkReply(old) return new(reply.NullBulkReply)
}
return reply.MakeBulkReply(old)
} }
func Incr(db *DB, args [][]byte) redis.Reply { func Incr(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'incr' command") return reply.MakeErrReply("ERR wrong number of arguments for 'incr' command")
} }
key := string(args[0]) key := string(args[0])
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
bytes, err := db.getAsString(key) bytes, err := db.getAsString(key)
if err != nil { if err != nil {
return err return err
} }
if bytes != nil { if bytes != nil {
val, err := strconv.ParseInt(string(bytes), 10, 64) val, err := strconv.ParseInt(string(bytes), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte(strconv.FormatInt(val+1, 10)), Data: []byte(strconv.FormatInt(val+1, 10)),
}) })
db.AddAof(makeAofCmd("incr", args)) db.AddAof(makeAofCmd("incr", args))
return reply.MakeIntReply(val + 1) return reply.MakeIntReply(val + 1)
} else { } else {
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte("1"), Data: []byte("1"),
}) })
db.AddAof(makeAofCmd("incr", args)) db.AddAof(makeAofCmd("incr", args))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
} }
func IncrBy(db *DB, args [][]byte) redis.Reply { func IncrBy(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'incrby' command") return reply.MakeErrReply("ERR wrong number of arguments for 'incrby' command")
} }
key := string(args[0]) key := string(args[0])
rawDelta := string(args[1]) rawDelta := string(args[1])
delta, err := strconv.ParseInt(rawDelta, 10, 64) delta, err := strconv.ParseInt(rawDelta, 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
bytes, errReply := db.getAsString(key) bytes, errReply := db.getAsString(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if bytes != nil { if bytes != nil {
val, err := strconv.ParseInt(string(bytes), 10, 64) val, err := strconv.ParseInt(string(bytes), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte(strconv.FormatInt(val+delta, 10)), Data: []byte(strconv.FormatInt(val+delta, 10)),
}) })
db.AddAof(makeAofCmd("incrby", args)) db.AddAof(makeAofCmd("incrby", args))
return reply.MakeIntReply(val + delta) return reply.MakeIntReply(val + delta)
} else { } else {
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: args[1], Data: args[1],
}) })
db.AddAof(makeAofCmd("incrby", args)) db.AddAof(makeAofCmd("incrby", args))
return reply.MakeIntReply(delta) return reply.MakeIntReply(delta)
} }
} }
func IncrByFloat(db *DB, args [][]byte) redis.Reply { func IncrByFloat(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'incrbyfloat' command") return reply.MakeErrReply("ERR wrong number of arguments for 'incrbyfloat' command")
} }
key := string(args[0]) key := string(args[0])
rawDelta := string(args[1]) rawDelta := string(args[1])
delta, err := decimal.NewFromString(rawDelta) delta, err := decimal.NewFromString(rawDelta)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
bytes, errReply := db.getAsString(key) bytes, errReply := db.getAsString(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if bytes != nil { if bytes != nil {
val, err := decimal.NewFromString(string(bytes)) val, err := decimal.NewFromString(string(bytes))
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
resultBytes := []byte(val.Add(delta).String()) resultBytes := []byte(val.Add(delta).String())
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: resultBytes, Data: resultBytes,
}) })
db.AddAof(makeAofCmd("incrbyfloat", args)) db.AddAof(makeAofCmd("incrbyfloat", args))
return reply.MakeBulkReply(resultBytes) return reply.MakeBulkReply(resultBytes)
} else { } else {
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: args[1], Data: args[1],
}) })
db.AddAof(makeAofCmd("incrbyfloat", args)) db.AddAof(makeAofCmd("incrbyfloat", args))
return reply.MakeBulkReply(args[1]) return reply.MakeBulkReply(args[1])
} }
} }
func Decr(db *DB, args [][]byte) redis.Reply { func Decr(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'decr' command") return reply.MakeErrReply("ERR wrong number of arguments for 'decr' command")
} }
key := string(args[0]) key := string(args[0])
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
bytes, errReply := db.getAsString(key) bytes, errReply := db.getAsString(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if bytes != nil { if bytes != nil {
val, err := strconv.ParseInt(string(bytes), 10, 64) val, err := strconv.ParseInt(string(bytes), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte(strconv.FormatInt(val-1, 10)), Data: []byte(strconv.FormatInt(val-1, 10)),
}) })
db.AddAof(makeAofCmd("decr", args)) db.AddAof(makeAofCmd("decr", args))
return reply.MakeIntReply(val - 1) return reply.MakeIntReply(val - 1)
} else { } else {
entity := &DataEntity{ entity := &DataEntity{
Data: []byte("-1"), Data: []byte("-1"),
} }
db.Put(key, entity) db.Put(key, entity)
db.AddAof(makeAofCmd("decr", args)) db.AddAof(makeAofCmd("decr", args))
return reply.MakeIntReply(-1) return reply.MakeIntReply(-1)
} }
} }
func DecrBy(db *DB, args [][]byte) redis.Reply { func DecrBy(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'decrby' command") return reply.MakeErrReply("ERR wrong number of arguments for 'decrby' command")
} }
key := string(args[0]) key := string(args[0])
rawDelta := string(args[1]) rawDelta := string(args[1])
delta, err := strconv.ParseInt(rawDelta, 10, 64) delta, err := strconv.ParseInt(rawDelta, 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
bytes, errReply := db.getAsString(key) bytes, errReply := db.getAsString(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if bytes != nil { if bytes != nil {
val, err := strconv.ParseInt(string(bytes), 10, 64) val, err := strconv.ParseInt(string(bytes), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte(strconv.FormatInt(val-delta, 10)), Data: []byte(strconv.FormatInt(val-delta, 10)),
}) })
db.AddAof(makeAofCmd("decrby", args)) db.AddAof(makeAofCmd("decrby", args))
return reply.MakeIntReply(val - delta) return reply.MakeIntReply(val - delta)
} else { } else {
valueStr := strconv.FormatInt(-delta, 10) valueStr := strconv.FormatInt(-delta, 10)
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: []byte(valueStr), Data: []byte(valueStr),
}) })
db.AddAof(makeAofCmd("decrby", args)) db.AddAof(makeAofCmd("decrby", args))
return reply.MakeIntReply(-delta) return reply.MakeIntReply(-delta)
} }
} }

View File

@@ -1,152 +1,288 @@
package db package db
import ( import (
"github.com/HDT3213/godis/src/datastruct/utils" "fmt"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/datastruct/utils"
"math/rand" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "github.com/HDT3213/godis/src/redis/reply/asserts"
"testing" "strconv"
"testing"
) )
var testDB = makeTestDB() var testDB = makeTestDB()
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
// normal set // normal set
Set(testDB, toArgs(key, value)) Set(testDB, toArgs(key, value))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(value)) expected := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
// set nx // set nx
actual = Set(testDB, toArgs(key, value, "NX")) actual = Set(testDB, toArgs(key, value, "NX"))
if _, ok := actual.(*reply.NullBulkReply); !ok { if _, ok := actual.(*reply.NullBulkReply); !ok {
t.Error("expected true actual false") t.Error("expected true actual false")
} }
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
value = strconv.FormatInt(int64(rand.Int()), 10) value = RandString(10)
Set(testDB, toArgs(key, value, "NX")) Set(testDB, toArgs(key, value, "NX"))
actual = Get(testDB, toArgs(key)) actual = Get(testDB, toArgs(key))
expected = reply.MakeBulkReply([]byte(value)) expected = reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
// set xx // set xx
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
value = strconv.FormatInt(int64(rand.Int()), 10) value = RandString(10)
actual = Set(testDB, toArgs(key, value, "XX")) actual = Set(testDB, toArgs(key, value, "XX"))
if _, ok := actual.(*reply.NullBulkReply); !ok { if _, ok := actual.(*reply.NullBulkReply); !ok {
t.Error("expected true actually false ") t.Error("expected true actually false ")
} }
Set(testDB, toArgs(key, value)) Set(testDB, toArgs(key, value))
Set(testDB, toArgs(key, value, "XX")) Set(testDB, toArgs(key, value, "XX"))
actual = Get(testDB, toArgs(key)) actual = Get(testDB, toArgs(key))
expected = reply.MakeBulkReply([]byte(value)) asserts.AssertBulkReply(t, actual, value)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
}
// set ex
Del(testDB, toArgs(key))
ttl := "1000"
Set(testDB, toArgs(key, value, "EX", ttl))
actual = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = TTL(testDB, toArgs(key))
intResult, ok := actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
// set px
Del(testDB, toArgs(key))
ttlPx := "1000000"
Set(testDB, toArgs(key, value, "PX", ttlPx))
actual = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = TTL(testDB, toArgs(key))
intResult, ok = actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
} }
func TestSetNX(t *testing.T) { func TestSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
SetNX(testDB, toArgs(key, value)) SetNX(testDB, toArgs(key, value))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(value)) expected := reply.MakeBulkReply([]byte(value))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
actual = SetNX(testDB, toArgs(key, value)) actual = SetNX(testDB, toArgs(key, value))
expected2 := reply.MakeIntReply(int64(0)) expected2 := reply.MakeIntReply(int64(0))
if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) {
t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
func TestSetEX(t *testing.T) { func TestSetEX(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
ttl := "1000" ttl := "1000"
SetEX(testDB, toArgs(key, ttl, value)) SetEX(testDB, toArgs(key, ttl, value))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected2 := reply.MakeBulkReply([]byte(value)) asserts.AssertBulkReply(t, actual, value)
if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { actual = TTL(testDB, toArgs(key))
t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) intResult, ok := actual.(*reply.IntReply)
} if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
}
func TestPSetEX(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
ttl := "1000000"
PSetEX(testDB, toArgs(key, ttl, value))
actual := Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, value)
actual = PTTL(testDB, toArgs(key))
intResult, ok := actual.(*reply.IntReply)
if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes()))
return
}
if intResult.Code <= 0 || intResult.Code > 1000000 {
t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code))
return
}
} }
func TestMSet(t *testing.T) { func TestMSet(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 10 size := 10
keys := make([]string, size) keys := make([]string, size)
values := make([][]byte, size) values := make([][]byte, size)
args := make([]string, 0, size*2) args := make([]string, 0, size*2)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = strconv.FormatInt(int64(rand.Int()), 10) keys[i] = RandString(10)
value := strconv.FormatInt(int64(rand.Int()), 10) value := RandString(10)
values[i] = []byte(value) values[i] = []byte(value)
args = append(args, keys[i], value) args = append(args, keys[i], value)
} }
MSet(testDB, toArgs(args...)) MSet(testDB, toArgs(args...))
actual := MGet(testDB, toArgs(keys...)) actual := MGet(testDB, toArgs(keys...))
expected := reply.MakeMultiBulkReply(values) expected := reply.MakeMultiBulkReply(values)
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
func TestIncr(t *testing.T) { func TestIncr(t *testing.T) {
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
size := 10 size := 10
key := strconv.FormatInt(int64(rand.Int()), 10) key := RandString(10)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
Incr(testDB, toArgs(key)) Incr(testDB, toArgs(key))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10)))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
IncrBy(testDB, toArgs(key, "-1")) IncrBy(testDB, toArgs(key, "-1"))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10)))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
FlushAll(testDB, [][]byte{}) FlushAll(testDB, [][]byte{})
key = strconv.FormatInt(int64(rand.Int()), 10) key = RandString(10)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
IncrBy(testDB, toArgs(key, "1")) IncrBy(testDB, toArgs(key, "1"))
actual := Get(testDB, toArgs(key)) actual := Get(testDB, toArgs(key))
expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10)))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) {
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes()))
} }
} }
for i := 0; i < size; i++ { Del(testDB, toArgs(key))
IncrByFloat(testDB, toArgs(key, "-1.0")) for i := 0; i < size; i++ {
actual := Get(testDB, toArgs(key)) IncrByFloat(testDB, toArgs(key, "-1.0"))
expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) actual := Get(testDB, toArgs(key))
if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { expected := -i - 1
t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) bulk, ok := actual.(*reply.BulkReply)
} if !ok {
} t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes()))
return
}
val, err := strconv.ParseFloat(string(bulk.Arg), 10)
if err != nil {
t.Error(err)
return
}
if int(val) != expected {
t.Errorf("expect %d, actual: %d", expected, int(val))
return
}
}
}
func TestDecr(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 10
key := RandString(10)
for i := 0; i < size; i++ {
Decr(testDB, toArgs(key))
actual := Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, actual, strconv.Itoa(-i-1))
}
Del(testDB, toArgs(key))
for i := 0; i < size; i++ {
DecrBy(testDB, toArgs(key, "1"))
actual := Get(testDB, toArgs(key))
expected := -i - 1
bulk, ok := actual.(*reply.BulkReply)
if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes()))
return
}
val, err := strconv.ParseFloat(string(bulk.Arg), 10)
if err != nil {
t.Error(err)
return
}
if int(val) != expected {
t.Errorf("expect %d, actual: %d", expected, int(val))
return
}
}
}
func TestGetSet(t *testing.T) {
FlushAll(testDB, [][]byte{})
key := RandString(10)
value := RandString(10)
result := GetSet(testDB, toArgs(key, value))
_, ok := result.(*reply.NullBulkReply)
if !ok {
t.Errorf("expect null bulk reply, get: %s", string(result.ToBytes()))
return
}
value2 := RandString(10)
result = GetSet(testDB, toArgs(key, value2))
asserts.AssertBulkReply(t, result, value)
result = Get(testDB, toArgs(key))
asserts.AssertBulkReply(t, result, value2)
}
func TestMSetNX(t *testing.T) {
FlushAll(testDB, [][]byte{})
size := 10
args := make([]string, 0, size*2)
for i := 0; i < size; i++ {
str := RandString(10)
args = append(args, str, str)
}
result := MSetNX(testDB, toArgs(args...))
asserts.AssertIntReply(t, result, 1)
result = MSetNX(testDB, toArgs(args[0:4]...))
asserts.AssertIntReply(t, result, 0)
} }

View File

@@ -1,24 +1,35 @@
package db package db
import ( import (
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/datastruct/lock"
"time" "math/rand"
"time"
) )
func makeTestDB() *DB { func makeTestDB() *DB {
return &DB{ return &DB{
Data: dict.MakeConcurrent(1), Data: dict.MakeConcurrent(1),
TTLMap: dict.MakeConcurrent(ttlDictSize), TTLMap: dict.MakeConcurrent(ttlDictSize),
Locker: lock.Make(lockerSize), Locker: lock.Make(lockerSize),
interval: 5 * time.Second, interval: 5 * time.Second,
} }
} }
func toArgs(cmd ...string) [][]byte { func toArgs(cmd ...string) [][]byte {
args := make([][]byte, len(cmd)) args := make([][]byte, len(cmd))
for i, s := range cmd { for i, s := range cmd {
args[i] = []byte(s) args[i] = []byte(s)
} }
return args return args
}
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
func RandString(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
} }

View File

@@ -1,49 +1,82 @@
package asserts package asserts
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/datastruct/utils" "github.com/HDT3213/godis/src/datastruct/utils"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"testing" "runtime"
"testing"
) )
func AssertIntReply(t *testing.T, actual redis.Reply, expected int) { func AssertIntReply(t *testing.T, actual redis.Reply, expected int) {
intResult, ok := actual.(*reply.IntReply) intResult, ok := actual.(*reply.IntReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) t.Errorf("expected int reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if intResult.Code != int64(expected) { if intResult.Code != int64(expected) {
t.Error(fmt.Sprintf("expected %d, actually %d", expected, intResult.Code)) t.Errorf("expected %d, actually %d, %s", expected, intResult.Code, printStack())
} }
} }
func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) { func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) {
bulkReply, ok := actual.(*reply.BulkReply) bulkReply, ok := actual.(*reply.BulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) { if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) {
t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual.ToBytes())) t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack())
} }
}
func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) {
statusReply, ok := actual.(*reply.StatusReply)
if !ok {
t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return
}
if statusReply.Status != expected {
t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack())
}
} }
func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) { func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) {
multiBulk, ok := actual.(*reply.MultiBulkReply) multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return return
} }
if len(multiBulk.Args) != len(expected) { if len(multiBulk.Args) != len(expected) {
t.Error(fmt.Sprintf("expected %d elements, actually %d", len(expected), len(multiBulk.Args))) t.Errorf("expected %d elements, actually %d, %s",
return len(expected), len(multiBulk.Args), printStack())
} return
for i, v := range multiBulk.Args { }
actual := string(v) for i, v := range multiBulk.Args {
if actual != expected[i] { str := string(v)
t.Error(fmt.Sprintf("expected %s, actually %s", expected[i], actual)) if str != expected[i] {
} t.Errorf("expected %s, actually %s, %s", expected[i], actual, printStack())
} }
}
}
func AssertMultiBulkReplySize(t *testing.T, actual redis.Reply, expected int) {
multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok {
t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack())
return
}
if len(multiBulk.Args) != expected {
t.Errorf("expected %d elements, actually %d, %s", expected, len(multiBulk.Args), printStack())
return
}
}
func printStack() string {
_, file, no, ok := runtime.Caller(2)
if ok {
return fmt.Sprintf("at %s#%d", file, no)
}
return ""
} }