Update Hash to allow for expirey commands (#146)

* Convert hash to composite type. Fixed broken Hash commands from Hash refactor. Coverage and fixed broken test - @osteensco
This commit is contained in:
osteensco
2024-11-03 13:24:31 -06:00
committed by GitHub
parent 05b7601752
commit 09640082c4
8 changed files with 3577 additions and 3517 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -813,7 +813,11 @@ func handleType(params internal.HandlerFuncParams) ([]byte, error) {
case reflect.Slice: case reflect.Slice:
type_string = "list" type_string = "list"
case reflect.Map: case reflect.Map:
type_string = "hash" if t.Elem().Name() == "HashValue" {
type_string = "hash"
} else {
type_string = t.Elem().Name()
}
case reflect.Pointer: case reflect.Pointer:
if t.Elem().Name() == "Set" { if t.Elem().Name() == "Set" {
type_string = "set" type_string = "set"

View File

@@ -2786,7 +2786,7 @@ func Test_Generic(t *testing.T) {
}() }()
client := resp.NewConn(conn) client := resp.NewConn(conn)
expected := "Key" expected := "key"
if err = client.WriteArray([]resp.Value{resp.StringValue("RANDOMKEY")}); err != nil { if err = client.WriteArray([]resp.Value{resp.StringValue("RANDOMKEY")}); err != nil {
t.Error(err) t.Error(err)
} }
@@ -2796,7 +2796,7 @@ func Test_Generic(t *testing.T) {
t.Error(err) t.Error(err)
} }
if !strings.Contains(res.String(), expected) { if !strings.Contains(strings.ToLower(res.String()), expected) {
t.Errorf("expected a key containing substring '%s', got %s", expected, res.String()) t.Errorf("expected a key containing substring '%s', got %s", expected, res.String())
} }
}) })

View File

@@ -33,27 +33,25 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
key := keys.WriteKeys[0] key := keys.WriteKeys[0]
keyExists := params.KeysExist(params.Context, keys.WriteKeys)[key] keyExists := params.KeysExist(params.Context, keys.WriteKeys)[key]
entries := make(map[string]interface{}) entries := Hash{}
if len(params.Command[2:])%2 != 0 { if len(params.Command[2:])%2 != 0 {
return nil, errors.New("each field must have a corresponding value") return nil, errors.New("each field must have a corresponding value")
} }
for i := 2; i <= len(params.Command)-2; i += 2 { for i := 2; i <= len(params.Command)-2; i += 2 {
entries[params.Command[i]] = internal.AdaptType(params.Command[i+1]) k := params.Command[i]
entries[k] = HashValue{Value: internal.AdaptType(params.Command[i+1])}
} }
if !keyExists { if !keyExists {
if err != nil {
return nil, err
}
if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil {
return nil, err return nil, err
} }
return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
// Not hash, save the entries map directly. // Not hash, save the entries map directly.
if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil {
@@ -67,18 +65,19 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
case "hsetnx": case "hsetnx":
// Handle HSETNX // Handle HSETNX
for field, _ := range entries { for field, _ := range entries {
if hash[field] == nil { if _, ok := hash[field]; !ok {
count += 1 count += 1
} }
} }
for field, value := range hash { for field, value := range hash {
entries[field] = value entries[field] = value
} }
default: default:
// Handle HSET // Handle HSET
for field, value := range hash { for field, value := range hash {
if entries[field] == nil { if entries[field].Value == nil {
entries[field] = value entries[field] = HashValue{Value: value}
} }
} }
count = len(entries) count = len(entries)
@@ -105,29 +104,29 @@ func handleHGET(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
var value interface{} var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields)) res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields { for _, field := range fields {
value = hash[field] value = hash[field]
if value == nil { if value.Value == nil {
res += "$-1\r\n" res += "$-1\r\n"
continue continue
} }
if s, ok := value.(string); ok { if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue continue
} }
if d, ok := value.(int); ok { if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
continue continue
} }
if f, ok := value.(float64); ok { if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue continue
@@ -150,14 +149,15 @@ func handleHMGET(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
fields := params.Command[2:] fields := params.Command[2:]
var value interface{} var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields)) res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields { for _, field := range fields {
value, ok = hash[field] value, ok = hash[field]
@@ -166,15 +166,15 @@ func handleHMGET(params internal.HandlerFuncParams) ([]byte, error) {
continue continue
} }
if s, ok := value.(string); ok { if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue continue
} }
if d, ok := value.(int); ok { if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
continue continue
} }
if f, ok := value.(float64); ok { if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue continue
@@ -199,30 +199,30 @@ func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
var value interface{} var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields)) res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields { for _, field := range fields {
value = hash[field] value = hash[field]
if value == nil { if value.Value == nil {
res += ":0\r\n" res += ":0\r\n"
continue continue
} }
if s, ok := value.(string); ok { if s, ok := value.Value.(string); ok {
res += fmt.Sprintf(":%d\r\n", len(s)) res += fmt.Sprintf(":%d\r\n", len(s))
continue continue
} }
if f, ok := value.(float64); ok { if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf(":%d\r\n", len(fs)) res += fmt.Sprintf(":%d\r\n", len(fs))
continue continue
} }
if d, ok := value.(int); ok { if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", len(strconv.Itoa(d))) res += fmt.Sprintf(":%d\r\n", len(strconv.Itoa(d)))
continue continue
} }
@@ -245,23 +245,23 @@ func handleHVALS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
res := fmt.Sprintf("*%d\r\n", len(hash)) res := fmt.Sprintf("*%d\r\n", len(hash))
for _, val := range hash { for _, val := range hash {
if s, ok := val.(string); ok { if s, ok := val.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue continue
} }
if f, ok := val.(float64); ok { if f, ok := val.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue continue
} }
if d, ok := val.(int); ok { if d, ok := val.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
} }
} }
@@ -303,7 +303,7 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -317,16 +317,16 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
for field, value := range hash { for field, value := range hash {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field) res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if withvalues { if withvalues {
if s, ok := value.(string); ok { if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue continue
} }
if f, ok := value.(float64); ok { if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue continue
} }
if d, ok := value.(int); ok { if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
continue continue
} }
@@ -362,16 +362,16 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
for _, field := range pluckedFields { for _, field := range pluckedFields {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field) res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if withvalues { if withvalues {
if s, ok := hash[field].(string); ok { if s, ok := hash[field].Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue continue
} }
if f, ok := hash[field].(float64); ok { if f, ok := hash[field].Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue continue
} }
if d, ok := hash[field].(int); ok { if d, ok := hash[field].Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
continue continue
} }
@@ -394,7 +394,7 @@ func handleHLEN(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -415,7 +415,7 @@ func handleHKEYS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -456,15 +456,15 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
} }
if !keyExists { if !keyExists {
hash := make(map[string]interface{}) hash := make(Hash)
if strings.EqualFold(params.Command[0], "hincrbyfloat") { if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = floatIncrement hash[field] = HashValue{Value: floatIncrement}
if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
} else { } else {
hash[field] = intIncrement hash[field] = HashValue{Value: intIncrement}
if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil { if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err return nil, err
} }
@@ -472,31 +472,31 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
} }
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
if hash[field] == nil { if hash[field].Value == nil {
hash[field] = 0 hash[field] = HashValue{Value: 0}
} }
switch hash[field].(type) { switch hash[field].Value.(type) {
default: default:
return nil, fmt.Errorf("value at field %s is not a number", field) return nil, fmt.Errorf("value at field %s is not a number", field)
case int: case int:
i, _ := hash[field].(int) i, _ := hash[field].Value.(int)
if strings.EqualFold(params.Command[0], "hincrbyfloat") { if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = float64(i) + floatIncrement hash[field] = HashValue{Value: float64(i) + floatIncrement}
} else { } else {
hash[field] = i + intIncrement hash[field] = HashValue{Value: i + intIncrement}
} }
case float64: case float64:
f, _ := hash[field].(float64) f, _ := hash[field].Value.(float64)
if strings.EqualFold(params.Command[0], "hincrbyfloat") { if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = f + floatIncrement hash[field] = HashValue{Value: f + floatIncrement}
} else { } else {
hash[field] = f + float64(intIncrement) hash[field] = HashValue{Value: f + float64(intIncrement)}
} }
} }
@@ -504,11 +504,11 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
return nil, err return nil, err
} }
if f, ok := hash[field].(float64); ok { if f, ok := hash[field].Value.(float64); ok {
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil
} }
i, _ := hash[field].(int) i, _ := hash[field].Value.(int)
return []byte(fmt.Sprintf(":%d\r\n", i)), nil return []byte(fmt.Sprintf(":%d\r\n", i)), nil
} }
@@ -525,7 +525,7 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -533,14 +533,16 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
res := fmt.Sprintf("*%d\r\n", len(hash)*2) res := fmt.Sprintf("*%d\r\n", len(hash)*2)
for field, value := range hash { for field, value := range hash {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field) res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if s, ok := value.(string); ok { if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
} }
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64) fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
} }
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d) res += fmt.Sprintf(":%d\r\n", d)
} }
} }
@@ -562,12 +564,12 @@ func handleHEXISTS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
if hash[field] != nil { if hash[field].Value != nil {
return []byte(":1\r\n"), nil return []byte(":1\r\n"), nil
} }
@@ -588,7 +590,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key) return nil, fmt.Errorf("value at %s is not a hash", key)
} }
@@ -596,7 +598,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
count := 0 count := 0
for _, field := range fields { for _, field := range fields {
if hash[field] != nil { if hash[field].Value != nil {
delete(hash, field) delete(hash, field)
count += 1 count += 1
} }
@@ -742,5 +744,14 @@ Return the string length of the values stored at the specified fields. 0 if the
KeyExtractionFunc: hdelKeyFunc, KeyExtractionFunc: hdelKeyFunc,
HandlerFunc: handleHDEL, HandlerFunc: handleHDEL,
}, },
// {
// Command: "hexpire",
// Module: constants.HashModule,
// Categories: []string{constants.HashCategory, constants.WriteCategory, constants.FastCategory},
// Description: `(HEXPIRE key seconds [NX | XX | GT | LT] FIELDS numfields field [field ...]) Sets the expiration, in seconds, of a field in a hash.`,
// Sync: true,
// KeyExtractionFunc: hexpireKeyFunc,
// HandlerFunc: handleHEXPIRE,
// },
} }
} }

View File

@@ -16,15 +16,17 @@ package hash_test
import ( import (
"errors" "errors"
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/config"
"github.com/echovault/sugardb/internal/constants"
"github.com/echovault/sugardb/sugardb"
"github.com/tidwall/resp"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/echovault/sugardb/internal"
"github.com/echovault/sugardb/internal/config"
"github.com/echovault/sugardb/internal/constants"
"github.com/echovault/sugardb/internal/modules/hash"
"github.com/echovault/sugardb/sugardb"
"github.com/tidwall/resp"
) )
func Test_Hash(t *testing.T) { func Test_Hash(t *testing.T) {
@@ -1534,15 +1536,15 @@ func Test_Hash(t *testing.T) {
key string key string
presetValue interface{} presetValue interface{}
command []string command []string
expectedResponse map[string]string expectedResponse hash.Hash
expectedError error expectedError error
}{ }{
{ {
name: "1. Return an array containing all the fields and values of the hash", name: "1. Return an array containing all the fields and values of the hash",
key: "HGetAllKey1", key: "HGetAllKey1",
presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, presetValue: hash.Hash{"field1": hash.HashValue{Value: "value1"}, "field2": hash.HashValue{Value: "123456789"}, "field3": hash.HashValue{Value: "3.142"}},
command: []string{"HGETALL", "HGetAllKey1"}, command: []string{"HGETALL", "HGetAllKey1"},
expectedResponse: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, expectedResponse: hash.Hash{"field1": hash.HashValue{Value: "value1"}, "field2": hash.HashValue{Value: "123456789"}, "field3": hash.HashValue{Value: "3.142"}},
expectedError: nil, expectedError: nil,
}, },
{ {
@@ -1593,15 +1595,15 @@ func Test_Hash(t *testing.T) {
resp.StringValue(test.presetValue.(string)), resp.StringValue(test.presetValue.(string)),
} }
expected = "ok" expected = "ok"
case map[string]string: case hash.Hash:
command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)}
for key, value := range test.presetValue.(map[string]string) { for key, value := range test.presetValue.(hash.Hash) {
command = append(command, []resp.Value{ command = append(command, []resp.Value{
resp.StringValue(key), resp.StringValue(key),
resp.StringValue(value)}..., resp.StringValue(value.Value.(string))}...,
) )
} }
expected = strconv.Itoa(len(test.presetValue.(map[string]string))) expected = strconv.Itoa(len(test.presetValue.(hash.Hash)))
} }
if err = client.WriteArray(command); err != nil { if err = client.WriteArray(command); err != nil {
@@ -1647,7 +1649,8 @@ func Test_Hash(t *testing.T) {
for i, item := range res.Array() { for i, item := range res.Array() {
if i%2 == 0 { if i%2 == 0 {
field := item.String() field := item.String()
value := res.Array()[i+1].String() value := hash.HashValue{Value: res.Array()[i+1].String()}
if test.expectedResponse[field] != value { if test.expectedResponse[field] != value {
t.Errorf("expected value at field \"%s\" to be \"%s\", got \"%s\"", field, test.expectedResponse[field], value) t.Errorf("expected value at field \"%s\" to be \"%s\", got \"%s\"", field, test.expectedResponse[field], value)
} }

View File

@@ -0,0 +1,48 @@
package hash
import (
"time"
"unsafe"
"github.com/echovault/sugardb/internal/constants"
)
type HashValue struct {
Value interface{}
ExpireAt time.Time
}
type Hash map[string]HashValue
func (h Hash) GetMem() int64 {
var size int64
// Map headers
size += int64(unsafe.Sizeof(h))
for key, val := range h {
size += int64(unsafe.Sizeof(key))
size += int64(len(key))
size += int64(unsafe.Sizeof(val))
size += int64(unsafe.Sizeof(val.ExpireAt))
switch vt := val.Value.(type) {
// AdaptType() will always ensure data type is of string, float64 or int.
case nil:
size += 0
case int:
size += int64(unsafe.Sizeof(vt))
case float64, int64:
size += 8
case string:
size += int64(unsafe.Sizeof(vt))
size += int64(len(vt))
}
}
return size
}
var _ constants.CompositeType = (*Hash)(nil)

View File

@@ -51,29 +51,6 @@ func (k *KeyData) GetMem() (int64, error) {
size += int64(unsafe.Sizeof(v)) size += int64(unsafe.Sizeof(v))
size += int64(len(v)) size += int64(len(v))
// handle hash
// AdaptType() will always ensure data type is of string, float64 or int.
case map[string]interface{}:
// Map headers
size += int64(unsafe.Sizeof(v))
for key, val := range v {
size += int64(unsafe.Sizeof(key))
size += int64(len(key))
switch vt := val.(type) {
case nil:
size += 0
case int:
size += int64(unsafe.Sizeof(vt))
case float64, int64:
size += 8
case string:
size += int64(unsafe.Sizeof(vt))
size += int64(len(vt))
}
}
// handle list // handle list
case []string: case []string:
for _, s := range v { for _, s := range v {
@@ -81,7 +58,7 @@ func (k *KeyData) GetMem() (int64, error) {
size += int64(len(s)) size += int64(len(s))
} }
// handle non primitive datatypes like set and sorted set // handle non primitive datatypes like hash, set, and sorted set
case constants.CompositeType: case constants.CompositeType:
size += k.Value.(constants.CompositeType).GetMem() size += k.Value.(constants.CompositeType).GetMem()

View File

@@ -19,6 +19,8 @@ import (
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
"github.com/echovault/sugardb/internal/modules/hash"
) )
func TestSugarDB_HDEL(t *testing.T) { func TestSugarDB_HDEL(t *testing.T) {
@@ -33,20 +35,29 @@ func TestSugarDB_HDEL(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return count of deleted fields in the specified hash", name: "Return count of deleted fields in the specified hash",
key: "key1", key: "key1",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142, "field7": "value7"}, presetValue: hash.Hash{
fields: []string{"field1", "field2", "field3", "field4", "field5", "field6"}, "field1": {Value: "value1"},
want: 3, "field2": {Value: 123456789},
wantErr: false, "field3": {Value: 3.142},
"field7": {Value: "value7"},
},
fields: []string{"field1", "field2", "field3", "field4", "field5", "field6"},
want: 3,
wantErr: false,
}, },
{ {
name: "0 response when passing delete fields that are non-existent on valid hash", name: "0 response when passing delete fields that are non-existent on valid hash",
key: "key2", key: "key2",
presetValue: map[string]interface{}{"field1": "value1", "field2": "value2", "field3": "value3"}, presetValue: hash.Hash{
fields: []string{"field4", "field5", "field6"}, "field1": {Value: "value1"},
want: 0, "field2": {Value: "value2"},
wantErr: false, "field3": {Value: "value3"},
},
fields: []string{"field4", "field5", "field6"},
want: 0,
wantErr: false,
}, },
{ {
name: "0 response when trying to call HDEL on non-existent key", name: "0 response when trying to call HDEL on non-existent key",
@@ -98,16 +109,20 @@ func TestSugarDB_HEXISTS(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return 1 if the field exists in the hash", name: "Return 1 if the field exists in the hash",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key1", "field1": {Value: "value1"},
field: "field1", "field2": {Value: 123456789},
want: true, "field3": {Value: 3.142},
wantErr: false, },
key: "key1",
field: "field1",
want: true,
wantErr: false,
}, },
{ {
name: "False response when trying to call HEXISTS on non-existent key", name: "False response when trying to call HEXISTS on non-existent key",
presetValue: map[string]interface{}{}, presetValue: hash.Hash{},
key: "key2", key: "key2",
field: "field1", field: "field1",
want: false, want: false,
@@ -154,16 +169,20 @@ func TestSugarDB_HGETALL(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return an array containing all the fields and values of the hash", name: "Return an array containing all the fields and values of the hash",
key: "key1", key: "key1",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, "field1": {Value: "value1"},
wantErr: false, "field2": {Value: 123456789},
"field3": {Value: 3.142},
},
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"},
wantErr: false,
}, },
{ {
name: "Empty array response when trying to call HGETALL on non-existent key", name: "Empty array response when trying to call HGETALL on non-existent key",
key: "key2", key: "key2",
presetValue: map[string]interface{}{}, presetValue: hash.Hash{},
want: []string{}, want: []string{},
wantErr: false, wantErr: false,
}, },
@@ -244,7 +263,7 @@ func TestSugarDB_HINCRBY(t *testing.T) {
}, },
{ {
name: "Increment by integer on existing hash", name: "Increment by integer on existing hash",
presetValue: map[string]interface{}{"field1": 1}, presetValue: hash.Hash{"field1": {Value: 1}},
incr_type: HINCRBY, incr_type: HINCRBY,
key: "key3", key: "key3",
field: "field1", field: "field1",
@@ -254,7 +273,7 @@ func TestSugarDB_HINCRBY(t *testing.T) {
}, },
{ {
name: "Increment by float on an existing hash", name: "Increment by float on an existing hash",
presetValue: map[string]interface{}{"field1": 3.142}, presetValue: hash.Hash{"field1": {Value: 3.142}},
incr_type: HINCRBYFLOAT, incr_type: HINCRBYFLOAT,
key: "key4", key: "key4",
field: "field1", field: "field1",
@@ -274,7 +293,7 @@ func TestSugarDB_HINCRBY(t *testing.T) {
}, },
{ {
name: "Error when trying to increment a hash field that is not a number", name: "Error when trying to increment a hash field that is not a number",
presetValue: map[string]interface{}{"field1": "value1"}, presetValue: hash.Hash{"field1": {Value: "value1"}},
incr_type: HINCRBY, incr_type: HINCRBY,
key: "key10", key: "key10",
field: "field1", field: "field1",
@@ -326,15 +345,19 @@ func TestSugarDB_HKEYS(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return an array containing all the keys of the hash", name: "Return an array containing all the keys of the hash",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key1", "field1": {Value: "value1"},
want: []string{"field1", "field2", "field3"}, "field2": {Value: 123456789},
wantErr: false, "field3": {Value: 3.142},
},
key: "key1",
want: []string{"field1", "field2", "field3"},
wantErr: false,
}, },
{ {
name: "Empty array response when trying to call HKEYS on non-existent key", name: "Empty array response when trying to call HKEYS on non-existent key",
presetValue: map[string]interface{}{}, presetValue: hash.Hash{},
key: "key2", key: "key2",
want: []string{}, want: []string{},
wantErr: false, wantErr: false,
@@ -384,11 +407,15 @@ func TestSugarDB_HLEN(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return the correct length of the hash", name: "Return the correct length of the hash",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key1", "field1": {Value: "value1"},
want: 3, "field2": {Value: 123456789},
wantErr: false, "field3": {Value: 3.142},
},
key: "key1",
want: 3,
wantErr: false,
}, },
{ {
name: "0 Response when trying to call HLEN on non-existent key", name: "0 Response when trying to call HLEN on non-existent key",
@@ -439,31 +466,39 @@ func TestSugarDB_HRANDFIELD(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Get a random field", name: "Get a random field",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key1", "field1": {Value: "value1"},
options: HRandFieldOptions{Count: 1}, "field2": {Value: 123456789},
wantCount: 1, "field3": {Value: 3.142},
want: []string{"field1", "field2", "field3"}, },
wantErr: false, key: "key1",
options: HRandFieldOptions{Count: 1},
wantCount: 1,
want: []string{"field1", "field2", "field3"},
wantErr: false,
}, },
{ {
name: "Get a random field with a value", name: "Get a random field with a value",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key2", "field1": {Value: "value1"},
options: HRandFieldOptions{WithValues: true, Count: 1}, "field2": {Value: 123456789},
wantCount: 2, "field3": {Value: 3.142},
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, },
wantErr: false, key: "key2",
options: HRandFieldOptions{WithValues: true, Count: 1},
wantCount: 2,
want: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"},
wantErr: false,
}, },
{ {
name: "Get several random fields", name: "Get several random fields",
presetValue: map[string]interface{}{ presetValue: hash.Hash{
"field1": "value1", "field1": {Value: "value1"},
"field2": 123456789, "field2": {Value: 123456789},
"field3": 3.142, "field3": {Value: 3.142},
"field4": "value4", "field4": {Value: "value4"},
"field5": "value5", "field5": {Value: "value6"},
}, },
key: "key3", key: "key3",
options: HRandFieldOptions{Count: 3}, options: HRandFieldOptions{Count: 3},
@@ -473,12 +508,12 @@ func TestSugarDB_HRANDFIELD(t *testing.T) {
}, },
{ {
name: "Get several random fields with their corresponding values", name: "Get several random fields with their corresponding values",
presetValue: map[string]interface{}{ presetValue: hash.Hash{
"field1": "value1", "field1": {Value: "value1"},
"field2": 123456789, "field2": {Value: 123456789},
"field3": 3.142, "field3": {Value: 3.142},
"field4": "value4", "field4": {Value: "value4"},
"field5": "value5", "field5": {Value: "value5"},
}, },
key: "key4", key: "key4",
options: HRandFieldOptions{WithValues: true, Count: 3}, options: HRandFieldOptions{WithValues: true, Count: 3},
@@ -491,12 +526,12 @@ func TestSugarDB_HRANDFIELD(t *testing.T) {
}, },
{ {
name: "Get the entire hash", name: "Get the entire hash",
presetValue: map[string]interface{}{ presetValue: hash.Hash{
"field1": "value1", "field1": {Value: "value1"},
"field2": 123456789, "field2": {Value: 123456789},
"field3": 3.142, "field3": {Value: 3.142},
"field4": "value4", "field4": {Value: "value4"},
"field5": "value5", "field5": {Value: "value5"},
}, },
key: "key5", key: "key5",
options: HRandFieldOptions{Count: 5}, options: HRandFieldOptions{Count: 5},
@@ -506,12 +541,12 @@ func TestSugarDB_HRANDFIELD(t *testing.T) {
}, },
{ {
name: "Get the entire hash with values", name: "Get the entire hash with values",
presetValue: map[string]interface{}{ presetValue: hash.Hash{
"field1": "value1", "field1": {Value: "value1"},
"field2": 123456789, "field2": {Value: 123456789},
"field3": 3.142, "field3": {Value: 3.142},
"field4": "value4", "field4": {Value: "value4"},
"field5": "value5", "field5": {Value: "value5"},
}, },
key: "key5", key: "key5",
options: HRandFieldOptions{WithValues: true, Count: 5}, options: HRandFieldOptions{WithValues: true, Count: 5},
@@ -582,7 +617,7 @@ func TestSugarDB_HSET(t *testing.T) {
{ {
name: "HSETNX set field on existing hash map", name: "HSETNX set field on existing hash map",
key: "key2", key: "key2",
presetValue: map[string]interface{}{"field1": "value1"}, presetValue: hash.Hash{"field1": {Value: "value1"}},
hsetFunc: server.HSetNX, hsetFunc: server.HSetNX,
fieldValuePairs: map[string]string{"field2": "value2"}, fieldValuePairs: map[string]string{"field2": "value2"},
want: 1, want: 1,
@@ -591,7 +626,7 @@ func TestSugarDB_HSET(t *testing.T) {
{ {
name: "HSETNX skips operation when setting on existing field", name: "HSETNX skips operation when setting on existing field",
key: "key3", key: "key3",
presetValue: map[string]interface{}{"field1": "value1"}, presetValue: hash.Hash{"field1": {Value: "value1"}},
hsetFunc: server.HSetNX, hsetFunc: server.HSetNX,
fieldValuePairs: map[string]string{"field1": "value1"}, fieldValuePairs: map[string]string{"field1": "value1"},
want: 0, want: 0,
@@ -609,7 +644,7 @@ func TestSugarDB_HSET(t *testing.T) {
{ {
name: "Regular HSET update on existing hash map", name: "Regular HSET update on existing hash map",
key: "key5", key: "key5",
presetValue: map[string]interface{}{"field1": "value1", "field2": "value2"}, presetValue: hash.Hash{"field1": {Value: "value1"}, "field2": {Value: "value2"}},
fieldValuePairs: map[string]string{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"}, fieldValuePairs: map[string]string{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"},
hsetFunc: server.HSet, hsetFunc: server.HSet,
want: 3, want: 3,
@@ -660,31 +695,35 @@ func TestSugarDB_HSTRLEN(t *testing.T) {
{ {
// Return lengths of field values. // Return lengths of field values.
// If the key does not exist, its length should be 0. // If the key does not exist, its length should be 0.
name: "Return lengths of field values", name: "1. Return lengths of field values",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
key: "key1", "field1": {Value: "value1"},
fields: []string{"field1", "field2", "field3", "field4"}, "field2": {Value: 123456789},
want: []int{len("value1"), len("123456789"), len("3.142"), 0}, "field3": {Value: 3.142},
wantErr: false, },
key: "key1",
fields: []string{"field1", "field2", "field3", "field4"},
want: []int{len("value1"), len("123456789"), len("3.142"), 0},
wantErr: false,
}, },
{ {
name: "Response when trying to get HSTRLEN non-existent key", name: "2. Response when trying to get HSTRLEN non-existent key",
presetValue: map[string]interface{}{}, presetValue: hash.Hash{},
key: "key2", key: "key2",
fields: []string{"field1"}, fields: []string{"field1"},
want: []int{0}, want: []int{0},
wantErr: false, wantErr: false,
}, },
{ {
name: "Command too short", name: "3. Command too short",
key: "key3", key: "key3",
presetValue: map[string]interface{}{}, presetValue: hash.Hash{},
fields: []string{}, fields: []string{},
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "Trying to get lengths on a non hash map returns error", name: "4. Trying to get lengths on a non hash map returns error",
key: "key4", key: "key4",
presetValue: "Default value", presetValue: "Default value",
fields: []string{"field1"}, fields: []string{"field1"},
@@ -694,6 +733,7 @@ func TestSugarDB_HSTRLEN(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Log(tt.name)
if tt.presetValue != nil { if tt.presetValue != nil {
err := presetValue(server, context.Background(), tt.key, tt.presetValue) err := presetValue(server, context.Background(), tt.key, tt.presetValue)
if err != nil { if err != nil {
@@ -724,11 +764,15 @@ func TestSugarDB_HVALS(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "Return all the values from a hash", name: "Return all the values from a hash",
key: "key1", key: "key1",
presetValue: map[string]interface{}{"field1": "value1", "field2": 123456789, "field3": 3.142}, presetValue: hash.Hash{
want: []string{"value1", "123456789", "3.142"}, "field1": {Value: "value1"},
wantErr: false, "field2": {Value: 123456789},
"field3": {Value: 3.142},
},
want: []string{"value1", "123456789", "3.142"},
wantErr: false,
}, },
{ {
name: "Empty array response when trying to get HSTRLEN non-existent key", name: "Empty array response when trying to get HSTRLEN non-existent key",
@@ -782,12 +826,16 @@ func TestSugarDB_HGet(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "1. Get values from existing hash.", name: "1. Get values from existing hash.",
key: "HgetKey1", key: "HgetKey1",
presetValue: map[string]interface{}{"field1": "value1", "field2": 365, "field3": 3.142}, presetValue: hash.Hash{
fields: []string{"field1", "field2", "field3", "field4"}, "field1": {Value: "value1"},
want: []string{"value1", "365", "3.142", ""}, "field2": {Value: 365},
wantErr: false, "field3": {Value: 3.142},
},
fields: []string{"field1", "field2", "field3", "field4"},
want: []string{"value1", "365", "3.142", ""},
wantErr: false,
}, },
{ {
name: "2. Return empty slice when attempting to get from non-existed key", name: "2. Return empty slice when attempting to get from non-existed key",
@@ -838,12 +886,16 @@ func TestSugarDB_HMGet(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "1. Get values from existing hash.", name: "1. Get values from existing hash.",
key: "HgetKey1", key: "HgetKey1",
presetValue: map[string]interface{}{"field1": "value1", "field2": 365, "field3": 3.142}, presetValue: hash.Hash{
fields: []string{"field1", "field2", "field3", "field4"}, "field1": {Value: "value1"},
want: []string{"value1", "365", "3.142", ""}, "field2": {Value: 365},
wantErr: false, "field3": {Value: 3.142},
},
fields: []string{"field1", "field2", "field3", "field4"},
want: []string{"value1", "365", "3.142", ""},
wantErr: false,
}, },
{ {
name: "2. Return empty slice when attempting to get from non-existed key", name: "2. Return empty slice when attempting to get from non-existed key",