Implement KEYS command

This commit is contained in:
Nicole Streltsov
2025-04-13 20:05:48 -04:00
committed by GitHub
parent 38f9ca4ed7
commit 469a5e5233
12 changed files with 797 additions and 4810 deletions

View File

@@ -222,6 +222,7 @@ Benchmark script options:
* [INCR](https://sugardb.io/docs/commands/generic/incr)
* [INCRBY](https://sugardb.io/docs/commands/generic/incrby)
* [INCRBYFLOAT](https://sugardb.io/docs/commands/generic/incrbyfloat)
* [KEYS](https://sugardb.io/docs/commands/generic/keys)
* [MGET](https://sugardb.io/docs/commands/generic/mget)
* [MOVE](https://sugardb.io/docs/commands/generic/move)
* [MSET](https://sugardb.io/docs/commands/generic/mset)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# KEYS
### Syntax
```
KEYS
```
### Module
<span className="acl-category">generic</span>
### Categories
<span className="acl-category">slow</span>
<span className="acl-category">read</span>
<span className="acl-category">keyspace</span>
<span className="acl-category">dangerous</span>
### Description
Returns an array of keys that match the provided glob pattern. This follows the same pattern matching rules as Redis.
### Examples
<Tabs
defaultValue="go"
values={[
{ label: 'Go (Embedded)', value: 'go', },
{ label: 'CLI', value: 'cli', },
]}
>
<TabItem value="go">
Return the keys matching the pattern:
```go
db, err := sugardb.NewSugarDB()
if err != nil {
log.Fatal(err)
}
key, err := db.Keys("a??")
// keys can be eg. age, all
```
</TabItem>
<TabItem value="cli">
Return the keys matching the pattern:
```
> KEYS a??
age
all
```
</TabItem>
</Tabs>

View File

@@ -995,6 +995,38 @@ func handleMove(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(fmt.Sprintf("+%v\r\n", 0)), nil
}
func handleKeys(params internal.HandlerFuncParams) ([]byte, error) {
keys, err := keysKeyFunc(params.Command)
if err != nil {
return nil, err
}
pattern := keys.ReadKeys[0]
storeKeys := params.GetKeys(params.Context)
var matchedKeys []string
// Special case for * pattern - return all keys
if pattern == "*" {
matchedKeys = storeKeys
} else {
// Find all matching keys using direct pattern matching
for _, key := range storeKeys {
if matchPattern(pattern, key) {
matchedKeys = append(matchedKeys, key)
}
}
}
// Build byte response
res := fmt.Sprintf("*%d\r\n", len(matchedKeys))
for _, key := range matchedKeys {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(key), key)
}
return []byte(res), nil
}
func Commands() []internal.Command {
return []internal.Command{
{
@@ -1414,5 +1446,15 @@ The REPLACE option removes the destination key before copying the value to it.`,
KeyExtractionFunc: existsKeyFunc,
HandlerFunc: handleExists,
},
{
Command: "keys",
Module: constants.GenericModule,
Categories: []string{constants.KeyspaceCategory, constants.ReadCategory, constants.SlowCategory, constants.DangerousCategory},
Description: "(KEYS pattern) Returns an array of keys that match the provided glob pattern.",
Sync: false,
Type: "BUILT_IN",
KeyExtractionFunc: keysKeyFunc,
HandlerFunc: handleKeys,
},
}
}

View File

@@ -21,6 +21,8 @@ import (
"strings"
"testing"
"time"
"sort"
"reflect"
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/clock"
@@ -3892,6 +3894,155 @@ func Test_Generic(t *testing.T) {
}
})
t.Run("Test_handleKeys", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := sugardb.NewSugarDB(
sugardb.WithConfig(config.Config{
BindAddr: "localhost",
Port: uint16(port),
DataDir: "",
EvictionPolicy: constants.NoEviction,
}),
)
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
t.Cleanup(func() {
mockServer.ShutDown()
})
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn)
tests := []struct {
name string
presetValues map[string]interface{}
command []string
expectedResponse []string
expectedError error
}{
{
name: "1. Return all keys with * pattern",
presetValues: map[string]interface{}{
"key1": "value1",
"key2": "value2",
"key3": "value3",
},
command: []string{"KEYS", "*"},
expectedResponse: []string{"key1", "key2", "key3"},
expectedError: nil,
},
{
name: "2. Return empty slice when no keys exist",
presetValues: nil,
command: []string{"KEYS", "*"},
expectedResponse: []string{},
expectedError: nil,
},
{
name: "3. Return error when command is invalid",
presetValues: nil,
command: []string{"KEYS"},
expectedResponse: nil,
expectedError: errors.New(constants.WrongArgsResponse),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Clear the database before each test
flushCommand := []resp.Value{
resp.StringValue("FLUSHALL"),
}
if err = client.WriteArray(flushCommand); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected FLUSHALL response to be \"ok\", got %s", res.String())
}
// preset values
if test.presetValues != nil {
for k, v := range test.presetValues {
command := []resp.Value{
resp.StringValue("SET"),
resp.StringValue(k),
resp.StringValue(v.(string)),
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected preset response to be \"ok\", got %s", res.String())
}
}
}
command := make([]resp.Value, len(test.command))
for i, c := range test.command {
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
}
if test.expectedError != nil {
if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
}
return
}
// Convert response array to string slice
responseArray := res.Array()
responseStrings := make([]string, len(responseArray))
for i, item := range responseArray {
responseStrings[i] = item.String()
}
// Sort both slices for comparison since KEYS command doesn't guarantee order
sort.Strings(responseStrings)
sort.Strings(test.expectedResponse)
if !reflect.DeepEqual(responseStrings, test.expectedResponse) {
t.Errorf("expected response %v, got %v", test.expectedResponse, responseStrings)
}
})
}
})
}
// Certain commands will need to be tested in a server with an eviction policy.

View File

@@ -321,3 +321,12 @@ func existsKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) {
WriteKeys: make([]string, 0),
}, nil
}
func keysKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) {
if len(cmd) != 2 {
return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse)
}
return internal.KeyExtractionFuncResult{
ReadKeys: cmd[1:2],
}, nil
}

View File

@@ -149,4 +149,98 @@ func getCopyCommandOptions(cmd []string, options CopyOptions) (CopyOptions, erro
default:
return CopyOptions{}, fmt.Errorf("unknown option %s for copy command", strings.ToUpper(cmd[0]))
}
}
}
func matchPattern(pattern string, key string) bool {
/*
Implementation of Redis-style pattern matching
https://redis.io/docs/latest/commands/keys/
*/
patternLen := len(pattern)
keyLen := len(key) // length of the key to match
patternPos := 0 // position in the pattern
keyPos := 0 // position in the key
for patternPos < patternLen {
switch pattern[patternPos] {
case '\\': // Match characters verbatum after slash
if patternPos+1 < patternLen {
patternPos++
if keyPos >= keyLen || pattern[patternPos] != key[keyPos] {
return false
}
keyPos++
}
case '?': // Match any single character (skip key position)
// key position is at the end, return false
if keyPos >= keyLen {
return false
}
keyPos++
case '*': // Match any sequence of characters
// If pattern is at the end, return true
if patternPos+1 >= patternLen {
return true
}
// Use recursion to match the rest of the pattern at each position
for i := keyPos; i <= keyLen; i++ {
if matchPattern(pattern[patternPos+1:], key[i:]) {
return true
}
}
return false
case '[': // Match any character in the character class brackets []
// key position is at the end, return false
if keyPos >= keyLen {
return false
}
patternPos++ // skip the [ character
// check if character class is negated (^)
negate := false
if patternPos < patternLen && pattern[patternPos] == '^' {
negate = true
patternPos++
}
// look through all characters in the character class
matched := false
for patternPos < patternLen && pattern[patternPos] != ']' {
// if character is escaped, check the next character
if pattern[patternPos] == '\\' && patternPos+1 < patternLen {
patternPos++
if pattern[patternPos] == key[keyPos] {
matched = true
}
// if character is a range, check if the key position is within the range
} else if patternPos+2 < patternLen && pattern[patternPos+1] == '-' {
// Handle range
if key[keyPos] >= pattern[patternPos] && key[keyPos] <= pattern[patternPos+2] {
matched = true
}
patternPos += 2
// if character is a match, set matched to true
} else if pattern[patternPos] == key[keyPos] {
matched = true
}
patternPos++
}
// if pattern position is at the end, return false
if patternPos >= patternLen {
return false
}
// negate check: if matched is true and negate is true, return false
if matched == negate {
return false
}
keyPos++
default: // Match literal character (just like slash but on the current key position)
if keyPos >= keyLen || pattern[patternPos] != key[keyPos] {
return false
}
keyPos++
}
patternPos++
}
return keyPos == keyLen
}

View File

@@ -138,6 +138,8 @@ type HandlerFuncParams struct {
Connection *net.Conn
// KeysExist returns a map that specifies which keys exist in the keyspace.
KeysExist func(ctx context.Context, keys []string) map[string]bool
// GetKeys returns all the keys in the keyspace.
GetKeys func(ctx context.Context) []string
// GetExpiry returns the expiry time of a key.
GetExpiry func(ctx context.Context, key string) time.Time
// GetHashExpiry returns the expiry time of a field in a key whose value is a hash.

View File

@@ -822,3 +822,18 @@ func (server *SugarDB) Exists(keys ...string) (int, error) {
}
return internal.ParseIntegerResponse(b)
}
// Keys returns all of the keys matching the glob pattern of the given key.
// Parameters:
//
// `pattern` - string - pattern of key to match on
//
// Returns: A string slice of all the matching keys. If there are no keys matching the pattern, an empty slice is returned.
func (server *SugarDB) Keys(pattern string) ([]string, error) {
b, err := server.handleCommand(server.context, internal.EncodeCommand([]string{"KEYS", pattern}), nil, false, true)
if err != nil {
return nil, err
}
return internal.ParseStringArrayResponse(b)
}

View File

@@ -2143,4 +2143,105 @@ func TestSugarDB_Generic(t *testing.T) {
})
}
})
t.Run("TestSugarDB_KEYS", func(t *testing.T) {
tests := []struct {
name string
presetValues map[string]interface{}
pattern string
want []string
wantErr bool
}{
{
name: "1. Return all keys with * pattern",
presetValues: map[string]interface{}{
"keys_key1": "value1",
"keys_key2": "value2",
"keys_key3": "value3",
},
pattern: "*",
want: []string{"keys_key1", "keys_key2", "keys_key3"},
wantErr: false,
},
{
name: "2. Return keys matching specific pattern",
presetValues: map[string]interface{}{
"keys_key1": "value1",
"keys_key2": "value2",
"other_key": "value3",
},
pattern: "keys_*",
want: []string{"keys_key1", "keys_key2"},
wantErr: false,
},
{
name: "3. Return keys matching single character pattern",
presetValues: map[string]interface{}{
"keys_key1": "value1",
"keys_key2": "value2",
"keys_kex3": "value3",
},
pattern: "keys_key?",
want: []string{"keys_key1", "keys_key2"},
wantErr: false,
},
{
name: "4. Return empty slice when no keys match pattern",
presetValues: map[string]interface{}{
"keys_key1": "value1",
"keys_key2": "value2",
},
pattern: "nonexistent_*",
want: []string{},
wantErr: false,
},
{
name: "5. Return empty slice when no keys exist",
presetValues: nil,
pattern: "*",
want: []string{},
wantErr: false,
},
{
name: "6. Return keys matching character class pattern",
presetValues: map[string]interface{}{
"keys_key1": "value1",
"keys_key2": "value2",
"keys_kex3": "value3",
},
pattern: "keys_key[12]",
want: []string{"keys_key1", "keys_key2"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a new server instance for each test case
server := createSugarDB()
t.Cleanup(func() {
server.ShutDown()
})
if tt.presetValues != nil {
for k, v := range tt.presetValues {
err := presetValue(server, context.Background(), k, v)
if err != nil {
t.Error(err)
return
}
}
}
got, err := server.Keys(tt.pattern)
if (err != nil) != tt.wantErr {
t.Errorf("KEYS() error = %v, wantErr %v", err, tt.wantErr)
return
}
slices.Sort(got)
slices.Sort(tt.want)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("KEYS() got = %v, want %v", got, tt.want)
}
})
}
})
}

View File

@@ -134,6 +134,22 @@ func (server *SugarDB) keysExist(ctx context.Context, keys []string) map[string]
return exists
}
func (server *SugarDB) getKeys(ctx context.Context) []string {
server.storeLock.RLock()
defer server.storeLock.RUnlock()
database := ctx.Value("Database").(int)
keys := make([]string, len(server.store[database]))
i := 0
for key := range server.store[database] {
keys[i] = key
i++
}
return keys
}
func (server *SugarDB) getExpiry(ctx context.Context, key string) time.Time {
server.storeLock.RLock()
defer server.storeLock.RUnlock()

View File

@@ -44,6 +44,7 @@ func (server *SugarDB) getHandlerFuncParams(ctx context.Context, cmd []string, c
Command: cmd,
Connection: conn,
KeysExist: server.keysExist,
GetKeys: server.getKeys,
GetExpiry: server.getExpiry,
GetHashExpiry: server.getHashExpiry,
GetValues: server.getValues,