mirror of
https://github.com/nalgeon/redka.git
synced 2025-12-24 12:38:00 +08:00
refactor: convert value to bytes before saving to db
This commit is contained in:
@@ -148,8 +148,42 @@ func (v Value) Exists() bool {
|
||||
// - byte slice
|
||||
func IsValueType(v any) bool {
|
||||
switch v.(type) {
|
||||
case string, int, float64, bool, []byte:
|
||||
case bool, float64, int, string, []byte:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToBytesMany converts multiple values to byte slices.
|
||||
func ToBytesMany(values ...any) ([][]byte, error) {
|
||||
blobs := make([][]byte, len(values))
|
||||
for i, v := range values {
|
||||
b, err := ToBytes(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blobs[i] = b
|
||||
}
|
||||
return blobs, nil
|
||||
}
|
||||
|
||||
// ToBytes converts a value to a byte slice.
|
||||
func ToBytes(v any) ([]byte, error) {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return []byte{'1'}, nil
|
||||
} else {
|
||||
return []byte{'0'}, nil
|
||||
}
|
||||
case float64:
|
||||
return []byte(strconv.FormatFloat(v, 'f', -1, 64)), nil
|
||||
case int:
|
||||
return []byte(strconv.Itoa(v)), nil
|
||||
case string:
|
||||
return []byte(v), nil
|
||||
case []byte:
|
||||
return v, nil
|
||||
}
|
||||
return nil, ErrValueType
|
||||
}
|
||||
|
||||
@@ -443,17 +443,21 @@ func (tx *Tx) count(key string, fields ...string) (int, error) {
|
||||
|
||||
// set creates or updates the value of a field in a hash.
|
||||
func (tx *Tx) set(key string, field string, value any) error {
|
||||
valueb, err := core.ToBytes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args := []any{
|
||||
key, // key
|
||||
core.TypeHash, // type
|
||||
core.InitialVersion, // version
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
_, err := tx.tx.Exec(sqlSet1, args...)
|
||||
_, err = tx.tx.Exec(sqlSet1, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.tx.Exec(sqlSet2, key, field, value)
|
||||
_, err = tx.tx.Exec(sqlSet2, key, field, valueb)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -268,7 +268,11 @@ func NewTx(tx sqlx.Tx) *Tx {
|
||||
// Returns the number of elements deleted.
|
||||
// Does nothing if the key does not exist or is not a list.
|
||||
func (tx *Tx) Delete(key string, elem any) (int, error) {
|
||||
args := []any{key, time.Now().UnixMilli(), elem}
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
args := []any{key, time.Now().UnixMilli(), elemb}
|
||||
res, err := tx.tx.Exec(sqlDelete, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -436,8 +440,9 @@ func (tx *Tx) Range(key string, start, stop int) ([]core.Value, error) {
|
||||
// If the index is out of bounds, returns ErrNotFound.
|
||||
// If the key does not exist or is not a list, returns ErrNotFound.
|
||||
func (tx *Tx) Set(key string, idx int, elem any) error {
|
||||
if !core.IsValueType(elem) {
|
||||
return core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var query = sqlSet
|
||||
@@ -447,7 +452,7 @@ func (tx *Tx) Set(key string, idx int, elem any) error {
|
||||
idx = -idx - 1
|
||||
}
|
||||
|
||||
args := []any{key, time.Now().UnixMilli(), elem, idx}
|
||||
args := []any{key, time.Now().UnixMilli(), elemb, idx}
|
||||
out, err := tx.tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -488,8 +493,12 @@ func (tx *Tx) delete(key string, elem any, count int, query string) (int, error)
|
||||
if count <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
args := []any{time.Now().UnixMilli(), key, elem, count}
|
||||
args := []any{time.Now().UnixMilli(), key, elemb, count}
|
||||
res, err := tx.tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -503,13 +512,18 @@ func (tx *Tx) delete(key string, elem any, count int, query string) (int, error)
|
||||
|
||||
// insert inserts an element before or after a pivot in a list.
|
||||
func (tx *Tx) insert(key string, pivot, elem any, query string) (int, error) {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
pivotb, err := core.ToBytes(pivot)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
args := []any{key, time.Now().UnixMilli(), pivot, elem}
|
||||
args := []any{key, time.Now().UnixMilli(), pivotb, elemb}
|
||||
var count int
|
||||
err := tx.tx.QueryRow(query, args...).Scan(&count)
|
||||
err = tx.tx.QueryRow(query, args...).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, core.ErrNotFound
|
||||
}
|
||||
@@ -538,8 +552,9 @@ func (tx *Tx) pop(key string, query string) (core.Value, error) {
|
||||
|
||||
// push inserts an element to the front or back of a list.
|
||||
func (tx *Tx) push(key string, elem any, query string) (int, error) {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Set the key if it does not exist.
|
||||
@@ -550,14 +565,14 @@ func (tx *Tx) push(key string, elem any, query string) (int, error) {
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
var keyID int
|
||||
err := tx.tx.QueryRow(sqlSetKey, args...).Scan(&keyID)
|
||||
err = tx.tx.QueryRow(sqlSetKey, args...).Scan(&keyID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Insert the element.
|
||||
var count int
|
||||
args = []any{keyID, elem, keyID, keyID}
|
||||
args = []any{keyID, elemb, keyID, keyID}
|
||||
err = tx.tx.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -166,9 +166,6 @@ func (tx *Tx) Set(key string, value any) error {
|
||||
// SetExpires sets the key value with an optional expiration time (if ttl > 0).
|
||||
// Overwrites the value and ttl if the key already exists.
|
||||
func (tx *Tx) SetExpires(key string, value any, ttl time.Duration) error {
|
||||
if !core.IsValueType(value) {
|
||||
return core.ErrValueType
|
||||
}
|
||||
var at time.Time
|
||||
if ttl > 0 {
|
||||
at = time.Now().Add(ttl)
|
||||
@@ -226,6 +223,11 @@ func get(tx sqlx.Tx, key string) (core.Value, error) {
|
||||
|
||||
// set sets the key value and (optionally) its expiration time.
|
||||
func set(tx sqlx.Tx, key string, value any, at time.Time) error {
|
||||
valueb, err := core.ToBytes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var etime *int64
|
||||
if !at.IsZero() {
|
||||
etime = new(int64)
|
||||
@@ -239,11 +241,11 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error {
|
||||
etime, // etime
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
_, err := tx.Exec(sqlSet1, args...)
|
||||
_, err = tx.Exec(sqlSet1, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(sqlSet2, key, value)
|
||||
_, err = tx.Exec(sqlSet2, key, valueb)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -251,16 +253,20 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error {
|
||||
// expiration time. If the key does not exist, creates a new key with
|
||||
// the specified value and no expiration time.
|
||||
func update(tx sqlx.Tx, key string, value any) error {
|
||||
valueb, err := core.ToBytes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args := []any{
|
||||
key, // key
|
||||
core.TypeString, // type
|
||||
core.InitialVersion, // version
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
_, err := tx.Exec(sqlUpdate1, args...)
|
||||
_, err = tx.Exec(sqlUpdate1, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(sqlUpdate2, key, value)
|
||||
_, err = tx.Exec(sqlUpdate2, key, valueb)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -163,14 +163,13 @@ func (tx *Tx) Count(key string, min, max float64) (int, error) {
|
||||
// Does not delete the key if the set becomes empty.
|
||||
func (tx *Tx) Delete(key string, elems ...any) (int, error) {
|
||||
// Check the types of the elements.
|
||||
for elem := range elems {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
}
|
||||
elembs, err := core.ToBytesMany(elems...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Remove the elements.
|
||||
query, elemArgs := sqlx.ExpandIn(sqlDelete, ":elems", elems)
|
||||
query, elemArgs := sqlx.ExpandIn(sqlDelete, ":elems", elembs)
|
||||
args := append([]any{key, time.Now().UnixMilli()}, elemArgs...)
|
||||
res, err := tx.tx.Exec(query, args...)
|
||||
if err != nil {
|
||||
@@ -208,14 +207,15 @@ func (tx *Tx) GetRankRev(key string, elem any) (rank int, score float64, err err
|
||||
// If the element does not exist, returns ErrNotFound.
|
||||
// If the key does not exist or is not a set, returns ErrNotFound.
|
||||
func (tx *Tx) GetScore(key string, elem any) (float64, error) {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var score float64
|
||||
args := []any{time.Now().UnixMilli(), key, elem}
|
||||
args := []any{time.Now().UnixMilli(), key, elemb}
|
||||
row := tx.tx.QueryRow(sqlGetScore, args...)
|
||||
err := row.Scan(&score)
|
||||
err = row.Scan(&score)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, core.ErrNotFound
|
||||
}
|
||||
@@ -230,8 +230,9 @@ func (tx *Tx) GetScore(key string, elem any) (float64, error) {
|
||||
// If the element does not exist, adds it and sets the score to 0.0
|
||||
// before the increment. If the key does not exist, creates it.
|
||||
func (tx *Tx) Incr(key string, elem any, delta float64) (float64, error) {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
args := []any{
|
||||
@@ -240,13 +241,13 @@ func (tx *Tx) Incr(key string, elem any, delta float64) (float64, error) {
|
||||
core.InitialVersion, // version
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
_, err := tx.tx.Exec(sqlIncr1, args...)
|
||||
_, err = tx.tx.Exec(sqlIncr1, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var score float64
|
||||
args = []any{key, elem, delta}
|
||||
args = []any{key, elemb, delta}
|
||||
err = tx.tx.QueryRow(sqlIncr2, args...).Scan(&score)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -358,8 +359,9 @@ func (tx *Tx) UnionWith(keys ...string) UnionCmd {
|
||||
|
||||
// add adds or updates the element in a set.
|
||||
func (tx *Tx) add(key string, elem any, score float64) error {
|
||||
if !core.IsValueType(elem) {
|
||||
return core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := []any{
|
||||
@@ -368,36 +370,35 @@ func (tx *Tx) add(key string, elem any, score float64) error {
|
||||
core.InitialVersion, // version
|
||||
time.Now().UnixMilli(), // mtime
|
||||
}
|
||||
_, err := tx.tx.Exec(sqlAdd1, args...)
|
||||
_, err = tx.tx.Exec(sqlAdd1, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.tx.Exec(sqlAdd2, key, elem, score)
|
||||
_, err = tx.tx.Exec(sqlAdd2, key, elemb, score)
|
||||
return err
|
||||
}
|
||||
|
||||
// count returns the number of existing elements in a set.
|
||||
func (tx *Tx) count(key string, elems ...any) (int, error) {
|
||||
for _, elem := range elems {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, core.ErrValueType
|
||||
}
|
||||
elembs, err := core.ToBytesMany(elems...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
query, elemArgs := sqlx.ExpandIn(sqlCount, ":elems", elems)
|
||||
query, elemArgs := sqlx.ExpandIn(sqlCount, ":elems", elembs)
|
||||
args := append([]any{time.Now().UnixMilli(), key}, elemArgs...)
|
||||
var count int
|
||||
err := tx.tx.QueryRow(query, args...).Scan(&count)
|
||||
err = tx.tx.QueryRow(query, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// getRank returns the rank and score of an element in a set.
|
||||
func (tx *Tx) getRank(key string, elem any, sortDir string) (rank int, score float64, err error) {
|
||||
if !core.IsValueType(elem) {
|
||||
return 0, 0, core.ErrValueType
|
||||
elemb, err := core.ToBytes(elem)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
args := []any{time.Now().UnixMilli(), key, elem}
|
||||
args := []any{time.Now().UnixMilli(), key, elemb}
|
||||
query := sqlGetRank
|
||||
if sortDir != sqlx.Asc {
|
||||
query = strings.Replace(query, sqlx.Asc, sortDir, 2)
|
||||
|
||||
Reference in New Issue
Block a user