optimize set inter,union,diff

This commit is contained in:
finley
2023-05-08 23:07:19 +08:00
parent 6a7fb6a692
commit a266cc5aba
3 changed files with 121 additions and 218 deletions

View File

@@ -179,38 +179,10 @@ func execSMembers(db *DB, args [][]byte) redis.Reply {
return protocol.MakeMultiBulkReply(arr) return protocol.MakeMultiBulkReply(arr)
} }
// execSInter intersect multiple sets func set2reply(set *HashSet.Set) redis.Reply {
func execSInter(db *DB, args [][]byte) redis.Reply { arr := make([][]byte, set.Len())
keys := make([]string, len(args))
for i, arg := range args {
keys[i] = string(arg)
}
var result *HashSet.Set
for _, key := range keys {
set, errReply := db.getAsSet(key)
if errReply != nil {
return errReply
}
if set == nil {
return &protocol.EmptyMultiBulkReply{}
}
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Intersect(set)
if result.Len() == 0 {
// early termination
return &protocol.EmptyMultiBulkReply{}
}
}
}
arr := make([][]byte, result.Len())
i := 0 i := 0
result.ForEach(func(member string) bool { set.ForEach(func(member string) bool {
arr[i] = []byte(member) arr[i] = []byte(member)
i++ i++
return true return true
@@ -218,221 +190,125 @@ func execSInter(db *DB, args [][]byte) redis.Reply {
return protocol.MakeMultiBulkReply(arr) return protocol.MakeMultiBulkReply(arr)
} }
// execSInter replies with the members common to all sets stored at the
// keys given in args, as a multi-bulk reply.
func execSInter(db *DB, args [][]byte) redis.Reply {
	operands := make([]*HashSet.Set, len(args))
	for idx := range args {
		operand, errReply := db.getAsSet(string(args[idx]))
		if errReply != nil {
			return errReply
		}
		if operand.Len() == 0 {
			// one empty operand makes the whole intersection empty
			return &protocol.EmptyMultiBulkReply{}
		}
		operands[idx] = operand
	}
	return set2reply(HashSet.Intersect(operands...))
}
// execSInterStore intersects multiple sets and store the result in a key // execSInterStore intersects multiple sets and store the result in a key
func execSInterStore(db *DB, args [][]byte) redis.Reply { func execSInterStore(db *DB, args [][]byte) redis.Reply {
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args)-1) sets := make([]*HashSet.Set, 0, len(args)-1)
keyArgs := args[1:] for i := 1; i < len(args); i++ {
for i, arg := range keyArgs { key := string(args[i])
keys[i] = string(arg)
}
var result *HashSet.Set
for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { if set.Len() == 0 {
db.Remove(dest) // clean ttl and old value
return protocol.MakeIntReply(0) return protocol.MakeIntReply(0)
} }
sets = append(sets, set)
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Intersect(set)
if result.Len() == 0 {
// early termination
db.Remove(dest) // clean ttl and old value
return protocol.MakeIntReply(0)
}
}
} }
result := HashSet.Intersect(sets...)
set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &database.DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: result,
}) })
db.addAof(utils.ToCmdLine3("sinterstore", args...)) db.addAof(utils.ToCmdLine3("sinterstore", args...))
return protocol.MakeIntReply(int64(set.Len())) return protocol.MakeIntReply(int64(result.Len()))
} }
// execSUnion adds multiple sets // execSUnion adds multiple sets
func execSUnion(db *DB, args [][]byte) redis.Reply { func execSUnion(db *DB, args [][]byte) redis.Reply {
keys := make([]string, len(args)) sets := make([]*HashSet.Set, 0, len(args))
for i, arg := range args { for _, arg := range args {
keys[i] = string(arg) key := string(arg)
}
var result *HashSet.Set
for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { sets = append(sets, set)
continue
}
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Union(set)
}
} }
result := HashSet.Union(sets...)
if result == nil { return set2reply(result)
// all keys are empty set
return &protocol.EmptyMultiBulkReply{}
}
arr := make([][]byte, result.Len())
i := 0
result.ForEach(func(member string) bool {
arr[i] = []byte(member)
i++
return true
})
return protocol.MakeMultiBulkReply(arr)
} }
// execSUnionStore adds multiple sets and store the result in a key // execSUnionStore adds multiple sets and store the result in a key
func execSUnionStore(db *DB, args [][]byte) redis.Reply { func execSUnionStore(db *DB, args [][]byte) redis.Reply {
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args)-1) sets := make([]*HashSet.Set, 0, len(args)-1)
keyArgs := args[1:] for i := 1; i < len(args); i++ {
for i, arg := range keyArgs { key := string(args[i])
keys[i] = string(arg)
}
var result *HashSet.Set
for _, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { sets = append(sets, set)
continue
}
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Union(set)
}
} }
result := HashSet.Union(sets...)
db.Remove(dest) // clean ttl db.Remove(dest) // clean ttl
if result == nil { if result.Len() == 0 {
// all keys are empty set return protocol.MakeIntReply(0)
return &protocol.EmptyMultiBulkReply{}
} }
set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &database.DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: result,
}) })
db.addAof(utils.ToCmdLine3("sunionstore", args...)) db.addAof(utils.ToCmdLine3("sunionstore", args...))
return protocol.MakeIntReply(int64(set.Len())) return protocol.MakeIntReply(int64(result.Len()))
} }
// execSDiff subtracts multiple sets // execSDiff subtracts multiple sets
func execSDiff(db *DB, args [][]byte) redis.Reply { func execSDiff(db *DB, args [][]byte) redis.Reply {
keys := make([]string, len(args)) sets := make([]*HashSet.Set, 0, len(args))
for i, arg := range args { for _, arg := range args {
keys[i] = string(arg) key := string(arg)
}
var result *HashSet.Set
for i, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { sets = append(sets, set)
if i == 0 {
// early termination
return &protocol.EmptyMultiBulkReply{}
}
continue
}
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Diff(set)
if result.Len() == 0 {
// early termination
return &protocol.EmptyMultiBulkReply{}
}
}
} }
result := HashSet.Diff(sets...)
if result == nil { return set2reply(result)
// all keys are nil
return &protocol.EmptyMultiBulkReply{}
}
arr := make([][]byte, result.Len())
i := 0
result.ForEach(func(member string) bool {
arr[i] = []byte(member)
i++
return true
})
return protocol.MakeMultiBulkReply(arr)
} }
// execSDiffStore subtracts multiple sets and store the result in a key // execSDiffStore subtracts multiple sets and store the result in a key
func execSDiffStore(db *DB, args [][]byte) redis.Reply { func execSDiffStore(db *DB, args [][]byte) redis.Reply {
dest := string(args[0]) dest := string(args[0])
keys := make([]string, len(args)-1) sets := make([]*HashSet.Set, 0, len(args)-1)
keyArgs := args[1:] for i := 1; i < len(args); i++ {
for i, arg := range keyArgs { key := string(args[i])
keys[i] = string(arg)
}
var result *HashSet.Set
for i, key := range keys {
set, errReply := db.getAsSet(key) set, errReply := db.getAsSet(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if set == nil { sets = append(sets, set)
if i == 0 {
// early termination
db.Remove(dest)
return protocol.MakeIntReply(0)
}
continue
}
if result == nil {
// init
result = HashSet.Make(set.ToSlice()...)
} else {
result = result.Diff(set)
if result.Len() == 0 {
// early termination
db.Remove(dest)
return protocol.MakeIntReply(0)
}
}
} }
result := HashSet.Diff(sets...)
if result == nil { db.Remove(dest) // clean ttl
// all keys are nil if result.Len() == 0 {
db.Remove(dest) return protocol.MakeIntReply(0)
return &protocol.EmptyMultiBulkReply{}
} }
set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &database.DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: result,
}) })
db.addAof(utils.ToCmdLine3("sdiffstore", args...)) db.addAof(utils.ToCmdLine3("sdiffstore", args...))
return protocol.MakeIntReply(int64(set.Len())) return protocol.MakeIntReply(int64(result.Len()))
} }
// execSRandMember gets random members from set // execSRandMember gets random members from set

View File

@@ -106,7 +106,7 @@ func TestSInter(t *testing.T) {
keys := make([]string, 0) keys := make([]string, 0)
start := 0 start := 0
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
key := utils.RandString(10) key := utils.RandString(10) + strconv.Itoa(i)
keys = append(keys, key) keys = append(keys, key)
for j := start; j < size+start; j++ { for j := start; j < size+start; j++ {
member := strconv.Itoa(j) member := strconv.Itoa(j)

View File

@@ -1,6 +1,8 @@
package set package set
import "github.com/hdt3213/godis/datastruct/dict" import (
"github.com/hdt3213/godis/datastruct/dict"
)
// Set is a set of elements based on hash table // Set is a set of elements based on hash table
type Set struct { type Set struct {
@@ -30,12 +32,18 @@ func (set *Set) Remove(val string) int {
// Has returns true if the val exists in the set // Has returns true if the val exists in the set
func (set *Set) Has(val string) bool { func (set *Set) Has(val string) bool {
if set == nil || set.dict == nil {
return false
}
_, exists := set.dict.Get(val) _, exists := set.dict.Get(val)
return exists return exists
} }
// Len returns number of members in the set // Len returns number of members in the set
func (set *Set) Len() int { func (set *Set) Len() int {
if set == nil || set.dict == nil {
return 0
}
return set.dict.Len() return set.dict.Len()
} }
@@ -58,62 +66,81 @@ func (set *Set) ToSlice() []string {
// ForEach visits each member in the set // ForEach visits each member in the set
func (set *Set) ForEach(consumer func(member string) bool) { func (set *Set) ForEach(consumer func(member string) bool) {
if set == nil || set.dict == nil {
return
}
set.dict.ForEach(func(key string, val interface{}) bool { set.dict.ForEach(func(key string, val interface{}) bool {
return consumer(key) return consumer(key)
}) })
} }
// Intersect intersects two sets // ShallowCopy copies all members to another set
func (set *Set) Intersect(another *Set) *Set { func (set *Set) ShallowCopy() *Set {
if set == nil {
panic("set is nil")
}
result := Make() result := Make()
another.ForEach(func(member string) bool { set.ForEach(func(member string) bool {
if set.Has(member) { result.Add(member)
result.Add(member)
}
return true return true
}) })
return result return result
} }
// Intersect returns a new set holding the members that are present in every
// one of the given sets. Called with no arguments it returns an empty set.
// The operands themselves are not modified.
func Intersect(sets ...*Set) *Set {
	result := Make()
	if len(sets) == 0 {
		return result
	}
	// Walk only the smallest operand: the intersection can never be larger
	// than it, so this bounds the work and avoids building a counter entry
	// for every member of every set.
	pivot := sets[0]
	for _, set := range sets[1:] {
		if set.Len() < pivot.Len() {
			pivot = set
		}
	}
	pivot.ForEach(func(member string) bool {
		for _, set := range sets {
			// Has reports false on a nil set, so a nil operand empties the result.
			if set != pivot && !set.Has(member) {
				return true // missing from one operand; move on to the next member
			}
		}
		result.Add(member)
		return true
	})
	return result
}
// Union adds two sets // Union adds two sets
func (set *Set) Union(another *Set) *Set { func Union(sets ...*Set) *Set {
if set == nil {
panic("set is nil")
}
result := Make() result := Make()
another.ForEach(func(member string) bool { for _, set := range sets {
result.Add(member) set.ForEach(func(member string) bool {
return true result.Add(member)
}) return true
set.ForEach(func(member string) bool { })
result.Add(member) }
return true
})
return result return result
} }
// Diff subtracts two sets // Diff subtracts two sets
func (set *Set) Diff(another *Set) *Set { func Diff(sets ...*Set) *Set {
if set == nil { if len(sets) == 0 {
panic("set is nil") return Make()
} }
result := sets[0].ShallowCopy()
result := Make() for i := 1; i < len(sets); i++ {
set.ForEach(func(member string) bool { sets[i].ForEach(func(member string) bool {
if !another.Has(member) { result.Remove(member)
result.Add(member) return true
})
if result.Len() == 0 {
break
} }
return true }
})
return result return result
} }
// RandomMembers randomly returns keys of the given number, may contain duplicated key // RandomMembers randomly returns keys of the given number, may contain duplicated key
func (set *Set) RandomMembers(limit int) []string { func (set *Set) RandomMembers(limit int) []string {
if set == nil || set.dict == nil {
return nil
}
return set.dict.RandomKeys(limit) return set.dict.RandomKeys(limit)
} }