Files
SugarDB/internal/modules/set/set.go
Kelvin Mwinuka 108bf97b4d Extend SugarDB Commands Using Lua Scripts (#155)
* Extend SugarDB by creating new commands using Lua - @kelvinmwinuka
2024-12-12 09:50:43 +08:00

217 lines
4.8 KiB
Go

// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package set
import (
"math/rand"
"slices"
"unsafe"
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/constants"
)
// Set is a collection of unique string members backed by a map.
type Set struct {
// members is keyed by element; Add stores an empty struct as each value.
members map[string]interface{}
// length is the member count, adjusted by Add and Remove.
length int
}
// GetMem returns an estimate, in bytes, of the memory attributed to the set.
//
// The estimate is the size of the receiver pointer plus the size of the map
// header word, plus, for every entry, the string header, the key's byte
// length, and the interface value header.
func (set *Set) GetMem() int64 {
	// Fixed part: unsafe.Sizeof on the receiver measures the *Set pointer, and
	// on set.members it measures the map word, not the map's contents.
	total := int64(unsafe.Sizeof(set)) + int64(unsafe.Sizeof(set.members))

	// Variable part: account for every key/value pair held by the map.
	for key, value := range set.members {
		total += int64(unsafe.Sizeof(key)) + int64(len(key)) + int64(unsafe.Sizeof(value))
	}
	return total
}
// Compile-time assertion that *Set satisfies constants.CompositeType.
var _ constants.CompositeType = (*Set)(nil)
// NewSet allocates an empty Set and adds every element of elems to it.
// Duplicate entries in elems are stored once.
func NewSet(elems []string) *Set {
	created := new(Set)
	created.members = make(map[string]interface{})
	created.length = 0
	created.Add(elems)
	return created
}
// Add inserts each element of elems that is not already a member and returns
// the number of elements that were newly inserted.
func (set *Set) Add(elems []string) int {
	added := 0
	for _, elem := range elems {
		if set.Contains(elem) {
			// Already present (or repeated earlier in elems); nothing to do.
			continue
		}
		set.members[elem] = struct{}{}
		added++
	}
	set.length += added
	return added
}
// get returns the value stored for e in the underlying map, or nil when the
// map has no entry for e.
func (set *Set) get(e string) interface{} {
	value, found := set.members[e]
	if !found {
		return nil
	}
	return value
}
// GetAll returns every member of the set as a freshly allocated slice.
// The order follows map iteration and is therefore unspecified.
// An empty set yields a nil slice.
func (set *Set) GetAll() []string {
	if len(set.members) == 0 {
		// Keep the nil result callers have always received for an empty set.
		return nil
	}
	// The final size is known, so allocate once instead of growing on append.
	res := make([]string, 0, len(set.members))
	for e := range set.members {
		res = append(res, e)
	}
	return res
}
// Cardinality returns the set's tracked member count (the length field).
func (set *Set) Cardinality() int {
return set.length
}
// GetRandom returns randomly chosen members without modifying the set.
//
//   - count == 0 returns an empty, non-nil slice.
//   - If the absolute value of count is at least the set's cardinality, all
//     members are returned, regardless of the sign of count.
//   - A negative count returns |count| members chosen with replacement, so
//     the same member may appear more than once.
//   - A positive count returns count distinct members.
func (set *Set) GetRandom(count int) []string {
	keys := set.GetAll()
	if count == 0 {
		return []string{}
	}
	want := internal.AbsInt(count)
	if want >= set.Cardinality() {
		return keys
	}

	res := make([]string, 0, want)
	if count < 0 {
		// Sample with replacement: the candidate pool is never shrunk.
		for i := 0; i < want; i++ {
			res = append(res, keys[rand.Intn(len(keys))])
		}
		return res
	}

	// Sample without replacement: each pick is deleted from the pool by
	// index, so it cannot be drawn again and no duplicate check is needed.
	// keys is a private copy produced by GetAll, so shrinking it is safe.
	for len(res) < want {
		n := rand.Intn(len(keys))
		res = append(res, keys[n])
		keys = slices.Delete(keys, n, n+1)
	}
	return res
}
// Remove deletes each element of elems that is currently a member and returns
// the number of elements that were actually deleted.
func (set *Set) Remove(elems []string) int {
	removed := 0
	for _, elem := range elems {
		if set.get(elem) == nil {
			// Not a member (or already deleted earlier in this call).
			continue
		}
		delete(set.members, elem)
		removed++
	}
	set.length -= removed
	return removed
}
// Pop selects members with GetRandom(count), removes the selected members from
// the set, and returns the selection.
func (set *Set) Pop(count int) []string {
	popped := set.GetRandom(count)
	set.Remove(popped)
	return popped
}
// Contains reports whether e is a member of the set. Membership is detected by
// a non-nil value being stored for e in the underlying map.
func (set *Set) Contains(e string) bool {
	stored := set.members[e]
	return stored != nil
}
// Subtract receives a list of sets and returns a new set holding the members
// of the receiver that appear in none of the provided sets. Neither the
// receiver nor the provided sets are modified.
func (set *Set) Subtract(others []*Set) *Set {
	result := NewSet(set.GetAll())
	for _, other := range others {
		// Gather the members this set shares with the running result, then
		// drop them in one batch.
		var overlap []string
		for member := range other.members {
			if result.Contains(member) {
				overlap = append(overlap, member)
			}
		}
		result.Remove(overlap)
	}
	return result
}
// Move transfers e from the receiver to destination.
// It returns 1 when e was a member of the receiver and has been moved, and 0
// when e was absent, in which case neither set is touched.
func (set *Set) Move(destination *Set, e string) int {
	// Remove reports how many elements it deleted, which doubles as the
	// membership test.
	if set.Remove([]string{e}) == 0 {
		return 0
	}
	destination.Add([]string{e})
	return 1
}
// Intersection returns a new set containing the members common to every set in
// sets. The input sets are never modified and the result never aliases one of
// them.
//
// When limit is greater than 0, the scan stops as soon as the result holds
// limit members, without computing the full intersection. A limit of 0 or less
// means no limit.
//
// The boolean result is true when the scan was cut short because the limit had
// been reached while candidates were still left to examine.
//
// Calling Intersection with no sets returns an empty set and false.
func Intersection(limit int, sets ...*Set) (*Set, bool) {
	intersection := NewSet([]string{})
	if len(sets) == 0 {
		return intersection, false
	}

	// Every common member must be in the smallest set, so scanning it gives
	// the fewest candidates to test.
	smallest := sets[0]
	for _, s := range sets[1:] {
		if s.Cardinality() < smallest.Cardinality() {
			smallest = s
		}
	}

candidates:
	for member := range smallest.members {
		if limit > 0 && intersection.Cardinality() >= limit {
			return intersection, true
		}
		// A candidate is kept only if every set contains it; a partial result
		// computed from a subset of the inputs is never returned.
		for _, s := range sets {
			if !s.Contains(member) {
				continue candidates
			}
		}
		intersection.Add([]string{member})
	}
	return intersection, false
}
// Union returns a new set containing every member that appears in at least one
// of the given sets.
//
// The result is always a freshly allocated set: none of the inputs is modified
// and the result never aliases one of them. Calling Union with no sets returns
// an empty set.
func Union(sets ...*Set) *Set {
	union := NewSet([]string{})
	for _, s := range sets {
		// Add skips members that are already present, so duplicates across
		// the inputs are stored once.
		union.Add(s.GetAll())
	}
	return union
}