diff --git a/internal/rhash/tx.go b/internal/rhash/tx.go index 83e42cf..b09a99f 100644 --- a/internal/rhash/tx.go +++ b/internal/rhash/tx.go @@ -12,71 +12,72 @@ const ( sqlCount = ` select count(field) from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and field in (:fields)` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and field in (:fields)` sqlDelete = ` delete from rhash where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) + select id from rkey where key = ? + and (etime is null or etime > ?) ) and field in (:fields)` sqlFields = ` select field from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlGet = ` select value from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and field = :field` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and field = ?` sqlGetMany = ` select field, value from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and field in (:fields)` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and field in (:fields)` sqlItems = ` select field, value from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlLen = ` select count(field) from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlScan = ` select rhash.rowid, field, value from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and rhash.rowid > :cursor and field glob :pattern - limit :count` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and rhash.rowid > ? and field glob ? + limit ?` - sqlSet = ` + sqlSet1 = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, - mtime = excluded.mtime; + mtime = excluded.mtime` + sqlSet2 = ` insert into rhash (key_id, field, value) - values ((select id from rkey where key = :key), :field, :value) + values ((select id from rkey where key = ?), ?, ?) on conflict (key_id, field) do update set value = excluded.value;` sqlValues = ` select value from rhash - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` ) const scanPageSize = 10 @@ -443,19 +444,17 @@ func (tx *Tx) count(key string, fields ...string) (int, error) { // set creates or updates the value of a field in a hash. func (tx *Tx) set(key string, field string, value any) error { args := []any{ - // insert into rkey key, // key core.TypeHash, // type core.InitialVersion, // version time.Now().UnixMilli(), // mtime - // insert into rhash - key, field, value, } - _, err := tx.tx.Exec(sqlSet, args...) + _, err := tx.tx.Exec(sqlSet1, args...) if err != nil { return err } - return nil + _, err = tx.tx.Exec(sqlSet2, key, field, value) + return err } // scanValue scans a hash field value the current row. diff --git a/internal/rkey/tx.go b/internal/rkey/tx.go index 700a0f4..3a7b41d 100644 --- a/internal/rkey/tx.go +++ b/internal/rkey/tx.go @@ -11,11 +11,11 @@ import ( const ( sqlCount = ` select count(id) from rkey - where key in (:keys) and (etime is null or etime > :now)` + where key in (:keys) and (etime is null or etime > ?)` sqlDelete = ` delete from rkey where key in (:keys) - and (etime is null or etime > :now)` + and (etime is null or etime > ?)` sqlDeleteAll = ` delete from rkey; @@ -24,59 +24,59 @@ const ( sqlDeleteAllExpired = ` delete from rkey - where etime <= :now` + where etime <= ?` sqlDeleteNExpired = ` delete from rkey where rowid in ( select rowid from rkey - where etime <= :now - limit :n + where etime <= ? + limit ? )` sqlExpire = ` - update rkey set etime = :at - where key = :key and (etime is null or etime > :now)` + update rkey set etime = ? + where key = ? and (etime is null or etime > ?)` sqlGet = ` select id, key, type, version, etime, mtime from rkey - where key = :key and (etime is null or etime > :now)` + where key = ? and (etime is null or etime > ?)` sqlKeys = ` select id, key, type, version, etime, mtime from rkey - where key glob :pattern and (etime is null or etime > :now)` + where key glob ? and (etime is null or etime > ?)` sqlPersist = ` update rkey set etime = null - where key = :key and (etime is null or etime > :now)` + where key = ? and (etime is null or etime > ?)` sqlRandom = ` select id, key, type, version, etime, mtime from rkey - where etime is null or etime > :now + where etime is null or etime > ? order by random() limit 1` sqlRename = ` update or replace rkey set id = old.id, - key = :new_key, + key = ?, type = old.type, version = old.version+1, etime = old.etime, - mtime = :now + mtime = ? from ( select id, key, type, version, etime, mtime from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ) as old - where rkey.key = :key and ( - rkey.etime is null or rkey.etime > :now + where rkey.key = ? and ( + rkey.etime is null or rkey.etime > ? )` sqlScan = ` select id, key, type, version, etime, mtime from rkey - where id > :cursor and key glob :pattern and (etime is null or etime > :now) - limit :count` + where id > ? and key glob ? and (etime is null or etime > ?) + limit ?` ) const scanPageSize = 10 diff --git a/internal/rlist/tx.go b/internal/rlist/tx.go index 28b9b78..e1ffabc 100644 --- a/internal/rlist/tx.go +++ b/internal/rlist/tx.go @@ -13,18 +13,18 @@ const ( sqlDelete = ` delete from rlist where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) - ) and elem = :elem` + select id from rkey where key = ? + and (etime is null or etime > ?) + ) and elem = ?` sqlDeleteBack = ` with ids as ( select rlist.rowid from rlist - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and elem = :elem + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and elem = ? order by pos desc - limit :count + limit ? ) delete from rlist where rowid in (select rowid from ids)` @@ -33,10 +33,10 @@ const ( with ids as ( select rlist.rowid from rlist - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and elem = :elem + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and elem = ? order by pos - limit :count + limit ? ) delete from rlist where rowid in (select rowid from ids)` @@ -45,21 +45,21 @@ const ( with elems as ( select elem, row_number() over (order by pos asc) as rownum from rlist - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? ) select elem from elems - where rownum = :idx + 1` + where rownum = ? + 1` sqlInsertAfter = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ), elprev as ( select min(pos) as pos from rlist - where key_id = (select id from keyid) and elem = :pivot + where key_id = (select id from keyid) and elem = ? ), elnext as ( select min(pos) as pos from rlist @@ -74,7 +74,7 @@ const ( from elprev, elnext ) insert into rlist (key_id, pos, elem) - select (select id from keyid), (select pos from newpos), :elem + select (select id from keyid), (select pos from newpos), ? from rlist where key_id = (select id from keyid) limit 1 @@ -86,11 +86,11 @@ const ( sqlInsertBefore = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ), elnext as ( select min(pos) as pos from rlist - where key_id = (select id from keyid) and elem = :pivot + where key_id = (select id from keyid) and elem = ? ), elprev as ( select max(pos) as pos from rlist @@ -105,7 +105,7 @@ const ( from elprev, elnext ) insert into rlist (key_id, pos, elem) - select (select id from keyid), (select pos from newpos), :elem + select (select id from keyid), (select pos from newpos), ? from rlist where key_id = (select id from keyid) limit 1 @@ -117,13 +117,13 @@ const ( sqlLen = ` select count(*) from rlist - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlPopBack = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ) delete from rlist where @@ -137,7 +137,7 @@ const ( sqlPopFront = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ) delete from rlist where @@ -150,28 +150,28 @@ const ( sqlPushBack = ` insert into rlist (key_id, pos, elem) - select :key_id, coalesce(max(pos)+1, 0), :elem + select ?, coalesce(max(pos)+1, 0), ? from rlist - where key_id = :key_id + where key_id = ? returning ( select count(*) from rlist - where key_id = :key_id + where key_id = ? )` sqlPushFront = ` insert into rlist (key_id, pos, elem) - select :key_id, coalesce(min(pos)-1, 0), :elem + select ?, coalesce(min(pos)-1, 0), ? from rlist - where key_id = :key_id + where key_id = ? returning ( select count(*) from rlist - where key_id = :key_id + where key_id = ? )` sqlRange = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ), counts as ( select count(*) as n_elem from rlist @@ -179,13 +179,13 @@ const ( ), bounds as ( select - case when :start < 0 - then (select n_elem from counts) + :start - else :start + case when ? < 0 + then (select n_elem from counts) + ? + else ? end as start, - case when :stop < 0 - then (select n_elem from counts) + :stop - else :stop + case when ? < 0 + then (select n_elem from counts) + ? + else ? end as stop ) select elem @@ -199,20 +199,20 @@ const ( sqlSet = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ), elems as ( select pos, row_number() over (order by pos asc) as rownum from rlist where key_id = (select id from keyid) ) - update rlist set elem = :elem + update rlist set elem = ? where key_id = (select id from keyid) - and pos = (select pos from elems where rownum = :idx + 1)` + and pos = (select pos from elems where rownum = ? + 1)` sqlSetKey = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, @@ -222,7 +222,7 @@ const ( sqlTrim = ` with keyid as ( select id from rkey - where key = :key and (etime is null or etime > :now) + where key = ? and (etime is null or etime > ?) ), counts as ( select count(*) as n_elem from rlist @@ -230,13 +230,13 @@ const ( ), bounds as ( select - case when :start < 0 - then (select n_elem from counts) + :start - else :start + case when ? < 0 + then (select n_elem from counts) + ? + else ? end as start, - case when :stop < 0 - then (select n_elem from counts) + :stop - else :stop + case when ? < 0 + then (select n_elem from counts) + ? + else ? end as stop ), remain as ( @@ -404,7 +404,11 @@ func (tx *Tx) Range(key string, start, stop int) ([]core.Value, error) { return nil, nil } - args := []any{key, time.Now().UnixMilli(), start, stop} + args := []any{ + key, time.Now().UnixMilli(), + start, start, start, + stop, stop, stop, + } rows, err := tx.tx.Query(sqlRange, args...) if err != nil { return nil, err @@ -465,7 +469,11 @@ func (tx *Tx) Set(key string, idx int, elem any) error { // // Does nothing if the key does not exist or is not a list. func (tx *Tx) Trim(key string, start, stop int) (int, error) { - args := []any{key, time.Now().UnixMilli(), start, stop} + args := []any{ + key, time.Now().UnixMilli(), + start, start, start, + stop, stop, stop, + } out, err := tx.tx.Exec(sqlTrim, args...) if err != nil { return 0, err @@ -506,7 +514,7 @@ func (tx *Tx) insert(key string, pivot, elem any, query string) (int, error) { return 0, core.ErrNotFound } if err != nil { - if err.Error() == "NOT NULL constraint failed: rlist.pos" { + if sqlx.ConstraintFailed(err, "NOT NULL", "rlist.pos") { return -1, core.ErrNotFound } return 0, err diff --git a/internal/rstring/db.go b/internal/rstring/db.go index cd6c7fc..855e574 100644 --- a/internal/rstring/db.go +++ b/internal/rstring/db.go @@ -70,15 +70,19 @@ func (d *DB) IncrFloat(key string, delta float64) (float64, error) { // Set sets the key value that will not expire. // Overwrites the value if the key already exists. func (d *DB) Set(key string, value any) error { - tx := NewTx(d.RW) - return tx.Set(key, value) + err := d.Update(func(tx *Tx) error { + return tx.Set(key, value) + }) + return err } // SetExpires sets the key value with an optional expiration time (if ttl > 0). // Overwrites the value and ttl if the key already exists. func (d *DB) SetExpires(key string, value any, ttl time.Duration) error { - tx := NewTx(d.RW) - return tx.SetExpires(key, value, ttl) + err := d.Update(func(tx *Tx) error { + return tx.SetExpires(key, value, ttl) + }) + return err } // SetMany sets the values of multiple keys. diff --git a/internal/rstring/tx.go b/internal/rstring/tx.go index 8508eb0..31161e5 100644 --- a/internal/rstring/tx.go +++ b/internal/rstring/tx.go @@ -12,40 +12,42 @@ const ( sqlGet = ` select value from rstring - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlGetMany = ` select key, value from rstring - join rkey on key_id = rkey.id and (etime is null or etime > :now) + join rkey on key_id = rkey.id and (etime is null or etime > ?) where key in (:keys)` - sqlSet = ` + sqlSet1 = ` insert into rkey (key, type, version, etime, mtime) - values (:key, :type, :version, :etime, :mtime) + values (?, ?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, etime = excluded.etime, - mtime = excluded.mtime; + mtime = excluded.mtime` + sqlSet2 = ` insert into rstring (key_id, value) - values ((select id from rkey where key = :key), :value) + values ((select id from rkey where key = ?), ?) on conflict (key_id) do update - set value = excluded.value;` + set value = excluded.value` - sqlUpdate = ` + sqlUpdate1 = ` insert into rkey (key, type, version, etime, mtime) - values (:key, :type, :version, null, :mtime) + values (?, ?, ?, null, ?) on conflict (key) do update set version = version+1, type = excluded.type, -- not changing etime - mtime = excluded.mtime; + mtime = excluded.mtime` + sqlUpdate2 = ` insert into rstring (key_id, value) - values ((select id from rkey where key = :key), :value) + values ((select id from rkey where key = ?), ?) on conflict (key_id) do update set value = excluded.value` ) @@ -203,6 +205,10 @@ func (tx *Tx) SetWith(key string, value any) SetCmd { } func get(tx sqlx.Tx, key string) (core.Value, error) { + // args := []any{ + // sql.Named("now", time.Now().UnixMilli()), // now + // sql.Named("key", key), // key + // } args := []any{ time.Now().UnixMilli(), // now key, // key @@ -227,21 +233,18 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error { } args := []any{ - // insert into rkey key, // key core.TypeString, // type core.InitialVersion, // version etime, // etime time.Now().UnixMilli(), // mtime - // insert into rstring - key, // key - value, // value } - _, err := tx.Exec(sqlSet, args...) + _, err := tx.Exec(sqlSet1, args...) if err != nil { return err } - return nil + _, err = tx.Exec(sqlSet2, key, value) + return err } // update updates the value of the existing key without changing its @@ -249,18 +252,15 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error { // the specified value and no expiration time. func update(tx sqlx.Tx, key string, value any) error { args := []any{ - // insert into rkey key, // key core.TypeString, // type core.InitialVersion, // version time.Now().UnixMilli(), // mtime - // insert into rstring - key, // key - value, // value } - _, err := tx.Exec(sqlUpdate, args...) + _, err := tx.Exec(sqlUpdate1, args...) if err != nil { return err } - return nil + _, err = tx.Exec(sqlUpdate2, key, value) + return err } diff --git a/internal/rzset/db.go b/internal/rzset/db.go index 25a13ae..f034f16 100644 --- a/internal/rzset/db.go +++ b/internal/rzset/db.go @@ -127,14 +127,13 @@ func (d *DB) Incr(key string, elem any, delta float64) (float64, error) { // The score of each element is the sum of its scores in the given sets. // If any of the source keys do not exist or are not sets, returns an empty slice. func (d *DB) Inter(keys ...string) ([]SetItem, error) { - tx := NewTx(d.RO) - return tx.Inter(keys...) + cmd := InterCmd{db: d, keys: keys, aggregate: sqlx.Sum} + return cmd.Run() } // InterWith intersects multiple sets with additional options. func (d *DB) InterWith(keys ...string) InterCmd { - tx := NewTx(d.RW) - return tx.InterWith(keys...) + return InterCmd{db: d, keys: keys, aggregate: sqlx.Sum} } // Len returns the number of elements in a set. @@ -187,12 +186,11 @@ func (d *DB) Scanner(key, pattern string, pageSize int) *Scanner { // Ignores the keys that do not exist or are not sets. // If no keys exist, returns a nil slice. func (d *DB) Union(keys ...string) ([]SetItem, error) { - tx := NewTx(d.RO) - return tx.Union(keys...) + cmd := UnionCmd{db: d, keys: keys, aggregate: sqlx.Sum} + return cmd.Run() } // UnionWith unions multiple sets with additional options. func (d *DB) UnionWith(keys ...string) UnionCmd { - tx := NewTx(d.RW) - return tx.UnionWith(keys...) + return UnionCmd{db: d, keys: keys, aggregate: sqlx.Sum} } diff --git a/internal/rzset/delete.go b/internal/rzset/delete.go index c3abd20..fbcf674 100644 --- a/internal/rzset/delete.go +++ b/internal/rzset/delete.go @@ -12,11 +12,11 @@ const ( select rowid, elem, score from rzset where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) + select id from rkey where key = ? + and (etime is null or etime > ?) ) order by score, elem - limit :start, :count + limit ?, ? ) delete from rzset where rowid in (select rowid from ranked)` @@ -24,10 +24,10 @@ const ( sqlDeleteScore = ` delete from rzset where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) + select id from rkey where key = ? + and (etime is null or etime > ?) ) - and score between :start and :stop` + and score between ? and ?` ) // DeleteCmd removes elements from a set. diff --git a/internal/rzset/inter.go b/internal/rzset/inter.go index b3fac77..e479c2c 100644 --- a/internal/rzset/inter.go +++ b/internal/rzset/inter.go @@ -14,41 +14,43 @@ const ( sqlInter = ` select elem, sum(score) as score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) + join rkey on key_id = rkey.id and (etime is null or etime > ?) where key in (:keys) group by elem - having count(distinct key_id) = :nkeys + having count(distinct key_id) = ? order by sum(score), elem` - sqlInterStore = ` + sqlInterStore1 = ` delete from rzset where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) - ); + select id from rkey where key = ? + and (etime is null or etime > ?) + )` + sqlInterStore2 = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, - mtime = excluded.mtime; + mtime = excluded.mtime + returning id` + sqlInterStore3 = ` insert into rzset (key_id, elem, score) - select - (select id from rkey where key = :key), - elem, sum(score) as score + select ?, elem, sum(score) as score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) + join rkey on key_id = rkey.id and (etime is null or etime > ?) where key in (:keys) group by elem - having count(distinct key_id) = :nkeys + having count(distinct key_id) = ? order by sum(score), elem;` ) // InterCmd intersects multiple sets. type InterCmd struct { - tx sqlx.Tx + db *DB + tx *Tx dest string keys []string aggregate string @@ -83,6 +85,39 @@ func (c InterCmd) Max() InterCmd { // The score of each element is the aggregate of its scores in the given sets. // If any of the source keys do not exist or are not sets, returns an empty slice. func (c InterCmd) Run() ([]SetItem, error) { + if c.db != nil { + return c.run(c.db.RO) + } + if c.tx != nil { + return c.run(c.tx.tx) + } + return nil, nil +} + +// Store intersects multiple sets and stores the result in a new set. +// Returns the number of elements in the resulting set. +// If the destination key already exists, it is fully overwritten +// (all old elements are removed and the new ones are inserted). +// If any of the source keys do not exist or are not sets, does nothing, +// except deleting the destination key if it exists. +func (c InterCmd) Store() (int, error) { + if c.db != nil { + var count int + err := c.db.Update(func(tx *Tx) error { + var err error + count, err = c.store(tx.tx) + return err + }) + return count, err + } + if c.tx != nil { + return c.store(c.tx.tx) + } + return 0, nil +} + +// run returns the intersection of multiple sets. +func (c InterCmd) run(tx sqlx.Tx) ([]SetItem, error) { // Prepare query arguments. query := sqlInter if c.aggregate != sqlx.Sum { @@ -97,7 +132,7 @@ func (c InterCmd) Run() ([]SetItem, error) { // Execute the query. var rows *sql.Rows - rows, err := c.tx.Query(query, args...) + rows, err := tx.Query(query, args...) if err != nil { return nil, err } @@ -119,38 +154,38 @@ func (c InterCmd) Run() ([]SetItem, error) { return items, nil } -// Store intersects multiple sets and stores the result in a new set. -// Returns the number of elements in the resulting set. -// If the destination key already exists, it is fully overwritten -// (all old elements are removed and the new ones are inserted). -// If any of the source keys do not exist or are not sets, does nothing, -// except deleting the destination key if it exists. -func (c InterCmd) Store() (int, error) { - // Insert the destination key and get its ID. +// store intersects multiple sets and stores the result in a new set. +func (c InterCmd) store(tx sqlx.Tx) (int, error) { now := time.Now().UnixMilli() - args := []any{ - // delete from rzset - c.dest, // key - now, // now - // insert into rkey + + // Delete the destination key if it exists. + args := []any{c.dest, now} + _, err := tx.Exec(sqlInterStore1, args...) + if err != nil { + return 0, err + } + + // Create the destination key. + args = []any{ c.dest, // key core.TypeSortedSet, // type core.InitialVersion, // version now, // mtime - // insert into rzset - c.dest, // key - now, // now - // keys - // nkeys } - query := sqlInterStore + var destID int + err = tx.QueryRow(sqlInterStore2, args...).Scan(&destID) + if err != nil { + return 0, err + } + + // Intersect the source sets and store the result. + query := sqlInterStore3 if c.aggregate != sqlx.Sum { query = strings.Replace(query, sqlx.Sum, c.aggregate, 2) } query, keyArgs := sqlx.ExpandIn(query, ":keys", c.keys) - args = slices.Concat(args, keyArgs, []any{len(c.keys)}) - - res, err := c.tx.Exec(query, args...) + args = slices.Concat([]any{destID, now}, keyArgs, []any{len(c.keys)}) + res, err := tx.Exec(query, args...) if err != nil { return 0, err } diff --git a/internal/rzset/range.go b/internal/rzset/range.go index ffc0007..98290f1 100644 --- a/internal/rzset/range.go +++ b/internal/rzset/range.go @@ -12,21 +12,21 @@ const ( with ranked as ( select elem, score, (row_number() over w - 1) as rank from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? window w as (partition by key_id order by score asc, elem asc) ) select elem, score from ranked - where rank between :start and :stop + where rank between ? and ? order by rank, elem asc` sqlRangeScore = ` select elem, score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key - and score between :start and :stop + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? + and score between ? and ? order by score asc, elem asc` ) @@ -173,13 +173,13 @@ func (c RangeCmd) rangeScore() ([]SetItem, error) { // Add offset and count if necessary. if c.offset > 0 && c.count > 0 { - query += " limit :offset, :count" + query += " limit ?, ?" args = append(args, c.offset, c.count) } else if c.count > 0 { - query += " limit :count" + query += " limit ?" args = append(args, c.count) } else if c.offset > 0 { - query += " limit :offset, -1" + query += " limit ?, -1" args = append(args, c.offset) } diff --git a/internal/rzset/tx.go b/internal/rzset/tx.go index 5cba6bf..fb9e46e 100644 --- a/internal/rzset/tx.go +++ b/internal/rzset/tx.go @@ -10,59 +10,60 @@ import ( ) const ( - sqlAdd = ` + sqlAdd1 = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, - mtime = excluded.mtime; + mtime = excluded.mtime` + sqlAdd2 = ` insert into rzset (key_id, elem, score) - values ((select id from rkey where key = :key), :elem, :score) + values ((select id from rkey where key = ?), ?, ?) on conflict (key_id, elem) do update - set score = excluded.score;` + set score = excluded.score` sqlCount = ` select count(elem) from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and elem in (:elems)` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and elem in (:elems)` sqlCountScore = ` select count(elem) from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and score between :min and :max` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and score between ? and ?` sqlDelete = ` delete from rzset where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) + select id from rkey where key = ? + and (etime is null or etime > ?) ) and elem in (:elems)` sqlGetRank = ` with ranked as ( select elem, score, (row_number() over w - 1) as rank from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? window w as (partition by key_id order by score asc, elem asc) ) select rank, score from ranked - where elem = :elem` + where elem = ?` sqlGetScore = ` select score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and elem = :elem` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and elem = ?` sqlIncr1 = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, @@ -70,7 +71,7 @@ const ( sqlIncr2 = ` insert into rzset (key_id, elem, score) - values ((select id from rkey where key = :key), :elem, :delta) + values ((select id from rkey where key = ?), ?, ?) on conflict (key_id, elem) do update set score = score + excluded.score returning score` @@ -78,15 +79,15 @@ const ( sqlLen = ` select count(elem) from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ?` sqlScan = ` select rzset.rowid, elem, score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) - where key = :key and rzset.rowid > :cursor and elem glob :pattern - limit :count` + join rkey on key_id = rkey.id and (etime is null or etime > ?) + where key = ? and rzset.rowid > ? and elem glob ? + limit ?` ) const scanPageSize = 10 @@ -259,12 +260,13 @@ func (tx *Tx) Incr(key string, elem any, delta float64) (float64, error) { // The score of each element is the sum of its scores in the given sets. // If any of the source keys do not exist or are not sets, returns an empty slice. func (tx *Tx) Inter(keys ...string) ([]SetItem, error) { - return tx.InterWith(keys...).Run() + cmd := InterCmd{tx: tx, keys: keys, aggregate: sqlx.Sum} + return cmd.Run() } // InterWith intersects multiple sets with additional options. func (tx *Tx) InterWith(keys ...string) InterCmd { - return InterCmd{tx: tx.tx, keys: keys, aggregate: sqlx.Sum} + return InterCmd{tx: tx, keys: keys, aggregate: sqlx.Sum} } // Len returns the number of elements in a set. @@ -345,12 +347,13 @@ func (tx *Tx) Scanner(key, pattern string, pageSize int) *Scanner { // Ignores the keys that do not exist or are not sets. // If no keys exist, returns a nil slice. func (tx *Tx) Union(keys ...string) ([]SetItem, error) { - return tx.UnionWith(keys...).Run() + cmd := UnionCmd{tx: tx, keys: keys, aggregate: sqlx.Sum} + return cmd.Run() } // UnionWith unions multiple sets with additional options. func (tx *Tx) UnionWith(keys ...string) UnionCmd { - return UnionCmd{tx: tx.tx, keys: keys, aggregate: sqlx.Sum} + return UnionCmd{tx: tx, keys: keys, aggregate: sqlx.Sum} } // add adds or updates the element in a set. @@ -360,20 +363,17 @@ func (tx *Tx) add(key string, elem any, score float64) error { } args := []any{ - // insert into rkey key, // key core.TypeSortedSet, // type core.InitialVersion, // version time.Now().UnixMilli(), // mtime - // insert into rzset - key, elem, score, } - - _, err := tx.tx.Exec(sqlAdd, args...) + _, err := tx.tx.Exec(sqlAdd1, args...) if err != nil { return err } - return nil + _, err = tx.tx.Exec(sqlAdd2, key, elem, score) + return err } // count returns the number of existing elements in a set. @@ -384,8 +384,8 @@ func (tx *Tx) count(key string, elems ...any) (int, error) { } } - query, fieldArgs := sqlx.ExpandIn(sqlCount, ":elems", elems) - args := append([]any{time.Now().UnixMilli(), key}, fieldArgs...) + query, elemArgs := sqlx.ExpandIn(sqlCount, ":elems", elems) + args := append([]any{time.Now().UnixMilli(), key}, elemArgs...) var count int err := tx.tx.QueryRow(query, args...).Scan(&count) return count, err diff --git a/internal/rzset/union.go b/internal/rzset/union.go index 9167ddd..c1f29cf 100644 --- a/internal/rzset/union.go +++ b/internal/rzset/union.go @@ -13,31 +13,32 @@ const ( sqlUnion = ` select elem, sum(score) as score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) + join rkey on key_id = rkey.id and (etime is null or etime > ?) where key in (:keys) group by elem order by sum(score), elem` - sqlUnionStore = ` + sqlUnionStore1 = ` delete from rzset where key_id = ( - select id from rkey where key = :key - and (etime is null or etime > :now) - ); + select id from rkey where key = ? + and (etime is null or etime > ?) + )` + sqlUnionStore2 = ` insert into rkey (key, type, version, mtime) - values (:key, :type, :version, :mtime) + values (?, ?, ?, ?) on conflict (key) do update set version = version+1, type = excluded.type, - mtime = excluded.mtime; + mtime = excluded.mtime + returning id` + sqlUnionStore3 = ` insert into rzset (key_id, elem, score) - select - (select id from rkey where key = :key), - elem, sum(score) as score + select ?, elem, sum(score) as score from rzset - join rkey on key_id = rkey.id and (etime is null or etime > :now) + join rkey on key_id = rkey.id and (etime is null or etime > ?) where key in (:keys) group by elem order by sum(score), elem;` @@ -45,7 +46,8 @@ const ( // UnionCmd unions multiple sets. type UnionCmd struct { - tx sqlx.Tx + db *DB + tx *Tx dest string keys []string aggregate string @@ -81,6 +83,40 @@ func (c UnionCmd) Max() UnionCmd { // Ignores the keys that do not exist or are not sets. // If no keys exist, returns a nil slice. func (c UnionCmd) Run() ([]SetItem, error) { + if c.db != nil { + return c.run(c.db.RO) + } + if c.tx != nil { + return c.run(c.tx.tx) + } + return nil, nil +} + +// Store unions multiple sets and stores the result in a new set. +// Returns the number of elements in the resulting set. +// If the destination key already exists, it is fully overwritten +// (all old elements are removed and the new ones are inserted). +// Ignores the source keys that do not exist or are not sets. +// If all of the source keys do not exist or are not sets, does nothing, +// except deleting the destination key if it exists. +func (c UnionCmd) Store() (int, error) { + if c.db != nil { + var count int + err := c.db.Update(func(tx *Tx) error { + var err error + count, err = c.store(tx.tx) + return err + }) + return count, err + } + if c.tx != nil { + return c.store(c.tx.tx) + } + return 0, nil +} + +// run returns the union of multiple sets. +func (c UnionCmd) run(tx sqlx.Tx) ([]SetItem, error) { // Prepare query arguments. now := time.Now().UnixMilli() query := sqlUnion @@ -92,7 +128,7 @@ func (c UnionCmd) Run() ([]SetItem, error) { // Execute the query. var rows *sql.Rows - rows, err := c.tx.Query(query, args...) + rows, err := tx.Query(query, args...) if err != nil { return nil, err } @@ -114,38 +150,38 @@ func (c UnionCmd) Run() ([]SetItem, error) { return items, nil } -// Store unions multiple sets and stores the result in a new set. -// Returns the number of elements in the resulting set. -// If the destination key already exists, it is fully overwritten -// (all old elements are removed and the new ones are inserted). -// Ignores the source keys that do not exist or are not sets. -// If all of the source keys do not exist or are not sets, does nothing, -// except deleting the destination key if it exists. -func (c UnionCmd) Store() (int, error) { - // Union the sets and store the result. +// store unions multiple sets and stores the result in a new set. +func (c UnionCmd) store(tx sqlx.Tx) (int, error) { now := time.Now().UnixMilli() - args := []any{ - // delete from rzset - c.dest, // key - now, // now - // insert into rkey + + // Delete the destination key if it exists. + args := []any{c.dest, now} + _, err := tx.Exec(sqlUnionStore1, args...) + if err != nil { + return 0, err + } + + // Create the destination key. + args = []any{ c.dest, // key core.TypeSortedSet, // type core.InitialVersion, // version now, // mtime - // insert into rzset - c.dest, // key - now, // now - // keys } - query := sqlUnionStore + var destID int + err = tx.QueryRow(sqlUnionStore2, args...).Scan(&destID) + if err != nil { + return 0, err + } + + // Union the source sets and store the result. + query := sqlUnionStore3 if c.aggregate != sqlx.Sum { query = strings.Replace(query, sqlx.Sum, c.aggregate, 2) } query, keyArgs := sqlx.ExpandIn(query, ":keys", c.keys) - args = append(args, keyArgs...) - - res, err := c.tx.Exec(query, args...) + args = append([]any{destID, now}, keyArgs...) + res, err := tx.Exec(query, args...) if err != nil { return 0, err } diff --git a/internal/sqlx/sql.go b/internal/sqlx/sql.go index f7e0e94..b4c6419 100644 --- a/internal/sqlx/sql.go +++ b/internal/sqlx/sql.go @@ -67,3 +67,10 @@ func Select[T any](db Tx, query string, args []any, return vals, err } + +// ConstraintFailed checks if the error is due to +// a constraint violation on a column. +func ConstraintFailed(err error, constraint, column string) bool { + msg := constraint + " constraint failed: " + column + return strings.Contains(err.Error(), msg) +} diff --git a/internal/testx/testx.go b/internal/testx/testx.go index 31d55d2..749bf90 100644 --- a/internal/testx/testx.go +++ b/internal/testx/testx.go @@ -2,22 +2,12 @@ package testx import ( - "database/sql" "reflect" "testing" _ "github.com/mattn/go-sqlite3" ) -func GetDB(tb testing.TB) *sql.DB { - tb.Helper() - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - tb.Fatal(err) - } - return db -} - func AssertEqual(tb testing.TB, got, want any) { tb.Helper() if !reflect.DeepEqual(got, want) {