From 992dd8cc2c6df7ca159203054fc970a39eaacc38 Mon Sep 17 00:00:00 2001 From: Anton Date: Sun, 21 Apr 2024 17:53:43 +0500 Subject: [PATCH] refactor: rstring - ErrNotFound if key does not exist --- internal/command/get.go | 12 +- internal/rkey/tx.go | 15 ++ internal/rstring/db.go | 34 ++-- internal/rstring/db_test.go | 357 +++++++++++++++++++++++------------- internal/rstring/tx.go | 91 +++++---- 5 files changed, 322 insertions(+), 187 deletions(-) diff --git a/internal/command/get.go b/internal/command/get.go index 2975aba..83f3343 100644 --- a/internal/command/get.go +++ b/internal/command/get.go @@ -1,5 +1,7 @@ package command +import "github.com/nalgeon/redka/internal/core" + // Get returns the string value of a key. // GET key // https://redis.io/commands/get @@ -19,14 +21,14 @@ func parseGet(b baseCmd) (*Get, error) { func (cmd *Get) Run(w Writer, red Redka) (any, error) { val, err := red.Str().Get(cmd.key) - if err != nil { - w.WriteError(cmd.Error(err)) - return nil, err - } - if !val.Exists() { + if err == core.ErrNotFound { w.WriteNull() return val, nil } + if err != nil { + w.WriteError(cmd.Error(err)) + return val, err + } w.WriteBulk(val) return val, nil } diff --git a/internal/rkey/tx.go b/internal/rkey/tx.go index dcc6e02..0756d47 100644 --- a/internal/rkey/tx.go +++ b/internal/rkey/tx.go @@ -18,6 +18,10 @@ const sqlCount = ` select count(id) from rkey where key in (:keys) and (etime is null or etime > :now)` +const sqlCountType = ` +select count(id) from rkey +where key in (:keys) and (etime is null or etime > :now) and type = :type` + const sqlKeys = ` select id, key, type, version, etime, mtime from rkey where key glob :pattern and (etime is null or etime > :now)` @@ -424,6 +428,17 @@ func Count(tx sqlx.Tx, keys ...string) (int, error) { return count, err } +// CountType returns the number of existing keys +// of a specific type among specified keys. +func CountType(tx sqlx.Tx, typ core.TypeID, keys ...string) (int, error) { + now := time.Now().UnixMilli() + query, keyArgs := sqlx.ExpandIn(sqlCountType, ":keys", keys) + args := slices.Concat(keyArgs, []any{sql.Named("now", now), sql.Named("type", typ)}) + var count int + err := tx.QueryRow(query, args...).Scan(&count) + return count, err +} + // Delete deletes keys and their values (regardless of the type). func Delete(tx sqlx.Tx, keys ...string) (int, error) { now := time.Now().UnixMilli() diff --git a/internal/rstring/db.go b/internal/rstring/db.go index 5cf4c42..40ba032 100644 --- a/internal/rstring/db.go +++ b/internal/rstring/db.go @@ -25,14 +25,15 @@ func New(db *sql.DB) *DB { } // Get returns the value of the key. -// Returns nil if the key does not exist. +// If the key does not exist or is not a string, returns ErrNotFound. func (d *DB) Get(key string) (core.Value, error) { tx := NewTx(d.SQL) return tx.Get(key) } // GetMany returns a map of values for given keys. -// Returns nil for keys that do not exist. +// Ignores keys that do not exist or not strings, +// and does not return them in the map. func (d *DB) GetMany(keys ...string) (map[string]core.Value, error) { tx := NewTx(d.SQL) return tx.GetMany(keys...) @@ -40,8 +41,9 @@ func (d *DB) GetMany(keys ...string) (map[string]core.Value, error) { // GetSet returns the previous value of a key after setting it to a new value. // Optionally sets the expiration time (if ttl > 0). -// Overwrites the value and ttl if the key already exists. -// Returns nil if the key did not exist. +// If the key already exists, overwrites the value and ttl. +// If the key exists but is not a string, returns ErrKeyType. +// If the key does not exist, returns nil as the previous value. func (d *DB) GetSet(key string, value any, ttl time.Duration) (core.Value, error) { var val core.Value err := d.Update(func(tx *Tx) error { @@ -52,10 +54,11 @@ func (d *DB) GetSet(key string, value any, ttl time.Duration) (core.Value, error return val, err } -// Incr increments the key value by the specified amount. -// If the key does not exist, sets it to 0 before the increment. +// Incr increments the integer key value by the specified amount. // Returns the value after the increment. -// Returns an error if the key value is not an integer. +// If the key does not exist, sets it to 0 before the increment. +// If the key value is not an integer, returns ErrValueType. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) Incr(key string, delta int) (int, error) { var val int err := d.Update(func(tx *Tx) error { @@ -66,10 +69,11 @@ func (d *DB) Incr(key string, delta int) (int, error) { return val, err } -// IncrFloat increments the key value by the specified amount. -// If the key does not exist, sets it to 0 before the increment. +// IncrFloat increments the float key value by the specified amount. // Returns the value after the increment. -// Returns an error if the key value is not a float. +// If the key does not exist, sets it to 0 before the increment. +// If the key value is not an float, returns ErrValueType. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) IncrFloat(key string, delta float64) (float64, error) { var val float64 err := d.Update(func(tx *Tx) error { @@ -82,6 +86,7 @@ func (d *DB) IncrFloat(key string, delta float64) (float64, error) { // Set sets the key value that will not expire. // Overwrites the value if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) Set(key string, value any) error { err := d.Update(func(tx *Tx) error { return tx.Set(key, value) @@ -92,6 +97,7 @@ func (d *DB) Set(key string, value any) error { // SetExists sets the key value if the key exists. // Optionally sets the expiration time (if ttl > 0). // Returns true if the key was set, false if the key does not exist. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) SetExists(key string, value any, ttl time.Duration) (bool, error) { var ok bool err := d.Update(func(tx *Tx) error { @@ -104,6 +110,7 @@ func (d *DB) SetExists(key string, value any, ttl time.Duration) (bool, error) { // SetExpires sets the key value with an optional expiration time (if ttl > 0). // Overwrites the value and ttl if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) SetExpires(key string, value any, ttl time.Duration) error { err := d.Update(func(tx *Tx) error { return tx.SetExpires(key, value, ttl) @@ -115,6 +122,7 @@ func (d *DB) SetExpires(key string, value any, ttl time.Duration) error { // Overwrites values for keys that already exist and // creates new keys/values for keys that do not exist. // Removes the TTL for existing keys. +// If any of the keys exists but is not a string, returns ErrKeyType. func (d *DB) SetMany(items map[string]any) error { err := d.Update(func(tx *Tx) error { return tx.SetMany(items) @@ -123,8 +131,9 @@ func (d *DB) SetMany(items map[string]any) error { } // SetManyNX sets the values of multiple keys, but only if none -// of them yet exist. Returns true if the keys were set, false if any -// of them already exist. +// of them yet exist. Returns true if the keys were set, +// false if any of them already exist. +// If any of the keys exists but is not a string, returns ErrKeyType. func (d *DB) SetManyNX(items map[string]any) (bool, error) { var ok bool err := d.Update(func(tx *Tx) error { @@ -138,6 +147,7 @@ func (d *DB) SetManyNX(items map[string]any) (bool, error) { // SetNotExists sets the key value if the key does not exist. // Optionally sets the expiration time (if ttl > 0). // Returns true if the key was set, false if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (d *DB) SetNotExists(key string, value any, ttl time.Duration) (bool, error) { var ok bool err := d.Update(func(tx *Tx) error { diff --git a/internal/rstring/db_test.go b/internal/rstring/db_test.go index d47a85e..b155f4d 100644 --- a/internal/rstring/db_test.go +++ b/internal/rstring/db_test.go @@ -11,26 +11,32 @@ import ( ) func TestGet(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("key found", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("name", "alice") - _ = db.Set("name", "alice") + val, err := db.Get("name") + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, core.Value("alice")) + }) + t.Run("key not found", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() - tests := []struct { - name string - key string - want any - }{ - {"key found", "name", core.Value("alice")}, - {"key not found", "key1", core.Value(nil)}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - val, err := db.Get(test.key) - testx.AssertNoErr(t, err) - testx.AssertEqual(t, val, test.want) - }) - } + val, err := db.Get("name") + testx.AssertErr(t, err, core.ErrNotFound) + testx.AssertEqual(t, val, core.Value(nil)) + }) + t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _, _ = red.Hash().Set("person", "name", "alice") + + val, err := db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) + testx.AssertEqual(t, val, core.Value(nil)) + }) } func TestGetMany(t *testing.T) { @@ -39,6 +45,8 @@ func TestGetMany(t *testing.T) { _ = db.Set("name", "alice") _ = db.Set("age", 25) + _, _ = red.Hash().Set("hash1", "f1", "v1") + _, _ = red.Hash().Set("hash2", "f2", "v2") tests := []struct { name string @@ -52,13 +60,14 @@ func TestGetMany(t *testing.T) { }, {"some found", []string{"name", "key1"}, map[string]core.Value{ - "name": core.Value("alice"), "key1": core.Value(nil), + "name": core.Value("alice"), }, }, {"none found", []string{"key1", "key2"}, - map[string]core.Value{ - "key1": core.Value(nil), "key2": core.Value(nil), - }, + map[string]core.Value{}, + }, + {"key type mismatch", []string{"hash1", "hash2"}, + map[string]core.Value{}, }, } for _, test := range tests { @@ -131,34 +140,56 @@ func TestGetSet(t *testing.T) { } func TestIncr(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("create", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() - tests := []struct { - name string - key string - value int - want int - }{ - {"create", "age", 10, 10}, - {"increment", "age", 15, 25}, - {"decrement", "age", -5, 20}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - val, err := db.Incr(test.key, test.value) - testx.AssertNoErr(t, err) - testx.AssertEqual(t, val, test.want) - }) - } + val, err := db.Incr("age", 25) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 25) + + age, _ := db.Get("age") + testx.AssertEqual(t, age.MustInt(), 25) + }) + t.Run("increment", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("age", "25") + + val, err := db.Incr("age", 10) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 35) + + age, _ := db.Get("age") + testx.AssertEqual(t, age.MustInt(), 35) + }) + + t.Run("decrement", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("age", "25") + + val, err := db.Incr("age", -10) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 15) + + age, _ := db.Get("age") + testx.AssertEqual(t, age.MustInt(), 15) + }) t.Run("invalid int", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _ = db.Set("name", "alice") + val, err := db.Incr("name", 1) testx.AssertErr(t, err, core.ErrValueType) testx.AssertEqual(t, val, 0) }) t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _, _ = red.Hash().Set("person", "age", 25) + val, err := db.Incr("person", 10) testx.AssertErr(t, err, core.ErrKeyType) testx.AssertEqual(t, val, 0) @@ -166,34 +197,56 @@ func TestIncr(t *testing.T) { } func TestIncrFloat(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("create", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() - tests := []struct { - name string - key string - value float64 - want float64 - }{ - {"create", "pi", 3.14, 3.14}, - {"increment", "pi", 1.86, 5}, - {"decrement", "pi", -1.5, 3.5}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - val, err := db.IncrFloat(test.key, test.value) - testx.AssertNoErr(t, err) - testx.AssertEqual(t, val, test.want) - }) - } + val, err := db.IncrFloat("pi", 3.14) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 3.14) + + pi, _ := db.Get("pi") + testx.AssertEqual(t, pi.MustFloat(), 3.14) + }) + t.Run("increment", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("pi", "3.14") + + val, err := db.IncrFloat("pi", 1.86) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 5.0) + + pi, _ := db.Get("pi") + testx.AssertEqual(t, pi.MustFloat(), 5.0) + }) + + t.Run("decrement", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("pi", "3.14") + + val, err := db.IncrFloat("pi", -1.14) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, val, 2.0) + + pi, _ := db.Get("pi") + testx.AssertEqual(t, pi.MustFloat(), 2.0) + }) t.Run("invalid float", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _ = db.Set("name", "alice") + val, err := db.IncrFloat("name", 1.5) testx.AssertErr(t, err, core.ErrValueType) testx.AssertEqual(t, val, 0.0) }) t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _, _ = red.Hash().Set("person", "age", 25.5) + val, err := db.IncrFloat("person", 10.5) testx.AssertErr(t, err, core.ErrKeyType) testx.AssertEqual(t, val, 0.0) @@ -201,25 +254,25 @@ func TestIncrFloat(t *testing.T) { } func TestSet(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("set", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() - tests := []struct { - name string - key string - value any - want any - }{ - {"string", "name", "alice", core.Value("alice")}, - {"empty string", "empty", "", core.Value("")}, - {"int", "age", 25, core.Value("25")}, - {"float", "pi", 3.14, core.Value("3.14")}, - {"bool true", "ok", true, core.Value("1")}, - {"bool false", "ok", false, core.Value("0")}, - {"bytes", "bytes", []byte("hello"), core.Value("hello")}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { + tests := []struct { + name string + key string + value any + want any + }{ + {"string", "name", "alice", core.Value("alice")}, + {"empty string", "empty", "", core.Value("")}, + {"int", "age", 25, core.Value("25")}, + {"float", "pi", 3.14, core.Value("3.14")}, + {"bool true", "ok", true, core.Value("1")}, + {"bool false", "ok", false, core.Value("0")}, + {"bytes", "bytes", []byte("hello"), core.Value("hello")}, + } + for _, test := range tests { err := db.Set(test.key, test.value) testx.AssertNoErr(t, err) @@ -228,70 +281,97 @@ func TestSet(t *testing.T) { key, _ := red.Key().Get(test.key) testx.AssertEqual(t, key.ETime, (*int64)(nil)) - }) - } + } + }) t.Run("struct", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + err := db.Set("struct", struct{ Name string }{"alice"}) testx.AssertErr(t, err, core.ErrValueType) }) t.Run("nil", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + err := db.Set("nil", nil) testx.AssertErr(t, err, core.ErrValueType) }) t.Run("update", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _ = db.Set("name", "alice") + err := db.Set("name", "bob") testx.AssertNoErr(t, err) val, _ := db.Get("name") testx.AssertEqual(t, val, core.Value("bob")) }) t.Run("change value type", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _ = db.Set("name", "alice") + err := db.Set("name", true) testx.AssertNoErr(t, err) val, _ := db.Get("name") testx.AssertEqual(t, val, core.Value("1")) }) t.Run("not changed", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _ = db.Set("name", "alice") + err := db.Set("name", "alice") testx.AssertNoErr(t, err) val, _ := db.Get("name") testx.AssertEqual(t, val, core.Value("alice")) }) t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _, _ = red.Hash().Set("person", "name", "alice") + err := db.Set("person", "name") testx.AssertErr(t, err, core.ErrKeyType) + + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } func TestSetExists(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("key exists", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("name", "alice") - _ = db.Set("name", "alice") + ok, err := db.SetExists("name", "bob", 0) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, ok, true) - tests := []struct { - name string - key string - value any - want bool - }{ - {"new key", "age", 25, false}, - {"existing key", "name", "bob", true}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ok, err := db.SetExists(test.key, test.value, 0) - testx.AssertNoErr(t, err) - testx.AssertEqual(t, ok, test.want) + name, _ := db.Get("name") + testx.AssertEqual(t, name, core.Value("bob")) - key, _ := red.Key().Get(test.key) - testx.AssertEqual(t, key.ETime, (*int64)(nil)) - }) - } + key, _ := red.Key().Get("name") + testx.AssertEqual(t, key.ETime, (*int64)(nil)) + }) + t.Run("key not found", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + + ok, err := db.SetExists("name", "alice", 0) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, ok, false) + + _, err = db.Get("name") + testx.AssertErr(t, err, core.ErrNotFound) + }) t.Run("with ttl", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("name", "alice") + now := time.Now() ttl := time.Second ok, err := db.SetExists("name", "cindy", ttl) @@ -304,10 +384,16 @@ func TestSetExists(t *testing.T) { testx.AssertEqual(t, got, want) }) t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _, _ = red.Hash().Set("person", "name", "alice") + ok, err := db.SetExists("person", "name", 0) testx.AssertErr(t, err, core.ErrKeyType) testx.AssertEqual(t, ok, false) + + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } @@ -348,6 +434,9 @@ func TestSetExpires(t *testing.T) { _, _ = red.Hash().Set("person", "name", "alice") err := db.SetExpires("person", "name", time.Second) testx.AssertErr(t, err, core.ErrKeyType) + + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } @@ -396,11 +485,17 @@ func TestSetMany(t *testing.T) { red, db := getDB(t) defer red.Close() _, _ = red.Hash().Set("person", "name", "alice") + err := db.SetMany(map[string]any{ "name": "alice", "person": "alice", }) testx.AssertErr(t, err, core.ErrKeyType) + + _, err = db.Get("name") + testx.AssertErr(t, err, core.ErrNotFound) + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } @@ -452,41 +547,49 @@ func TestSetManyNX(t *testing.T) { red, db := getDB(t) defer red.Close() _, _ = red.Hash().Set("person", "name", "alice") + ok, err := db.SetManyNX(map[string]any{ "name": "alice", "person": "alice", }) - testx.AssertNoErr(t, err) + testx.AssertErr(t, err, core.ErrKeyType) testx.AssertEqual(t, ok, false) + + _, err = db.Get("name") + testx.AssertErr(t, err, core.ErrNotFound) + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } func TestSetNotExists(t *testing.T) { - red, db := getDB(t) - defer red.Close() + t.Run("key exists", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + _ = db.Set("name", "alice") - _ = db.Set("name", "alice") + ok, err := db.SetNotExists("name", "bob", 0) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, ok, false) - tests := []struct { - name string - key string - value any - want bool - }{ - {"new key", "age", 25, true}, - {"existing key", "name", "bob", false}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ok, err := db.SetNotExists(test.key, test.value, 0) - testx.AssertNoErr(t, err) - testx.AssertEqual(t, ok, test.want) + name, _ := db.Get("name") + testx.AssertEqual(t, name, core.Value("alice")) + }) + t.Run("key not found", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() - key, _ := red.Key().Get(test.key) - testx.AssertEqual(t, key.ETime, (*int64)(nil)) - }) - } + ok, err := db.SetNotExists("name", "alice", 0) + testx.AssertNoErr(t, err) + testx.AssertEqual(t, ok, true) + + name, _ := db.Get("name") + testx.AssertEqual(t, name, core.Value("alice")) + }) t.Run("with ttl", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() + now := time.Now() ttl := time.Second ok, err := db.SetNotExists("city", "paris", ttl) @@ -499,10 +602,16 @@ func TestSetNotExists(t *testing.T) { testx.AssertEqual(t, got, want) }) t.Run("key type mismatch", func(t *testing.T) { + red, db := getDB(t) + defer red.Close() _, _ = red.Hash().Set("person", "name", "alice") + ok, err := db.SetNotExists("person", "name", 0) - testx.AssertNoErr(t, err) + testx.AssertErr(t, err, core.ErrKeyType) testx.AssertEqual(t, ok, false) + + _, err = db.Get("person") + testx.AssertErr(t, err, core.ErrNotFound) }) } diff --git a/internal/rstring/tx.go b/internal/rstring/tx.go index aea6733..ce5617d 100644 --- a/internal/rstring/tx.go +++ b/internal/rstring/tx.go @@ -12,7 +12,7 @@ import ( const ( sqlGet = ` - select key, value + select value from rstring join rkey on key_id = rkey.id and (etime is null or etime > :now) where key = :key` @@ -66,26 +66,27 @@ func NewTx(tx sqlx.Tx) *Tx { } // Get returns the value of the key. -// Returns nil if the key does not exist. +// If the key does not exist or is not a string, returns ErrNotFound. func (tx *Tx) Get(key string) (core.Value, error) { args := []any{ sql.Named("key", key), sql.Named("now", time.Now().UnixMilli()), } - row := tx.tx.QueryRow(sqlGet, args...) - _, val, err := scanValue(row) - return val, err + var val []byte + err := tx.tx.QueryRow(sqlGet, args...).Scan(&val) + if err == sql.ErrNoRows { + return core.Value(nil), core.ErrNotFound + } + if err != nil { + return core.Value(nil), err + } + return core.Value(val), nil } // GetMany returns a map of values for given keys. -// Returns nil for keys that do not exist. +// Ignores keys that do not exist or not strings, +// and does not return them in the map. func (tx *Tx) GetMany(keys ...string) (map[string]core.Value, error) { - // Build a map of requested keys. - items := make(map[string]core.Value, len(keys)) - for _, key := range keys { - items[key] = nil - } - // Get the values of the requested keys. now := time.Now().UnixMilli() query, keyArgs := sqlx.ExpandIn(sqlGetMany, ":keys", keys) @@ -98,14 +99,16 @@ func (tx *Tx) GetMany(keys ...string) (map[string]core.Value, error) { } defer rows.Close() - // Fill the map with the values for existing keys - // (the rest of the keys will remain nil). + // Fill the map with the values for existing keys. + items := map[string]core.Value{} for rows.Next() { - key, val, err := scanValue(rows) + var key string + var val []byte + err = rows.Scan(&key, &val) if err != nil { return nil, err } - items[key] = val + items[key] = core.Value(val) } if rows.Err() != nil { return nil, rows.Err() @@ -116,15 +119,16 @@ func (tx *Tx) GetMany(keys ...string) (map[string]core.Value, error) { // GetSet returns the previous value of a key after setting it to a new value. // Optionally sets the expiration time (if ttl > 0). -// Overwrites the value and ttl if the key already exists. -// Returns nil if the key did not exist. +// If the key already exists, overwrites the value and ttl. +// If the key exists but is not a string, returns ErrKeyType. +// If the key does not exist, returns nil as the previous value. func (tx *Tx) GetSet(key string, value any, ttl time.Duration) (core.Value, error) { if !core.IsValueType(value) { return nil, core.ErrValueType } prev, err := tx.Get(key) - if err != nil { + if err != nil && err != core.ErrNotFound { return nil, err } @@ -132,14 +136,15 @@ func (tx *Tx) GetSet(key string, value any, ttl time.Duration) (core.Value, erro return prev, err } -// Incr increments the key value by the specified amount. -// If the key does not exist, sets it to 0 before the increment. +// Incr increments the integer key value by the specified amount. // Returns the value after the increment. -// Returns an error if the key value is not an integer. +// If the key does not exist, sets it to 0 before the increment. +// If the key value is not an integer, returns ErrValueType. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) Incr(key string, delta int) (int, error) { // get the current value val, err := tx.Get(key) - if err != nil { + if err != nil && err != core.ErrNotFound { return 0, err } @@ -159,14 +164,15 @@ func (tx *Tx) Incr(key string, delta int) (int, error) { return newVal, nil } -// IncrFloat increments the key value by the specified amount. -// If the key does not exist, sets it to 0 before the increment. +// IncrFloat increments the float key value by the specified amount. // Returns the value after the increment. -// Returns an error if the key value is not a float. +// If the key does not exist, sets it to 0 before the increment. +// If the key value is not an float, returns ErrValueType. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) IncrFloat(key string, delta float64) (float64, error) { // get the current value val, err := tx.Get(key) - if err != nil { + if err != nil && err != core.ErrNotFound { return 0, err } @@ -188,6 +194,7 @@ func (tx *Tx) IncrFloat(key string, delta float64) (float64, error) { // Set sets the key value that will not expire. // Overwrites the value if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) Set(key string, value any) error { return tx.SetExpires(key, value, 0) } @@ -195,6 +202,7 @@ func (tx *Tx) Set(key string, value any) error { // SetExists sets the key value if the key exists. // Optionally sets the expiration time (if ttl > 0). // Returns true if the key was set, false if the key does not exist. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) SetExists(key string, value any, ttl time.Duration) (bool, error) { if !core.IsValueType(value) { return false, core.ErrValueType @@ -214,6 +222,7 @@ func (tx *Tx) SetExists(key string, value any, ttl time.Duration) (bool, error) // SetExpires sets the key value with an optional expiration time (if ttl > 0). // Overwrites the value and ttl if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) SetExpires(key string, value any, ttl time.Duration) error { if !core.IsValueType(value) { return core.ErrValueType @@ -226,6 +235,7 @@ func (tx *Tx) SetExpires(key string, value any, ttl time.Duration) error { // Overwrites values for keys that already exist and // creates new keys/values for keys that do not exist. // Removes the TTL for existing keys. +// If any of the keys exists but is not a string, returns ErrKeyType. func (tx *Tx) SetMany(items map[string]any) error { for _, val := range items { if !core.IsValueType(val) { @@ -244,8 +254,9 @@ func (tx *Tx) SetMany(items map[string]any) error { } // SetManyNX sets the values of multiple keys, but only if none -// of them yet exist. Returns true if the keys were set, false if any -// of them already exist. +// of them yet exist. Returns true if the keys were set, +// false if any of them already exist. +// If any of the keys exists but is not a string, returns ErrKeyType. func (tx *Tx) SetManyNX(items map[string]any) (bool, error) { for _, val := range items { if !core.IsValueType(val) { @@ -260,7 +271,7 @@ func (tx *Tx) SetManyNX(items map[string]any) (bool, error) { } // check if any of the keys exist - count, err := rkey.Count(tx.tx, keys...) + count, err := rkey.CountType(tx.tx, core.TypeString, keys...) if err != nil { return false, err } @@ -284,16 +295,17 @@ func (tx *Tx) SetManyNX(items map[string]any) (bool, error) { // SetNotExists sets the key value if the key does not exist. // Optionally sets the expiration time (if ttl > 0). // Returns true if the key was set, false if the key already exists. +// If the key exists but is not a string, returns ErrKeyType. func (tx *Tx) SetNotExists(key string, value any, ttl time.Duration) (bool, error) { if !core.IsValueType(value) { return false, core.ErrValueType } - k, err := rkey.Get(tx.tx, key) - if err != nil { + val, err := tx.Get(key) + if err != nil && err != core.ErrNotFound { return false, err } - if k.Exists() { + if val.Exists() { return false, nil } @@ -347,16 +359,3 @@ func (tx *Tx) update(key string, value any) error { _, err = tx.tx.Exec(sqlUpdate2, args...) return err } - -// scanValue scans a key value from the row (rows). -func scanValue(scanner sqlx.RowScanner) (key string, val core.Value, err error) { - var value []byte - err = scanner.Scan(&key, &value) - if err == sql.ErrNoRows { - return "", nil, nil - } - if err != nil { - return "", nil, err - } - return key, core.Value(value), nil -}