mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-07 01:32:56 +08:00
add set/get cmd
This commit is contained in:
@@ -113,6 +113,23 @@ func (d *Dict)PutIfAbsent(key string, val interface{})int {
|
||||
|
||||
if existed {
|
||||
return 0
|
||||
} else {
|
||||
// insert
|
||||
shard.table[key] = val
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// return the number of updated key-value
|
||||
func (d *Dict)PutIfExists(key string, val interface{})int {
|
||||
shard := d.shards[d.spread(key)]
|
||||
shard.mutex.Lock()
|
||||
defer shard.mutex.Unlock()
|
||||
|
||||
_, existed := shard.table[key]
|
||||
|
||||
if !existed {
|
||||
return 0
|
||||
} else {
|
||||
// update
|
||||
shard.table[key] = val
|
@@ -73,6 +73,28 @@ func TestPutIfAbsent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutIfExists(t *testing.T) {
|
||||
d := Make(0)
|
||||
|
||||
// insert
|
||||
ret := d.PutIfExists("a", 1)
|
||||
if ret != 0 { // insert
|
||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
|
||||
d.Put("a", 1)
|
||||
ret = d.PutIfExists("a", 2)
|
||||
val, ok := d.Get("a")
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != 2 {
|
||||
t.Error("put test failed: expected 2, actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
d := Make(0)
|
||||
|
64
src/db/db.go
64
src/db/db.go
@@ -3,53 +3,73 @@ package db
|
||||
import (
|
||||
"strings"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/HDT3213/godis/src/db/db"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
)
|
||||
|
||||
// args don't include cmd line
|
||||
type CmdFunc func(args [][]byte)redis.Reply
|
||||
const (
|
||||
StringCode = iota // Data is []byte
|
||||
ListCode // *list.LinkedList
|
||||
SetCode
|
||||
DictCode // *dict.Dict
|
||||
SortedSetCode
|
||||
)
|
||||
|
||||
|
||||
type DB struct {
|
||||
cmdMap map[string]CmdFunc
|
||||
type DataEntity struct {
|
||||
Code uint8
|
||||
TTL int64 // ttl in seconds, 0 for unlimited ttl
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
type UnknownErrReply struct {}
|
||||
// args don't include cmd line
|
||||
type CmdFunc func(db *DB, args [][]byte)redis.Reply
|
||||
|
||||
func (r *UnknownErrReply)ToBytes()[]byte {
|
||||
return []byte("-Err unknown\r\n")
|
||||
type DB struct {
|
||||
Data *dict.Dict // key -> DataEntity
|
||||
}
|
||||
|
||||
var cmdMap = MakeCmdMap()
|
||||
|
||||
func MakeCmdMap()map[string]CmdFunc {
|
||||
cmdMap := make(map[string]CmdFunc)
|
||||
cmdMap["ping"] = Ping
|
||||
|
||||
cmdMap["set"] = Set
|
||||
cmdMap["setnx"] = SetNX
|
||||
cmdMap["setex"] = SetEX
|
||||
cmdMap["psetex"] = PSetEX
|
||||
|
||||
cmdMap["get"] = Get
|
||||
|
||||
return cmdMap
|
||||
}
|
||||
|
||||
func MakeDB() *DB {
|
||||
return &DB{
|
||||
Data: dict.Make(1024),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB)Exec(args [][]byte)(result redis.Reply) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
|
||||
result = &UnknownErrReply{}
|
||||
result = &reply.UnknownErrReply{}
|
||||
}
|
||||
}()
|
||||
|
||||
cmd := strings.ToLower(string(args[0]))
|
||||
cmdFunc, ok := db.cmdMap[cmd]
|
||||
cmdFunc, ok := cmdMap[cmd]
|
||||
if !ok {
|
||||
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
result = cmdFunc(args[1:])
|
||||
result = cmdFunc(db, args[1:])
|
||||
} else {
|
||||
result = cmdFunc([][]byte{})
|
||||
result = cmdFunc(db, [][]byte{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func MakeDB() *DB {
|
||||
cmdMap := make(map[string]CmdFunc)
|
||||
cmdMap["ping"] = db.Ping
|
||||
|
||||
return &DB{
|
||||
cmdMap:cmdMap,
|
||||
}
|
||||
}
|
||||
|
@@ -1,28 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
)
|
||||
|
||||
type PongReply struct {}
|
||||
|
||||
func (r *PongReply)ToBytes()[]byte {
|
||||
return []byte("+PONG\r\n")
|
||||
}
|
||||
|
||||
type ArgNumErrReply struct {}
|
||||
|
||||
func (r *ArgNumErrReply)ToBytes()[]byte {
|
||||
return []byte("-ERR wrong number of arguments for 'ping' command\r\n")
|
||||
}
|
||||
|
||||
func Ping(args [][]byte)redis.Reply {
|
||||
if len(args) == 0 {
|
||||
return &PongReply{}
|
||||
} else if len(args) == 1 {
|
||||
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
||||
} else {
|
||||
return &ArgNumErrReply{}
|
||||
}
|
||||
}
|
27
src/db/get.go
Normal file
27
src/db/get.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
)
|
||||
|
||||
func Get(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'get' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
val, ok := db.Data.Get(key)
|
||||
if !ok {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
entity, _ := val.(*DataEntity)
|
||||
if entity.Code == StringCode {
|
||||
bytes, ok := entity.Data.([]byte)
|
||||
if !ok {
|
||||
return &reply.UnknownErrReply{}
|
||||
}
|
||||
return reply.MakeBulkReply(bytes)
|
||||
} else {
|
||||
return &reply.WrongTypeErrReply{}
|
||||
}
|
||||
}
|
16
src/db/ping.go
Normal file
16
src/db/ping.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
)
|
||||
|
||||
func Ping(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) == 0 {
|
||||
return &reply.PongReply{}
|
||||
} else if len(args) == 1 {
|
||||
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
||||
} else {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
|
||||
}
|
||||
}
|
159
src/db/set.go
Normal file
159
src/db/set.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
upsertPolicy = iota // default
|
||||
insertPolicy // set nx
|
||||
updatePolicy // set ex
|
||||
)
|
||||
|
||||
const unlimitedTTL int64 = 0
|
||||
|
||||
// SET key value [EX seconds] [PX milliseconds] [NX|XX]
|
||||
func Set(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'set' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
value := args[1]
|
||||
policy := upsertPolicy
|
||||
ttl := unlimitedTTL
|
||||
|
||||
// parse options
|
||||
if len(args) > 2 {
|
||||
for i := 2; i < len(args); i++ {
|
||||
arg := strings.ToUpper(string(args[i]))
|
||||
if arg == "NX" { // insert
|
||||
if policy == updatePolicy {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
policy = insertPolicy
|
||||
} else if arg == "XX" { // update policy
|
||||
if policy == insertPolicy {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
policy = updatePolicy
|
||||
} else if arg == "EX" { // ttl in seconds
|
||||
if ttl != unlimitedTTL {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if i + 1 >= len(args) {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
|
||||
if err != nil {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if ttlArg <= 0 {
|
||||
return reply.MakeErrReply("ERR invalid expire time in set")
|
||||
}
|
||||
ttl = ttlArg * 1000
|
||||
i++ // skip next arg
|
||||
} else if arg == "PX" { // ttl in milliseconds
|
||||
if ttl != unlimitedTTL {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if i + 1 >= len(args) {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
|
||||
if err != nil {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if ttlArg <= 0 {
|
||||
return reply.MakeErrReply("ERR invalid expire time in set")
|
||||
}
|
||||
ttl = ttlArg
|
||||
i++ // skip next arg
|
||||
} else {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
entity := &DataEntity{
|
||||
Code: StringCode,
|
||||
TTL: ttl,
|
||||
Data: value,
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case upsertPolicy:
|
||||
db.Data.Put(key, entity)
|
||||
case insertPolicy:
|
||||
db.Data.PutIfAbsent(key, entity)
|
||||
case updatePolicy:
|
||||
db.Data.PutIfExists(key, entity)
|
||||
}
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
func SetNX(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) != 2 {
|
||||
reply.MakeErrReply("ERR wrong number of arguments for 'setnx' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
value := args[1]
|
||||
entity := &DataEntity{
|
||||
Code: StringCode,
|
||||
Data: value,
|
||||
}
|
||||
result := db.Data.PutIfAbsent(key, entity)
|
||||
return reply.MakeIntReply(int64(result))
|
||||
}
|
||||
|
||||
func SetEX(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'setex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
value := args[1]
|
||||
|
||||
ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if ttlArg <= 0 {
|
||||
return reply.MakeErrReply("ERR invalid expire time in setex")
|
||||
}
|
||||
ttl := ttlArg * 1000
|
||||
|
||||
entity := &DataEntity{
|
||||
Code: StringCode,
|
||||
TTL: ttl,
|
||||
Data: value,
|
||||
}
|
||||
db.Data.PutIfExists(key, entity)
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
func PSetEX(db *DB, args [][]byte)redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'psetex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
value := args[1]
|
||||
|
||||
ttl, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return &reply.SyntaxErrReply{}
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return reply.MakeErrReply("ERR invalid expire time in psetex")
|
||||
}
|
||||
|
||||
entity := &DataEntity{
|
||||
Code: StringCode,
|
||||
TTL: ttl,
|
||||
Data: value,
|
||||
}
|
||||
db.Data.PutIfExists(key, entity)
|
||||
return &reply.OkReply{}
|
||||
}
|
@@ -5,7 +5,3 @@ import "github.com/HDT3213/godis/src/interface/redis"
|
||||
type DB interface {
|
||||
Exec([][]byte)redis.Reply
|
||||
}
|
||||
|
||||
type DataEntity interface {
|
||||
|
||||
}
|
25
src/redis/reply/consts.go
Normal file
25
src/redis/reply/consts.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package reply
|
||||
|
||||
type PongReply struct {}
|
||||
|
||||
var PongBytes = []byte("+PONG\r\n")
|
||||
|
||||
func (r *PongReply)ToBytes()[]byte {
|
||||
return PongBytes
|
||||
}
|
||||
|
||||
type OkReply struct {}
|
||||
|
||||
var OkBytes = []byte("+OK\r\n")
|
||||
|
||||
func (r *OkReply)ToBytes()[]byte {
|
||||
return OkBytes
|
||||
}
|
||||
|
||||
var nullBulkBytes = []byte("$-1\r\n")
|
||||
|
||||
type NullBulkReply struct {}
|
||||
|
||||
func (r *NullBulkReply)ToBytes()[]byte {
|
||||
return nullBulkBytes
|
||||
}
|
37
src/redis/reply/errors.go
Normal file
37
src/redis/reply/errors.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package reply
|
||||
|
||||
// UnknownErr
|
||||
type UnknownErrReply struct {}
|
||||
|
||||
var unknownErrBytes = []byte("-Err unknown\r\n")
|
||||
|
||||
func (r *UnknownErrReply)ToBytes()[]byte {
|
||||
return unknownErrBytes
|
||||
}
|
||||
|
||||
// ArgNumErr
|
||||
type ArgNumErrReply struct {
|
||||
Cmd string
|
||||
}
|
||||
|
||||
func (r *ArgNumErrReply)ToBytes()[]byte {
|
||||
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
|
||||
}
|
||||
|
||||
// SyntaxErr
|
||||
type SyntaxErrReply struct {}
|
||||
|
||||
var syntaxErrBytes = []byte("-Err syntax error\r\n")
|
||||
|
||||
func (r *SyntaxErrReply)ToBytes()[]byte {
|
||||
return syntaxErrBytes
|
||||
}
|
||||
|
||||
// WrongTypeErr
|
||||
type WrongTypeErrReply struct {}
|
||||
|
||||
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
|
||||
|
||||
func (r *WrongTypeErrReply)ToBytes()[]byte {
|
||||
return wrongTypeErrBytes
|
||||
}
|
Reference in New Issue
Block a user