diff --git a/internal/core/core.go b/internal/core/core.go index 3433f49..d391e81 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -148,8 +148,42 @@ func (v Value) Exists() bool { // - byte slice func IsValueType(v any) bool { switch v.(type) { - case string, int, float64, bool, []byte: + case bool, float64, int, string, []byte: return true } return false } + +// ToBytesMany converts multiple values to byte slices. +func ToBytesMany(values ...any) ([][]byte, error) { + blobs := make([][]byte, len(values)) + for i, v := range values { + b, err := ToBytes(v) + if err != nil { + return nil, err + } + blobs[i] = b + } + return blobs, nil +} + +// ToBytes converts a value to a byte slice. +func ToBytes(v any) ([]byte, error) { + switch v := v.(type) { + case bool: + if v { + return []byte{'1'}, nil + } else { + return []byte{'0'}, nil + } + case float64: + return []byte(strconv.FormatFloat(v, 'f', -1, 64)), nil + case int: + return []byte(strconv.Itoa(v)), nil + case string: + return []byte(v), nil + case []byte: + return v, nil + } + return nil, ErrValueType +} diff --git a/internal/rhash/tx.go b/internal/rhash/tx.go index f2eba0d..380282c 100644 --- a/internal/rhash/tx.go +++ b/internal/rhash/tx.go @@ -443,17 +443,21 @@ func (tx *Tx) count(key string, fields ...string) (int, error) { // set creates or updates the value of a field in a hash. func (tx *Tx) set(key string, field string, value any) error { + valueb, err := core.ToBytes(value) + if err != nil { + return err + } args := []any{ key, // key core.TypeHash, // type core.InitialVersion, // version time.Now().UnixMilli(), // mtime } - _, err := tx.tx.Exec(sqlSet1, args...) + _, err = tx.tx.Exec(sqlSet1, args...) if err != nil { return err } - _, err = tx.tx.Exec(sqlSet2, key, field, value) + _, err = tx.tx.Exec(sqlSet2, key, field, valueb) return err } diff --git a/internal/rlist/tx.go b/internal/rlist/tx.go index 9c7e8e5..c7e2687 100644 --- a/internal/rlist/tx.go +++ b/internal/rlist/tx.go @@ -268,7 +268,11 @@ func NewTx(tx sqlx.Tx) *Tx { // Returns the number of elements deleted. // Does nothing if the key does not exist or is not a list. func (tx *Tx) Delete(key string, elem any) (int, error) { - args := []any{key, time.Now().UnixMilli(), elem} + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err + } + args := []any{key, time.Now().UnixMilli(), elemb} res, err := tx.tx.Exec(sqlDelete, args...) if err != nil { return 0, err @@ -436,8 +440,9 @@ func (tx *Tx) Range(key string, start, stop int) ([]core.Value, error) { // If the index is out of bounds, returns ErrNotFound. // If the key does not exist or is not a list, returns ErrNotFound. func (tx *Tx) Set(key string, idx int, elem any) error { - if !core.IsValueType(elem) { - return core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return err } var query = sqlSet @@ -447,7 +452,7 @@ func (tx *Tx) Set(key string, idx int, elem any) error { idx = -idx - 1 } - args := []any{key, time.Now().UnixMilli(), elem, idx} + args := []any{key, time.Now().UnixMilli(), elemb, idx} out, err := tx.tx.Exec(query, args...) if err != nil { return err @@ -488,8 +493,12 @@ func (tx *Tx) delete(key string, elem any, count int, query string) (int, error) if count <= 0 { return 0, nil } + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err + } - args := []any{time.Now().UnixMilli(), key, elem, count} + args := []any{time.Now().UnixMilli(), key, elemb, count} res, err := tx.tx.Exec(query, args...) if err != nil { return 0, err @@ -503,13 +512,18 @@ func (tx *Tx) delete(key string, elem any, count int, query string) (int, error) // insert inserts an element before or after a pivot in a list. func (tx *Tx) insert(key string, pivot, elem any, query string) (int, error) { - if !core.IsValueType(elem) { - return 0, core.ErrValueType + pivotb, err := core.ToBytes(pivot) + if err != nil { + return 0, err + } + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err } - args := []any{key, time.Now().UnixMilli(), pivot, elem} + args := []any{key, time.Now().UnixMilli(), pivotb, elemb} var count int - err := tx.tx.QueryRow(query, args...).Scan(&count) + err = tx.tx.QueryRow(query, args...).Scan(&count) if err == sql.ErrNoRows { return 0, core.ErrNotFound } @@ -538,8 +552,9 @@ func (tx *Tx) pop(key string, query string) (core.Value, error) { // push inserts an element to the front or back of a list. func (tx *Tx) push(key string, elem any, query string) (int, error) { - if !core.IsValueType(elem) { - return 0, core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err } // Set the key if it does not exist. @@ -550,14 +565,14 @@ func (tx *Tx) push(key string, elem any, query string) (int, error) { time.Now().UnixMilli(), // mtime } var keyID int - err := tx.tx.QueryRow(sqlSetKey, args...).Scan(&keyID) + err = tx.tx.QueryRow(sqlSetKey, args...).Scan(&keyID) if err != nil { return 0, err } // Insert the element. var count int - args = []any{keyID, elem, keyID, keyID} + args = []any{keyID, elemb, keyID, keyID} err = tx.tx.QueryRow(query, args...).Scan(&count) if err != nil { return 0, err diff --git a/internal/rstring/tx.go b/internal/rstring/tx.go index 9052567..6c8eca3 100644 --- a/internal/rstring/tx.go +++ b/internal/rstring/tx.go @@ -166,9 +166,6 @@ func (tx *Tx) Set(key string, value any) error { // SetExpires sets the key value with an optional expiration time (if ttl > 0). // Overwrites the value and ttl if the key already exists. func (tx *Tx) SetExpires(key string, value any, ttl time.Duration) error { - if !core.IsValueType(value) { - return core.ErrValueType - } var at time.Time if ttl > 0 { at = time.Now().Add(ttl) @@ -226,6 +223,11 @@ func get(tx sqlx.Tx, key string) (core.Value, error) { // set sets the key value and (optionally) its expiration time. func set(tx sqlx.Tx, key string, value any, at time.Time) error { + valueb, err := core.ToBytes(value) + if err != nil { + return err + } + var etime *int64 if !at.IsZero() { etime = new(int64) @@ -239,11 +241,11 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error { etime, // etime time.Now().UnixMilli(), // mtime } - _, err := tx.Exec(sqlSet1, args...) + _, err = tx.Exec(sqlSet1, args...) if err != nil { return err } - _, err = tx.Exec(sqlSet2, key, value) + _, err = tx.Exec(sqlSet2, key, valueb) return err } @@ -251,16 +253,20 @@ func set(tx sqlx.Tx, key string, value any, at time.Time) error { // expiration time. If the key does not exist, creates a new key with // the specified value and no expiration time. func update(tx sqlx.Tx, key string, value any) error { + valueb, err := core.ToBytes(value) + if err != nil { + return err + } args := []any{ key, // key core.TypeString, // type core.InitialVersion, // version time.Now().UnixMilli(), // mtime } - _, err := tx.Exec(sqlUpdate1, args...) + _, err = tx.Exec(sqlUpdate1, args...) if err != nil { return err } - _, err = tx.Exec(sqlUpdate2, key, value) + _, err = tx.Exec(sqlUpdate2, key, valueb) return err } diff --git a/internal/rzset/tx.go b/internal/rzset/tx.go index 2dc9f1d..87a17d9 100644 --- a/internal/rzset/tx.go +++ b/internal/rzset/tx.go @@ -163,14 +163,13 @@ func (tx *Tx) Count(key string, min, max float64) (int, error) { // Does not delete the key if the set becomes empty. func (tx *Tx) Delete(key string, elems ...any) (int, error) { // Check the types of the elements. - for elem := range elems { - if !core.IsValueType(elem) { - return 0, core.ErrValueType - } + elembs, err := core.ToBytesMany(elems...) + if err != nil { + return 0, err } // Remove the elements. - query, elemArgs := sqlx.ExpandIn(sqlDelete, ":elems", elems) + query, elemArgs := sqlx.ExpandIn(sqlDelete, ":elems", elembs) args := append([]any{key, time.Now().UnixMilli()}, elemArgs...) res, err := tx.tx.Exec(query, args...) if err != nil { @@ -208,14 +207,15 @@ func (tx *Tx) GetRankRev(key string, elem any) (rank int, score float64, err err // If the element does not exist, returns ErrNotFound. // If the key does not exist or is not a set, returns ErrNotFound. func (tx *Tx) GetScore(key string, elem any) (float64, error) { - if !core.IsValueType(elem) { - return 0, core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err } var score float64 - args := []any{time.Now().UnixMilli(), key, elem} + args := []any{time.Now().UnixMilli(), key, elemb} row := tx.tx.QueryRow(sqlGetScore, args...) - err := row.Scan(&score) + err = row.Scan(&score) if err == sql.ErrNoRows { return 0, core.ErrNotFound } @@ -230,8 +230,9 @@ func (tx *Tx) GetScore(key string, elem any) (float64, error) { // If the element does not exist, adds it and sets the score to 0.0 // before the increment. If the key does not exist, creates it. func (tx *Tx) Incr(key string, elem any, delta float64) (float64, error) { - if !core.IsValueType(elem) { - return 0, core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, err } args := []any{ @@ -240,13 +241,13 @@ func (tx *Tx) Incr(key string, elem any, delta float64) (float64, error) { core.InitialVersion, // version time.Now().UnixMilli(), // mtime } - _, err := tx.tx.Exec(sqlIncr1, args...) + _, err = tx.tx.Exec(sqlIncr1, args...) if err != nil { return 0, err } var score float64 - args = []any{key, elem, delta} + args = []any{key, elemb, delta} err = tx.tx.QueryRow(sqlIncr2, args...).Scan(&score) if err != nil { return 0, err @@ -358,8 +359,9 @@ func (tx *Tx) UnionWith(keys ...string) UnionCmd { // add adds or updates the element in a set. func (tx *Tx) add(key string, elem any, score float64) error { - if !core.IsValueType(elem) { - return core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return err } args := []any{ @@ -368,36 +370,35 @@ func (tx *Tx) add(key string, elem any, score float64) error { core.InitialVersion, // version time.Now().UnixMilli(), // mtime } - _, err := tx.tx.Exec(sqlAdd1, args...) + _, err = tx.tx.Exec(sqlAdd1, args...) if err != nil { return err } - _, err = tx.tx.Exec(sqlAdd2, key, elem, score) + _, err = tx.tx.Exec(sqlAdd2, key, elemb, score) return err } // count returns the number of existing elements in a set. func (tx *Tx) count(key string, elems ...any) (int, error) { - for _, elem := range elems { - if !core.IsValueType(elem) { - return 0, core.ErrValueType - } + elembs, err := core.ToBytesMany(elems...) + if err != nil { + return 0, err } - - query, elemArgs := sqlx.ExpandIn(sqlCount, ":elems", elems) + query, elemArgs := sqlx.ExpandIn(sqlCount, ":elems", elembs) args := append([]any{time.Now().UnixMilli(), key}, elemArgs...) var count int - err := tx.tx.QueryRow(query, args...).Scan(&count) + err = tx.tx.QueryRow(query, args...).Scan(&count) return count, err } // getRank returns the rank and score of an element in a set. func (tx *Tx) getRank(key string, elem any, sortDir string) (rank int, score float64, err error) { - if !core.IsValueType(elem) { - return 0, 0, core.ErrValueType + elemb, err := core.ToBytes(elem) + if err != nil { + return 0, 0, err } - args := []any{time.Now().UnixMilli(), key, elem} + args := []any{time.Now().UnixMilli(), key, elemb} query := sqlGetRank if sortDir != sqlx.Asc { query = strings.Replace(query, sqlx.Asc, sortDir, 2)