Files
redka/internal/rzset/inter.go
2024-04-28 08:49:25 +05:00

191 lines
4.5 KiB
Go

package rzset
import (
"database/sql"
"slices"
"strings"
"time"
"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/sqlx"
)
const (
// sqlInter selects the elements that are present in every one of the
// requested keys, together with their aggregated score. Keys whose etime
// has passed are excluded by the join condition. An element qualifies only
// if the number of distinct keys it was found in equals :nkeys. Rows are
// ordered by aggregated score, then by element.
//
// The literal "sum" appears twice (select list and order by); inter swaps
// both occurrences for another aggregate name when one is requested.
// Arguments are passed by inter in order of appearance:
// :now, the expanded :keys list, :nkeys.
sqlInter = `
select elem, sum(score) as score
from rzset
join rkey on key_id = rkey.id and (etime is null or etime > :now)
where key in (:keys)
group by elem
having count(distinct key_id) = :nkeys
order by sum(score), elem`
// sqlInterStore is a three-statement script that stores the intersection
// under a destination key:
//  1. delete the existing elements of the (non-expired) destination key;
//  2. insert the destination key row, or on conflict bump its version and
//     reset its type and mtime;
//  3. insert the intersection of the source keys (same selection as
//     sqlInter) as elements of the destination key.
//
// As in sqlInter, "sum" appears twice and store replaces both occurrences
// when a different aggregate is requested. Arguments are passed by store in
// order of appearance: :key, :now | :key, :type, :version, :mtime |
// :key, :now, the expanded :keys list, :nkeys.
sqlInterStore = `
delete from rzset
where key_id = (
select id from rkey where key = :key
and (etime is null or etime > :now)
);
insert into rkey (key, type, version, mtime)
values (:key, :type, :version, :mtime)
on conflict (key) do update set
version = version+1,
type = excluded.type,
mtime = excluded.mtime;
insert into rzset (key_id, elem, score)
select
(select id from rkey where key = :key),
elem, sum(score) as score
from rzset
join rkey on key_id = rkey.id and (etime is null or etime > :now)
where key in (:keys)
group by elem
having count(distinct key_id) = :nkeys
order by sum(score), elem;`
)
// InterCmd intersects multiple sets.
// It is configured through the chainable Dest, Sum, Min and Max methods
// (each returns a modified copy) and executed with Run or Store.
type InterCmd struct {
// db is the database handle used for execution; Run and Store check it
// before tx.
db *DB
// tx is the transaction used for execution when db is nil.
tx *Tx
// dest is the destination key that Store writes the result to.
dest string
// keys are the source keys whose elements are intersected.
keys []string
// aggregate is the name of the SQL aggregate function applied to scores.
// When it differs from sqlx.Sum, it is substituted for sqlx.Sum
// in the query text.
aggregate string
}
// Dest returns a copy of the command whose result will be stored
// under the given destination key (see Store).
// The receiver itself is left unchanged.
func (c InterCmd) Dest(dest string) InterCmd {
	updated := c
	updated.dest = dest
	return updated
}
// Sum returns a copy of the command that aggregates the scores
// of an element across the source sets with sqlx.Sum.
// The receiver itself is left unchanged.
func (c InterCmd) Sum() InterCmd {
	updated := c
	updated.aggregate = sqlx.Sum
	return updated
}
// Min returns a copy of the command that aggregates the scores
// of an element across the source sets with sqlx.Min.
// The receiver itself is left unchanged.
func (c InterCmd) Min() InterCmd {
	updated := c
	updated.aggregate = sqlx.Min
	return updated
}
// Max returns a copy of the command that aggregates the scores
// of an element across the source sets with sqlx.Max.
// The receiver itself is left unchanged.
func (c InterCmd) Max() InterCmd {
	updated := c
	updated.aggregate = sqlx.Max
	return updated
}
// Run computes the intersection of the source sets and returns it.
// Only elements found in every source key are included; each element's
// score is the configured aggregate of its scores in those keys.
// A source key that is missing or holds no sorted set elements makes
// the result empty.
//
// The query runs against the database handle if one is set, otherwise
// against the transaction. With neither set, Run returns (nil, nil).
func (c InterCmd) Run() ([]SetItem, error) {
	switch {
	case c.db != nil:
		return c.inter(c.db.SQL)
	case c.tx != nil:
		return c.inter(c.tx.tx)
	default:
		return nil, nil
	}
}
// Store computes the intersection of the source sets and writes it to
// the destination key, returning the element count reported by store.
// Existing elements of the destination key are removed before the new
// ones are inserted, so the destination is fully overwritten.
// If a source key is missing or holds no sorted set elements, the
// intersection is empty: the old destination elements are still removed
// and nothing is inserted.
//
// With a database handle, the work is wrapped in db.Update; with only
// a transaction, it runs directly in that transaction. With neither set,
// Store returns (0, nil).
func (c InterCmd) Store() (int, error) {
	switch {
	case c.db != nil:
		var stored int
		updateErr := c.db.Update(func(tx *Tx) (err error) {
			stored, err = c.store(tx.tx)
			return err
		})
		return stored, updateErr
	case c.tx != nil:
		return c.store(c.tx.tx)
	default:
		return 0, nil
	}
}
// inter runs the intersection query in the given transaction and
// collects the resulting element-score pairs.
// It returns a nil slice when the query yields no rows.
//
// NOTE(review): nkeys is len(c.keys) while the query compares it with
// count(distinct key_id), so a key listed twice appears to produce
// an empty result — confirm whether that is intended.
func (c InterCmd) inter(tx sqlx.Tx) ([]SetItem, error) {
	// The query text is written for the sum aggregate; swap in
	// the requested function (select list and order by) if it differs.
	stmt := sqlInter
	if c.aggregate != sqlx.Sum {
		stmt = strings.Replace(stmt, sqlx.Sum, c.aggregate, 2)
	}

	// Expand the :keys placeholder and line up the arguments
	// in the order the placeholders appear: now, keys..., nkeys.
	stmt, keyArgs := sqlx.ExpandIn(stmt, ":keys", c.keys)
	nowArg := []any{time.Now().UnixMilli()}
	nkeysArg := []any{len(c.keys)}
	params := slices.Concat(nowArg, keyArgs, nkeysArg)

	// Run the query.
	var (
		rows *sql.Rows
		err  error
	)
	if rows, err = tx.Query(stmt, params...); err != nil {
		return nil, err
	}
	defer rows.Close()

	// Scan every row into the result slice.
	var result []SetItem
	for rows.Next() {
		item, scanErr := scanItem(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, item)
	}
	// Report an error that ended the iteration early, if any.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// store intersects multiple sets and stores the result in a new set.
// It executes the sqlInterStore script in the given transaction:
// the destination's old elements are deleted, the destination key row
// is inserted or updated, and the intersection is inserted as its
// new elements.
//
// Returns the row count reported by the driver for the executed script,
// which is used as the number of elements in the resulting set.
// Any error from executing the script or from reading the affected
// row count is returned to the caller.
func (c InterCmd) store(tx sqlx.Tx) (int, error) {
	// Prepare the arguments in the order the placeholders appear
	// in the script. The same timestamp is used for all statements.
	now := time.Now().UnixMilli()
	args := []any{
		// delete from rzset
		c.dest, // key
		now,    // now
		// insert into rkey
		c.dest,              // key
		core.TypeSortedSet,  // type
		core.InitialVersion, // version
		now,                 // mtime
		// insert into rzset
		c.dest, // key
		now,    // now
		// keys
		// nkeys
	}

	// The script is written for the sum aggregate; swap in
	// the requested function (select list and order by) if it differs.
	query := sqlInterStore
	if c.aggregate != sqlx.Sum {
		query = strings.Replace(query, sqlx.Sum, c.aggregate, 2)
	}

	// Expand the :keys placeholder and append the trailing
	// keys and nkeys arguments.
	query, keyArgs := sqlx.ExpandIn(query, ":keys", c.keys)
	args = slices.Concat(args, keyArgs, []any{len(c.keys)})

	res, err := tx.Exec(query, args...)
	if err != nil {
		return 0, err
	}

	// Return the number of elements in the resulting set.
	// Do not swallow a driver error here: reporting 0 stored
	// elements would be indistinguishable from an empty intersection.
	n, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(n), nil
}