diff --git a/database/set.go b/database/set.go index 41282fb..b283a0d 100644 --- a/database/set.go +++ b/database/set.go @@ -179,38 +179,10 @@ func execSMembers(db *DB, args [][]byte) redis.Reply { return protocol.MakeMultiBulkReply(arr) } -// execSInter intersect multiple sets -func execSInter(db *DB, args [][]byte) redis.Reply { - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } - - var result *HashSet.Set - for _, key := range keys { - set, errReply := db.getAsSet(key) - if errReply != nil { - return errReply - } - if set == nil { - return &protocol.EmptyMultiBulkReply{} - } - - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Intersect(set) - if result.Len() == 0 { - // early termination - return &protocol.EmptyMultiBulkReply{} - } - } - } - - arr := make([][]byte, result.Len()) +func set2reply(set *HashSet.Set) redis.Reply { + arr := make([][]byte, set.Len()) i := 0 - result.ForEach(func(member string) bool { + set.ForEach(func(member string) bool { arr[i] = []byte(member) i++ return true @@ -218,221 +190,125 @@ func execSInter(db *DB, args [][]byte) redis.Reply { return protocol.MakeMultiBulkReply(arr) } +// execSInter intersect multiple sets +func execSInter(db *DB, args [][]byte) redis.Reply { + sets := make([]*HashSet.Set, 0, len(args)) + for _, arg := range args { + key := string(arg) + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set.Len() == 0 { + return &protocol.EmptyMultiBulkReply{} + } + sets = append(sets, set) + } + result := HashSet.Intersect(sets...) + return set2reply(result) +} + // execSInterStore intersects multiple sets and store the result in a key func execSInterStore(db *DB, args [][]byte) redis.Reply { dest := string(args[0]) - keys := make([]string, len(args)-1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } - - var result *HashSet.Set - for _, key := range keys { + sets := make([]*HashSet.Set, 0, len(args)-1) + for i := 1; i < len(args); i++ { + key := string(args[i]) set, errReply := db.getAsSet(key) if errReply != nil { return errReply } - if set == nil { - db.Remove(dest) // clean ttl and old value + if set.Len() == 0 { return protocol.MakeIntReply(0) } - - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Intersect(set) - if result.Len() == 0 { - // early termination - db.Remove(dest) // clean ttl and old value - return protocol.MakeIntReply(0) - } - } + sets = append(sets, set) } + result := HashSet.Intersect(sets...) - set := HashSet.Make(result.ToSlice()...) db.PutEntity(dest, &database.DataEntity{ - Data: set, + Data: result, }) db.addAof(utils.ToCmdLine3("sinterstore", args...)) - return protocol.MakeIntReply(int64(set.Len())) + return protocol.MakeIntReply(int64(result.Len())) } // execSUnion adds multiple sets func execSUnion(db *DB, args [][]byte) redis.Reply { - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } - - var result *HashSet.Set - for _, key := range keys { + sets := make([]*HashSet.Set, 0, len(args)) + for _, arg := range args { + key := string(arg) set, errReply := db.getAsSet(key) if errReply != nil { return errReply } - if set == nil { - continue - } - - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Union(set) - } + sets = append(sets, set) } - - if result == nil { - // all keys are empty set - return &protocol.EmptyMultiBulkReply{} - } - arr := make([][]byte, result.Len()) - i := 0 - result.ForEach(func(member string) bool { - arr[i] = []byte(member) - i++ - return true - }) - return protocol.MakeMultiBulkReply(arr) + result := HashSet.Union(sets...) + return set2reply(result) } // execSUnionStore adds multiple sets and store the result in a key func execSUnionStore(db *DB, args [][]byte) redis.Reply { dest := string(args[0]) - keys := make([]string, len(args)-1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } - - var result *HashSet.Set - for _, key := range keys { + sets := make([]*HashSet.Set, 0, len(args)-1) + for i := 1; i < len(args); i++ { + key := string(args[i]) set, errReply := db.getAsSet(key) if errReply != nil { return errReply } - if set == nil { - continue - } - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Union(set) - } + sets = append(sets, set) } - + result := HashSet.Union(sets...) db.Remove(dest) // clean ttl - if result == nil { - // all keys are empty set - return &protocol.EmptyMultiBulkReply{} + if result.Len() == 0 { + return protocol.MakeIntReply(0) } - set := HashSet.Make(result.ToSlice()...) db.PutEntity(dest, &database.DataEntity{ - Data: set, + Data: result, }) - db.addAof(utils.ToCmdLine3("sunionstore", args...)) - return protocol.MakeIntReply(int64(set.Len())) + return protocol.MakeIntReply(int64(result.Len())) } // execSDiff subtracts multiple sets func execSDiff(db *DB, args [][]byte) redis.Reply { - keys := make([]string, len(args)) - for i, arg := range args { - keys[i] = string(arg) - } - - var result *HashSet.Set - for i, key := range keys { + sets := make([]*HashSet.Set, 0, len(args)) + for _, arg := range args { + key := string(arg) set, errReply := db.getAsSet(key) if errReply != nil { return errReply } - if set == nil { - if i == 0 { - // early termination - return &protocol.EmptyMultiBulkReply{} - } - continue - } - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Diff(set) - if result.Len() == 0 { - // early termination - return &protocol.EmptyMultiBulkReply{} - } - } + sets = append(sets, set) } - - if result == nil { - // all keys are nil - return &protocol.EmptyMultiBulkReply{} - } - arr := make([][]byte, result.Len()) - i := 0 - result.ForEach(func(member string) bool { - arr[i] = []byte(member) - i++ - return true - }) - return protocol.MakeMultiBulkReply(arr) + result := HashSet.Diff(sets...) + return set2reply(result) } // execSDiffStore subtracts multiple sets and store the result in a key func execSDiffStore(db *DB, args [][]byte) redis.Reply { dest := string(args[0]) - keys := make([]string, len(args)-1) - keyArgs := args[1:] - for i, arg := range keyArgs { - keys[i] = string(arg) - } - - var result *HashSet.Set - for i, key := range keys { + sets := make([]*HashSet.Set, 0, len(args)-1) + for i := 1; i < len(args); i++ { + key := string(args[i]) set, errReply := db.getAsSet(key) if errReply != nil { return errReply } - if set == nil { - if i == 0 { - // early termination - db.Remove(dest) - return protocol.MakeIntReply(0) - } - continue - } - if result == nil { - // init - result = HashSet.Make(set.ToSlice()...) - } else { - result = result.Diff(set) - if result.Len() == 0 { - // early termination - db.Remove(dest) - return protocol.MakeIntReply(0) - } - } + sets = append(sets, set) } - - if result == nil { - // all keys are nil - db.Remove(dest) - return &protocol.EmptyMultiBulkReply{} + result := HashSet.Diff(sets...) + db.Remove(dest) // clean ttl + if result.Len() == 0 { + return protocol.MakeIntReply(0) } - set := HashSet.Make(result.ToSlice()...) db.PutEntity(dest, &database.DataEntity{ - Data: set, + Data: result, }) - db.addAof(utils.ToCmdLine3("sdiffstore", args...)) - return protocol.MakeIntReply(int64(set.Len())) + return protocol.MakeIntReply(int64(result.Len())) } // execSRandMember gets random members from set diff --git a/database/set_test.go b/database/set_test.go index a4a3561..638e0ac 100644 --- a/database/set_test.go +++ b/database/set_test.go @@ -106,7 +106,7 @@ func TestSInter(t *testing.T) { keys := make([]string, 0) start := 0 for i := 0; i < 4; i++ { - key := utils.RandString(10) + key := utils.RandString(10) + strconv.Itoa(i) keys = append(keys, key) for j := start; j < size+start; j++ { member := strconv.Itoa(j) diff --git a/datastruct/set/set.go b/datastruct/set/set.go index 50f29f1..2b1648b 100644 --- a/datastruct/set/set.go +++ b/datastruct/set/set.go @@ -1,6 +1,8 @@ package set -import "github.com/hdt3213/godis/datastruct/dict" +import ( + "github.com/hdt3213/godis/datastruct/dict" +) // Set is a set of elements based on hash table type Set struct { @@ -30,12 +32,18 @@ func (set *Set) Remove(val string) int { // Has returns true if the val exists in the set func (set *Set) Has(val string) bool { + if set == nil || set.dict == nil { + return false + } _, exists := set.dict.Get(val) return exists } // Len returns number of members in the set func (set *Set) Len() int { + if set == nil || set.dict == nil { + return 0 + } return set.dict.Len() } @@ -58,62 +66,81 @@ func (set *Set) ToSlice() []string { // ForEach visits each member in the set func (set *Set) ForEach(consumer func(member string) bool) { + if set == nil || set.dict == nil { + return + } set.dict.ForEach(func(key string, val interface{}) bool { return consumer(key) }) } -// Intersect intersects two sets -func (set *Set) Intersect(another *Set) *Set { - if set == nil { - panic("set is nil") - } - +// ShallowCopy copies all members to another set +func (set *Set) ShallowCopy() *Set { result := Make() - another.ForEach(func(member string) bool { - if set.Has(member) { - result.Add(member) - } + set.ForEach(func(member string) bool { + result.Add(member) return true }) return result } +// Intersect intersects two sets +func Intersect(sets ...*Set) *Set { + result := Make() + if len(sets) == 0 { + return result + } + + countMap := make(map[string]int) + for _, set := range sets { + set.ForEach(func(member string) bool { + countMap[member]++ + return true + }) + } + for k, v := range countMap { + if v == len(sets) { + result.Add(k) + } + } + return result +} + // Union adds two sets -func (set *Set) Union(another *Set) *Set { - if set == nil { - panic("set is nil") - } +func Union(sets ...*Set) *Set { result := Make() - another.ForEach(func(member string) bool { - result.Add(member) - return true - }) - set.ForEach(func(member string) bool { - result.Add(member) - return true - }) + for _, set := range sets { + set.ForEach(func(member string) bool { + result.Add(member) + return true + }) + } return result } // Diff subtracts two sets -func (set *Set) Diff(another *Set) *Set { - if set == nil { - panic("set is nil") +func Diff(sets ...*Set) *Set { + if len(sets) == 0 { + return Make() } - - result := Make() - set.ForEach(func(member string) bool { - if !another.Has(member) { - result.Add(member) + result := sets[0].ShallowCopy() + for i := 1; i < len(sets); i++ { + sets[i].ForEach(func(member string) bool { + result.Remove(member) + return true + }) + if result.Len() == 0 { + break } - return true - }) + } return result } // RandomMembers randomly returns keys of the given number, may contain duplicated key func (set *Set) RandomMembers(limit int) []string { + if set == nil || set.dict == nil { + return nil + } return set.dict.RandomKeys(limit) }