From bf913a5aca4d0a22fa93b58acf21a0fb4b6dcdc0 Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Wed, 31 Mar 2021 17:11:46 +0800 Subject: [PATCH] add some unit tests --- src/cluster/cluster.go | 220 +++----- src/cluster/com.go | 52 ++ src/datastruct/set/set.go | 164 +++--- src/db/geo.go | 424 +++++++-------- src/db/geo_test.go | 87 +++ src/db/hash_test.go | 351 ++++++------ src/db/keys.go | 11 +- src/db/keys_test.go | 197 +++++++ src/db/list_test.go | 656 ++++++++++++----------- src/db/set.go | 860 +++++++++++++++--------------- src/db/set_test.go | 181 +++++++ src/db/sortedset_test.go | 486 ++++++++--------- src/db/string.go | 849 +++++++++++++++-------------- src/db/string_test.go | 382 ++++++++----- src/db/util_test.go | 39 +- src/redis/reply/asserts/assert.go | 105 ++-- 16 files changed, 2866 insertions(+), 2198 deletions(-) create mode 100644 src/cluster/com.go create mode 100644 src/db/geo_test.go create mode 100644 src/db/keys_test.go create mode 100644 src/db/set_test.go diff --git a/src/cluster/cluster.go b/src/cluster/cluster.go index e9264c9..896908a 100644 --- a/src/cluster/cluster.go +++ b/src/cluster/cluster.go @@ -1,177 +1,133 @@ package cluster import ( - "context" - "errors" - "fmt" - "github.com/HDT3213/godis/src/cluster/idgenerator" - "github.com/HDT3213/godis/src/config" - "github.com/HDT3213/godis/src/datastruct/dict" - "github.com/HDT3213/godis/src/db" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/lib/consistenthash" - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/redis/client" - "github.com/HDT3213/godis/src/redis/reply" - "github.com/jolestar/go-commons-pool/v2" - "runtime/debug" - "strings" + "context" + "fmt" + "github.com/HDT3213/godis/src/cluster/idgenerator" + "github.com/HDT3213/godis/src/config" + "github.com/HDT3213/godis/src/datastruct/dict" + "github.com/HDT3213/godis/src/db" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/consistenthash" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/jolestar/go-commons-pool/v2" + "runtime/debug" + "strings" ) type Cluster struct { - self string + self string - peerPicker *consistenthash.Map - peerConnection map[string]*pool.ObjectPool + peerPicker *consistenthash.Map + peerConnection map[string]*pool.ObjectPool - db *db.DB - transactions *dict.SimpleDict // id -> Transaction + db *db.DB + transactions *dict.SimpleDict // id -> Transaction - idGenerator *idgenerator.IdGenerator + idGenerator *idgenerator.IdGenerator } const ( - replicas = 4 - lockSize = 64 + replicas = 4 + lockSize = 64 ) func MakeCluster() *Cluster { - cluster := &Cluster{ - self: config.Properties.Self, + cluster := &Cluster{ + self: config.Properties.Self, - db: db.MakeDB(), - transactions: dict.MakeSimple(), - peerPicker: consistenthash.New(replicas, nil), - peerConnection: make(map[string]*pool.ObjectPool), + db: db.MakeDB(), + transactions: dict.MakeSimple(), + peerPicker: consistenthash.New(replicas, nil), + peerConnection: make(map[string]*pool.ObjectPool), - idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self), - } - if config.Properties.Peers != nil && len(config.Properties.Peers) > 0 && config.Properties.Self != "" { - contains := make(map[string]bool) - peers := make([]string, 0, len(config.Properties.Peers)+1) - for _, peer := range config.Properties.Peers { - if _, ok := contains[peer]; ok { - continue - } - contains[peer] = true - peers = append(peers, peer) - } - peers = append(peers, config.Properties.Self) - cluster.peerPicker.Add(peers...) - ctx := context.Background() - for _, peer := range peers { - cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{ - Peer: peer, - }) - } - } - return cluster + idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self), + } + if config.Properties.Peers != nil && len(config.Properties.Peers) > 0 && config.Properties.Self != "" { + contains := make(map[string]bool) + peers := make([]string, 0, len(config.Properties.Peers)+1) + for _, peer := range config.Properties.Peers { + if _, ok := contains[peer]; ok { + continue + } + contains[peer] = true + peers = append(peers, peer) + } + peers = append(peers, config.Properties.Self) + cluster.peerPicker.Add(peers...) + ctx := context.Background() + for _, peer := range peers { + cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{ + Peer: peer, + }) + } + } + return cluster } // args contains all type CmdFunc func(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply func (cluster *Cluster) Close() { - cluster.db.Close() + cluster.db.Close() } var router = MakeRouter() func (cluster *Cluster) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { - defer func() { - if err := recover(); err != nil { - logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) - result = &reply.UnknownErrReply{} - } - }() + defer func() { + if err := recover(); err != nil { + logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) + result = &reply.UnknownErrReply{} + } + }() - cmd := strings.ToLower(string(args[0])) - cmdFunc, ok := router[cmd] - if !ok { - return reply.MakeErrReply("ERR unknown command '" + cmd + "', or not supported in cluster mode") - } - result = cmdFunc(cluster, c, args) - return + cmd := strings.ToLower(string(args[0])) + cmdFunc, ok := router[cmd] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmd + "', or not supported in cluster mode") + } + result = cmdFunc(cluster, c, args) + return } func (cluster *Cluster) AfterClientClose(c redis.Connection) { } -func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) { - connectionFactory, ok := cluster.peerConnection[peer] - if !ok { - return nil, errors.New("connection factory not found") - } - raw, err := connectionFactory.BorrowObject(context.Background()) - if err != nil { - return nil, err - } - conn, ok := raw.(*client.Client) - if !ok { - return nil, errors.New("connection factory make wrong type") - } - return conn, nil -} - -func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error { - connectionFactory, ok := cluster.peerConnection[peer] - if !ok { - return errors.New("connection factory not found") - } - return connectionFactory.ReturnObject(context.Background(), peerClient) -} - func Ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) == 1 { - return &reply.PongReply{} - } else if len(args) == 2 { - return reply.MakeStatusReply("\"" + string(args[1]) + "\"") - } else { - return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") - } -} - -// relay command to peer -// cannot call Prepare, Commit, Rollback of self node -func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply { - if peer == cluster.self { - // to self db - return cluster.db.Exec(c, args) - } else { - peerClient, err := cluster.getPeerClient(peer) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - defer func() { - _ = cluster.returnPeerClient(peer, peerClient) - }() - return peerClient.Send(args) - } + if len(args) == 1 { + return &reply.PongReply{} + } else if len(args) == 2 { + return reply.MakeStatusReply("\"" + string(args[1]) + "\"") + } else { + return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") + } } /*----- utils -------*/ func makeArgs(cmd string, args ...string) [][]byte { - result := make([][]byte, len(args)+1) - result[0] = []byte(cmd) - for i, arg := range args { - result[i+1] = []byte(arg) - } - return result + result := make([][]byte, len(args)+1) + result[0] = []byte(cmd) + for i, arg := range args { + result[i+1] = []byte(arg) + } + return result } // return peer -> keys func (cluster *Cluster) groupBy(keys []string) map[string][]string { - result := make(map[string][]string) - for _, key := range keys { - peer := cluster.peerPicker.Get(key) - group, ok := result[peer] - if !ok { - group = make([]string, 0) - } - group = append(group, key) - result[peer] = group - } - return result -} \ No newline at end of file + result := make(map[string][]string) + for _, key := range keys { + peer := cluster.peerPicker.Get(key) + group, ok := result[peer] + if !ok { + group = make([]string, 0) + } + group = append(group, key) + result[peer] = group + } + return result +} diff --git a/src/cluster/com.go b/src/cluster/com.go new file mode 100644 index 0000000..ba46199 --- /dev/null +++ b/src/cluster/com.go @@ -0,0 +1,52 @@ +// communicate with peers within cluster +package cluster + +import ( + "context" + "errors" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/client" + "github.com/HDT3213/godis/src/redis/reply" +) + +func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) { + connectionFactory, ok := cluster.peerConnection[peer] + if !ok { + return nil, errors.New("connection factory not found") + } + raw, err := connectionFactory.BorrowObject(context.Background()) + if err != nil { + return nil, err + } + conn, ok := raw.(*client.Client) + if !ok { + return nil, errors.New("connection factory make wrong type") + } + return conn, nil +} + +func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) error { + connectionFactory, ok := cluster.peerConnection[peer] + if !ok { + return errors.New("connection factory not found") + } + return connectionFactory.ReturnObject(context.Background(), peerClient) +} + +// relay command to peer +// cannot call Prepare, Commit, Rollback of self node +func (cluster *Cluster) Relay(peer string, c redis.Connection, args [][]byte) redis.Reply { + if peer == cluster.self { + // to self db + return cluster.db.Exec(c, args) + } else { + peerClient, err := cluster.getPeerClient(peer) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + defer func() { + _ = cluster.returnPeerClient(peer, peerClient) + }() + return peerClient.Send(args) + } +} diff --git a/src/datastruct/set/set.go b/src/datastruct/set/set.go index 36daf91..c52d1c7 100644 --- a/src/datastruct/set/set.go +++ b/src/datastruct/set/set.go @@ -3,114 +3,114 @@ package set import "github.com/HDT3213/godis/src/datastruct/dict" type Set struct { - dict dict.Dict + dict dict.Dict } func Make() *Set { - return &Set{ - dict: dict.MakeSimple(), - } + return &Set{ + dict: dict.MakeSimple(), + } } -func MakeFromVals(members ...string)*Set { - set := &Set{ - dict: dict.MakeConcurrent(len(members)), - } - for _, member := range members { - set.Add(member) - } - return set +func MakeFromVals(members ...string) *Set { + set := &Set{ + dict: dict.MakeConcurrent(len(members)), + } + for _, member := range members { + set.Add(member) + } + return set } -func (set *Set)Add(val string)int { - return set.dict.Put(val, true) +func (set *Set) Add(val string) int { + return set.dict.Put(val, nil) } -func (set *Set)Remove(val string)int { - return set.dict.Remove(val) +func (set *Set) Remove(val string) int { + return set.dict.Remove(val) } -func (set *Set)Has(val string)bool { - _, exists := set.dict.Get(val) - return exists +func (set *Set) Has(val string) bool { + _, exists := set.dict.Get(val) + return exists } -func (set *Set)Len()int { - return set.dict.Len() +func (set *Set) Len() int { + return set.dict.Len() } -func (set *Set)ToSlice()[]string { - slice := make([]string, set.Len()) - i := 0 - set.dict.ForEach(func(key string, val interface{})bool { - if i < len(slice) { - slice[i] = key - } else { - // set extended during traversal - slice = append(slice, key) - } - i++ - return true - }) - return slice +func (set *Set) ToSlice() []string { + slice := make([]string, set.Len()) + i := 0 + set.dict.ForEach(func(key string, val interface{}) bool { + if i < len(slice) { + slice[i] = key + } else { + // set extended during traversal + slice = append(slice, key) + } + i++ + return true + }) + return slice } -func (set *Set)ForEach(consumer func(member string)bool) { - set.dict.ForEach(func(key string, val interface{})bool { - return consumer(key) - }) +func (set *Set) ForEach(consumer func(member string) bool) { + set.dict.ForEach(func(key string, val interface{}) bool { + return consumer(key) + }) } -func (set *Set)Intersect(another *Set)*Set { - if set == nil { - panic("set is nil") - } +func (set *Set) Intersect(another *Set) *Set { + if set == nil { + panic("set is nil") + } - result := Make() - another.ForEach(func(member string)bool { - if set.Has(member) { - result.Add(member) - } - return true - }) - return result + result := Make() + another.ForEach(func(member string) bool { + if set.Has(member) { + result.Add(member) + } + return true + }) + return result } -func (set *Set)Union(another *Set)*Set { - if set == nil { - panic("set is nil") - } - result := Make() - another.ForEach(func(member string)bool { - result.Add(member) - return true - }) - set.ForEach(func(member string)bool { - result.Add(member) - return true - }) - return result +func (set *Set) Union(another *Set) *Set { + if set == nil { + panic("set is nil") + } + result := Make() + another.ForEach(func(member string) bool { + result.Add(member) + return true + }) + set.ForEach(func(member string) bool { + result.Add(member) + return true + }) + return result } -func (set *Set)Diff(another *Set)*Set { - if set == nil { - panic("set is nil") - } +func (set *Set) Diff(another *Set) *Set { + if set == nil { + panic("set is nil") + } - result := Make() - set.ForEach(func(member string)bool { - if !another.Has(member) { - result.Add(member) - } - return true - }) - return result + result := Make() + set.ForEach(func(member string) bool { + if !another.Has(member) { + result.Add(member) + } + return true + }) + return result } -func (set *Set)RandomMembers(limit int)[]string { - return set.dict.RandomKeys(limit) +func (set *Set) RandomMembers(limit int) []string { + return set.dict.RandomKeys(limit) } -func (set *Set)RandomDistinctMembers(limit int)[]string { - return set.dict.RandomDistinctKeys(limit) -} \ No newline at end of file +func (set *Set) RandomDistinctMembers(limit int) []string { + return set.dict.RandomDistinctKeys(limit) +} diff --git a/src/db/geo.go b/src/db/geo.go index 3bd1196..05bd0c6 100644 --- a/src/db/geo.go +++ b/src/db/geo.go @@ -1,249 +1,251 @@ package db import ( - "fmt" - "github.com/HDT3213/godis/src/datastruct/sortedset" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/lib/geohash" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" - "strings" + "fmt" + "github.com/HDT3213/godis/src/datastruct/sortedset" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/geohash" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "strings" ) func GeoAdd(db *DB, args [][]byte) redis.Reply { - if len(args) < 4 || len(args)%3 != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'geoadd' command") - } - key := string(args[0]) - size := (len(args) - 1) / 3 - elements := make([]*sortedset.Element, size) - for i := 0; i < size; i += 1 { - lngStr := string(args[3*i+1]) - latStr := string(args[3*i+2]) - lng, err := strconv.ParseFloat(lngStr, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - lat, err := strconv.ParseFloat(latStr, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - if lat < -90 || lat > 90 || lng < -180 || lng > 180 { - return reply.MakeErrReply(fmt.Sprintf("ERR invalid longitude,latitude pair %s,%s", latStr, lngStr)) - } - code := float64(geohash.Encode(lat, lng)) - elements[i] = &sortedset.Element{ - Member: string(args[3*i+3]), - Score: code, - } - } + if len(args) < 4 || len(args)%3 != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'geoadd' command") + } + key := string(args[0]) + size := (len(args) - 1) / 3 + elements := make([]*sortedset.Element, size) + for i := 0; i < size; i += 1 { + lngStr := string(args[3*i+1]) + latStr := string(args[3*i+2]) + lng, err := strconv.ParseFloat(lngStr, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + lat, err := strconv.ParseFloat(latStr, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + if lat < -90 || lat > 90 || lng < -180 || lng > 180 { + return reply.MakeErrReply(fmt.Sprintf("ERR invalid longitude,latitude pair %s,%s", latStr, lngStr)) + } + code := float64(geohash.Encode(lat, lng)) + elements[i] = &sortedset.Element{ + Member: string(args[3*i+3]), + Score: code, + } + } - // lock - db.Lock(key) - defer db.UnLock(key) + // lock + db.Lock(key) + defer db.UnLock(key) - // get or init entity - sortedSet, _, errReply := db.getOrInitSortedSet(key) - if errReply != nil { - return errReply - } + // get or init entity + sortedSet, _, errReply := db.getOrInitSortedSet(key) + if errReply != nil { + return errReply + } - i := 0 - for _, e := range elements { - if sortedSet.Add(e.Member, e.Score) { - i++ - } - } + i := 0 + for _, e := range elements { + if sortedSet.Add(e.Member, e.Score) { + i++ + } + } - db.AddAof(makeAofCmd("geoadd", args)) + db.AddAof(makeAofCmd("geoadd", args)) - return reply.MakeIntReply(int64(i)) + return reply.MakeIntReply(int64(i)) } func GeoPos(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'geopos' command") - } - key := string(args[0]) - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } + // parse args + if len(args) < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'geopos' command") + } + key := string(args[0]) + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } - positions := make([][]byte, len(args)-1) - for i := 0; i < len(args)-1; i++ { - member := string(args[i+1]) - elem, exists := sortedSet.Get(member) - if !exists { - positions[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() - continue - } - lat, lng := geohash.Decode(uint64(elem.Score)) - lngStr := strconv.FormatFloat(lng, 'f', -1, 64) - latStr := strconv.FormatFloat(lat, 'f', -1, 64) - positions[i] = reply.MakeMultiBulkReply([][]byte{ - []byte(lngStr), []byte(latStr), - }).ToBytes() - } - return reply.MakeMultiRawReply(positions) + positions := make([][]byte, len(args)-1) + for i := 0; i < len(args)-1; i++ { + member := string(args[i+1]) + elem, exists := sortedSet.Get(member) + if !exists { + positions[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() + continue + } + lat, lng := geohash.Decode(uint64(elem.Score)) + lngStr := strconv.FormatFloat(lng, 'f', -1, 64) + latStr := strconv.FormatFloat(lat, 'f', -1, 64) + positions[i] = reply.MakeMultiBulkReply([][]byte{ + []byte(lngStr), []byte(latStr), + }).ToBytes() + } + return reply.MakeMultiRawReply(positions) } func GeoDist(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 && len(args) != 4 { - return reply.MakeErrReply("ERR wrong number of arguments for 'geodist' command") - } - key := string(args[0]) - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } + // parse args + if len(args) != 3 && len(args) != 4 { + return reply.MakeErrReply("ERR wrong number of arguments for 'geodist' command") + } + key := string(args[0]) + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } - positions := make([][]float64, 2) - for i := 1; i < 3; i++ { - member := string(args[i]) - elem, exists := sortedSet.Get(member) - if !exists { - return &reply.NullBulkReply{} - } - lat, lng := geohash.Decode(uint64(elem.Score)) - positions[i-1] = []float64{lat, lng} - } - unit := "m" - if len(args) == 4 { - unit = strings.ToLower(string(args[3])) - } - dis := geohash.Distance(positions[0][1], positions[0][0], positions[1][1], positions[1][0]) - switch unit { - case "m": - disStr := strconv.FormatFloat(dis, 'f', -1, 64) - return reply.MakeBulkReply([]byte(disStr)) - case "km": - disStr := strconv.FormatFloat(dis/1000, 'f', -1, 64) - return reply.MakeBulkReply([]byte(disStr)) - } - return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") + positions := make([][]float64, 2) + for i := 1; i < 3; i++ { + member := string(args[i]) + elem, exists := sortedSet.Get(member) + if !exists { + return &reply.NullBulkReply{} + } + lat, lng := geohash.Decode(uint64(elem.Score)) + positions[i-1] = []float64{lat, lng} + } + unit := "m" + if len(args) == 4 { + unit = strings.ToLower(string(args[3])) + } + dis := geohash.Distance(positions[0][0], positions[0][1], positions[1][0], positions[1][1]) + switch unit { + case "m": + disStr := strconv.FormatFloat(dis, 'f', -1, 64) + return reply.MakeBulkReply([]byte(disStr)) + case "km": + disStr := strconv.FormatFloat(dis/1000, 'f', -1, 64) + return reply.MakeBulkReply([]byte(disStr)) + } + return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") } func GeoHash(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'geohash' command") - } + // parse args + if len(args) < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'geohash' command") + } - key := string(args[0]) - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } + key := string(args[0]) + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } - strs := make([][]byte, len(args)-1) - for i := 0; i < len(args)-1; i++ { - member := string(args[i+1]) - elem, exists := sortedSet.Get(member) - if !exists { - strs[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() - continue - } - str := geohash.ToString(geohash.FromInt(uint64(elem.Score))) - strs[i] = []byte(str) - } - return reply.MakeMultiBulkReply(strs) + strs := make([][]byte, len(args)-1) + for i := 0; i < len(args)-1; i++ { + member := string(args[i+1]) + elem, exists := sortedSet.Get(member) + if !exists { + strs[i] = (&reply.EmptyMultiBulkReply{}).ToBytes() + continue + } + str := geohash.ToString(geohash.FromInt(uint64(elem.Score))) + strs[i] = []byte(str) + } + return reply.MakeMultiBulkReply(strs) } func GeoRadius(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 5 { - return reply.MakeErrReply("ERR wrong number of arguments for 'georadius' command") - } + // parse args + if len(args) < 5 { + return reply.MakeErrReply("ERR wrong number of arguments for 'georadius' command") + } - key := string(args[0]) - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } + key := string(args[0]) + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } - lng, err := strconv.ParseFloat(string(args[1]), 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - lat, err := strconv.ParseFloat(string(args[2]), 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - radius, err := strconv.ParseFloat(string(args[3]), 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - unit := strings.ToLower(string(args[4])) - if unit == "m" { - } else if unit == "km" { - radius *= 1000 - } else { - return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") - } - return geoRadius0(sortedSet, lat, lng, radius) + lng, err := strconv.ParseFloat(string(args[1]), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + lat, err := strconv.ParseFloat(string(args[2]), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + radius, err := strconv.ParseFloat(string(args[3]), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + unit := strings.ToLower(string(args[4])) + if unit == "m" { + } else if unit == "km" { + radius *= 1000 + } else { + return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") + } + return geoRadius0(sortedSet, lat, lng, radius) } func GeoRadiusByMember(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 4 { - return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command") - } + // parse args + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'georadiusbymember' command") + } - key := string(args[0]) - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } + key := string(args[0]) + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } - member := string(args[1]) - elem, ok := sortedSet.Get(member) - if !ok { - return &reply.NullBulkReply{} - } - lat, lng := geohash.Decode(uint64(elem.Score)) + member := string(args[1]) + elem, ok := sortedSet.Get(member) + if !ok { + return &reply.NullBulkReply{} + } + lat, lng := geohash.Decode(uint64(elem.Score)) - radius, err := strconv.ParseFloat(string(args[2]), 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - unit := strings.ToLower(string(args[4])) - if unit == "m" { - } else if unit == "km" { - radius *= 1000 - } else { - return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") - } - return geoRadius0(sortedSet, lat, lng, radius) + radius, err := strconv.ParseFloat(string(args[2]), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + if len(args) > 3 { + unit := strings.ToLower(string(args[3])) + if unit == "m" { + } else if unit == "km" { + radius *= 1000 + } else { + return reply.MakeErrReply("ERR unsupported unit provided. please use m, km") + } + } + return geoRadius0(sortedSet, lat, lng, radius) } func geoRadius0(sortedSet *sortedset.SortedSet, lat float64, lng float64, radius float64) redis.Reply { - areas := geohash.GetNeighbours(lat, lng, radius) - members := make([][]byte, 0) - for _, area := range areas { - lower := &sortedset.ScoreBorder{Value: float64(area[0])} - upper := &sortedset.ScoreBorder{Value: float64(area[1])} - elements := sortedSet.RangeByScore(lower, upper, 0, -1, true) - for _, elem := range elements { - members = append(members, []byte(elem.Member)) - } - } - return reply.MakeMultiBulkReply(members) -} \ No newline at end of file + areas := geohash.GetNeighbours(lat, lng, radius) + members := make([][]byte, 0) + for _, area := range areas { + lower := &sortedset.ScoreBorder{Value: float64(area[0])} + upper := &sortedset.ScoreBorder{Value: float64(area[1])} + elements := sortedSet.RangeByScore(lower, upper, 0, -1, true) + for _, elem := range elements { + members = append(members, []byte(elem.Member)) + } + } + return reply.MakeMultiBulkReply(members) +} diff --git a/src/db/geo_test.go b/src/db/geo_test.go new file mode 100644 index 0000000..f49f9b5 --- /dev/null +++ b/src/db/geo_test.go @@ -0,0 +1,87 @@ +package db + +import ( + "fmt" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "strconv" + "testing" +) + +func TestGeoHash(t *testing.T) { + FlushDB(testDB, toArgs()) + key := RandString(10) + pos := RandString(10) + result := GeoAdd(testDB, toArgs(key, "13.361389", "38.115556", pos)) + asserts.AssertIntReply(t, result, 1) + result = GeoHash(testDB, toArgs(key, pos)) + asserts.AssertMultiBulkReply(t, result, []string{"sqc8b49rnys00"}) +} + +func TestGeoRadius(t *testing.T) { + FlushDB(testDB, toArgs()) + key := RandString(10) + pos1 := RandString(10) + pos2 := RandString(10) + GeoAdd(testDB, toArgs(key, + "13.361389", "38.115556", pos1, + "15.087269", "37.502669", pos2, + )) + result := GeoRadius(testDB, toArgs(key, "15", "37", "200", "km")) + asserts.AssertMultiBulkReplySize(t, result, 2) +} + +func TestGeoRadiusByMember(t *testing.T) { + FlushDB(testDB, toArgs()) + key := RandString(10) + pos1 := RandString(10) + pos2 := RandString(10) + pivot := RandString(10) + GeoAdd(testDB, toArgs(key, + "13.361389", "38.115556", pos1, + "17.087269", "38.502669", pos2, + "13.583333", "37.316667", pivot, + )) + result := GeoRadiusByMember(testDB, toArgs(key, pivot, "100", "km")) + asserts.AssertMultiBulkReplySize(t, result, 2) +} + +func TestGeoPos(t *testing.T) { + FlushDB(testDB, toArgs()) + key := RandString(10) + pos1 := RandString(10) + pos2 := RandString(10) + GeoAdd(testDB, toArgs(key, + "13.361389", "38.115556", pos1, + )) + result := GeoPos(testDB, toArgs(key, pos1, pos2)) + expected := "*2\r\n*2\r\n$18\r\n13.361386698670685\r\n$17\r\n38.11555536696687\r\n*0\r\n" + if string(result.ToBytes()) != expected { + t.Error("test failed") + } +} + +func TestGeoDist(t *testing.T) { + FlushDB(testDB, toArgs()) + key := RandString(10) + pos1 := RandString(10) + pos2 := RandString(10) + GeoAdd(testDB, toArgs(key, + "13.361389", "38.115556", pos1, + "15.087269", "37.502669", pos2, + )) + result := GeoDist(testDB, toArgs(key, pos1, pos2, "km")) + bulkReply, ok := result.(*reply.BulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) + return + } + dist, err := strconv.ParseFloat(string(bulkReply.Arg), 10) + if err != nil { + t.Error(err) + return + } + if dist < 166.274 || dist > 166.275 { + t.Errorf("expected 166.274, actual: %f", dist) + } +} diff --git a/src/db/hash_test.go b/src/db/hash_test.go index 64fc71c..479181b 100644 --- a/src/db/hash_test.go +++ b/src/db/hash_test.go @@ -1,200 +1,215 @@ package db import ( - "fmt" - "github.com/HDT3213/godis/src/datastruct/utils" - "github.com/HDT3213/godis/src/redis/reply" - "math/rand" - "strconv" - "testing" + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "strconv" + "testing" ) func TestHSet(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 + FlushAll(testDB, [][]byte{}) + size := 100 - // test hset - key := strconv.FormatInt(int64(rand.Int()), 10) - values := make(map[string][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - field := strconv.Itoa(i) - values[field] = []byte(value) - result := HSet(testDB, toArgs(key, field, value)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) { - t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) - } - } + // test hset + key := RandString(10) + values := make(map[string][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + field := strconv.Itoa(i) + values[field] = []byte(value) + result := HSet(testDB, toArgs(key, field, value)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) { + t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) + } + } - // test hget and hexists - for field, v := range values { - actual := HGet(testDB, toArgs(key, field)) - expected := reply.MakeBulkReply(v) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) - } - actual = HExists(testDB, toArgs(key, field)) - if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) { - t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) - } - } + // test hget and hexists + for field, v := range values { + actual := HGet(testDB, toArgs(key, field)) + expected := reply.MakeBulkReply(v) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) + } + actual = HExists(testDB, toArgs(key, field)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) { + t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) + } + } - // test hlen - actual := HLen(testDB, toArgs(key)) - if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) { - t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code)) - } + // test hlen + actual := HLen(testDB, toArgs(key)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) { + t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code)) + } } func TestHDel(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 + FlushAll(testDB, [][]byte{}) + size := 100 - // set values - key := strconv.FormatInt(int64(rand.Int()), 10) - fields := make([]string, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - field := strconv.Itoa(i) - fields[i] = field - HSet(testDB, toArgs(key, field, value)) - } + // set values + key := RandString(10) + fields := make([]string, size) + for i := 0; i < size; i++ { + value := RandString(10) + field := strconv.Itoa(i) + fields[i] = field + HSet(testDB, toArgs(key, field, value)) + } - // test HDel - args := []string{key} - args = append(args, fields...) - actual := HDel(testDB, toArgs(args...)) - if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) { - t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code)) - } + // test HDel + args := []string{key} + args = append(args, fields...) + actual := HDel(testDB, toArgs(args...)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) { + t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code)) + } - actual = HLen(testDB, toArgs(key)) - if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) { - t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code)) - } + actual = HLen(testDB, toArgs(key)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) { + t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code)) + } } func TestHMSet(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 + FlushAll(testDB, [][]byte{}) + size := 100 - // test hset - key := strconv.FormatInt(int64(rand.Int()), 10) - fields := make([]string, size) - values := make([]string, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - fields[i] = strconv.FormatInt(int64(rand.Int()), 10) - values[i] = strconv.FormatInt(int64(rand.Int()), 10) - setArgs = append(setArgs, fields[i], values[i]) - } - result := HMSet(testDB, toArgs(setArgs...)) - if _, ok := result.(*reply.OkReply); !ok { - t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes()))) - } + // test hset + key := RandString(10) + fields := make([]string, size) + values := make([]string, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + fields[i] = RandString(10) + values[i] = RandString(10) + setArgs = append(setArgs, fields[i], values[i]) + } + result := HMSet(testDB, toArgs(setArgs...)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes()))) + } - // test HMGet - getArgs := []string{key} - getArgs = append(getArgs, fields...) - actual := HMGet(testDB, toArgs(getArgs...)) - expected := reply.MakeMultiBulkReply(toArgs(values...)) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) - } + // test HMGet + getArgs := []string{key} + getArgs = append(getArgs, fields...) + actual := HMGet(testDB, toArgs(getArgs...)) + expected := reply.MakeMultiBulkReply(toArgs(values...)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) + } } func TestHGetAll(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - fields := make([]string, size) - valueSet := make(map[string]bool, size) - valueMap := make(map[string]string) - all := make([]string, 0) - for i := 0; i < size; i++ { - fields[i] = strconv.FormatInt(int64(rand.Int()), 10) - value := strconv.FormatInt(int64(rand.Int()), 10) - all = append(all, fields[i], value) - valueMap[fields[i]] = value - valueSet[value] = true - HSet(testDB, toArgs(key, fields[i], value)) - } + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + fields := make([]string, size) + valueSet := make(map[string]bool, size) + valueMap := make(map[string]string) + all := make([]string, 0) + for i := 0; i < size; i++ { + fields[i] = RandString(10) + value := RandString(10) + all = append(all, fields[i], value) + valueMap[fields[i]] = value + valueSet[value] = true + HSet(testDB, toArgs(key, fields[i], value)) + } - // test HGetAll - result := HGetAll(testDB, toArgs(key)) - multiBulk, ok := result.(*reply.MultiBulkReply) - if !ok { - t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) - } - if 2*len(fields) != len(multiBulk.Args) { - t.Error(fmt.Sprintf("expected %d items , actually %d ", 2*len(fields), len(multiBulk.Args))) - } - for i := range fields { - field := string(multiBulk.Args[2*i]) - actual := string(multiBulk.Args[2*i+1]) - expected, ok := valueMap[field] - if !ok { - t.Error(fmt.Sprintf("unexpected field %s", field)) - continue - } - if actual != expected { - t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual)) - } - } + // test HGetAll + result := HGetAll(testDB, toArgs(key)) + multiBulk, ok := result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if 2*len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", 2*len(fields), len(multiBulk.Args))) + } + for i := range fields { + field := string(multiBulk.Args[2*i]) + actual := string(multiBulk.Args[2*i+1]) + expected, ok := valueMap[field] + if !ok { + t.Error(fmt.Sprintf("unexpected field %s", field)) + continue + } + if actual != expected { + t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual)) + } + } - // test HKeys - result = HKeys(testDB, toArgs(key)) - multiBulk, ok = result.(*reply.MultiBulkReply) - if !ok { - t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) - } - if len(fields) != len(multiBulk.Args) { - t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) - } - for _, v := range multiBulk.Args { - field := string(v) - if _, ok := valueMap[field]; !ok { - t.Error(fmt.Sprintf("unexpected field %s", field)) - } - } + // test HKeys + result = HKeys(testDB, toArgs(key)) + multiBulk, ok = result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) + } + for _, v := range multiBulk.Args { + field := string(v) + if _, ok := valueMap[field]; !ok { + t.Error(fmt.Sprintf("unexpected field %s", field)) + } + } - // test HVals - result = HVals(testDB, toArgs(key)) - multiBulk, ok = result.(*reply.MultiBulkReply) - if !ok { - t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) - } - if len(fields) != len(multiBulk.Args) { - t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) - } - for _, v := range multiBulk.Args { - value := string(v) - _, ok := valueSet[value] - if !ok { - t.Error(fmt.Sprintf("unexpected value %s", value)) - } - } + // test HVals + result = HVals(testDB, toArgs(key)) + multiBulk, ok = result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) + } + for _, v := range multiBulk.Args { + value := string(v) + _, ok := valueSet[value] + if !ok { + t.Error(fmt.Sprintf("unexpected value %s", value)) + } + } } func TestHIncrBy(t *testing.T) { - FlushAll(testDB, [][]byte{}) + FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - result := HIncrBy(testDB, toArgs(key, "a", "1")) - if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { - t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) - } - result = HIncrBy(testDB, toArgs(key, "a", "1")) - if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" { - t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg))) - } + key := RandString(10) + result := HIncrBy(testDB, toArgs(key, "a", "1")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { + t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) + } + result = HIncrBy(testDB, toArgs(key, "a", "1")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" { + t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg))) + } + + result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" { + t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg))) + } + result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" { + t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg))) + } +} + +func TestHSetNX(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + field := RandString(10) + value := RandString(10) + result := HSetNX(testDB, toArgs(key, field, value)) + asserts.AssertIntReply(t, result, 1) + value2 := RandString(10) + result = HSetNX(testDB, toArgs(key, field, value2)) + asserts.AssertIntReply(t, result, 0) + result = HGet(testDB, toArgs(key, field)) + asserts.AssertBulkReply(t, result, value) - result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) - if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" { - t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg))) - } - result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) - if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" { - t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg))) - } } diff --git a/src/db/keys.go b/src/db/keys.go index ff40d27..535d56d 100644 --- a/src/db/keys.go +++ b/src/db/keys.go @@ -76,7 +76,7 @@ func Type(db *DB, args [][]byte) redis.Reply { return reply.MakeStatusReply("string") case *list.LinkedList: return reply.MakeStatusReply("list") - case *dict.Dict: + case dict.Dict: return reply.MakeStatusReply("hash") case *set.Set: return reply.MakeStatusReply("set") @@ -101,10 +101,11 @@ func Rename(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("no such key") } rawTTL, hasTTL := db.TTLMap.Get(src) - db.Persist(src) // clean src and dest with their ttl - db.Persist(dest) db.Put(dest, entity) + db.Remove(src) if hasTTL { + db.Persist(src) // clean src and dest with their ttl + db.Persist(dest) expireTime, _ := rawTTL.(time.Time) db.Expire(dest, expireTime) } @@ -135,6 +136,8 @@ func RenameNx(db *DB, args [][]byte) redis.Reply { db.Removes(src, dest) // clean src and dest with their ttl db.Put(dest, entity) if hasTTL { + db.Persist(src) // clean src and dest with their ttl + db.Persist(dest) expireTime, _ := rawTTL.(time.Time) db.Expire(dest, expireTime) } @@ -161,7 +164,7 @@ func Expire(db *DB, args [][]byte) redis.Reply { expireAt := time.Now().Add(ttl) db.Expire(key, expireAt) - db.AddAof(makeExpireCmd(key, expireAt), ) + db.AddAof(makeExpireCmd(key, expireAt)) return reply.MakeIntReply(1) } diff --git a/src/db/keys_test.go b/src/db/keys_test.go new file mode 100644 index 0000000..2747a59 --- /dev/null +++ b/src/db/keys_test.go @@ -0,0 +1,197 @@ +package db + +import ( + "fmt" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "strconv" + "testing" + "time" +) + +func TestExists(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + Set(testDB, toArgs(key, value)) + result := Exists(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, 1) + key = RandString(10) + result = Exists(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, 0) +} + +func TestType(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + Set(testDB, toArgs(key, value)) + result := Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "string") + + Del(testDB, toArgs(key)) + result = Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "none") + RPush(testDB, toArgs(key, value)) + result = Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "list") + + Del(testDB, toArgs(key)) + HSet(testDB, toArgs(key, key, value)) + result = Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "hash") + + Del(testDB, toArgs(key)) + SAdd(testDB, toArgs(key, value)) + result = Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "set") + + Del(testDB, toArgs(key)) + ZAdd(testDB, toArgs(key, "1", value)) + result = Type(testDB, toArgs(key)) + asserts.AssertStatusReply(t, result, "zset") +} + +func TestRename(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + newKey := key + RandString(2) + Set(testDB, toArgs(key, value, "ex", "1000")) + result := Rename(testDB, toArgs(key, newKey)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error("expect ok") + return + } + result = Exists(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, 0) + result = Exists(testDB, toArgs(newKey)) + asserts.AssertIntReply(t, result, 1) + // check ttl + result = TTL(testDB, toArgs(newKey)) + intResult, ok := result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } +} + +func TestRenameNx(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + newKey := key + RandString(2) + Set(testDB, toArgs(key, value, "ex", "1000")) + result := RenameNx(testDB, toArgs(key, newKey)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error("expect ok") + return + } + result = Exists(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, 0) + result = Exists(testDB, toArgs(newKey)) + asserts.AssertIntReply(t, result, 1) + result = TTL(testDB, toArgs(newKey)) + intResult, ok := result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } +} + +func TestTTL(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + Set(testDB, toArgs(key, value)) + + result := Expire(testDB, toArgs(key, "1000")) + asserts.AssertIntReply(t, result, 1) + result = TTL(testDB, toArgs(key)) + intResult, ok := result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } + + result = Persist(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, 1) + result = TTL(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, -1) + + result = PExpire(testDB, toArgs(key, "1000000")) + asserts.AssertIntReply(t, result, 1) + result = PTTL(testDB, toArgs(key)) + intResult, ok = result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } +} + +func TestExpireAt(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + Set(testDB, toArgs(key, value)) + + expireAt := time.Now().Add(time.Minute).Unix() + result := ExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt, 10))) + asserts.AssertIntReply(t, result, 1) + result = TTL(testDB, toArgs(key)) + intResult, ok := result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } + + expireAt = time.Now().Add(time.Minute).Unix() + result = PExpireAt(testDB, toArgs(key, strconv.FormatInt(expireAt*1000, 10))) + asserts.AssertIntReply(t, result, 1) + result = TTL(testDB, toArgs(key)) + intResult, ok = result.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) + return + } + if intResult.Code <= 0 { + t.Errorf("expected ttl more than 0, actual: %d", intResult.Code) + return + } +} + +func TestKeys(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + Set(testDB, toArgs(key, value)) + Set(testDB, toArgs("a:"+key, value)) + Set(testDB, toArgs("b:"+key, value)) + + result := Keys(testDB, toArgs("*")) + asserts.AssertMultiBulkReplySize(t, result, 3) + result = Keys(testDB, toArgs("a:*")) + asserts.AssertMultiBulkReplySize(t, result, 1) + result = Keys(testDB, toArgs("?:*")) + asserts.AssertMultiBulkReplySize(t, result, 2) +} diff --git a/src/db/list_test.go b/src/db/list_test.go index 2846916..9ca0562 100644 --- a/src/db/list_test.go +++ b/src/db/list_test.go @@ -1,386 +1,384 @@ package db import ( - "fmt" - "github.com/HDT3213/godis/src/datastruct/utils" - "github.com/HDT3213/godis/src/redis/reply" - "math/rand" - "strconv" - "testing" + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "testing" ) func TestPush(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 + FlushAll(testDB, [][]byte{}) + size := 100 - // rpush single - key := strconv.FormatInt(int64(rand.Int()), 10) - values := make([][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - values[i] = []byte(value) - result := RPush(testDB, toArgs(key, value)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { - t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) - } - } - actual := LRange(testDB, toArgs(key, "0", "-1")) - expected := reply.MakeMultiBulkReply(values) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("push error") - } - Del(testDB, toArgs(key)) + // rpush single + key := RandString(10) + values := make([][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + values[i] = []byte(value) + result := RPush(testDB, toArgs(key, value)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { + t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) + } + } + actual := LRange(testDB, toArgs(key, "0", "-1")) + expected := reply.MakeMultiBulkReply(values) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("push error") + } + Del(testDB, toArgs(key)) - // rpush multi - key = strconv.FormatInt(int64(rand.Int()), 10) - values = make([][]byte, size+1) - values[0] = []byte(key) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - values[i+1] = []byte(value) - } - result := RPush(testDB, values) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { - t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) - } - actual = LRange(testDB, toArgs(key, "0", "-1")) - expected = reply.MakeMultiBulkReply(values[1:]) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("push error") - } - Del(testDB, toArgs(key)) + // rpush multi + key = RandString(10) + values = make([][]byte, size+1) + values[0] = []byte(key) + for i := 0; i < size; i++ { + value := RandString(10) + values[i+1] = []byte(value) + } + result := RPush(testDB, values) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { + t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) + } + actual = LRange(testDB, toArgs(key, "0", "-1")) + expected = reply.MakeMultiBulkReply(values[1:]) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("push error") + } + Del(testDB, toArgs(key)) - // left push single - key = strconv.FormatInt(int64(rand.Int()), 10) - values = make([][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - values[size-i-1] = []byte(value) - result = LPush(testDB, toArgs(key, value)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { - t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) - } - } - actual = LRange(testDB, toArgs(key, "0", "-1")) - expected = reply.MakeMultiBulkReply(values) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("push error") - } - Del(testDB, toArgs(key)) + // left push single + key = RandString(10) + values = make([][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + values[size-i-1] = []byte(value) + result = LPush(testDB, toArgs(key, value)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { + t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) + } + } + actual = LRange(testDB, toArgs(key, "0", "-1")) + expected = reply.MakeMultiBulkReply(values) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("push error") + } + Del(testDB, toArgs(key)) - // left push multi - key = strconv.FormatInt(int64(rand.Int()), 10) - values = make([][]byte, size+1) - values[0] = []byte(key) - expectedValues := make([][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - values[i+1] = []byte(value) - expectedValues[size-i-1] = []byte(value) - } - result = LPush(testDB, values) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { - t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) - } - actual = LRange(testDB, toArgs(key, "0", "-1")) - expected = reply.MakeMultiBulkReply(expectedValues) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("push error") - } - Del(testDB, toArgs(key)) + // left push multi + key = RandString(10) + values = make([][]byte, size+1) + values[0] = []byte(key) + expectedValues := make([][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + values[i+1] = []byte(value) + expectedValues[size-i-1] = []byte(value) + } + result = LPush(testDB, values) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { + t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) + } + actual = LRange(testDB, toArgs(key, "0", "-1")) + expected = reply.MakeMultiBulkReply(expectedValues) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("push error") + } + Del(testDB, toArgs(key)) } func TestLRange(t *testing.T) { - // prepare list - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - values := make([][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - RPush(testDB, toArgs(key, value)) - values[i] = []byte(value) - } + // prepare list + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + values := make([][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + RPush(testDB, toArgs(key, value)) + values[i] = []byte(value) + } - start := "0" - end := "9" - actual := LRange(testDB, toArgs(key, start, end)) - expected := reply.MakeMultiBulkReply(values[0:10]) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) - } + start := "0" + end := "9" + actual := LRange(testDB, toArgs(key, start, end)) + expected := reply.MakeMultiBulkReply(values[0:10]) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) + } - start = "0" - end = "200" - actual = LRange(testDB, toArgs(key, start, end)) - expected = reply.MakeMultiBulkReply(values) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) - } + start = "0" + end = "200" + actual = LRange(testDB, toArgs(key, start, end)) + expected = reply.MakeMultiBulkReply(values) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) + } - start = "0" - end = "-10" - actual = LRange(testDB, toArgs(key, start, end)) - expected = reply.MakeMultiBulkReply(values[0 : size-10+1]) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) - } + start = "0" + end = "-10" + actual = LRange(testDB, toArgs(key, start, end)) + expected = reply.MakeMultiBulkReply(values[0 : size-10+1]) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) + } - start = "0" - end = "-200" - actual = LRange(testDB, toArgs(key, start, end)) - expected = reply.MakeMultiBulkReply(values[0:0]) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) - } + start = "0" + end = "-200" + actual = LRange(testDB, toArgs(key, start, end)) + expected = reply.MakeMultiBulkReply(values[0:0]) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) + } - start = "-10" - end = "-1" - actual = LRange(testDB, toArgs(key, start, end)) - expected = reply.MakeMultiBulkReply(values[90:]) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) - } + start = "-10" + end = "-1" + actual = LRange(testDB, toArgs(key, start, end)) + expected = reply.MakeMultiBulkReply(values[90:]) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) + } } func TestLIndex(t *testing.T) { - // prepare list - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - values := make([][]byte, size) - for i := 0; i < size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - RPush(testDB, toArgs(key, value)) - values[i] = []byte(value) - } + // prepare list + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + values := make([][]byte, size) + for i := 0; i < size; i++ { + value := RandString(10) + RPush(testDB, toArgs(key, value)) + values[i] = []byte(value) + } - result := LLen(testDB, toArgs(key)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { - t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) - } + result := LLen(testDB, toArgs(key)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { + t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) + } - for i := 0; i < size; i++ { - result = LIndex(testDB, toArgs(key, strconv.Itoa(i))) - expected := reply.MakeBulkReply(values[i]) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } + for i := 0; i < size; i++ { + result = LIndex(testDB, toArgs(key, strconv.Itoa(i))) + expected := reply.MakeBulkReply(values[i]) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } - for i := 1; i <= size; i++ { - result = LIndex(testDB, toArgs(key, strconv.Itoa(-i))) - expected := reply.MakeBulkReply(values[size-i]) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } + for i := 1; i <= size; i++ { + result = LIndex(testDB, toArgs(key, strconv.Itoa(-i))) + expected := reply.MakeBulkReply(values[size-i]) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } } func TestLRem(t *testing.T) { - // prepare list - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - values := []string{key, "a", "b", "a", "a", "c", "a", "a"} - RPush(testDB, toArgs(values...)) + // prepare list + FlushAll(testDB, [][]byte{}) + key := RandString(10) + values := []string{key, "a", "b", "a", "a", "c", "a", "a"} + RPush(testDB, toArgs(values...)) - result := LRem(testDB, toArgs(key, "1", "a")) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 1 { - t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) - } - result = LLen(testDB, toArgs(key)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 6 { - t.Error(fmt.Sprintf("expected %d, actually %d", 6, intResult.Code)) - } + result := LRem(testDB, toArgs(key, "1", "a")) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 1 { + t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) + } + result = LLen(testDB, toArgs(key)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 6 { + t.Error(fmt.Sprintf("expected %d, actually %d", 6, intResult.Code)) + } - result = LRem(testDB, toArgs(key, "-2", "a")) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { - t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) - } - result = LLen(testDB, toArgs(key)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 4 { - t.Error(fmt.Sprintf("expected %d, actually %d", 4, intResult.Code)) - } + result = LRem(testDB, toArgs(key, "-2", "a")) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { + t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) + } + result = LLen(testDB, toArgs(key)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 4 { + t.Error(fmt.Sprintf("expected %d, actually %d", 4, intResult.Code)) + } - result = LRem(testDB, toArgs(key, "0", "a")) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { - t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) - } - result = LLen(testDB, toArgs(key)) - if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { - t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) - } + result = LRem(testDB, toArgs(key, "0", "a")) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { + t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) + } + result = LLen(testDB, toArgs(key)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { + t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) + } } func TestLSet(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - values := []string{key, "a", "b", "c", "d", "e", "f"} - RPush(testDB, toArgs(values...)) + FlushAll(testDB, [][]byte{}) + key := RandString(10) + values := []string{key, "a", "b", "c", "d", "e", "f"} + RPush(testDB, toArgs(values...)) - // test positive index - size := len(values) - 1 - for i := 0; i < size; i++ { - indexStr := strconv.Itoa(i) - value := strconv.FormatInt(int64(rand.Int()), 10) - result := LSet(testDB, toArgs(key, indexStr, value)) - if _, ok := result.(*reply.OkReply); !ok { - t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) - } - result = LIndex(testDB, toArgs(key, indexStr)) - expected := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } - // test negative index - for i := 1; i <= size; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value)) - if _, ok := result.(*reply.OkReply); !ok { - t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) - } - result = LIndex(testDB, toArgs(key, strconv.Itoa(len(values)-i-1))) - expected := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } + // test positive index + size := len(values) - 1 + for i := 0; i < size; i++ { + indexStr := strconv.Itoa(i) + value := RandString(10) + result := LSet(testDB, toArgs(key, indexStr, value)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) + } + result = LIndex(testDB, toArgs(key, indexStr)) + expected := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } + // test negative index + for i := 1; i <= size; i++ { + value := RandString(10) + result := LSet(testDB, toArgs(key, strconv.Itoa(-i), value)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) + } + result = LIndex(testDB, toArgs(key, strconv.Itoa(len(values)-i-1))) + expected := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } - // test illegal index - value := strconv.FormatInt(int64(rand.Int()), 10) - result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value)) - expected := reply.MakeErrReply("ERR index out of range") - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - result = LSet(testDB, toArgs(key, strconv.Itoa(len(values)), value)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - result = LSet(testDB, toArgs(key, "a", value)) - expected = reply.MakeErrReply("ERR value is not an integer or out of range") - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + // test illegal index + value := RandString(10) + result := LSet(testDB, toArgs(key, strconv.Itoa(-len(values)-1), value)) + expected := reply.MakeErrReply("ERR index out of range") + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + result = LSet(testDB, toArgs(key, strconv.Itoa(len(values)), value)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + result = LSet(testDB, toArgs(key, "a", value)) + expected = reply.MakeErrReply("ERR value is not an integer or out of range") + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } } func TestLPop(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - values := []string{key, "a", "b", "c", "d", "e", "f"} - RPush(testDB, toArgs(values...)) - size := len(values) - 1 + FlushAll(testDB, [][]byte{}) + key := RandString(10) + values := []string{key, "a", "b", "c", "d", "e", "f"} + RPush(testDB, toArgs(values...)) + size := len(values) - 1 - for i := 0; i < size; i++ { - result := LPop(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(values[i+1])) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } - result := RPop(testDB, toArgs(key)) - expected := &reply.NullBulkReply{} - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + for i := 0; i < size; i++ { + result := LPop(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(values[i+1])) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } + result := RPop(testDB, toArgs(key)) + expected := &reply.NullBulkReply{} + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } } func TestRPop(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - values := []string{key, "a", "b", "c", "d", "e", "f"} - RPush(testDB, toArgs(values...)) - size := len(values) - 1 + FlushAll(testDB, [][]byte{}) + key := RandString(10) + values := []string{key, "a", "b", "c", "d", "e", "f"} + RPush(testDB, toArgs(values...)) + size := len(values) - 1 - for i := 0; i < size; i++ { - result := RPop(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } - result := RPop(testDB, toArgs(key)) - expected := &reply.NullBulkReply{} - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + for i := 0; i < size; i++ { + result := RPop(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } + result := RPop(testDB, toArgs(key)) + expected := &reply.NullBulkReply{} + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } } func TestRPopLPush(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key1 := strconv.FormatInt(int64(rand.Int()), 10) - key2 := strconv.FormatInt(int64(rand.Int()), 10) - values := []string{key1, "a", "b", "c", "d", "e", "f"} - RPush(testDB, toArgs(values...)) - size := len(values) - 1 + FlushAll(testDB, [][]byte{}) + key1 := RandString(10) + key2 := RandString(10) + values := []string{key1, "a", "b", "c", "d", "e", "f"} + RPush(testDB, toArgs(values...)) + size := len(values) - 1 - for i := 0; i < size; i++ { - result := RPopLPush(testDB, toArgs(key1, key2)) - expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - result = LIndex(testDB, toArgs(key2, "0")) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - } - result := RPop(testDB, toArgs(key1)) - expected := &reply.NullBulkReply{} - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + for i := 0; i < size; i++ { + result := RPopLPush(testDB, toArgs(key1, key2)) + expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + result = LIndex(testDB, toArgs(key2, "0")) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + } + result := RPop(testDB, toArgs(key1)) + expected := &reply.NullBulkReply{} + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } } func TestRPushX(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - result := RPushX(testDB, toArgs(key, "1")) - expected := reply.MakeIntReply(int64(0)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + FlushAll(testDB, [][]byte{}) + key := RandString(10) + result := RPushX(testDB, toArgs(key, "1")) + expected := reply.MakeIntReply(int64(0)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } - RPush(testDB, toArgs(key, "1")) - for i := 0; i < 10; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - result := RPushX(testDB, toArgs(key, value)) - expected := reply.MakeIntReply(int64(i + 2)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - result = LIndex(testDB, toArgs(key, "-1")) - expected2 := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) - } - } + RPush(testDB, toArgs(key, "1")) + for i := 0; i < 10; i++ { + value := RandString(10) + result := RPushX(testDB, toArgs(key, value)) + expected := reply.MakeIntReply(int64(i + 2)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + result = LIndex(testDB, toArgs(key, "-1")) + expected2 := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) + } + } } func TestLPushX(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - result := RPushX(testDB, toArgs(key, "1")) - expected := reply.MakeIntReply(int64(0)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } + FlushAll(testDB, [][]byte{}) + key := RandString(10) + result := RPushX(testDB, toArgs(key, "1")) + expected := reply.MakeIntReply(int64(0)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } - LPush(testDB, toArgs(key, "1")) - for i := 0; i < 10; i++ { - value := strconv.FormatInt(int64(rand.Int()), 10) - result := LPushX(testDB, toArgs(key, value)) - expected := reply.MakeIntReply(int64(i + 2)) - if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) - } - result = LIndex(testDB, toArgs(key, "0")) - expected2 := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { - t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) - } - } + LPush(testDB, toArgs(key, "1")) + for i := 0; i < 10; i++ { + value := RandString(10) + result := LPushX(testDB, toArgs(key, value)) + expected := reply.MakeIntReply(int64(i + 2)) + if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) + } + result = LIndex(testDB, toArgs(key, "0")) + expected2 := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) + } + } } - diff --git a/src/db/set.go b/src/db/set.go index 80e9587..d45a33d 100644 --- a/src/db/set.go +++ b/src/db/set.go @@ -1,511 +1,509 @@ package db import ( - HashSet "github.com/HDT3213/godis/src/datastruct/set" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" + HashSet "github.com/HDT3213/godis/src/datastruct/set" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" ) - -func (db *DB)getAsSet(key string)(*HashSet.Set, reply.ErrorReply) { - entity, exists := db.Get(key) - if !exists { - return nil, nil - } - set, ok := entity.Data.(*HashSet.Set) - if !ok { - return nil, &reply.WrongTypeErrReply{} - } - return set, nil +func (db *DB) getAsSet(key string) (*HashSet.Set, reply.ErrorReply) { + entity, exists := db.Get(key) + if !exists { + return nil, nil + } + set, ok := entity.Data.(*HashSet.Set) + if !ok { + return nil, &reply.WrongTypeErrReply{} + } + return set, nil } -func (db *DB) getOrInitSet(key string)(set *HashSet.Set, inited bool, errReply reply.ErrorReply) { - set, errReply = db.getAsSet(key) - if errReply != nil { - return nil, false, errReply - } - inited = false - if set == nil { - set = HashSet.Make() - db.Put(key, &DataEntity{ - Data: set, - }) - inited = true - } - return set, inited, nil +func (db *DB) getOrInitSet(key string) (set *HashSet.Set, inited bool, errReply reply.ErrorReply) { + set, errReply = db.getAsSet(key) + if errReply != nil { + return nil, false, errReply + } + inited = false + if set == nil { + set = HashSet.Make() + db.Put(key, &DataEntity{ + Data: set, + }) + inited = true + } + return set, inited, nil } func SAdd(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command") - } - key := string(args[0]) - members := args[1:] + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command") + } + key := string(args[0]) + members := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) + // lock + db.Lock(key) + defer db.UnLock(key) - // get or init entity - set, _, errReply := db.getOrInitSet(key) - if errReply != nil { - return errReply - } - counter := 0 - for _, member := range members { - counter += set.Add(string(member)) - } - db.AddAof(makeAofCmd("sadd", args)) - return reply.MakeIntReply(int64(counter)) + // get or init entity + set, _, errReply := db.getOrInitSet(key) + if errReply != nil { + return errReply + } + counter := 0 + for _, member := range members { + counter += set.Add(string(member)) + } + db.AddAof(makeAofCmd("sadd", args)) + return reply.MakeIntReply(int64(counter)) } func SIsMember(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command") - } - key := string(args[0]) - member := string(args[1]) + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command") + } + key := string(args[0]) + member := string(args[1]) - // get set - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return reply.MakeIntReply(0) - } + // get set + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return reply.MakeIntReply(0) + } - has := set.Has(member) - if has { - return reply.MakeIntReply(1) - } else { - return reply.MakeIntReply(0) - } + has := set.Has(member) + if has { + return reply.MakeIntReply(1) + } else { + return reply.MakeIntReply(0) + } } func SRem(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command") - } - key := string(args[0]) - members := args[1:] + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command") + } + key := string(args[0]) + members := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) + // lock + db.Lock(key) + defer db.UnLock(key) - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return reply.MakeIntReply(0) - } - counter := 0 - for _, member := range members { - counter += set.Remove(string(member)) - } - if set.Len() == 0 { - db.Remove(key) - } - if counter > 0 { - db.AddAof(makeAofCmd("srem", args)) - } - return reply.MakeIntReply(int64(counter)) + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return reply.MakeIntReply(0) + } + counter := 0 + for _, member := range members { + counter += set.Remove(string(member)) + } + if set.Len() == 0 { + db.Remove(key) + } + if counter > 0 { + db.AddAof(makeAofCmd("srem", args)) + } + return reply.MakeIntReply(int64(counter)) } func SCard(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command") - } - key := string(args[0]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command") + } + key := string(args[0]) - // get or init entity - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return reply.MakeIntReply(0) - } - return reply.MakeIntReply(int64(set.Len())) + // get or init entity + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return reply.MakeIntReply(0) + } + return reply.MakeIntReply(int64(set.Len())) } func SMembers(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command") - } - key := string(args[0]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command") + } + key := string(args[0]) - // lock - db.Locker.RLock(key) - defer db.Locker.RUnLock(key) + // lock + db.Locker.RLock(key) + defer db.Locker.RUnLock(key) - // get or init entity - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return &reply.EmptyMultiBulkReply{} - } + // get or init entity + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &reply.EmptyMultiBulkReply{} + } - - arr := make([][]byte, set.Len()) - i := 0 - set.ForEach(func (member string)bool { - arr[i] = []byte(member) - i++ - return true - }) - return reply.MakeMultiBulkReply(arr) + arr := make([][]byte, set.Len()) + i := 0 + set.ForEach(func(member string) bool { + arr[i] = []byte(member) + i++ + return true + }) + return reply.MakeMultiBulkReply(arr) } func SInter(db *DB, args [][]byte) redis.Reply { - if len(args) < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command") - } - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } + if len(args) < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command") + } + keys := make([]string, len(args)) + for i, arg := range args { + keys[i] = string(arg) + } - // lock - db.Locker.RLocks(keys...) - defer db.Locker.RUnLocks(keys...) + // lock + db.Locker.RLocks(keys...) + defer db.Locker.RUnLocks(keys...) - var result *HashSet.Set - for _, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return &reply.EmptyMultiBulkReply{} - } + var result *HashSet.Set + for _, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &reply.EmptyMultiBulkReply{} + } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Intersect(set) - if result.Len() == 0 { - // early termination - return &reply.EmptyMultiBulkReply{} - } - } - } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Intersect(set) + if result.Len() == 0 { + // early termination + return &reply.EmptyMultiBulkReply{} + } + } + } - arr := make([][]byte, result.Len()) - i := 0 - result.ForEach(func (member string)bool { - arr[i] = []byte(member) - i++ - return true - }) - return reply.MakeMultiBulkReply(arr) + arr := make([][]byte, result.Len()) + i := 0 + result.ForEach(func(member string) bool { + arr[i] = []byte(member) + i++ + return true + }) + return reply.MakeMultiBulkReply(arr) } func SInterStore(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command") - } - dest := string(args[0]) - keys := make([]string, len(args) - 1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command") + } + dest := string(args[0]) + keys := make([]string, len(args)-1) + keyArgs := args[1:] + for i, arg := range keyArgs { + keys[i] = string(arg) + } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - db.Lock(dest) - defer db.UnLock(dest) + // lock + db.RLocks(keys...) + defer db.RUnLocks(keys...) + db.Lock(dest) + defer db.UnLock(dest) - var result *HashSet.Set - for _, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - db.Remove(dest) // clean ttl and old value - return &reply.EmptyMultiBulkReply{} - } + var result *HashSet.Set + for _, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + db.Remove(dest) // clean ttl and old value + return &reply.EmptyMultiBulkReply{} + } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Intersect(set) - if result.Len() == 0 { - // early termination - db.Remove(dest) // clean ttl and old value - return reply.MakeIntReply(0) - } - } - } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Intersect(set) + if result.Len() == 0 { + // early termination + db.Remove(dest) // clean ttl and old value + return reply.MakeIntReply(0) + } + } + } - set := HashSet.MakeFromVals(result.ToSlice()...) - db.Put(dest, &DataEntity{ - Data: set, - }) - db.AddAof(makeAofCmd("sinterstore", args)) - return reply.MakeIntReply(int64(set.Len())) + set := HashSet.MakeFromVals(result.ToSlice()...) + db.Put(dest, &DataEntity{ + Data: set, + }) + db.AddAof(makeAofCmd("sinterstore", args)) + return reply.MakeIntReply(int64(set.Len())) } func SUnion(db *DB, args [][]byte) redis.Reply { - if len(args) < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command") - } - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } + if len(args) < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command") + } + keys := make([]string, len(args)) + for i, arg := range args { + keys[i] = string(arg) + } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) + // lock + db.RLocks(keys...) + defer db.RUnLocks(keys...) - var result *HashSet.Set - for _, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - continue - } + var result *HashSet.Set + for _, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + continue + } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Union(set) - } - } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Union(set) + } + } - if result == nil { - // all keys are empty set - return &reply.EmptyMultiBulkReply{} - } - arr := make([][]byte, result.Len()) - i := 0 - result.ForEach(func (member string)bool { - arr[i] = []byte(member) - i++ - return true - }) - return reply.MakeMultiBulkReply(arr) + if result == nil { + // all keys are empty set + return &reply.EmptyMultiBulkReply{} + } + arr := make([][]byte, result.Len()) + i := 0 + result.ForEach(func(member string) bool { + arr[i] = []byte(member) + i++ + return true + }) + return reply.MakeMultiBulkReply(arr) } func SUnionStore(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command") - } - dest := string(args[0]) - keys := make([]string, len(args) - 1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command") + } + dest := string(args[0]) + keys := make([]string, len(args)-1) + keyArgs := args[1:] + for i, arg := range keyArgs { + keys[i] = string(arg) + } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - db.Lock(dest) - defer db.UnLock(dest) + // lock + db.RLocks(keys...) + defer db.RUnLocks(keys...) + db.Lock(dest) + defer db.UnLock(dest) - var result *HashSet.Set - for _, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - continue - } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Union(set) - } - } + var result *HashSet.Set + for _, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + continue + } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Union(set) + } + } - db.Remove(dest) // clean ttl - if result == nil { - // all keys are empty set - return &reply.EmptyMultiBulkReply{} - } + db.Remove(dest) // clean ttl + if result == nil { + // all keys are empty set + return &reply.EmptyMultiBulkReply{} + } - set := HashSet.MakeFromVals(result.ToSlice()...) - db.Put(dest, &DataEntity{ - Data: set, - }) + set := HashSet.MakeFromVals(result.ToSlice()...) + db.Put(dest, &DataEntity{ + Data: set, + }) - db.AddAof(makeAofCmd("sunionstore", args)) - return reply.MakeIntReply(int64(set.Len())) + db.AddAof(makeAofCmd("sunionstore", args)) + return reply.MakeIntReply(int64(set.Len())) } func SDiff(db *DB, args [][]byte) redis.Reply { - if len(args) < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command") - } - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } + if len(args) < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command") + } + keys := make([]string, len(args)) + for i, arg := range args { + keys[i] = string(arg) + } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) + // lock + db.RLocks(keys...) + defer db.RUnLocks(keys...) - var result *HashSet.Set - for i, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - if i == 0 { - // early termination - return &reply.EmptyMultiBulkReply{} - } else { - continue - } - } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Diff(set) - if result.Len() == 0 { - // early termination - return &reply.EmptyMultiBulkReply{} - } - } - } + var result *HashSet.Set + for i, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + if i == 0 { + // early termination + return &reply.EmptyMultiBulkReply{} + } else { + continue + } + } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Diff(set) + if result.Len() == 0 { + // early termination + return &reply.EmptyMultiBulkReply{} + } + } + } - if result == nil { - // all keys are nil - return &reply.EmptyMultiBulkReply{} - } - arr := make([][]byte, result.Len()) - i := 0 - result.ForEach(func (member string)bool { - arr[i] = []byte(member) - i++ - return true - }) - return reply.MakeMultiBulkReply(arr) + if result == nil { + // all keys are nil + return &reply.EmptyMultiBulkReply{} + } + arr := make([][]byte, result.Len()) + i := 0 + result.ForEach(func(member string) bool { + arr[i] = []byte(member) + i++ + return true + }) + return reply.MakeMultiBulkReply(arr) } func SDiffStore(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command") - } - dest := string(args[0]) - keys := make([]string, len(args) - 1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command") + } + dest := string(args[0]) + keys := make([]string, len(args)-1) + keyArgs := args[1:] + for i, arg := range keyArgs { + keys[i] = string(arg) + } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - db.Lock(dest) - defer db.Locker.UnLock(dest) + // lock + db.RLocks(keys...) + defer db.RUnLocks(keys...) + db.Lock(dest) + defer db.Locker.UnLock(dest) - var result *HashSet.Set - for i, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - if i == 0 { - // early termination - db.Remove(dest) - return &reply.EmptyMultiBulkReply{} - } else { - continue - } - } - if result == nil { - // init - result = HashSet.MakeFromVals(set.ToSlice()...) - } else { - result = result.Diff(set) - if result.Len() == 0 { - // early termination - db.Remove(dest) - return &reply.EmptyMultiBulkReply{} - } - } - } + var result *HashSet.Set + for i, key := range keys { + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + if i == 0 { + // early termination + db.Remove(dest) + return &reply.EmptyMultiBulkReply{} + } else { + continue + } + } + if result == nil { + // init + result = HashSet.MakeFromVals(set.ToSlice()...) + } else { + result = result.Diff(set) + if result.Len() == 0 { + // early termination + db.Remove(dest) + return &reply.EmptyMultiBulkReply{} + } + } + } - if result == nil { - // all keys are nil - db.Remove(dest) - return &reply.EmptyMultiBulkReply{} - } - set := HashSet.MakeFromVals(result.ToSlice()...) - db.Put(dest, &DataEntity{ - Data: set, - }) + if result == nil { + // all keys are nil + db.Remove(dest) + return &reply.EmptyMultiBulkReply{} + } + set := HashSet.MakeFromVals(result.ToSlice()...) + db.Put(dest, &DataEntity{ + Data: set, + }) - db.AddAof(makeAofCmd("sdiffstore", args)) - return reply.MakeIntReply(int64(set.Len())) + db.AddAof(makeAofCmd("sdiffstore", args)) + return reply.MakeIntReply(int64(set.Len())) } func SRandMember(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 && len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command") - } - key := string(args[0]) - // lock - db.RLock(key) - defer db.RUnLock(key) + if len(args) != 1 && len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command") + } + key := string(args[0]) + // lock + db.RLock(key) + defer db.RUnLock(key) - // get or init entity - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return &reply.NullBulkReply{} - } - if len(args) == 1 { - members := set.RandomMembers(1) - return reply.MakeBulkReply([]byte(members[0])) - } else { - count64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - count := int(count64) + // get or init entity + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &reply.NullBulkReply{} + } + if len(args) == 1 { + members := set.RandomMembers(1) + return reply.MakeBulkReply([]byte(members[0])) + } else { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + count := int(count64) - if count > 0 { - members := set.RandomMembers(count) + if count > 0 { + members := set.RandomDistinctMembers(count) - result := make([][]byte, len(members)) - for i, v := range members { - result[i] = []byte(v) - } - return reply.MakeMultiBulkReply(result) - } else if count < 0 { - members := set.RandomDistinctMembers(-count) - result := make([][]byte, len(members)) - for i, v := range members { - result[i] = []byte(v) - } - return reply.MakeMultiBulkReply(result) - } else { - return &reply.EmptyMultiBulkReply{} - } - } -} \ No newline at end of file + result := make([][]byte, len(members)) + for i, v := range members { + result[i] = []byte(v) + } + return reply.MakeMultiBulkReply(result) + } else if count < 0 { + members := set.RandomMembers(-count) + result := make([][]byte, len(members)) + for i, v := range members { + result[i] = []byte(v) + } + return reply.MakeMultiBulkReply(result) + } else { + return &reply.EmptyMultiBulkReply{} + } + } +} diff --git a/src/db/set_test.go b/src/db/set_test.go new file mode 100644 index 0000000..233787a --- /dev/null +++ b/src/db/set_test.go @@ -0,0 +1,181 @@ +package db + +import ( + "fmt" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "strconv" + "testing" +) + +// basic add get and remove +func TestSAdd(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // test sadd + key := RandString(10) + for i := 0; i < size; i++ { + member := strconv.Itoa(i) + result := SAdd(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, 1) + } + // test scard + result := SCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size) + + // test is member + for i := 0; i < size; i++ { + member := strconv.Itoa(i) + result := SIsMember(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, 1) + } + + // test members + result = SMembers(testDB, toArgs(key)) + multiBulk, ok := result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) + return + } + if len(multiBulk.Args) != size { + t.Error(fmt.Sprintf("expected %d elements, actually %d", size, len(multiBulk.Args))) + return + } +} + +func TestSRem(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // mock data + key := RandString(10) + for i := 0; i < size; i++ { + member := strconv.Itoa(i) + SAdd(testDB, toArgs(key, member)) + } + for i := 0; i < size; i++ { + member := strconv.Itoa(i) + SRem(testDB, toArgs(key, member)) + result := SIsMember(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, 0) + } +} + +func TestSInter(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + step := 10 + + keys := make([]string, 0) + start := 0 + for i := 0; i < 4; i++ { + key := RandString(10) + keys = append(keys, key) + for j := start; j < size+start; j++ { + member := strconv.Itoa(j) + SAdd(testDB, toArgs(key, member)) + } + start += step + } + result := SInter(testDB, toArgs(keys...)) + asserts.AssertMultiBulkReplySize(t, result, 70) + + destKey := RandString(10) + keysWithDest := []string{destKey} + keysWithDest = append(keysWithDest, keys...) + result = SInterStore(testDB, toArgs(keysWithDest...)) + asserts.AssertIntReply(t, result, 70) +} + +func TestSUnion(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + step := 10 + + keys := make([]string, 0) + start := 0 + for i := 0; i < 4; i++ { + key := RandString(10) + keys = append(keys, key) + for j := start; j < size+start; j++ { + member := strconv.Itoa(j) + SAdd(testDB, toArgs(key, member)) + } + start += step + } + result := SUnion(testDB, toArgs(keys...)) + asserts.AssertMultiBulkReplySize(t, result, 130) + + destKey := RandString(10) + keysWithDest := []string{destKey} + keysWithDest = append(keysWithDest, keys...) + result = SUnionStore(testDB, toArgs(keysWithDest...)) + asserts.AssertIntReply(t, result, 130) +} + +func TestSDiff(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + step := 20 + + keys := make([]string, 0) + start := 0 + for i := 0; i < 3; i++ { + key := RandString(10) + keys = append(keys, key) + for j := start; j < size+start; j++ { + member := strconv.Itoa(j) + SAdd(testDB, toArgs(key, member)) + } + start += step + } + result := SDiff(testDB, toArgs(keys...)) + asserts.AssertMultiBulkReplySize(t, result, step) + + destKey := RandString(10) + keysWithDest := []string{destKey} + keysWithDest = append(keysWithDest, keys...) + result = SDiffStore(testDB, toArgs(keysWithDest...)) + asserts.AssertIntReply(t, result, step) +} + +func TestSRandMember(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + for j := 0; j < 100; j++ { + member := strconv.Itoa(j) + SAdd(testDB, toArgs(key, member)) + } + result := SRandMember(testDB, toArgs(key)) + br, ok := result.(*reply.BulkReply) + if !ok && len(br.Arg) > 0 { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) + return + } + + result = SRandMember(testDB, toArgs(key, "10")) + asserts.AssertMultiBulkReplySize(t, result, 10) + multiBulk, ok := result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) + return + } + m := make(map[string]struct{}) + for _, arg := range multiBulk.Args { + m[string(arg)] = struct{}{} + } + if len(m) != 10 { + t.Error(fmt.Sprintf("expected 10 members, actually %d", len(m))) + return + } + + result = SRandMember(testDB, toArgs(key, "110")) + asserts.AssertMultiBulkReplySize(t, result, 100) + + result = SRandMember(testDB, toArgs(key, "-10")) + asserts.AssertMultiBulkReplySize(t, result, 10) + + result = SRandMember(testDB, toArgs(key, "-110")) + asserts.AssertMultiBulkReplySize(t, result, 110) +} diff --git a/src/db/sortedset_test.go b/src/db/sortedset_test.go index f27adcf..8bc266d 100644 --- a/src/db/sortedset_test.go +++ b/src/db/sortedset_test.go @@ -1,298 +1,298 @@ package db import ( - "github.com/HDT3213/godis/src/redis/reply/asserts" - "math/rand" - "strconv" - "testing" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "math/rand" + "strconv" + "testing" ) func TestZAdd(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 + FlushAll(testDB, [][]byte{}) + size := 100 - // add new members - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]float64, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(rand.Int()), 10) - scores[i] = rand.Float64() - setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) - } - result := ZAdd(testDB, toArgs(setArgs...)) - asserts.AssertIntReply(t, result, size) + // add new members + key := RandString(10) + members := make([]string, size) + scores := make([]float64, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = RandString(10) + scores[i] = rand.Float64() + setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + asserts.AssertIntReply(t, result, size) - // test zscore and zrank - for i, member := range members { - result := ZScore(testDB, toArgs(key, member)) - score := strconv.FormatFloat(scores[i], 'f', -1, 64) - asserts.AssertBulkReply(t, result, score) - } + // test zscore and zrank + for i, member := range members { + result := ZScore(testDB, toArgs(key, member)) + score := strconv.FormatFloat(scores[i], 'f', -1, 64) + asserts.AssertBulkReply(t, result, score) + } - // test zcard - result = ZCard(testDB, toArgs(key)) - asserts.AssertIntReply(t, result, size) + // test zcard + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size) - // update members - setArgs = []string{key} - for i := 0; i < size; i++ { - scores[i] = rand.Float64() + 100 - setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) - } - result = ZAdd(testDB, toArgs(setArgs...)) - asserts.AssertIntReply(t, result, 0) // return number of new members + // update members + setArgs = []string{key} + for i := 0; i < size; i++ { + scores[i] = rand.Float64() + 100 + setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) + } + result = ZAdd(testDB, toArgs(setArgs...)) + asserts.AssertIntReply(t, result, 0) // return number of new members - // test updated score - for i, member := range members { - result := ZScore(testDB, toArgs(key, member)) - score := strconv.FormatFloat(scores[i], 'f', -1, 64) - asserts.AssertBulkReply(t, result, score) - } + // test updated score + for i, member := range members { + result := ZScore(testDB, toArgs(key, member)) + score := strconv.FormatFloat(scores[i], 'f', -1, 64) + asserts.AssertBulkReply(t, result, score) + } } func TestZRank(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]int, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(rand.Int()), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - ZAdd(testDB, toArgs(setArgs...)) + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = RandString(10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) - // test zrank - for i, member := range members { - result := ZRank(testDB, toArgs(key, member)) - asserts.AssertIntReply(t, result, i) + // test zrank + for i, member := range members { + result := ZRank(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, i) - result = ZRevRank(testDB, toArgs(key, member)) - asserts.AssertIntReply(t, result, size-i-1) - } + result = ZRevRank(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, size-i-1) + } } func TestZRange(t *testing.T) { - // prepare - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]int, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(rand.Int()), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - result := ZAdd(testDB, toArgs(setArgs...)) - reverseMembers := make([]string, size) - for i, v := range members { - reverseMembers[size-i-1] = v - } + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = RandString(10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + reverseMembers := make([]string, size) + for i, v := range members { + reverseMembers[size-i-1] = v + } - start := "0" - end := "9" - result = ZRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, members[0:10]) - result = ZRevRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10]) + start := "0" + end := "9" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:10]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10]) - start = "0" - end = "200" - result = ZRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, members) - result = ZRevRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, reverseMembers) + start = "0" + end = "200" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers) - start = "0" - end = "-10" - result = ZRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, members[0:size-10+1]) - result = ZRevRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1]) + start = "0" + end = "-10" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:size-10+1]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1]) - start = "0" - end = "-200" - result = ZRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, members[0:0]) - result = ZRevRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0]) + start = "0" + end = "-200" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:0]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0]) - start = "-10" - end = "-1" - result = ZRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, members[90:]) - result = ZRevRange(testDB, toArgs(key, start, end)) - asserts.AssertMultiBulkReply(t, result, reverseMembers[90:]) + start = "-10" + end = "-1" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[90:]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[90:]) } func reverse(src []string) []string { - result := make([]string, len(src)) - for i, v := range src { - result[len(src)-i-1] = v - } - return result + result := make([]string, len(src)) + for i, v := range src { + result[len(src)-i-1] = v + } + return result } func TestZRangeByScore(t *testing.T) { - // prepare - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]int, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(i), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - result := ZAdd(testDB, toArgs(setArgs...)) + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) - min := "20" - max := "30" - result = ZRangeByScore(testDB, toArgs(key, min, max)) - asserts.AssertMultiBulkReply(t, result, members[20:31]) - result = ZRevRangeByScore(testDB, toArgs(key, max, min)) - asserts.AssertMultiBulkReply(t, result, reverse(members[20:31])) + min := "20" + max := "30" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[20:31]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[20:31])) - min = "-10" - max = "10" - result = ZRangeByScore(testDB, toArgs(key, min, max)) - asserts.AssertMultiBulkReply(t, result, members[0:11]) - result = ZRevRangeByScore(testDB, toArgs(key, max, min)) - asserts.AssertMultiBulkReply(t, result, reverse(members[0:11])) + min = "-10" + max = "10" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[0:11]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[0:11])) - min = "90" - max = "110" - result = ZRangeByScore(testDB, toArgs(key, min, max)) - asserts.AssertMultiBulkReply(t, result, members[90:]) - result = ZRevRangeByScore(testDB, toArgs(key, max, min)) - asserts.AssertMultiBulkReply(t, result, reverse(members[90:])) + min = "90" + max = "110" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[90:]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[90:])) - min = "(20" - max = "(30" - result = ZRangeByScore(testDB, toArgs(key, min, max)) - asserts.AssertMultiBulkReply(t, result, members[21:30]) - result = ZRevRangeByScore(testDB, toArgs(key, max, min)) - asserts.AssertMultiBulkReply(t, result, reverse(members[21:30])) + min = "(20" + max = "(30" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[21:30]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[21:30])) - min = "20" - max = "40" - result = ZRangeByScore(testDB, toArgs(key, min, max, "LIMIT", "5", "5")) - asserts.AssertMultiBulkReply(t, result, members[25:30]) - result = ZRevRangeByScore(testDB, toArgs(key, max, min, "LIMIT", "5", "5")) - asserts.AssertMultiBulkReply(t, result, reverse(members[31:36])) + min = "20" + max = "40" + result = ZRangeByScore(testDB, toArgs(key, min, max, "LIMIT", "5", "5")) + asserts.AssertMultiBulkReply(t, result, members[25:30]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min, "LIMIT", "5", "5")) + asserts.AssertMultiBulkReply(t, result, reverse(members[31:36])) } func TestZRem(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]int, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(i), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - ZAdd(testDB, toArgs(setArgs...)) + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) - args := []string{key} - args = append(args, members[0:10]...) - result := ZRem(testDB, toArgs(args...)) - asserts.AssertIntReply(t, result, 10) - result = ZCard(testDB, toArgs(key)) - asserts.AssertIntReply(t, result, size-10) + args := []string{key} + args = append(args, members[0:10]...) + result := ZRem(testDB, toArgs(args...)) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) - // test ZRemRangeByRank - FlushAll(testDB, [][]byte{}) - size = 100 - key = strconv.FormatInt(int64(rand.Int()), 10) - members = make([]string, size) - scores = make([]int, size) - setArgs = []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(i), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - ZAdd(testDB, toArgs(setArgs...)) + // test ZRemRangeByRank + FlushAll(testDB, [][]byte{}) + size = 100 + key = RandString(10) + members = make([]string, size) + scores = make([]int, size) + setArgs = []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) - result = ZRemRangeByRank(testDB, toArgs(key, "0", "9")) - asserts.AssertIntReply(t, result, 10) - result = ZCard(testDB, toArgs(key)) - asserts.AssertIntReply(t, result, size-10) + result = ZRemRangeByRank(testDB, toArgs(key, "0", "9")) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) - // test ZRemRangeByScore - FlushAll(testDB, [][]byte{}) - size = 100 - key = strconv.FormatInt(int64(rand.Int()), 10) - members = make([]string, size) - scores = make([]int, size) - setArgs = []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(i), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - ZAdd(testDB, toArgs(setArgs...)) + // test ZRemRangeByScore + FlushAll(testDB, [][]byte{}) + size = 100 + key = RandString(10) + members = make([]string, size) + scores = make([]int, size) + setArgs = []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) - result = ZRemRangeByScore(testDB, toArgs(key, "0", "9")) - asserts.AssertIntReply(t, result, 10) - result = ZCard(testDB, toArgs(key)) - asserts.AssertIntReply(t, result, size-10) + result = ZRemRangeByScore(testDB, toArgs(key, "0", "9")) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) } func TestZCount(t *testing.T) { - // prepare - FlushAll(testDB, [][]byte{}) - size := 100 - key := strconv.FormatInt(int64(rand.Int()), 10) - members := make([]string, size) - scores := make([]int, size) - setArgs := []string{key} - for i := 0; i < size; i++ { - members[i] = strconv.FormatInt(int64(i), 10) - scores[i] = i - setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) - } - result := ZAdd(testDB, toArgs(setArgs...)) + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := RandString(10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) - min := "20" - max := "30" - result = ZCount(testDB, toArgs(key, min, max)) - asserts.AssertIntReply(t, result, 11) + min := "20" + max := "30" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 11) - min = "-10" - max = "10" - result = ZCount(testDB, toArgs(key, min, max)) - asserts.AssertIntReply(t, result, 11) + min = "-10" + max = "10" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 11) - min = "90" - max = "110" - result = ZCount(testDB, toArgs(key, min, max)) - asserts.AssertIntReply(t, result, 10) + min = "90" + max = "110" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 10) - min = "(20" - max = "(30" - result = ZCount(testDB, toArgs(key, min, max)) - asserts.AssertIntReply(t, result, 9) + min = "(20" + max = "(30" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 9) } func TestZIncrBy(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - ZAdd(testDB, toArgs(key, "10", "a")) - result := ZIncrBy(testDB, toArgs(key, "10", "a")) - asserts.AssertBulkReply(t, result, "20") + FlushAll(testDB, [][]byte{}) + key := RandString(10) + ZAdd(testDB, toArgs(key, "10", "a")) + result := ZIncrBy(testDB, toArgs(key, "10", "a")) + asserts.AssertBulkReply(t, result, "20") - result = ZScore(testDB, toArgs(key, "a")) - asserts.AssertBulkReply(t, result, "20") + result = ZScore(testDB, toArgs(key, "a")) + asserts.AssertBulkReply(t, result, "20") } diff --git a/src/db/string.go b/src/db/string.go index 28a5727..f7d8ed5 100644 --- a/src/db/string.go +++ b/src/db/string.go @@ -1,509 +1,508 @@ package db import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "github.com/shopspring/decimal" - "strconv" - "strings" - "time" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/shopspring/decimal" + "strconv" + "strings" + "time" ) func (db *DB) getAsString(key string) ([]byte, reply.ErrorReply) { - entity, ok := db.Get(key) - if !ok { - return nil, nil - } - bytes, ok := entity.Data.([]byte) - if !ok { - return nil, &reply.WrongTypeErrReply{} - } - return bytes, nil + entity, ok := db.Get(key) + if !ok { + return nil, nil + } + bytes, ok := entity.Data.([]byte) + if !ok { + return nil, &reply.WrongTypeErrReply{} + } + return bytes, nil } func Get(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'get' command") - } - key := string(args[0]) - bytes, err := db.getAsString(key) - if err != nil { - return err - } - if bytes == nil { - return &reply.NullBulkReply{} - } - return reply.MakeBulkReply(bytes) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'get' command") + } + key := string(args[0]) + bytes, err := db.getAsString(key) + if err != nil { + return err + } + if bytes == nil { + return &reply.NullBulkReply{} + } + return reply.MakeBulkReply(bytes) } const ( - upsertPolicy = iota // default - insertPolicy // set nx - updatePolicy // set ex + upsertPolicy = iota // default + insertPolicy // set nx + updatePolicy // set ex ) const unlimitedTTL int64 = 0 // SET key value [EX seconds] [PX milliseconds] [NX|XX] func Set(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'set' command") - } - key := string(args[0]) - value := args[1] - policy := upsertPolicy - ttl := unlimitedTTL + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'set' command") + } + key := string(args[0]) + value := args[1] + policy := upsertPolicy + ttl := unlimitedTTL - // parse options - if len(args) > 2 { - for i := 2; i < len(args); i++ { - arg := strings.ToUpper(string(args[i])) - if arg == "NX" { // insert - if policy == updatePolicy { - return &reply.SyntaxErrReply{} - } - policy = insertPolicy - } else if arg == "XX" { // update policy - if policy == insertPolicy { - return &reply.SyntaxErrReply{} - } - policy = updatePolicy - } else if arg == "EX" { // ttl in seconds - if ttl != unlimitedTTL { - // ttl has been set - return &reply.SyntaxErrReply{} - } - if i+1 >= len(args) { - return &reply.SyntaxErrReply{} - } - ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) - if err != nil { - return &reply.SyntaxErrReply{} - } - if ttlArg <= 0 { - return reply.MakeErrReply("ERR invalid expire time in set") - } - ttl = ttlArg * 1000 - i++ // skip next arg - } else if arg == "PX" { // ttl in milliseconds - if ttl != unlimitedTTL { - return &reply.SyntaxErrReply{} - } - if i+1 >= len(args) { - return &reply.SyntaxErrReply{} - } - ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) - if err != nil { - return &reply.SyntaxErrReply{} - } - if ttlArg <= 0 { - return reply.MakeErrReply("ERR invalid expire time in set") - } - ttl = ttlArg - i++ // skip next arg - } else { - return &reply.SyntaxErrReply{} - } - } - } + // parse options + if len(args) > 2 { + for i := 2; i < len(args); i++ { + arg := strings.ToUpper(string(args[i])) + if arg == "NX" { // insert + if policy == updatePolicy { + return &reply.SyntaxErrReply{} + } + policy = insertPolicy + } else if arg == "XX" { // update policy + if policy == insertPolicy { + return &reply.SyntaxErrReply{} + } + policy = updatePolicy + } else if arg == "EX" { // ttl in seconds + if ttl != unlimitedTTL { + // ttl has been set + return &reply.SyntaxErrReply{} + } + if i+1 >= len(args) { + return &reply.SyntaxErrReply{} + } + ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return &reply.SyntaxErrReply{} + } + if ttlArg <= 0 { + return reply.MakeErrReply("ERR invalid expire time in set") + } + ttl = ttlArg * 1000 + i++ // skip next arg + } else if arg == "PX" { // ttl in milliseconds + if ttl != unlimitedTTL { + return &reply.SyntaxErrReply{} + } + if i+1 >= len(args) { + return &reply.SyntaxErrReply{} + } + ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return &reply.SyntaxErrReply{} + } + if ttlArg <= 0 { + return reply.MakeErrReply("ERR invalid expire time in set") + } + ttl = ttlArg + i++ // skip next arg + } else { + return &reply.SyntaxErrReply{} + } + } + } - entity := &DataEntity{ - Data: value, - } + entity := &DataEntity{ + Data: value, + } - db.Persist(key) // clean ttl - var result int - switch policy { - case upsertPolicy: - result = db.Put(key, entity) - case insertPolicy: - result = db.PutIfAbsent(key, entity) - case updatePolicy: - result = db.PutIfExists(key, entity) - } - /* - * 如果设置了ttl 则以最新的ttl为准 - * 如果没有设置ttl 是新增key的情况,不设置ttl。 - * 如果没有设置ttl 且已存在key的 不修改ttl 但需要增加aof - */ - if ttl != unlimitedTTL { - expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) - db.Expire(key, expireTime) - db.AddAof(reply.MakeMultiBulkReply([][]byte{ - []byte("SET"), - args[0], - args[1], - })) - db.AddAof(makeExpireCmd(key, expireTime)) - } else if result > 0{ - db.Persist(key) // override ttl - db.AddAof(makeAofCmd("set", args)) - }else{ - db.AddAof(makeAofCmd("set", args)) - } - - if policy == upsertPolicy || result > 0 { - return &reply.OkReply{} - } else { - return &reply.NullBulkReply{} - } + db.Persist(key) // clean ttl + var result int + switch policy { + case upsertPolicy: + result = db.Put(key, entity) + case insertPolicy: + result = db.PutIfAbsent(key, entity) + case updatePolicy: + result = db.PutIfExists(key, entity) + } + /* + * 如果设置了ttl 则以最新的ttl为准 + * 如果没有设置ttl 是新增key的情况,不设置ttl。 + * 如果没有设置ttl 且已存在key的 不修改ttl 但需要增加aof + */ + if ttl != unlimitedTTL { + expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) + db.Expire(key, expireTime) + db.AddAof(reply.MakeMultiBulkReply([][]byte{ + []byte("SET"), + args[0], + args[1], + })) + db.AddAof(makeExpireCmd(key, expireTime)) + } else if result > 0 { + db.Persist(key) // override ttl + db.AddAof(makeAofCmd("set", args)) + } else { + db.AddAof(makeAofCmd("set", args)) + } + + if policy == upsertPolicy || result > 0 { + return &reply.OkReply{} + } else { + return &reply.NullBulkReply{} + } } func SetNX(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command") - } - key := string(args[0]) - value := args[1] - entity := &DataEntity{ - Data: value, - } - result := db.PutIfAbsent(key, entity) - if result > 0 { - db.AddAof(makeAofCmd("setnx", args)) - } - return reply.MakeIntReply(int64(result)) + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command") + } + key := string(args[0]) + value := args[1] + entity := &DataEntity{ + Data: value, + } + result := db.PutIfAbsent(key, entity) + db.AddAof(makeAofCmd("setnx", args)) + return reply.MakeIntReply(int64(result)) } func SetEX(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command") - } - key := string(args[0]) - value := args[2] + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command") + } + key := string(args[0]) + value := args[2] - ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return &reply.SyntaxErrReply{} - } - if ttlArg <= 0 { - return reply.MakeErrReply("ERR invalid expire time in setex") - } - ttl := ttlArg * 1000 + ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return &reply.SyntaxErrReply{} + } + if ttlArg <= 0 { + return reply.MakeErrReply("ERR invalid expire time in setex") + } + ttl := ttlArg * 1000 - entity := &DataEntity{ - Data: value, - } + entity := &DataEntity{ + Data: value, + } - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - result := db.Put(key, entity) - expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) - db.Expire(key, expireTime) - if result > 0 { - db.AddAof(makeAofCmd("setex", args)) - db.AddAof(makeExpireCmd(key, expireTime)) - } - return &reply.OkReply{} + db.Put(key, entity) + expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) + db.Expire(key, expireTime) + db.AddAof(makeAofCmd("setex", args)) + db.AddAof(makeExpireCmd(key, expireTime)) + return &reply.OkReply{} } func PSetEX(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'psetex' command") - } - key := string(args[0]) - value := args[1] + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command") + } + key := string(args[0]) + value := args[2] - ttl, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return &reply.SyntaxErrReply{} - } - if ttl <= 0 { - return reply.MakeErrReply("ERR invalid expire time in psetex") - } + ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return &reply.SyntaxErrReply{} + } + if ttlArg <= 0 { + return reply.MakeErrReply("ERR invalid expire time in setex") + } - entity := &DataEntity{ - Data: value, - } - result := db.PutIfExists(key, entity) - if result > 0 { - db.AddAof(makeAofCmd("psetex", args)) - if ttl != unlimitedTTL { - expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) - db.Expire(key, expireTime) - db.AddAof(makeExpireCmd(key, expireTime)) - } - } - return &reply.OkReply{} + entity := &DataEntity{ + Data: value, + } + + db.Lock(key) + defer db.UnLock(key) + + db.Put(key, entity) + expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond) + db.Expire(key, expireTime) + db.AddAof(makeAofCmd("setex", args)) + db.AddAof(makeExpireCmd(key, expireTime)) + + return &reply.OkReply{} } func MSet(db *DB, args [][]byte) redis.Reply { - if len(args)%2 != 0 || len(args) == 0 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") - } + if len(args)%2 != 0 || len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") + } - size := len(args) / 2 - keys := make([]string, size) - values := make([][]byte, size) - for i := 0; i < size; i++ { - keys[i] = string(args[2*i]) - values[i] = args[2*i+1] - } + size := len(args) / 2 + keys := make([]string, size) + values := make([][]byte, size) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i]) + values[i] = args[2*i+1] + } - db.Locks(keys...) - defer db.UnLocks(keys...) + db.Locks(keys...) + defer db.UnLocks(keys...) - for i, key := range keys { - value := values[i] - db.Put(key, &DataEntity{Data: value}) - } - db.AddAof(makeAofCmd("mset", args)) - return &reply.OkReply{} + for i, key := range keys { + value := values[i] + db.Put(key, &DataEntity{Data: value}) + } + db.AddAof(makeAofCmd("mset", args)) + return &reply.OkReply{} } func MGet(db *DB, args [][]byte) redis.Reply { - if len(args) == 0 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") - } - keys := make([]string, len(args)) - for i, v := range args { - keys[i] = string(v) - } + if len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") + } + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } - result := make([][]byte, len(args)) - for i, key := range keys { - bytes, err := db.getAsString(key) - if err != nil { - _, isWrongType := err.(*reply.WrongTypeErrReply) - if isWrongType { - result[i] = nil - continue - } else { - return err - } - } - result[i] = bytes // nil or []byte - } + result := make([][]byte, len(args)) + for i, key := range keys { + bytes, err := db.getAsString(key) + if err != nil { + _, isWrongType := err.(*reply.WrongTypeErrReply) + if isWrongType { + result[i] = nil + continue + } else { + return err + } + } + result[i] = bytes // nil or []byte + } - return reply.MakeMultiBulkReply(result) + return reply.MakeMultiBulkReply(result) } func MSetNX(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args)%2 != 0 || len(args) == 0 { - return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command") - } - size := len(args) / 2 - values := make([][]byte, size) - keys := make([]string, size) - for i := 0; i < size; i++ { - keys[i] = string(args[2*i]) - values[i] = args[2*i+1] - } + // parse args + if len(args)%2 != 0 || len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command") + } + size := len(args) / 2 + values := make([][]byte, size) + keys := make([]string, size) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i]) + values[i] = args[2*i+1] + } - // lock keys - db.Locks(keys...) - defer db.UnLocks(keys...) + // lock keys + db.Locks(keys...) + defer db.UnLocks(keys...) - for _, key := range keys { - _, exists := db.Get(key) - if exists { - return reply.MakeIntReply(0) - } - } + for _, key := range keys { + _, exists := db.Get(key) + if exists { + return reply.MakeIntReply(0) + } + } - for i, key := range keys { - value := values[i] - db.Put(key, &DataEntity{Data: value}) - } - db.AddAof(makeAofCmd("msetnx", args)) - return reply.MakeIntReply(1) + for i, key := range keys { + value := values[i] + db.Put(key, &DataEntity{Data: value}) + } + db.AddAof(makeAofCmd("msetnx", args)) + return reply.MakeIntReply(1) } func GetSet(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'getset' command") - } - key := string(args[0]) - value := args[1] + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'getset' command") + } + key := string(args[0]) + value := args[1] - old, err := db.getAsString(key) - if err != nil { - return err - } + old, err := db.getAsString(key) + if err != nil { + return err + } - db.Put(key, &DataEntity{Data: value}) - db.Persist(key) // override ttl - db.AddAof(makeAofCmd("getset", args)) - - return reply.MakeBulkReply(old) + db.Put(key, &DataEntity{Data: value}) + db.Persist(key) // override ttl + db.AddAof(makeAofCmd("getset", args)) + if old == nil { + return new(reply.NullBulkReply) + } + return reply.MakeBulkReply(old) } func Incr(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'incr' command") - } - key := string(args[0]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'incr' command") + } + key := string(args[0]) - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - bytes, err := db.getAsString(key) - if err != nil { - return err - } - if bytes != nil { - val, err := strconv.ParseInt(string(bytes), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - db.Put(key, &DataEntity{ - Data: []byte(strconv.FormatInt(val+1, 10)), - }) - db.AddAof(makeAofCmd("incr", args)) - return reply.MakeIntReply(val + 1) - } else { - db.Put(key, &DataEntity{ - Data: []byte("1"), - }) - db.AddAof(makeAofCmd("incr", args)) - return reply.MakeIntReply(1) - } + bytes, err := db.getAsString(key) + if err != nil { + return err + } + if bytes != nil { + val, err := strconv.ParseInt(string(bytes), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + db.Put(key, &DataEntity{ + Data: []byte(strconv.FormatInt(val+1, 10)), + }) + db.AddAof(makeAofCmd("incr", args)) + return reply.MakeIntReply(val + 1) + } else { + db.Put(key, &DataEntity{ + Data: []byte("1"), + }) + db.AddAof(makeAofCmd("incr", args)) + return reply.MakeIntReply(1) + } } func IncrBy(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'incrby' command") - } - key := string(args[0]) - rawDelta := string(args[1]) - delta, err := strconv.ParseInt(rawDelta, 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'incrby' command") + } + key := string(args[0]) + rawDelta := string(args[1]) + delta, err := strconv.ParseInt(rawDelta, 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - bytes, errReply := db.getAsString(key) - if errReply != nil { - return errReply - } - if bytes != nil { - val, err := strconv.ParseInt(string(bytes), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - db.Put(key, &DataEntity{ - Data: []byte(strconv.FormatInt(val+delta, 10)), - }) - db.AddAof(makeAofCmd("incrby", args)) - return reply.MakeIntReply(val + delta) - } else { - db.Put(key, &DataEntity{ - Data: args[1], - }) - db.AddAof(makeAofCmd("incrby", args)) - return reply.MakeIntReply(delta) - } + bytes, errReply := db.getAsString(key) + if errReply != nil { + return errReply + } + if bytes != nil { + val, err := strconv.ParseInt(string(bytes), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + db.Put(key, &DataEntity{ + Data: []byte(strconv.FormatInt(val+delta, 10)), + }) + db.AddAof(makeAofCmd("incrby", args)) + return reply.MakeIntReply(val + delta) + } else { + db.Put(key, &DataEntity{ + Data: args[1], + }) + db.AddAof(makeAofCmd("incrby", args)) + return reply.MakeIntReply(delta) + } } func IncrByFloat(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'incrbyfloat' command") - } - key := string(args[0]) - rawDelta := string(args[1]) - delta, err := decimal.NewFromString(rawDelta) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'incrbyfloat' command") + } + key := string(args[0]) + rawDelta := string(args[1]) + delta, err := decimal.NewFromString(rawDelta) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - bytes, errReply := db.getAsString(key) - if errReply != nil { - return errReply - } - if bytes != nil { - val, err := decimal.NewFromString(string(bytes)) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - resultBytes := []byte(val.Add(delta).String()) - db.Put(key, &DataEntity{ - Data: resultBytes, - }) - db.AddAof(makeAofCmd("incrbyfloat", args)) - return reply.MakeBulkReply(resultBytes) - } else { - db.Put(key, &DataEntity{ - Data: args[1], - }) - db.AddAof(makeAofCmd("incrbyfloat", args)) - return reply.MakeBulkReply(args[1]) - } + bytes, errReply := db.getAsString(key) + if errReply != nil { + return errReply + } + if bytes != nil { + val, err := decimal.NewFromString(string(bytes)) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + resultBytes := []byte(val.Add(delta).String()) + db.Put(key, &DataEntity{ + Data: resultBytes, + }) + db.AddAof(makeAofCmd("incrbyfloat", args)) + return reply.MakeBulkReply(resultBytes) + } else { + db.Put(key, &DataEntity{ + Data: args[1], + }) + db.AddAof(makeAofCmd("incrbyfloat", args)) + return reply.MakeBulkReply(args[1]) + } } func Decr(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'decr' command") - } - key := string(args[0]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'decr' command") + } + key := string(args[0]) - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - bytes, errReply := db.getAsString(key) - if errReply != nil { - return errReply - } - if bytes != nil { - val, err := strconv.ParseInt(string(bytes), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - db.Put(key, &DataEntity{ - Data: []byte(strconv.FormatInt(val-1, 10)), - }) - db.AddAof(makeAofCmd("decr", args)) - return reply.MakeIntReply(val - 1) - } else { - entity := &DataEntity{ - Data: []byte("-1"), - } - db.Put(key, entity) - db.AddAof(makeAofCmd("decr", args)) - return reply.MakeIntReply(-1) - } + bytes, errReply := db.getAsString(key) + if errReply != nil { + return errReply + } + if bytes != nil { + val, err := strconv.ParseInt(string(bytes), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + db.Put(key, &DataEntity{ + Data: []byte(strconv.FormatInt(val-1, 10)), + }) + db.AddAof(makeAofCmd("decr", args)) + return reply.MakeIntReply(val - 1) + } else { + entity := &DataEntity{ + Data: []byte("-1"), + } + db.Put(key, entity) + db.AddAof(makeAofCmd("decr", args)) + return reply.MakeIntReply(-1) + } } func DecrBy(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'decrby' command") - } - key := string(args[0]) - rawDelta := string(args[1]) - delta, err := strconv.ParseInt(rawDelta, 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'decrby' command") + } + key := string(args[0]) + rawDelta := string(args[1]) + delta, err := strconv.ParseInt(rawDelta, 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } - db.Lock(key) - defer db.UnLock(key) + db.Lock(key) + defer db.UnLock(key) - bytes, errReply := db.getAsString(key) - if errReply != nil { - return errReply - } - if bytes != nil { - val, err := strconv.ParseInt(string(bytes), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - db.Put(key, &DataEntity{ - Data: []byte(strconv.FormatInt(val-delta, 10)), - }) - db.AddAof(makeAofCmd("decrby", args)) - return reply.MakeIntReply(val - delta) - } else { - valueStr := strconv.FormatInt(-delta, 10) - db.Put(key, &DataEntity{ - Data: []byte(valueStr), - }) - db.AddAof(makeAofCmd("decrby", args)) - return reply.MakeIntReply(-delta) - } + bytes, errReply := db.getAsString(key) + if errReply != nil { + return errReply + } + if bytes != nil { + val, err := strconv.ParseInt(string(bytes), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + db.Put(key, &DataEntity{ + Data: []byte(strconv.FormatInt(val-delta, 10)), + }) + db.AddAof(makeAofCmd("decrby", args)) + return reply.MakeIntReply(val - delta) + } else { + valueStr := strconv.FormatInt(-delta, 10) + db.Put(key, &DataEntity{ + Data: []byte(valueStr), + }) + db.AddAof(makeAofCmd("decrby", args)) + return reply.MakeIntReply(-delta) + } } diff --git a/src/db/string_test.go b/src/db/string_test.go index 792fb9c..ed6e1cb 100644 --- a/src/db/string_test.go +++ b/src/db/string_test.go @@ -1,152 +1,288 @@ package db import ( - "github.com/HDT3213/godis/src/datastruct/utils" - "github.com/HDT3213/godis/src/redis/reply" - "math/rand" - "strconv" - "testing" + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/redis/reply/asserts" + "strconv" + "testing" ) var testDB = makeTestDB() func TestSet(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - value := strconv.FormatInt(int64(rand.Int()), 10) + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) - // normal set - Set(testDB, toArgs(key, value)) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + // normal set + Set(testDB, toArgs(key, value)) + actual := Get(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } - // set nx - actual = Set(testDB, toArgs(key, value, "NX")) - if _, ok := actual.(*reply.NullBulkReply); !ok { - t.Error("expected true actual false") - } + // set nx + actual = Set(testDB, toArgs(key, value, "NX")) + if _, ok := actual.(*reply.NullBulkReply); !ok { + t.Error("expected true actual false") + } - FlushAll(testDB, [][]byte{}) - key = strconv.FormatInt(int64(rand.Int()), 10) - value = strconv.FormatInt(int64(rand.Int()), 10) - Set(testDB, toArgs(key, value, "NX")) - actual = Get(testDB, toArgs(key)) - expected = reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + FlushAll(testDB, [][]byte{}) + key = RandString(10) + value = RandString(10) + Set(testDB, toArgs(key, value, "NX")) + actual = Get(testDB, toArgs(key)) + expected = reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } - // set xx - FlushAll(testDB, [][]byte{}) - key = strconv.FormatInt(int64(rand.Int()), 10) - value = strconv.FormatInt(int64(rand.Int()), 10) - actual = Set(testDB, toArgs(key, value, "XX")) - if _, ok := actual.(*reply.NullBulkReply); !ok { - t.Error("expected true actually false ") - } + // set xx + FlushAll(testDB, [][]byte{}) + key = RandString(10) + value = RandString(10) + actual = Set(testDB, toArgs(key, value, "XX")) + if _, ok := actual.(*reply.NullBulkReply); !ok { + t.Error("expected true actually false ") + } - Set(testDB, toArgs(key, value)) - Set(testDB, toArgs(key, value, "XX")) - actual = Get(testDB, toArgs(key)) - expected = reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + Set(testDB, toArgs(key, value)) + Set(testDB, toArgs(key, value, "XX")) + actual = Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, value) + // set ex + Del(testDB, toArgs(key)) + ttl := "1000" + Set(testDB, toArgs(key, value, "EX", ttl)) + actual = Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, value) + actual = TTL(testDB, toArgs(key)) + intResult, ok := actual.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000 { + t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code)) + return + } + + // set px + Del(testDB, toArgs(key)) + ttlPx := "1000000" + Set(testDB, toArgs(key, value, "PX", ttlPx)) + actual = Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, value) + actual = TTL(testDB, toArgs(key)) + intResult, ok = actual.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000 { + t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code)) + return + } } func TestSetNX(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - value := strconv.FormatInt(int64(rand.Int()), 10) - SetNX(testDB, toArgs(key, value)) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + SetNX(testDB, toArgs(key, value)) + actual := Get(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } - actual = SetNX(testDB, toArgs(key, value)) - expected2 := reply.MakeIntReply(int64(0)) - if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { - t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + actual = SetNX(testDB, toArgs(key, value)) + expected2 := reply.MakeIntReply(int64(0)) + if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { + t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } } func TestSetEX(t *testing.T) { - FlushAll(testDB, [][]byte{}) - key := strconv.FormatInt(int64(rand.Int()), 10) - value := strconv.FormatInt(int64(rand.Int()), 10) - ttl := "1000" + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + ttl := "1000" - SetEX(testDB, toArgs(key, ttl, value)) - actual := Get(testDB, toArgs(key)) - expected2 := reply.MakeBulkReply([]byte(value)) - if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { - t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + SetEX(testDB, toArgs(key, ttl, value)) + actual := Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, value) + actual = TTL(testDB, toArgs(key)) + intResult, ok := actual.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000 { + t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code)) + return + } +} + +func TestPSetEX(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + ttl := "1000000" + + PSetEX(testDB, toArgs(key, ttl, value)) + actual := Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, value) + actual = PTTL(testDB, toArgs(key)) + intResult, ok := actual.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) + return + } + if intResult.Code <= 0 || intResult.Code > 1000000 { + t.Error(fmt.Sprintf("expected int between [0, 1000], actually %d", intResult.Code)) + return + } } func TestMSet(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 10 - keys := make([]string, size) - values := make([][]byte, size) - args := make([]string, 0, size*2) - for i := 0; i < size; i++ { - keys[i] = strconv.FormatInt(int64(rand.Int()), 10) - value := strconv.FormatInt(int64(rand.Int()), 10) - values[i] = []byte(value) - args = append(args, keys[i], value) - } - MSet(testDB, toArgs(args...)) - actual := MGet(testDB, toArgs(keys...)) - expected := reply.MakeMultiBulkReply(values) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } + FlushAll(testDB, [][]byte{}) + size := 10 + keys := make([]string, size) + values := make([][]byte, size) + args := make([]string, 0, size*2) + for i := 0; i < size; i++ { + keys[i] = RandString(10) + value := RandString(10) + values[i] = []byte(value) + args = append(args, keys[i], value) + } + MSet(testDB, toArgs(args...)) + actual := MGet(testDB, toArgs(keys...)) + expected := reply.MakeMultiBulkReply(values) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } } func TestIncr(t *testing.T) { - FlushAll(testDB, [][]byte{}) - size := 10 - key := strconv.FormatInt(int64(rand.Int()), 10) - for i := 0; i < size; i++ { - Incr(testDB, toArgs(key)) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } - } - for i := 0; i < size; i++ { - IncrBy(testDB, toArgs(key, "-1")) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } - } + FlushAll(testDB, [][]byte{}) + size := 10 + key := RandString(10) + for i := 0; i < size; i++ { + Incr(testDB, toArgs(key)) + actual := Get(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } + } + for i := 0; i < size; i++ { + IncrBy(testDB, toArgs(key, "-1")) + actual := Get(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } + } - FlushAll(testDB, [][]byte{}) - key = strconv.FormatInt(int64(rand.Int()), 10) - for i := 0; i < size; i++ { - IncrBy(testDB, toArgs(key, "1")) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } - } - for i := 0; i < size; i++ { - IncrByFloat(testDB, toArgs(key, "-1.0")) - actual := Get(testDB, toArgs(key)) - expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) - if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { - t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) - } - } -} \ No newline at end of file + FlushAll(testDB, [][]byte{}) + key = RandString(10) + for i := 0; i < size; i++ { + IncrBy(testDB, toArgs(key, "1")) + actual := Get(testDB, toArgs(key)) + expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } + } + Del(testDB, toArgs(key)) + for i := 0; i < size; i++ { + IncrByFloat(testDB, toArgs(key, "-1.0")) + actual := Get(testDB, toArgs(key)) + expected := -i - 1 + bulk, ok := actual.(*reply.BulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) + return + } + val, err := strconv.ParseFloat(string(bulk.Arg), 10) + if err != nil { + t.Error(err) + return + } + if int(val) != expected { + t.Errorf("expect %d, actual: %d", expected, int(val)) + return + } + } +} + +func TestDecr(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 10 + key := RandString(10) + for i := 0; i < size; i++ { + Decr(testDB, toArgs(key)) + actual := Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, actual, strconv.Itoa(-i-1)) + } + Del(testDB, toArgs(key)) + for i := 0; i < size; i++ { + DecrBy(testDB, toArgs(key, "1")) + actual := Get(testDB, toArgs(key)) + expected := -i - 1 + bulk, ok := actual.(*reply.BulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) + return + } + val, err := strconv.ParseFloat(string(bulk.Arg), 10) + if err != nil { + t.Error(err) + return + } + if int(val) != expected { + t.Errorf("expect %d, actual: %d", expected, int(val)) + return + } + } +} + +func TestGetSet(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := RandString(10) + value := RandString(10) + + result := GetSet(testDB, toArgs(key, value)) + _, ok := result.(*reply.NullBulkReply) + if !ok { + t.Errorf("expect null bulk reply, get: %s", string(result.ToBytes())) + return + } + + value2 := RandString(10) + result = GetSet(testDB, toArgs(key, value2)) + asserts.AssertBulkReply(t, result, value) + result = Get(testDB, toArgs(key)) + asserts.AssertBulkReply(t, result, value2) +} + +func TestMSetNX(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 10 + args := make([]string, 0, size*2) + for i := 0; i < size; i++ { + str := RandString(10) + args = append(args, str, str) + } + result := MSetNX(testDB, toArgs(args...)) + asserts.AssertIntReply(t, result, 1) + + result = MSetNX(testDB, toArgs(args[0:4]...)) + asserts.AssertIntReply(t, result, 0) +} diff --git a/src/db/util_test.go b/src/db/util_test.go index c4b6246..15ce516 100644 --- a/src/db/util_test.go +++ b/src/db/util_test.go @@ -1,24 +1,35 @@ package db import ( - "github.com/HDT3213/godis/src/datastruct/dict" - "github.com/HDT3213/godis/src/datastruct/lock" - "time" + "github.com/HDT3213/godis/src/datastruct/dict" + "github.com/HDT3213/godis/src/datastruct/lock" + "math/rand" + "time" ) func makeTestDB() *DB { - return &DB{ - Data: dict.MakeConcurrent(1), - TTLMap: dict.MakeConcurrent(ttlDictSize), - Locker: lock.Make(lockerSize), - interval: 5 * time.Second, - } + return &DB{ + Data: dict.MakeConcurrent(1), + TTLMap: dict.MakeConcurrent(ttlDictSize), + Locker: lock.Make(lockerSize), + interval: 5 * time.Second, + } } func toArgs(cmd ...string) [][]byte { - args := make([][]byte, len(cmd)) - for i, s := range cmd { - args[i] = []byte(s) - } - return args + args := make([][]byte, len(cmd)) + for i, s := range cmd { + args[i] = []byte(s) + } + return args +} + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + +func RandString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) } diff --git a/src/redis/reply/asserts/assert.go b/src/redis/reply/asserts/assert.go index 7cda3ef..2f5991f 100644 --- a/src/redis/reply/asserts/assert.go +++ b/src/redis/reply/asserts/assert.go @@ -1,49 +1,82 @@ package asserts import ( - "fmt" - "github.com/HDT3213/godis/src/datastruct/utils" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "testing" + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "runtime" + "testing" ) func AssertIntReply(t *testing.T, actual redis.Reply, expected int) { - intResult, ok := actual.(*reply.IntReply) - if !ok { - t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) - return - } - if intResult.Code != int64(expected) { - t.Error(fmt.Sprintf("expected %d, actually %d", expected, intResult.Code)) - } + intResult, ok := actual.(*reply.IntReply) + if !ok { + t.Errorf("expected int reply, actually %s, %s", actual.ToBytes(), printStack()) + return + } + if intResult.Code != int64(expected) { + t.Errorf("expected %d, actually %d, %s", expected, intResult.Code, printStack()) + } } func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) { - bulkReply, ok := actual.(*reply.BulkReply) - if !ok { - t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) - return - } - if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) { - t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual.ToBytes())) - } + bulkReply, ok := actual.(*reply.BulkReply) + if !ok { + t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack()) + return + } + if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) { + t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack()) + } +} + +func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) { + statusReply, ok := actual.(*reply.StatusReply) + if !ok { + t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack()) + return + } + if statusReply.Status != expected { + t.Errorf("expected %s, actually %s, %s", expected, actual.ToBytes(), printStack()) + } } func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) { - multiBulk, ok := actual.(*reply.MultiBulkReply) - if !ok { - t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) - return - } - if len(multiBulk.Args) != len(expected) { - t.Error(fmt.Sprintf("expected %d elements, actually %d", len(expected), len(multiBulk.Args))) - return - } - for i, v := range multiBulk.Args { - actual := string(v) - if actual != expected[i] { - t.Error(fmt.Sprintf("expected %s, actually %s", expected[i], actual)) - } - } + multiBulk, ok := actual.(*reply.MultiBulkReply) + if !ok { + t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack()) + return + } + if len(multiBulk.Args) != len(expected) { + t.Errorf("expected %d elements, actually %d, %s", + len(expected), len(multiBulk.Args), printStack()) + return + } + for i, v := range multiBulk.Args { + str := string(v) + if str != expected[i] { + t.Errorf("expected %s, actually %s, %s", expected[i], actual, printStack()) + } + } +} + +func AssertMultiBulkReplySize(t *testing.T, actual redis.Reply, expected int) { + multiBulk, ok := actual.(*reply.MultiBulkReply) + if !ok { + t.Errorf("expected bulk reply, actually %s, %s", actual.ToBytes(), printStack()) + return + } + if len(multiBulk.Args) != expected { + t.Errorf("expected %d elements, actually %d, %s", expected, len(multiBulk.Args), printStack()) + return + } +} + +func printStack() string { + _, file, no, ok := runtime.Caller(2) + if ok { + return fmt.Sprintf("at %s#%d", file, no) + } + return "" }