Update Hash to allow for expirey commands (#146)

* Convert hash to composite type. Fixed broken Hash commands from Hash refactor. Coverage and fixed broken test - @osteensco
This commit is contained in:
osteensco
2024-11-03 13:24:31 -06:00
committed by GitHub
parent 05b7601752
commit 09640082c4
8 changed files with 3577 additions and 3517 deletions

View File

@@ -33,27 +33,25 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
key := keys.WriteKeys[0]
keyExists := params.KeysExist(params.Context, keys.WriteKeys)[key]
entries := make(map[string]interface{})
entries := Hash{}
if len(params.Command[2:])%2 != 0 {
return nil, errors.New("each field must have a corresponding value")
}
for i := 2; i <= len(params.Command)-2; i += 2 {
entries[params.Command[i]] = internal.AdaptType(params.Command[i+1])
k := params.Command[i]
entries[k] = HashValue{Value: internal.AdaptType(params.Command[i+1])}
}
if !keyExists {
if err != nil {
return nil, err
}
if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil {
return nil, err
}
return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
// Not hash, save the entries map directly.
if err = params.SetValues(params.Context, map[string]interface{}{key: entries}); err != nil {
@@ -67,18 +65,19 @@ func handleHSET(params internal.HandlerFuncParams) ([]byte, error) {
case "hsetnx":
// Handle HSETNX
for field, _ := range entries {
if hash[field] == nil {
if _, ok := hash[field]; !ok {
count += 1
}
}
for field, value := range hash {
entries[field] = value
}
default:
// Handle HSET
for field, value := range hash {
if entries[field] == nil {
entries[field] = value
if entries[field].Value == nil {
entries[field] = HashValue{Value: value}
}
}
count = len(entries)
@@ -105,29 +104,29 @@ func handleHGET(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
var value interface{}
var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields {
value = hash[field]
if value == nil {
if value.Value == nil {
res += "$-1\r\n"
continue
}
if s, ok := value.(string); ok {
if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue
}
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
continue
}
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue
@@ -150,14 +149,15 @@ func handleHMGET(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
fields := params.Command[2:]
var value interface{}
var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields {
value, ok = hash[field]
@@ -166,15 +166,15 @@ func handleHMGET(params internal.HandlerFuncParams) ([]byte, error) {
continue
}
if s, ok := value.(string); ok {
if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue
}
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
continue
}
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue
@@ -199,30 +199,30 @@ func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("$-1\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
var value interface{}
var value HashValue
res := fmt.Sprintf("*%d\r\n", len(fields))
for _, field := range fields {
value = hash[field]
if value == nil {
if value.Value == nil {
res += ":0\r\n"
continue
}
if s, ok := value.(string); ok {
if s, ok := value.Value.(string); ok {
res += fmt.Sprintf(":%d\r\n", len(s))
continue
}
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf(":%d\r\n", len(fs))
continue
}
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", len(strconv.Itoa(d)))
continue
}
@@ -245,23 +245,23 @@ func handleHVALS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
res := fmt.Sprintf("*%d\r\n", len(hash))
for _, val := range hash {
if s, ok := val.(string); ok {
if s, ok := val.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue
}
if f, ok := val.(float64); ok {
if f, ok := val.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue
}
if d, ok := val.(int); ok {
if d, ok := val.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
}
}
@@ -303,7 +303,7 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -317,16 +317,16 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
for field, value := range hash {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if withvalues {
if s, ok := value.(string); ok {
if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue
}
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue
}
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
continue
}
@@ -362,16 +362,16 @@ func handleHRANDFIELD(params internal.HandlerFuncParams) ([]byte, error) {
for _, field := range pluckedFields {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if withvalues {
if s, ok := hash[field].(string); ok {
if s, ok := hash[field].Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
continue
}
if f, ok := hash[field].(float64); ok {
if f, ok := hash[field].Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
continue
}
if d, ok := hash[field].(int); ok {
if d, ok := hash[field].Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
continue
}
@@ -394,7 +394,7 @@ func handleHLEN(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -415,7 +415,7 @@ func handleHKEYS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -456,15 +456,15 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
}
if !keyExists {
hash := make(map[string]interface{})
hash := make(Hash)
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = floatIncrement
hash[field] = HashValue{Value: floatIncrement}
if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err
}
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
} else {
hash[field] = intIncrement
hash[field] = HashValue{Value: intIncrement}
if err = params.SetValues(params.Context, map[string]interface{}{key: hash}); err != nil {
return nil, err
}
@@ -472,31 +472,31 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
}
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
if hash[field] == nil {
hash[field] = 0
if hash[field].Value == nil {
hash[field] = HashValue{Value: 0}
}
switch hash[field].(type) {
switch hash[field].Value.(type) {
default:
return nil, fmt.Errorf("value at field %s is not a number", field)
case int:
i, _ := hash[field].(int)
i, _ := hash[field].Value.(int)
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = float64(i) + floatIncrement
hash[field] = HashValue{Value: float64(i) + floatIncrement}
} else {
hash[field] = i + intIncrement
hash[field] = HashValue{Value: i + intIncrement}
}
case float64:
f, _ := hash[field].(float64)
f, _ := hash[field].Value.(float64)
if strings.EqualFold(params.Command[0], "hincrbyfloat") {
hash[field] = f + floatIncrement
hash[field] = HashValue{Value: f + floatIncrement}
} else {
hash[field] = f + float64(intIncrement)
hash[field] = HashValue{Value: f + float64(intIncrement)}
}
}
@@ -504,11 +504,11 @@ func handleHINCRBY(params internal.HandlerFuncParams) ([]byte, error) {
return nil, err
}
if f, ok := hash[field].(float64); ok {
if f, ok := hash[field].Value.(float64); ok {
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil
}
i, _ := hash[field].(int)
i, _ := hash[field].Value.(int)
return []byte(fmt.Sprintf(":%d\r\n", i)), nil
}
@@ -525,7 +525,7 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
return []byte("*0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -533,14 +533,16 @@ func handleHGETALL(params internal.HandlerFuncParams) ([]byte, error) {
res := fmt.Sprintf("*%d\r\n", len(hash)*2)
for field, value := range hash {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
if s, ok := value.(string); ok {
if s, ok := value.Value.(string); ok {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s)
}
if f, ok := value.(float64); ok {
if f, ok := value.Value.(float64); ok {
fs := strconv.FormatFloat(f, 'f', -1, 64)
res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs)
}
if d, ok := value.(int); ok {
if d, ok := value.Value.(int); ok {
res += fmt.Sprintf(":%d\r\n", d)
}
}
@@ -562,12 +564,12 @@ func handleHEXISTS(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
if hash[field] != nil {
if hash[field].Value != nil {
return []byte(":1\r\n"), nil
}
@@ -588,7 +590,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
return []byte(":0\r\n"), nil
}
hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{})
hash, ok := params.GetValues(params.Context, []string{key})[key].(Hash)
if !ok {
return nil, fmt.Errorf("value at %s is not a hash", key)
}
@@ -596,7 +598,7 @@ func handleHDEL(params internal.HandlerFuncParams) ([]byte, error) {
count := 0
for _, field := range fields {
if hash[field] != nil {
if hash[field].Value != nil {
delete(hash, field)
count += 1
}
@@ -742,5 +744,14 @@ Return the string length of the values stored at the specified fields. 0 if the
KeyExtractionFunc: hdelKeyFunc,
HandlerFunc: handleHDEL,
},
// {
// Command: "hexpire",
// Module: constants.HashModule,
// Categories: []string{constants.HashCategory, constants.WriteCategory, constants.FastCategory},
// Description: `(HEXPIRE key seconds [NX | XX | GT | LT] FIELDS numfields field [field ...]) Sets the expiration, in seconds, of a field in a hash.`,
// Sync: true,
// KeyExtractionFunc: hexpireKeyFunc,
// HandlerFunc: handleHEXPIRE,
// },
}
}