mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-09-26 20:11:15 +08:00
Implement KEYS command
This commit is contained in:
@@ -222,6 +222,7 @@ Benchmark script options:
|
||||
* [INCR](https://sugardb.io/docs/commands/generic/incr)
|
||||
* [INCRBY](https://sugardb.io/docs/commands/generic/incrby)
|
||||
* [INCRBYFLOAT](https://sugardb.io/docs/commands/generic/incrbyfloat)
|
||||
* [KEYS](https://sugardb.io/docs/commands/generic/keys)
|
||||
* [MGET](https://sugardb.io/docs/commands/generic/mget)
|
||||
* [MOVE](https://sugardb.io/docs/commands/generic/move)
|
||||
* [MSET](https://sugardb.io/docs/commands/generic/mset)
|
||||
|
File diff suppressed because it is too large
Load Diff
51
docs/docs/commands/generic/keys.mdx
Normal file
51
docs/docs/commands/generic/keys.mdx
Normal file
@@ -0,0 +1,51 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# KEYS
|
||||
|
||||
### Syntax
|
||||
```
|
||||
KEYS
|
||||
```
|
||||
|
||||
### Module
|
||||
<span className="acl-category">generic</span>
|
||||
|
||||
### Categories
|
||||
<span className="acl-category">slow</span>
|
||||
<span className="acl-category">read</span>
|
||||
<span className="acl-category">keyspace</span>
|
||||
<span className="acl-category">dangerous</span>
|
||||
|
||||
### Description
|
||||
Returns an array of keys that match the provided glob pattern. This follows the same pattern matching rules as Redis.
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs
|
||||
defaultValue="go"
|
||||
values={[
|
||||
{ label: 'Go (Embedded)', value: 'go', },
|
||||
{ label: 'CLI', value: 'cli', },
|
||||
]}
|
||||
>
|
||||
<TabItem value="go">
|
||||
Return the keys matching the pattern:
|
||||
```go
|
||||
db, err := sugardb.NewSugarDB()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
key, err := db.Keys("a??")
|
||||
// keys can be eg. age, all
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="cli">
|
||||
Return the keys matching the pattern:
|
||||
```
|
||||
> KEYS a??
|
||||
age
|
||||
all
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
@@ -995,6 +995,38 @@ func handleMove(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("+%v\r\n", 0)), nil
|
||||
}
|
||||
|
||||
func handleKeys(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
keys, err := keysKeyFunc(params.Command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pattern := keys.ReadKeys[0]
|
||||
storeKeys := params.GetKeys(params.Context)
|
||||
|
||||
var matchedKeys []string
|
||||
|
||||
// Special case for * pattern - return all keys
|
||||
if pattern == "*" {
|
||||
matchedKeys = storeKeys
|
||||
} else {
|
||||
// Find all matching keys using direct pattern matching
|
||||
for _, key := range storeKeys {
|
||||
if matchPattern(pattern, key) {
|
||||
matchedKeys = append(matchedKeys, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build byte response
|
||||
res := fmt.Sprintf("*%d\r\n", len(matchedKeys))
|
||||
for _, key := range matchedKeys {
|
||||
res += fmt.Sprintf("$%d\r\n%s\r\n", len(key), key)
|
||||
}
|
||||
|
||||
return []byte(res), nil
|
||||
}
|
||||
|
||||
func Commands() []internal.Command {
|
||||
return []internal.Command{
|
||||
{
|
||||
@@ -1414,5 +1446,15 @@ The REPLACE option removes the destination key before copying the value to it.`,
|
||||
KeyExtractionFunc: existsKeyFunc,
|
||||
HandlerFunc: handleExists,
|
||||
},
|
||||
{
|
||||
Command: "keys",
|
||||
Module: constants.GenericModule,
|
||||
Categories: []string{constants.KeyspaceCategory, constants.ReadCategory, constants.SlowCategory, constants.DangerousCategory},
|
||||
Description: "(KEYS pattern) Returns an array of keys that match the provided glob pattern.",
|
||||
Sync: false,
|
||||
Type: "BUILT_IN",
|
||||
KeyExtractionFunc: keysKeyFunc,
|
||||
HandlerFunc: handleKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -21,6 +21,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"sort"
|
||||
"reflect"
|
||||
|
||||
"github.com/echovault/sugardb/internal"
|
||||
"github.com/echovault/sugardb/internal/clock"
|
||||
@@ -3892,6 +3894,155 @@ func Test_Generic(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_handleKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
mockServer, err := sugardb.NewSugarDB(
|
||||
sugardb.WithConfig(config.Config{
|
||||
BindAddr: "localhost",
|
||||
Port: uint16(port),
|
||||
DataDir: "",
|
||||
EvictionPolicy: constants.NoEviction,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
presetValues map[string]interface{}
|
||||
command []string
|
||||
expectedResponse []string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "1. Return all keys with * pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
},
|
||||
command: []string{"KEYS", "*"},
|
||||
expectedResponse: []string{"key1", "key2", "key3"},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "2. Return empty slice when no keys exist",
|
||||
presetValues: nil,
|
||||
command: []string{"KEYS", "*"},
|
||||
expectedResponse: []string{},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "3. Return error when command is invalid",
|
||||
presetValues: nil,
|
||||
command: []string{"KEYS"},
|
||||
expectedResponse: nil,
|
||||
expectedError: errors.New(constants.WrongArgsResponse),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Clear the database before each test
|
||||
flushCommand := []resp.Value{
|
||||
resp.StringValue("FLUSHALL"),
|
||||
}
|
||||
if err = client.WriteArray(flushCommand); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected FLUSHALL response to be \"ok\", got %s", res.String())
|
||||
}
|
||||
|
||||
// preset values
|
||||
if test.presetValues != nil {
|
||||
for k, v := range test.presetValues {
|
||||
command := []resp.Value{
|
||||
resp.StringValue("SET"),
|
||||
resp.StringValue(k),
|
||||
resp.StringValue(v.(string)),
|
||||
}
|
||||
|
||||
if err = client.WriteArray(command); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected preset response to be \"ok\", got %s", res.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
command := make([]resp.Value, len(test.command))
|
||||
for i, c := range test.command {
|
||||
command[i] = resp.StringValue(c)
|
||||
}
|
||||
|
||||
if err = client.WriteArray(command); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if test.expectedError != nil {
|
||||
if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
|
||||
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Convert response array to string slice
|
||||
responseArray := res.Array()
|
||||
responseStrings := make([]string, len(responseArray))
|
||||
for i, item := range responseArray {
|
||||
responseStrings[i] = item.String()
|
||||
}
|
||||
|
||||
// Sort both slices for comparison since KEYS command doesn't guarantee order
|
||||
sort.Strings(responseStrings)
|
||||
sort.Strings(test.expectedResponse)
|
||||
|
||||
if !reflect.DeepEqual(responseStrings, test.expectedResponse) {
|
||||
t.Errorf("expected response %v, got %v", test.expectedResponse, responseStrings)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Certain commands will need to be tested in a server with an eviction policy.
|
||||
|
@@ -321,3 +321,12 @@ func existsKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) {
|
||||
WriteKeys: make([]string, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func keysKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) {
|
||||
if len(cmd) != 2 {
|
||||
return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse)
|
||||
}
|
||||
return internal.KeyExtractionFuncResult{
|
||||
ReadKeys: cmd[1:2],
|
||||
}, nil
|
||||
}
|
||||
|
@@ -149,4 +149,98 @@ func getCopyCommandOptions(cmd []string, options CopyOptions) (CopyOptions, erro
|
||||
default:
|
||||
return CopyOptions{}, fmt.Errorf("unknown option %s for copy command", strings.ToUpper(cmd[0]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func matchPattern(pattern string, key string) bool {
|
||||
/*
|
||||
Implementation of Redis-style pattern matching
|
||||
https://redis.io/docs/latest/commands/keys/
|
||||
*/
|
||||
patternLen := len(pattern)
|
||||
keyLen := len(key) // length of the key to match
|
||||
patternPos := 0 // position in the pattern
|
||||
keyPos := 0 // position in the key
|
||||
|
||||
for patternPos < patternLen {
|
||||
switch pattern[patternPos] {
|
||||
case '\\': // Match characters verbatum after slash
|
||||
if patternPos+1 < patternLen {
|
||||
patternPos++
|
||||
if keyPos >= keyLen || pattern[patternPos] != key[keyPos] {
|
||||
return false
|
||||
}
|
||||
keyPos++
|
||||
}
|
||||
case '?': // Match any single character (skip key position)
|
||||
// key position is at the end, return false
|
||||
if keyPos >= keyLen {
|
||||
return false
|
||||
}
|
||||
keyPos++
|
||||
case '*': // Match any sequence of characters
|
||||
// If pattern is at the end, return true
|
||||
if patternPos+1 >= patternLen {
|
||||
return true
|
||||
}
|
||||
// Use recursion to match the rest of the pattern at each position
|
||||
for i := keyPos; i <= keyLen; i++ {
|
||||
if matchPattern(pattern[patternPos+1:], key[i:]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case '[': // Match any character in the character class brackets []
|
||||
// key position is at the end, return false
|
||||
if keyPos >= keyLen {
|
||||
return false
|
||||
}
|
||||
patternPos++ // skip the [ character
|
||||
// check if character class is negated (^)
|
||||
negate := false
|
||||
if patternPos < patternLen && pattern[patternPos] == '^' {
|
||||
negate = true
|
||||
patternPos++
|
||||
}
|
||||
|
||||
// look through all characters in the character class
|
||||
matched := false
|
||||
for patternPos < patternLen && pattern[patternPos] != ']' {
|
||||
// if character is escaped, check the next character
|
||||
if pattern[patternPos] == '\\' && patternPos+1 < patternLen {
|
||||
patternPos++
|
||||
if pattern[patternPos] == key[keyPos] {
|
||||
matched = true
|
||||
}
|
||||
// if character is a range, check if the key position is within the range
|
||||
} else if patternPos+2 < patternLen && pattern[patternPos+1] == '-' {
|
||||
// Handle range
|
||||
if key[keyPos] >= pattern[patternPos] && key[keyPos] <= pattern[patternPos+2] {
|
||||
matched = true
|
||||
}
|
||||
patternPos += 2
|
||||
// if character is a match, set matched to true
|
||||
} else if pattern[patternPos] == key[keyPos] {
|
||||
matched = true
|
||||
}
|
||||
patternPos++
|
||||
}
|
||||
// if pattern position is at the end, return false
|
||||
if patternPos >= patternLen {
|
||||
return false
|
||||
}
|
||||
// negate check: if matched is true and negate is true, return false
|
||||
if matched == negate {
|
||||
return false
|
||||
}
|
||||
keyPos++
|
||||
default: // Match literal character (just like slash but on the current key position)
|
||||
if keyPos >= keyLen || pattern[patternPos] != key[keyPos] {
|
||||
return false
|
||||
}
|
||||
keyPos++
|
||||
}
|
||||
patternPos++
|
||||
}
|
||||
|
||||
return keyPos == keyLen
|
||||
}
|
||||
|
@@ -138,6 +138,8 @@ type HandlerFuncParams struct {
|
||||
Connection *net.Conn
|
||||
// KeysExist returns a map that specifies which keys exist in the keyspace.
|
||||
KeysExist func(ctx context.Context, keys []string) map[string]bool
|
||||
// GetKeys returns all the keys in the keyspace.
|
||||
GetKeys func(ctx context.Context) []string
|
||||
// GetExpiry returns the expiry time of a key.
|
||||
GetExpiry func(ctx context.Context, key string) time.Time
|
||||
// GetHashExpiry returns the expiry time of a field in a key whose value is a hash.
|
||||
|
@@ -822,3 +822,18 @@ func (server *SugarDB) Exists(keys ...string) (int, error) {
|
||||
}
|
||||
return internal.ParseIntegerResponse(b)
|
||||
}
|
||||
|
||||
// Keys returns all of the keys matching the glob pattern of the given key.
|
||||
// Parameters:
|
||||
//
|
||||
// `pattern` - string - pattern of key to match on
|
||||
//
|
||||
// Returns: A string slice of all the matching keys. If there are no keys matching the pattern, an empty slice is returned.
|
||||
func (server *SugarDB) Keys(pattern string) ([]string, error) {
|
||||
b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"KEYS", pattern}), nil, false, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return internal.ParseStringArrayResponse(b)
|
||||
}
|
@@ -2143,4 +2143,105 @@ func TestSugarDB_Generic(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSugarDB_KEYS", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
presetValues map[string]interface{}
|
||||
pattern string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1. Return all keys with * pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"keys_key1": "value1",
|
||||
"keys_key2": "value2",
|
||||
"keys_key3": "value3",
|
||||
},
|
||||
pattern: "*",
|
||||
want: []string{"keys_key1", "keys_key2", "keys_key3"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "2. Return keys matching specific pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"keys_key1": "value1",
|
||||
"keys_key2": "value2",
|
||||
"other_key": "value3",
|
||||
},
|
||||
pattern: "keys_*",
|
||||
want: []string{"keys_key1", "keys_key2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "3. Return keys matching single character pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"keys_key1": "value1",
|
||||
"keys_key2": "value2",
|
||||
"keys_kex3": "value3",
|
||||
},
|
||||
pattern: "keys_key?",
|
||||
want: []string{"keys_key1", "keys_key2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "4. Return empty slice when no keys match pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"keys_key1": "value1",
|
||||
"keys_key2": "value2",
|
||||
},
|
||||
pattern: "nonexistent_*",
|
||||
want: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "5. Return empty slice when no keys exist",
|
||||
presetValues: nil,
|
||||
pattern: "*",
|
||||
want: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "6. Return keys matching character class pattern",
|
||||
presetValues: map[string]interface{}{
|
||||
"keys_key1": "value1",
|
||||
"keys_key2": "value2",
|
||||
"keys_kex3": "value3",
|
||||
},
|
||||
pattern: "keys_key[12]",
|
||||
want: []string{"keys_key1", "keys_key2"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a new server instance for each test case
|
||||
server := createSugarDB()
|
||||
t.Cleanup(func() {
|
||||
server.ShutDown()
|
||||
})
|
||||
|
||||
if tt.presetValues != nil {
|
||||
for k, v := range tt.presetValues {
|
||||
err := presetValue(server, context.Background(), k, v)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
got, err := server.Keys(tt.pattern)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KEYS() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
slices.Sort(got)
|
||||
slices.Sort(tt.want)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KEYS() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -134,6 +134,22 @@ func (server *SugarDB) keysExist(ctx context.Context, keys []string) map[string]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (server *SugarDB) getKeys(ctx context.Context) []string {
|
||||
server.storeLock.RLock()
|
||||
defer server.storeLock.RUnlock()
|
||||
|
||||
database := ctx.Value("Database").(int)
|
||||
|
||||
keys := make([]string, len(server.store[database]))
|
||||
i := 0
|
||||
for key := range server.store[database] {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func (server *SugarDB) getExpiry(ctx context.Context, key string) time.Time {
|
||||
server.storeLock.RLock()
|
||||
defer server.storeLock.RUnlock()
|
||||
|
@@ -44,6 +44,7 @@ func (server *SugarDB) getHandlerFuncParams(ctx context.Context, cmd []string, c
|
||||
Command: cmd,
|
||||
Connection: conn,
|
||||
KeysExist: server.keysExist,
|
||||
GetKeys: server.getKeys,
|
||||
GetExpiry: server.getExpiry,
|
||||
GetHashExpiry: server.getHashExpiry,
|
||||
GetValues: server.getValues,
|
||||
|
Reference in New Issue
Block a user