Files
redis-go/db/string.go
2021-05-02 14:54:42 +08:00

517 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/redis/reply"
"github.com/shopspring/decimal"
"strconv"
"strings"
"time"
)
// getAsString looks up key and returns the stored payload as a byte slice.
// It returns (nil, nil) when db.Get reports the key as absent, and a
// WrongTypeErrReply when the stored data is not a []byte.
func (db *DB) getAsString(key string) ([]byte, reply.ErrorReply) {
	entity, exists := db.Get(key)
	if !exists {
		return nil, nil
	}
	if data, isBytes := entity.Data.([]byte); isBytes {
		return data, nil
	}
	return nil, &reply.WrongTypeErrReply{}
}
// Get handles the GET command. It expects exactly one argument (the key) and
// replies with the string value bound to it, a null bulk reply when there is
// no value, or the error reply produced by getAsString (wrong type).
func Get(db *DB, args [][]byte) redis.Reply {
	if len(args) != 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'get' command")
	}
	value, errReply := db.getAsString(string(args[0]))
	switch {
	case errReply != nil:
		return errReply
	case value == nil:
		return &reply.NullBulkReply{}
	default:
		return reply.MakeBulkReply(value)
	}
}
// Write policies selected by the options of the SET command.
const (
upsertPolicy = iota // default: no NX/XX option given
insertPolicy // set nx: Set writes through db.PutIfAbsent
updatePolicy // set xx: Set writes through db.PutIfExists
)
// unlimitedTTL is the ttl value Set uses to mean "no EX/PX option was given".
const unlimitedTTL int64 = 0
// Set handles the SET command:
//
//	SET key value [NX|XX] [EX seconds|PX milliseconds]
//
// args[0] is the key and args[1] the value; the remaining arguments are
// options, matched case-insensitively:
//   - NX writes only through db.PutIfAbsent, XX only through db.PutIfExists;
//     giving both is a syntax error.
//   - EX / PX give a time to live in seconds / milliseconds. The ttl may be
//     given only once, must be positive, and must fit in a time.Duration.
//
// It replies OK when the value was written. When an NX/XX condition prevented
// the write it replies with a null bulk reply and leaves the key, its ttl and
// the AOF untouched.
func Set(db *DB, args [][]byte) redis.Reply {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'set' command")
	}

	// Largest ttl (in milliseconds) that time.Duration can represent without
	// overflowing int64 nanoseconds.
	const maxTTLMillis = int64(1<<63-1) / int64(time.Millisecond)

	key := string(args[0])
	value := args[1]
	policy := upsertPolicy
	ttl := unlimitedTTL

	// parse options
	for i := 2; i < len(args); i++ {
		option := strings.ToUpper(string(args[i]))
		switch option {
		case "NX": // insert only
			if policy == updatePolicy {
				return &reply.SyntaxErrReply{}
			}
			policy = insertPolicy
		case "XX": // update only
			if policy == insertPolicy {
				return &reply.SyntaxErrReply{}
			}
			policy = updatePolicy
		case "EX", "PX": // ttl in seconds / milliseconds
			if ttl != unlimitedTTL {
				// ttl has already been set by an earlier option
				return &reply.SyntaxErrReply{}
			}
			if i+1 >= len(args) {
				return &reply.SyntaxErrReply{}
			}
			ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
			if err != nil {
				return &reply.SyntaxErrReply{}
			}
			limit := maxTTLMillis
			if option == "EX" {
				limit = maxTTLMillis / 1000
			}
			// Reject non-positive values and values that would overflow
			// when converted to milliseconds or to a time.Duration.
			if ttlArg <= 0 || ttlArg > limit {
				return reply.MakeErrReply("ERR invalid expire time in set")
			}
			if option == "EX" {
				ttlArg *= 1000
			}
			ttl = ttlArg
			i++ // skip the ttl value
		default:
			return &reply.SyntaxErrReply{}
		}
	}

	entity := &DataEntity{
		Data: value,
	}

	var result int
	switch policy {
	case upsertPolicy:
		// An unconditional SET always writes, so the previous ttl can be
		// dropped before the new value is stored.
		db.Persist(key)
		result = db.Put(key, entity)
	case insertPolicy:
		result = db.PutIfAbsent(key, entity)
	case updatePolicy:
		result = db.PutIfExists(key, entity)
	}

	// The default policy always writes; NX/XX wrote only if result > 0.
	// A rejected conditional write must not touch the existing key's ttl
	// and must not be recorded in the AOF, otherwise replaying the AOF
	// would apply a write that never happened.
	if policy != upsertPolicy && result <= 0 {
		return &reply.NullBulkReply{}
	}

	if ttl != unlimitedTTL {
		// A ttl was given: it replaces whatever ttl the key had. The AOF
		// records a plain SET followed by the expire command built by
		// makeExpireCmd from the computed expire time.
		expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
		db.Expire(key, expireTime)
		db.AddAof(reply.MakeMultiBulkReply([][]byte{
			[]byte("SET"),
			args[0],
			args[1],
		}))
		db.AddAof(makeExpireCmd(key, expireTime))
	} else {
		// No ttl was given: the freshly written value must not inherit
		// the ttl of the value it replaced.
		db.Persist(key)
		db.AddAof(makeAofCmd("set", args))
	}
	return &reply.OkReply{}
}
// SetNX handles the SETNX command. It expects a key and a value, stores the
// value through db.PutIfAbsent, appends the command to the AOF, and replies
// with the integer returned by db.PutIfAbsent.
func SetNX(db *DB, args [][]byte) redis.Reply {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command")
	}
	inserted := db.PutIfAbsent(string(args[0]), &DataEntity{Data: args[1]})
	db.AddAof(makeAofCmd("setnx", args))
	return reply.MakeIntReply(int64(inserted))
}
// SetEX handles the SETEX command: SETEX key seconds value.
// It stores the value and sets the key to expire after the given number of
// seconds. The ttl must parse as a positive integer. The key is locked while
// the value and its expire time are written.
func SetEX(db *DB, args [][]byte) redis.Reply {
	if len(args) != 3 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command")
	}
	seconds, parseErr := strconv.ParseInt(string(args[1]), 10, 64)
	if parseErr != nil {
		return &reply.SyntaxErrReply{}
	}
	if seconds <= 0 {
		return reply.MakeErrReply("ERR invalid expire time in setex")
	}
	key := string(args[0])
	millis := seconds * 1000

	db.Lock(key)
	defer db.UnLock(key)

	db.Put(key, &DataEntity{Data: args[2]})
	deadline := time.Now().Add(time.Duration(millis) * time.Millisecond)
	db.Expire(key, deadline)

	// Record the command itself and the expire command for the deadline.
	db.AddAof(makeAofCmd("setex", args))
	db.AddAof(makeExpireCmd(key, deadline))
	return &reply.OkReply{}
}
// PSetEX handles the PSETEX command: PSETEX key milliseconds value.
// It stores the value and sets the key to expire after the given number of
// milliseconds. The ttl must parse as a positive integer. The key is locked
// while the value and its expire time are written.
func PSetEX(db *DB, args [][]byte) redis.Reply {
	if len(args) != 3 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'psetex' command")
	}
	key := string(args[0])
	value := args[2]

	ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
	if err != nil {
		return &reply.SyntaxErrReply{}
	}
	if ttlArg <= 0 {
		return reply.MakeErrReply("ERR invalid expire time in psetex")
	}

	entity := &DataEntity{
		Data: value,
	}

	db.Lock(key)
	defer db.UnLock(key)

	db.Put(key, entity)
	expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
	db.Expire(key, expireTime)

	// Log under the command's own name: args[1] is in milliseconds, so
	// recording it as "setex" would make the entry read as seconds.
	db.AddAof(makeAofCmd("psetex", args))
	db.AddAof(makeExpireCmd(key, expireTime))
	return &reply.OkReply{}
}
// MSet handles the MSET command: MSET key value [key value ...].
// The arguments must form a non-empty list of key/value pairs. All keys are
// locked together while every pair is written, then the command is appended
// to the AOF.
func MSet(db *DB, args [][]byte) redis.Reply {
	if len(args) == 0 || len(args)%2 != 0 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
	}

	// Keys sit at the even positions, their values right after them.
	keys := make([]string, 0, len(args)/2)
	for pos := 0; pos < len(args); pos += 2 {
		keys = append(keys, string(args[pos]))
	}

	db.Locks(keys...)
	defer db.UnLocks(keys...)

	for idx, key := range keys {
		db.Put(key, &DataEntity{Data: args[2*idx+1]})
	}
	db.AddAof(makeAofCmd("mset", args))
	return &reply.OkReply{}
}
// MGet handles the MGET command: MGET key [key ...].
// It replies with one entry per key, in argument order. An entry is nil when
// the key has no value or holds a non-string value; any other error reply
// from getAsString aborts the command and is returned as is.
func MGet(db *DB, args [][]byte) redis.Reply {
	if len(args) == 0 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
	}

	values := make([][]byte, len(args))
	for idx, rawKey := range args {
		value, errReply := db.getAsString(string(rawKey))
		if errReply == nil {
			values[idx] = value // nil or []byte
			continue
		}
		if _, wrongType := errReply.(*reply.WrongTypeErrReply); !wrongType {
			return errReply
		}
		// wrong type: the slot keeps its nil zero value
	}
	return reply.MakeMultiBulkReply(values)
}
// MSetNX handles the MSETNX command: MSETNX key value [key value ...].
// All keys are locked together. If db.Get finds any of them, nothing is
// written and the reply is 0; otherwise every pair is written, the command
// is appended to the AOF and the reply is 1.
func MSetNX(db *DB, args [][]byte) redis.Reply {
	if len(args) == 0 || len(args)%2 != 0 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command")
	}

	// Keys sit at the even positions, their values right after them.
	keys := make([]string, 0, len(args)/2)
	for pos := 0; pos < len(args); pos += 2 {
		keys = append(keys, string(args[pos]))
	}

	db.Locks(keys...)
	defer db.UnLocks(keys...)

	// First pass: refuse the whole command if any key is already present.
	for _, key := range keys {
		if _, found := db.Get(key); found {
			return reply.MakeIntReply(0)
		}
	}

	// Second pass: none of the keys exist, store every pair.
	for idx, key := range keys {
		db.Put(key, &DataEntity{Data: args[2*idx+1]})
	}
	db.AddAof(makeAofCmd("msetnx", args))
	return reply.MakeIntReply(1)
}
// GetSet handles the GETSET command: GETSET key value.
// It stores the new value, clears the key's ttl, appends the command to the
// AOF and replies with the previous string value, or a null bulk reply when
// there was none. If the key holds a non-string value the error reply from
// getAsString is returned and nothing is written.
func GetSet(db *DB, args [][]byte) redis.Reply {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'getset' command")
	}
	key := string(args[0])

	previous, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	db.Put(key, &DataEntity{Data: args[1]})
	db.Persist(key) // the new value does not inherit the old ttl
	db.AddAof(makeAofCmd("getset", args))

	if previous != nil {
		return reply.MakeBulkReply(previous)
	}
	return new(reply.NullBulkReply)
}
// Incr handles the INCR command: INCR key.
// It adds one to the integer stored at key, treating a key without a value
// as 0, stores the result as decimal text, appends the command to the AOF
// and replies with the new value. The key is locked for the whole
// read-modify-write.
//
// It replies with an error when the key holds a non-string value, when the
// stored string is not a 64-bit signed integer, or when the increment would
// overflow int64 (in which case nothing is written).
func Incr(db *DB, args [][]byte) redis.Reply {
	if len(args) != 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'incr' command")
	}
	key := string(args[0])

	db.Lock(key)
	defer db.UnLock(key)

	bytes, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	var val int64 // a key without a value counts as 0
	if bytes != nil {
		parsed, err := strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return reply.MakeErrReply("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	// Signed addition wraps around in Go; a wrapped result is smaller than
	// the operand it started from.
	next := val + 1
	if next < val {
		return reply.MakeErrReply("ERR increment or decrement would overflow")
	}

	db.Put(key, &DataEntity{
		Data: []byte(strconv.FormatInt(next, 10)),
	})
	db.AddAof(makeAofCmd("incr", args))
	return reply.MakeIntReply(next)
}
// IncrBy handles the INCRBY command: INCRBY key increment.
// It adds the increment to the integer stored at key, treating a key without
// a value as 0, stores the result as decimal text, appends the command to
// the AOF and replies with the new value. The key is locked for the whole
// read-modify-write.
//
// It replies with an error when the increment or the stored string is not a
// 64-bit signed integer, when the key holds a non-string value, or when the
// addition would overflow int64 (in which case nothing is written).
func IncrBy(db *DB, args [][]byte) redis.Reply {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'incrby' command")
	}
	key := string(args[0])
	delta, err := strconv.ParseInt(string(args[1]), 10, 64)
	if err != nil {
		return reply.MakeErrReply("ERR value is not an integer or out of range")
	}

	db.Lock(key)
	defer db.UnLock(key)

	bytes, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	var val int64 // a key without a value counts as 0
	if bytes != nil {
		parsed, err := strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return reply.MakeErrReply("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	// Signed addition wraps around in Go: adding a positive delta must not
	// make the value smaller, adding a negative one must not make it larger.
	sum := val + delta
	if (delta > 0 && sum < val) || (delta < 0 && sum > val) {
		return reply.MakeErrReply("ERR increment or decrement would overflow")
	}

	// Always store the canonical decimal text rather than the raw argument,
	// which may carry a sign prefix or leading zeros.
	db.Put(key, &DataEntity{
		Data: []byte(strconv.FormatInt(sum, 10)),
	})
	db.AddAof(makeAofCmd("incrby", args))
	return reply.MakeIntReply(sum)
}
// IncrByFloat handles the INCRBYFLOAT command: INCRBYFLOAT key increment.
// It adds the decimal increment to the decimal number stored at key, treating
// a key without a value as 0, stores the result as text, appends the command
// to the AOF and replies with the new value as a bulk string. The key is
// locked for the whole read-modify-write.
//
// It replies with an error when the increment or the stored string cannot be
// parsed as a decimal, or when the key holds a non-string value.
func IncrByFloat(db *DB, args [][]byte) redis.Reply {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'incrbyfloat' command")
	}
	key := string(args[0])
	delta, err := decimal.NewFromString(string(args[1]))
	if err != nil {
		return reply.MakeErrReply("ERR value is not a valid float")
	}

	db.Lock(key)
	defer db.UnLock(key)

	bytes, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	// With no existing value the result is the increment itself.
	result := delta
	if bytes != nil {
		val, err := decimal.NewFromString(string(bytes))
		if err != nil {
			return reply.MakeErrReply("ERR value is not a valid float")
		}
		result = val.Add(delta)
	}

	// Store and reply with the decimal's own text form in both cases, so a
	// new key gets the same formatting as an updated one instead of the raw
	// argument bytes.
	resultBytes := []byte(result.String())
	db.Put(key, &DataEntity{
		Data: resultBytes,
	})
	db.AddAof(makeAofCmd("incrbyfloat", args))
	return reply.MakeBulkReply(resultBytes)
}
// Decr handles the DECR command: DECR key.
// It subtracts one from the integer stored at key, treating a key without a
// value as 0, stores the result as decimal text, appends the command to the
// AOF and replies with the new value. The key is locked for the whole
// read-modify-write.
//
// It replies with an error when the key holds a non-string value, when the
// stored string is not a 64-bit signed integer, or when the decrement would
// overflow int64 (in which case nothing is written).
func Decr(db *DB, args [][]byte) redis.Reply {
	if len(args) != 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'decr' command")
	}
	key := string(args[0])

	db.Lock(key)
	defer db.UnLock(key)

	bytes, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	var val int64 // a key without a value counts as 0
	if bytes != nil {
		parsed, err := strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return reply.MakeErrReply("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	// Signed subtraction wraps around in Go; a wrapped result is larger
	// than the operand it started from.
	next := val - 1
	if next > val {
		return reply.MakeErrReply("ERR increment or decrement would overflow")
	}

	db.Put(key, &DataEntity{
		Data: []byte(strconv.FormatInt(next, 10)),
	})
	db.AddAof(makeAofCmd("decr", args))
	return reply.MakeIntReply(next)
}
// DecrBy handles the DECRBY command: DECRBY key decrement.
// It subtracts the decrement from the integer stored at key, treating a key
// without a value as 0, stores the result as decimal text, appends the
// command to the AOF and replies with the new value. The key is locked for
// the whole read-modify-write.
//
// It replies with an error when the decrement or the stored string is not a
// 64-bit signed integer, when the key holds a non-string value, or when the
// subtraction would overflow int64 (in which case nothing is written).
func DecrBy(db *DB, args [][]byte) redis.Reply {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'decrby' command")
	}
	key := string(args[0])
	delta, err := strconv.ParseInt(string(args[1]), 10, 64)
	if err != nil {
		return reply.MakeErrReply("ERR value is not an integer or out of range")
	}

	db.Lock(key)
	defer db.UnLock(key)

	bytes, errReply := db.getAsString(key)
	if errReply != nil {
		return errReply
	}

	var val int64 // a key without a value counts as 0
	if bytes != nil {
		parsed, err := strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return reply.MakeErrReply("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	// Signed subtraction wraps around in Go: subtracting a positive delta
	// must not make the value larger, subtracting a negative one must not
	// make it smaller. This also covers negating the minimum int64 when the
	// key has no value.
	diff := val - delta
	if (delta > 0 && diff > val) || (delta < 0 && diff < val) {
		return reply.MakeErrReply("ERR increment or decrement would overflow")
	}

	db.Put(key, &DataEntity{
		Data: []byte(strconv.FormatInt(diff, 10)),
	})
	db.AddAof(makeAofCmd("decrby", args))
	return reply.MakeIntReply(diff)
}