mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 00:42:43 +08:00
601 lines
16 KiB
Go
601 lines
16 KiB
Go
package database
|
|
|
|
import (
|
|
"github.com/hdt3213/godis/aof"
|
|
"github.com/hdt3213/godis/interface/database"
|
|
"github.com/hdt3213/godis/interface/redis"
|
|
"github.com/hdt3213/godis/lib/utils"
|
|
"github.com/hdt3213/godis/redis/protocol"
|
|
"github.com/shopspring/decimal"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func (db *DB) getAsString(key string) ([]byte, protocol.ErrorReply) {
|
|
entity, ok := db.GetEntity(key)
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
bytes, ok := entity.Data.([]byte)
|
|
if !ok {
|
|
return nil, &protocol.WrongTypeErrReply{}
|
|
}
|
|
return bytes, nil
|
|
}
|
|
|
|
// execGet returns string value bound to the given key
|
|
func execGet(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if bytes == nil {
|
|
return &protocol.NullBulkReply{}
|
|
}
|
|
return protocol.MakeBulkReply(bytes)
|
|
}
|
|
|
|
// Write policies used by SET, selected through its NX / XX options.
const (
	upsertPolicy = 0 // default: write whether or not the key exists
	insertPolicy = 1 // SET ... NX: write only when the key does not exist
	updatePolicy = 2 // SET ... XX: write only when the key already exists
)

// unlimitedTTL is the ttl sentinel meaning "no expiration was requested".
const unlimitedTTL int64 = 0
|
|
|
|
// execSet sets string value and time to live to the given key
|
|
func execSet(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
value := args[1]
|
|
policy := upsertPolicy
|
|
ttl := unlimitedTTL
|
|
|
|
// parse options
|
|
if len(args) > 2 {
|
|
for i := 2; i < len(args); i++ {
|
|
arg := strings.ToUpper(string(args[i]))
|
|
if arg == "NX" { // insert
|
|
if policy == updatePolicy {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
policy = insertPolicy
|
|
} else if arg == "XX" { // update policy
|
|
if policy == insertPolicy {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
policy = updatePolicy
|
|
} else if arg == "EX" { // ttl in seconds
|
|
if ttl != unlimitedTTL {
|
|
// ttl has been set
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if i+1 >= len(args) {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
|
|
if err != nil {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if ttlArg <= 0 {
|
|
return protocol.MakeErrReply("ERR invalid expire time in set")
|
|
}
|
|
ttl = ttlArg * 1000
|
|
i++ // skip next arg
|
|
} else if arg == "PX" { // ttl in milliseconds
|
|
if ttl != unlimitedTTL {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if i+1 >= len(args) {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
|
|
if err != nil {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if ttlArg <= 0 {
|
|
return protocol.MakeErrReply("ERR invalid expire time in set")
|
|
}
|
|
ttl = ttlArg
|
|
i++ // skip next arg
|
|
} else {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
}
|
|
}
|
|
|
|
entity := &database.DataEntity{
|
|
Data: value,
|
|
}
|
|
|
|
var result int
|
|
switch policy {
|
|
case upsertPolicy:
|
|
db.PutEntity(key, entity)
|
|
result = 1
|
|
case insertPolicy:
|
|
result = db.PutIfAbsent(key, entity)
|
|
case updatePolicy:
|
|
result = db.PutIfExists(key, entity)
|
|
}
|
|
if result > 0 {
|
|
if ttl != unlimitedTTL {
|
|
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
|
|
db.Expire(key, expireTime)
|
|
db.addAof(CmdLine{
|
|
[]byte("SET"),
|
|
args[0],
|
|
args[1],
|
|
})
|
|
db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
|
|
} else {
|
|
db.Persist(key) // override ttl
|
|
db.addAof(utils.ToCmdLine3("set", args...))
|
|
}
|
|
}
|
|
|
|
if result > 0 {
|
|
return &protocol.OkReply{}
|
|
}
|
|
return &protocol.NullBulkReply{}
|
|
}
|
|
|
|
// execSetNX sets string if not exists
|
|
func execSetNX(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
value := args[1]
|
|
entity := &database.DataEntity{
|
|
Data: value,
|
|
}
|
|
result := db.PutIfAbsent(key, entity)
|
|
db.addAof(utils.ToCmdLine3("setnx", args...))
|
|
return protocol.MakeIntReply(int64(result))
|
|
}
|
|
|
|
// execSetEX sets string and its ttl
|
|
func execSetEX(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
value := args[2]
|
|
|
|
ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
|
|
if err != nil {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if ttlArg <= 0 {
|
|
return protocol.MakeErrReply("ERR invalid expire time in setex")
|
|
}
|
|
ttl := ttlArg * 1000
|
|
|
|
entity := &database.DataEntity{
|
|
Data: value,
|
|
}
|
|
|
|
db.PutEntity(key, entity)
|
|
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
|
|
db.Expire(key, expireTime)
|
|
db.addAof(utils.ToCmdLine3("setex", args...))
|
|
db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
|
|
return &protocol.OkReply{}
|
|
}
|
|
|
|
// execPSetEX set a key's time to live in milliseconds
|
|
func execPSetEX(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
value := args[2]
|
|
|
|
ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
|
|
if err != nil {
|
|
return &protocol.SyntaxErrReply{}
|
|
}
|
|
if ttlArg <= 0 {
|
|
return protocol.MakeErrReply("ERR invalid expire time in setex")
|
|
}
|
|
|
|
entity := &database.DataEntity{
|
|
Data: value,
|
|
}
|
|
|
|
db.PutEntity(key, entity)
|
|
expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
|
|
db.Expire(key, expireTime)
|
|
db.addAof(utils.ToCmdLine3("setex", args...))
|
|
db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
|
|
|
|
return &protocol.OkReply{}
|
|
}
|
|
|
|
// prepareMSet reports the keys an MSET/MSETNX command writes: every
// even-indexed argument. It reports no read keys. A trailing argument without
// a paired value is ignored.
func prepareMSet(args [][]byte) ([]string, []string) {
	writeKeys := make([]string, 0, len(args)/2)
	for idx := 0; idx+1 < len(args); idx += 2 {
		writeKeys = append(writeKeys, string(args[idx]))
	}
	return writeKeys, nil
}
|
|
|
|
func undoMSet(db *DB, args [][]byte) []CmdLine {
|
|
writeKeys, _ := prepareMSet(args)
|
|
return rollbackGivenKeys(db, writeKeys...)
|
|
}
|
|
|
|
// execMSet sets multi key-value in database
|
|
func execMSet(db *DB, args [][]byte) redis.Reply {
|
|
if len(args)%2 != 0 {
|
|
return protocol.MakeSyntaxErrReply()
|
|
}
|
|
|
|
size := len(args) / 2
|
|
keys := make([]string, size)
|
|
values := make([][]byte, size)
|
|
for i := 0; i < size; i++ {
|
|
keys[i] = string(args[2*i])
|
|
values[i] = args[2*i+1]
|
|
}
|
|
|
|
for i, key := range keys {
|
|
value := values[i]
|
|
db.PutEntity(key, &database.DataEntity{Data: value})
|
|
}
|
|
db.addAof(utils.ToCmdLine3("mset", args...))
|
|
return &protocol.OkReply{}
|
|
}
|
|
|
|
// prepareMGet reports the keys MGET reads (every argument) and no write keys.
func prepareMGet(args [][]byte) ([]string, []string) {
	readKeys := make([]string, 0, len(args))
	for _, raw := range args {
		readKeys = append(readKeys, string(raw))
	}
	return nil, readKeys
}
|
|
|
|
// execMGet get multi key-value from database
|
|
func execMGet(db *DB, args [][]byte) redis.Reply {
|
|
keys := make([]string, len(args))
|
|
for i, v := range args {
|
|
keys[i] = string(v)
|
|
}
|
|
|
|
result := make([][]byte, len(args))
|
|
for i, key := range keys {
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
_, isWrongType := err.(*protocol.WrongTypeErrReply)
|
|
if isWrongType {
|
|
result[i] = nil
|
|
continue
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
result[i] = bytes // nil or []byte
|
|
}
|
|
|
|
return protocol.MakeMultiBulkReply(result)
|
|
}
|
|
|
|
// execMSetNX sets multi key-value in database, only if none of the given keys exist
|
|
func execMSetNX(db *DB, args [][]byte) redis.Reply {
|
|
// parse args
|
|
if len(args)%2 != 0 {
|
|
return protocol.MakeSyntaxErrReply()
|
|
}
|
|
size := len(args) / 2
|
|
values := make([][]byte, size)
|
|
keys := make([]string, size)
|
|
for i := 0; i < size; i++ {
|
|
keys[i] = string(args[2*i])
|
|
values[i] = args[2*i+1]
|
|
}
|
|
|
|
for _, key := range keys {
|
|
_, exists := db.GetEntity(key)
|
|
if exists {
|
|
return protocol.MakeIntReply(0)
|
|
}
|
|
}
|
|
|
|
for i, key := range keys {
|
|
value := values[i]
|
|
db.PutEntity(key, &database.DataEntity{Data: value})
|
|
}
|
|
db.addAof(utils.ToCmdLine3("msetnx", args...))
|
|
return protocol.MakeIntReply(1)
|
|
}
|
|
|
|
// execGetSet sets value of a string-type key and returns its old value
|
|
func execGetSet(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
value := args[1]
|
|
|
|
old, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
db.PutEntity(key, &database.DataEntity{Data: value})
|
|
db.Persist(key) // override ttl
|
|
db.addAof(utils.ToCmdLine3("getset", args...))
|
|
if old == nil {
|
|
return new(protocol.NullBulkReply)
|
|
}
|
|
return protocol.MakeBulkReply(old)
|
|
}
|
|
|
|
// execIncr increments the integer value of a key by one
|
|
func execIncr(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if bytes != nil {
|
|
val, err := strconv.ParseInt(string(bytes), 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte(strconv.FormatInt(val+1, 10)),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incr", args...))
|
|
return protocol.MakeIntReply(val + 1)
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte("1"),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incr", args...))
|
|
return protocol.MakeIntReply(1)
|
|
}
|
|
|
|
// execIncrBy increments the integer value of a key by given value
|
|
func execIncrBy(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
rawDelta := string(args[1])
|
|
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
|
|
bytes, errReply := db.getAsString(key)
|
|
if errReply != nil {
|
|
return errReply
|
|
}
|
|
if bytes != nil {
|
|
// existed value
|
|
val, err := strconv.ParseInt(string(bytes), 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte(strconv.FormatInt(val+delta, 10)),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incrby", args...))
|
|
return protocol.MakeIntReply(val + delta)
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: args[1],
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incrby", args...))
|
|
return protocol.MakeIntReply(delta)
|
|
}
|
|
|
|
// execIncrByFloat increments the float value of a key by given value
|
|
func execIncrByFloat(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
rawDelta := string(args[1])
|
|
delta, err := decimal.NewFromString(rawDelta)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not a valid float")
|
|
}
|
|
|
|
bytes, errReply := db.getAsString(key)
|
|
if errReply != nil {
|
|
return errReply
|
|
}
|
|
if bytes != nil {
|
|
val, err := decimal.NewFromString(string(bytes))
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not a valid float")
|
|
}
|
|
resultBytes := []byte(val.Add(delta).String())
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: resultBytes,
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incrbyfloat", args...))
|
|
return protocol.MakeBulkReply(resultBytes)
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: args[1],
|
|
})
|
|
db.addAof(utils.ToCmdLine3("incrbyfloat", args...))
|
|
return protocol.MakeBulkReply(args[1])
|
|
}
|
|
|
|
// execDecr decrements the integer value of a key by one
|
|
func execDecr(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
|
|
bytes, errReply := db.getAsString(key)
|
|
if errReply != nil {
|
|
return errReply
|
|
}
|
|
if bytes != nil {
|
|
val, err := strconv.ParseInt(string(bytes), 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte(strconv.FormatInt(val-1, 10)),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("decr", args...))
|
|
return protocol.MakeIntReply(val - 1)
|
|
}
|
|
entity := &database.DataEntity{
|
|
Data: []byte("-1"),
|
|
}
|
|
db.PutEntity(key, entity)
|
|
db.addAof(utils.ToCmdLine3("decr", args...))
|
|
return protocol.MakeIntReply(-1)
|
|
}
|
|
|
|
// execDecrBy decrements the integer value of a key by onedecrement
|
|
func execDecrBy(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
rawDelta := string(args[1])
|
|
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
|
|
bytes, errReply := db.getAsString(key)
|
|
if errReply != nil {
|
|
return errReply
|
|
}
|
|
if bytes != nil {
|
|
val, err := strconv.ParseInt(string(bytes), 10, 64)
|
|
if err != nil {
|
|
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte(strconv.FormatInt(val-delta, 10)),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("decrby", args...))
|
|
return protocol.MakeIntReply(val - delta)
|
|
}
|
|
valueStr := strconv.FormatInt(-delta, 10)
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: []byte(valueStr),
|
|
})
|
|
db.addAof(utils.ToCmdLine3("decrby", args...))
|
|
return protocol.MakeIntReply(-delta)
|
|
}
|
|
|
|
// execStrLen returns len of string value bound to the given key
|
|
func execStrLen(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if bytes == nil {
|
|
return protocol.MakeIntReply(0)
|
|
}
|
|
return protocol.MakeIntReply(int64(len(bytes)))
|
|
}
|
|
|
|
// execAppend sets string value to the given key
|
|
func execAppend(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bytes = append(bytes, args[1]...)
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: bytes,
|
|
})
|
|
db.addAof(utils.ToCmdLine3("append", args...))
|
|
return protocol.MakeIntReply(int64(len(bytes)))
|
|
}
|
|
|
|
// execSetRange overwrites part of the string stored at key, starting at the specified offset.
|
|
// If the offset is larger than the current length of the string at key, the string is padded with zero-bytes.
|
|
func execSetRange(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
offset, errNative := strconv.ParseInt(string(args[1]), 10, 64)
|
|
if errNative != nil {
|
|
return protocol.MakeErrReply(errNative.Error())
|
|
}
|
|
value := args[2]
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bytesLen := int64(len(bytes))
|
|
if bytesLen < offset {
|
|
diff := offset - bytesLen
|
|
diffArray := make([]byte, diff)
|
|
bytes = append(bytes, diffArray...)
|
|
bytesLen = int64(len(bytes))
|
|
}
|
|
for i := 0; i < len(value); i++ {
|
|
idx := offset + int64(i)
|
|
if idx >= bytesLen {
|
|
bytes = append(bytes, value[i])
|
|
} else {
|
|
bytes[idx] = value[i]
|
|
}
|
|
}
|
|
db.PutEntity(key, &database.DataEntity{
|
|
Data: bytes,
|
|
})
|
|
db.addAof(utils.ToCmdLine3("setRange", args...))
|
|
return protocol.MakeIntReply(int64(len(bytes)))
|
|
}
|
|
|
|
func execGetRange(db *DB, args [][]byte) redis.Reply {
|
|
key := string(args[0])
|
|
startIdx, errNative := strconv.ParseInt(string(args[1]), 10, 64)
|
|
if errNative != nil {
|
|
return protocol.MakeErrReply(errNative.Error())
|
|
}
|
|
endIdx, errNative := strconv.ParseInt(string(args[2]), 10, 64)
|
|
if errNative != nil {
|
|
return protocol.MakeErrReply(errNative.Error())
|
|
}
|
|
|
|
bytes, err := db.getAsString(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if bytes == nil {
|
|
return protocol.MakeNullBulkReply()
|
|
}
|
|
|
|
bytesLen := int64(len(bytes))
|
|
if startIdx < -1*bytesLen {
|
|
return &protocol.NullBulkReply{}
|
|
} else if startIdx < 0 {
|
|
startIdx = bytesLen + startIdx
|
|
} else if startIdx >= bytesLen {
|
|
return &protocol.NullBulkReply{}
|
|
}
|
|
if endIdx < -1*bytesLen {
|
|
return &protocol.NullBulkReply{}
|
|
} else if endIdx < 0 {
|
|
endIdx = bytesLen + endIdx + 1
|
|
} else if endIdx < bytesLen {
|
|
endIdx = endIdx + 1
|
|
} else {
|
|
endIdx = bytesLen
|
|
}
|
|
if startIdx > endIdx {
|
|
return protocol.MakeNullBulkReply()
|
|
}
|
|
|
|
return protocol.MakeBulkReply(bytes[startIdx:endIdx])
|
|
}
|
|
|
|
// init registers the string commands.
// RegisterCommand arguments, as used here: command name, executor, a function
// reporting the keys the command writes and reads (compare prepareMSet and
// prepareMGet), a function building undo command lines (nil for read-only
// commands), and the arity. The arity appears to count the command name
// itself (GET key -> 2); a negative value such as -3 for SET presumably means
// "at least 3 arguments", since SET accepts trailing options — confirm in
// RegisterCommand.
func init() {
	RegisterCommand("Set", execSet, writeFirstKey, rollbackFirstKey, -3)
	RegisterCommand("SetNx", execSetNX, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("SetEX", execSetEX, writeFirstKey, rollbackFirstKey, 4)
	RegisterCommand("PSetEX", execPSetEX, writeFirstKey, rollbackFirstKey, 4)
	// multi-key commands compute their key sets from the argument pairs
	RegisterCommand("MSet", execMSet, prepareMSet, undoMSet, -3)
	RegisterCommand("MGet", execMGet, prepareMGet, nil, -2)
	RegisterCommand("MSetNX", execMSetNX, prepareMSet, undoMSet, -3)
	RegisterCommand("Get", execGet, readFirstKey, nil, 2)
	RegisterCommand("GetSet", execGetSet, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("Incr", execIncr, writeFirstKey, rollbackFirstKey, 2)
	RegisterCommand("IncrBy", execIncrBy, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("IncrByFloat", execIncrByFloat, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("Decr", execDecr, writeFirstKey, rollbackFirstKey, 2)
	RegisterCommand("DecrBy", execDecrBy, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("StrLen", execStrLen, readFirstKey, nil, 2)
	RegisterCommand("Append", execAppend, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand("SetRange", execSetRange, writeFirstKey, rollbackFirstKey, 4)
	RegisterCommand("GetRange", execGetRange, readFirstKey, nil, 4)
}
|