mirror of
https://github.com/nalgeon/redka.git
synced 2025-10-29 02:22:34 +08:00
191 lines
4.5 KiB
Go
191 lines
4.5 KiB
Go
package rzset
|
|
|
|
import (
|
|
"database/sql"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/nalgeon/redka/internal/core"
|
|
"github.com/nalgeon/redka/internal/sqlx"
|
|
)
|
|
|
|
const (
	// sqlInter selects the elements present under every source key
	// (:keys, :nkeys = number of keys to match), skipping expired keys.
	// Scores are aggregated with sum(), which occurs exactly twice
	// (select list and order by) so callers can swap the aggregate
	// function with strings.Replace(..., 2).
	sqlInter = `
	select elem, sum(score) as score
	from rzset
	join rkey on key_id = rkey.id and (etime is null or etime > :now)
	where key in (:keys)
	group by elem
	having count(distinct key_id) = :nkeys
	order by sum(score), elem`

	// sqlInterStore runs three statements in order:
	//  1. delete the destination's elements (if the key is not expired);
	//  2. create the destination key or bump its version/type/mtime;
	//  3. insert the intersection of the source keys under the destination.
	// Because step 1 runs first, a destination that is also a source key has
	// already lost its elements when step 3 reads the sources.
	// Like sqlInter, it contains exactly two sum() calls.
	sqlInterStore = `
	delete from rzset
	where key_id = (
		select id from rkey where key = :key
		and (etime is null or etime > :now)
	);

	insert into rkey (key, type, version, mtime)
	values (:key, :type, :version, :mtime)
	on conflict (key) do update set
		version = version+1,
		type = excluded.type,
		mtime = excluded.mtime;

	insert into rzset (key_id, elem, score)
	select
		(select id from rkey where key = :key),
		elem, sum(score) as score
	from rzset
	join rkey on key_id = rkey.id and (etime is null or etime > :now)
	where key in (:keys)
	group by elem
	having count(distinct key_id) = :nkeys
	order by sum(score), elem;`
)
|
|
|
|
// InterCmd intersects multiple sets.
|
|
type InterCmd struct {
|
|
db *DB
|
|
tx *Tx
|
|
dest string
|
|
keys []string
|
|
aggregate string
|
|
}
|
|
|
|
// Dest sets the key to store the result of the intersection.
|
|
func (c InterCmd) Dest(dest string) InterCmd {
|
|
c.dest = dest
|
|
return c
|
|
}
|
|
|
|
// Sum changes the aggregation function to take the sum of scores.
|
|
func (c InterCmd) Sum() InterCmd {
|
|
c.aggregate = sqlx.Sum
|
|
return c
|
|
}
|
|
|
|
// Min changes the aggregation function to take the minimum score.
|
|
func (c InterCmd) Min() InterCmd {
|
|
c.aggregate = sqlx.Min
|
|
return c
|
|
}
|
|
|
|
// Max changes the aggregation function to take the maximum score.
|
|
func (c InterCmd) Max() InterCmd {
|
|
c.aggregate = sqlx.Max
|
|
return c
|
|
}
|
|
|
|
// Run returns the intersection of multiple sets.
|
|
// The intersection consists of elements that exist in all given sets.
|
|
// The score of each element is the aggregate of its scores in the given sets.
|
|
// If any of the source keys do not exist or are not sets, returns an empty slice.
|
|
func (c InterCmd) Run() ([]SetItem, error) {
|
|
if c.db != nil {
|
|
return c.inter(c.db.SQL)
|
|
}
|
|
if c.tx != nil {
|
|
return c.inter(c.tx.tx)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// Store intersects multiple sets and stores the result in a new set.
|
|
// Returns the number of elements in the resulting set.
|
|
// If the destination key already exists, it is fully overwritten
|
|
// (all old elements are removed and the new ones are inserted).
|
|
// If any of the source keys do not exist or are not sets, does nothing,
|
|
// except deleting the destination key if it exists.
|
|
func (c InterCmd) Store() (int, error) {
|
|
if c.db != nil {
|
|
var count int
|
|
err := c.db.Update(func(tx *Tx) error {
|
|
var err error
|
|
count, err = c.store(tx.tx)
|
|
return err
|
|
})
|
|
return count, err
|
|
}
|
|
if c.tx != nil {
|
|
return c.store(c.tx.tx)
|
|
}
|
|
return 0, nil
|
|
}
|
|
|
|
// inter returns the intersection of multiple sets.
|
|
func (c InterCmd) inter(tx sqlx.Tx) ([]SetItem, error) {
|
|
// Prepare query arguments.
|
|
query := sqlInter
|
|
if c.aggregate != sqlx.Sum {
|
|
query = strings.Replace(query, sqlx.Sum, c.aggregate, 2)
|
|
}
|
|
query, keyArgs := sqlx.ExpandIn(query, ":keys", c.keys)
|
|
args := slices.Concat(
|
|
[]any{time.Now().UnixMilli()}, // now
|
|
keyArgs, // keys
|
|
[]any{len(c.keys)}, // nkeys
|
|
)
|
|
|
|
// Execute the query.
|
|
var rows *sql.Rows
|
|
rows, err := tx.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// Build the resulting element-score slice.
|
|
var items []SetItem
|
|
for rows.Next() {
|
|
it, err := scanItem(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, it)
|
|
}
|
|
if rows.Err() != nil {
|
|
return nil, rows.Err()
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
// store intersects multiple sets and stores the result in a new set.
|
|
func (c InterCmd) store(tx sqlx.Tx) (int, error) {
|
|
// Insert the destination key and get its ID.
|
|
now := time.Now().UnixMilli()
|
|
args := []any{
|
|
// delete from rzset
|
|
c.dest, // key
|
|
now, // now
|
|
// insert into rkey
|
|
c.dest, // key
|
|
core.TypeSortedSet, // type
|
|
core.InitialVersion, // version
|
|
now, // mtime
|
|
// insert into rzset
|
|
c.dest, // key
|
|
now, // now
|
|
// keys
|
|
// nkeys
|
|
}
|
|
query := sqlInterStore
|
|
if c.aggregate != sqlx.Sum {
|
|
query = strings.Replace(query, sqlx.Sum, c.aggregate, 2)
|
|
}
|
|
query, keyArgs := sqlx.ExpandIn(query, ":keys", c.keys)
|
|
args = slices.Concat(args, keyArgs, []any{len(c.keys)})
|
|
|
|
res, err := tx.Exec(query, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Return the number of elements in the resulting set.
|
|
n, _ := res.RowsAffected()
|
|
return int(n), nil
|
|
}
|