reformat code

This commit is contained in:
hdt3213
2021-04-03 20:14:12 +08:00
parent bf913a5aca
commit bcf0cd5e92
54 changed files with 4887 additions and 4896 deletions

View File

@@ -2,11 +2,13 @@
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang.
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent
middleware using golang.
Please be advised, NEVER think about using this in production environment.
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode.
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence
and server side cluster mode.
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
@@ -22,7 +24,7 @@ You could use redis-cli or other redis client to connect godis server, which lis
The program will try to read config file path from environment variable `CONFIG`.
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
If there is no such file, then the program will run with default config.
@@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster
self localhost:6399 // self address
```
We provide node1.conf and node2.conf for demonstration.
use following command line to start a two-node-cluster:
We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster:
```bash
CONFIG=node1.conf ./godis-darwin &

View File

@@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
**请注意:不要在生产环境使用使用此项目**
Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。
Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于
Godis 的信息。
# 运行 Godis
@@ -149,19 +149,19 @@ redis-cli -p 6399
- tcp: tcp 服务器实现
- redis: redis 协议解析器
- datastruct: redis 的各类数据结构实现
- dict: hash 表
- list: 链表
- lock: 用于锁定 key 的锁组件
- set 基于hash表的集合
- sortedset: 基于跳表实现的有序集合
- dict: hash 表
- list: 链表
- lock: 用于锁定 key 的锁组件
- set 基于hash表的集合
- sortedset: 基于跳表实现的有序集合
- db: redis 存储引擎实现
- db.go: 引擎的基础功能
- router.go: 将命令路由给响应的处理函数
- keys.go: del、ttl、expire 等通用命令实现
- string.go: get、set 等字符串命令实现
- list.go: lpush、lindex 等列表命令实现
- hash.go: hget、hset 等哈希表命令实现
- set.go: sadd 等集合命令实现
- sortedset.go: zadd 等有序集合命令实现
- pubsub.go: 发布订阅命令实现
- aof.go: aof持久化实现
- db.go: 引擎的基础功能
- router.go: 将命令路由给响应的处理函数
- keys.go: del、ttl、expire 等通用命令实现
- string.go: get、set 等字符串命令实现
- list.go: lpush、lindex 等列表命令实现
- hash.go: hget、hset 等哈希表命令实现
- set.go: sadd 等集合命令实现
- sortedset.go: zadd 等有序集合命令实现
- pubsub.go: 发布订阅命令实现
- aof.go: aof持久化实现

View File

@@ -1,45 +1,45 @@
package cluster
import (
"context"
"errors"
"github.com/HDT3213/godis/src/redis/client"
"github.com/jolestar/go-commons-pool/v2"
"context"
"errors"
"github.com/HDT3213/godis/src/redis/client"
"github.com/jolestar/go-commons-pool/v2"
)
type ConnectionFactory struct {
Peer string
Peer string
}
func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) {
c, err := client.MakeClient(f.Peer)
if err != nil {
return nil, err
}
c.Start()
return pool.NewPooledObject(c), nil
c, err := client.MakeClient(f.Peer)
if err != nil {
return nil, err
}
c.Start()
return pool.NewPooledObject(c), nil
}
func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error {
c, ok := object.Object.(*client.Client)
if !ok {
return errors.New("type mismatch")
}
c.Close()
return nil
c, ok := object.Object.(*client.Client)
if !ok {
return errors.New("type mismatch")
}
c.Close()
return nil
}
func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool {
// do validate
return true
// do validate
return true
}
func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error {
// do activate
return nil
// do activate
return nil
}
func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error {
// do passivate
return nil
// do passivate
return nil
}

View File

@@ -1,99 +1,99 @@
package cluster
import (
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast
for peer, group := range groupMap { // only one group
return cluster.Relay(peer, c, makeArgs("DEL", group...))
}
}
// prepare
var errReply redis.Reply
txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10)
rollback := false
for peer, group := range groupMap {
args := []string{txIdStr}
args = append(args, group...)
var resp redis.Reply
if peer == cluster.self {
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
} else {
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
}
if reply.IsErrorReply(resp) {
errReply = resp
rollback = true
break
}
}
var respList []redis.Reply
if rollback {
// rollback
RequestRollback(cluster, c, txId, groupMap)
} else {
// commit
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
if errReply != nil {
rollback = true
}
}
if !rollback {
var deleted int64 = 0
for _, resp := range respList {
intResp := resp.(*reply.IntReply)
deleted += intResp.Code
}
return reply.MakeIntReply(int64(deleted))
}
return errReply
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast
for peer, group := range groupMap { // only one group
return cluster.Relay(peer, c, makeArgs("DEL", group...))
}
}
// prepare
var errReply redis.Reply
txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10)
rollback := false
for peer, group := range groupMap {
args := []string{txIdStr}
args = append(args, group...)
var resp redis.Reply
if peer == cluster.self {
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
} else {
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
}
if reply.IsErrorReply(resp) {
errReply = resp
rollback = true
break
}
}
var respList []redis.Reply
if rollback {
// rollback
RequestRollback(cluster, c, txId, groupMap)
} else {
// commit
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
if errReply != nil {
rollback = true
}
}
if !rollback {
var deleted int64 = 0
for _, resp := range respList {
intResp := resp.(*reply.IntReply)
deleted += intResp.Code
}
return reply.MakeIntReply(int64(deleted))
}
return errReply
}
// args: PrepareDel id keys...
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
}
txId := string(args[1])
keys := make([]string, 0, len(args)-2)
for i := 2; i < len(args); i++ {
arg := args[i]
keys = append(keys, string(arg))
}
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx)
err := tx.prepare()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return &reply.OkReply{}
if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
}
txId := string(args[1])
keys := make([]string, 0, len(args)-2)
for i := 2; i < len(args); i++ {
arg := args[i]
keys = append(keys, string(arg))
}
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx)
err := tx.prepare()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return &reply.OkReply{}
}
// invoker should provide lock
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
keys := make([]string, len(tx.args))
for i, v := range tx.args {
keys[i] = string(v)
}
keys = keys[1:]
keys := make([]string, len(tx.args))
for i, v := range tx.args {
keys[i] = string(v)
}
keys = keys[1:]
deleted := cluster.db.Removes(keys...)
if deleted > 0 {
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
}
return reply.MakeIntReply(int64(deleted))
deleted := cluster.db.Removes(keys...)
if deleted > 0 {
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
}
return reply.MakeIntReply(int64(deleted))
}

View File

@@ -1,85 +1,85 @@
package idgenerator
import (
"hash/fnv"
"log"
"sync"
"time"
"hash/fnv"
"log"
"sync"
"time"
)
const (
workerIdBits int64 = 5
datacenterIdBits int64 = 5
sequenceBits int64 = 12
workerIdBits int64 = 5
datacenterIdBits int64 = 5
sequenceBits int64 = 12
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
timeLeft uint8 = 22
dataLeft uint8 = 17
workLeft uint8 = 12
timeLeft uint8 = 22
dataLeft uint8 = 17
workLeft uint8 = 12
twepoch int64 = 1525705533000
twepoch int64 = 1525705533000
)
type IdGenerator struct {
mu *sync.Mutex
lastStamp int64
workerId int64
dataCenterId int64
sequence int64
mu *sync.Mutex
lastStamp int64
workerId int64
dataCenterId int64
sequence int64
}
func MakeGenerator(cluster string, node string) *IdGenerator {
fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(cluster))
dataCenterId := int64(fnv64.Sum64())
fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(cluster))
dataCenterId := int64(fnv64.Sum64())
fnv64.Reset()
_, _ = fnv64.Write([]byte(node))
workerId := int64(fnv64.Sum64())
fnv64.Reset()
_, _ = fnv64.Write([]byte(node))
workerId := int64(fnv64.Sum64())
return &IdGenerator{
mu: &sync.Mutex{},
lastStamp: -1,
dataCenterId: dataCenterId,
workerId: workerId,
sequence: 1,
}
return &IdGenerator{
mu: &sync.Mutex{},
lastStamp: -1,
dataCenterId: dataCenterId,
workerId: workerId,
sequence: 1,
}
}
func (w *IdGenerator) getCurrentTime() int64 {
return time.Now().UnixNano() / 1e6
return time.Now().UnixNano() / 1e6
}
func (w *IdGenerator) NextId() int64 {
w.mu.Lock()
defer w.mu.Unlock()
w.mu.Lock()
defer w.mu.Unlock()
timestamp := w.getCurrentTime()
if timestamp < w.lastStamp {
log.Fatal("can not generate id")
}
if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 {
for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
}
} else {
w.sequence = 0
}
w.lastStamp = timestamp
timestamp := w.getCurrentTime()
if timestamp < w.lastStamp {
log.Fatal("can not generate id")
}
if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 {
for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
}
} else {
w.sequence = 0
}
w.lastStamp = timestamp
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
}
func (w *IdGenerator) tilNextMillis() int64 {
timestamp := w.getCurrentTime()
if timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
return timestamp
timestamp := w.getCurrentTime()
if timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
return timestamp
}

View File

@@ -1,159 +1,159 @@
package cluster
import (
"fmt"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
"fmt"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
}
keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i])
}
resultMap := make(map[string][]byte)
groupMap := cluster.groupBy(keys)
for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
if reply.IsErrorReply(resp) {
errReply := resp.(reply.ErrorReply)
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
}
arrReply, _ := resp.(*reply.MultiBulkReply)
for i, v := range arrReply.Args {
key := group[i]
resultMap[key] = v
}
}
result := make([][]byte, len(keys))
for i, k := range keys {
result[i] = resultMap[k]
}
return reply.MakeMultiBulkReply(result)
resultMap := make(map[string][]byte)
groupMap := cluster.groupBy(keys)
for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
if reply.IsErrorReply(resp) {
errReply := resp.(reply.ErrorReply)
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
}
arrReply, _ := resp.(*reply.MultiBulkReply)
for i, v := range arrReply.Args {
key := group[i]
resultMap[key] = v
}
}
result := make([][]byte, len(keys))
for i, k := range keys {
result[i] = resultMap[k]
}
return reply.MakeMultiBulkReply(result)
}
// args: PrepareMSet id keys...
func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
}
txId := string(args[1])
size := (len(args) - 2) / 2
keys := make([]string, size)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i+2])
}
if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
}
txId := string(args[1])
size := (len(args) - 2) / 2
keys := make([]string, size)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i+2])
}
txArgs := [][]byte{
[]byte("MSet"),
} // actual args for cluster.db
txArgs = append(txArgs, args[2:]...)
tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx)
err := tx.prepare()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return &reply.OkReply{}
txArgs := [][]byte{
[]byte("MSet"),
} // actual args for cluster.db
txArgs = append(txArgs, args[2:]...)
tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx)
err := tx.prepare()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return &reply.OkReply{}
}
// invoker should provide lock
func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
size := len(tx.args) / 2
keys := make([]string, size)
values := make([][]byte, size)
for i := 0; i < size; i++ {
keys[i] = string(tx.args[2*i+1])
values[i] = tx.args[2*i+2]
}
for i, key := range keys {
value := values[i]
cluster.db.Put(key, &db.DataEntity{Data: value})
}
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
return &reply.OkReply{}
size := len(tx.args) / 2
keys := make([]string, size)
values := make([][]byte, size)
for i := 0; i < size; i++ {
keys[i] = string(tx.args[2*i+1])
values[i] = tx.args[2*i+2]
}
for i, key := range keys {
value := values[i]
cluster.db.Put(key, &db.DataEntity{Data: value})
}
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
return &reply.OkReply{}
}
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
size := argCount / 2
keys := make([]string, size)
valueMap := make(map[string]string)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i+1])
valueMap[keys[i]] = string(args[2*i+2])
}
size := argCount / 2
keys := make([]string, size)
valueMap := make(map[string]string)
for i := 0; i < size; i++ {
keys[i] = string(args[2*i+1])
valueMap[keys[i]] = string(args[2*i+2])
}
groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast
for peer := range groupMap {
return cluster.Relay(peer, c, args)
}
}
groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast
for peer := range groupMap {
return cluster.Relay(peer, c, args)
}
}
//prepare
var errReply redis.Reply
txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10)
rollback := false
for peer, group := range groupMap {
peerArgs := []string{txIdStr}
for _, k := range group {
peerArgs = append(peerArgs, k, valueMap[k])
}
var resp redis.Reply
if peer == cluster.self {
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
} else {
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
}
if reply.IsErrorReply(resp) {
errReply = resp
rollback = true
break
}
}
if rollback {
// rollback
RequestRollback(cluster, c, txId, groupMap)
} else {
_, errReply = RequestCommit(cluster, c, txId, groupMap)
rollback = errReply != nil
}
if !rollback {
return &reply.OkReply{}
}
return errReply
//prepare
var errReply redis.Reply
txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10)
rollback := false
for peer, group := range groupMap {
peerArgs := []string{txIdStr}
for _, k := range group {
peerArgs = append(peerArgs, k, valueMap[k])
}
var resp redis.Reply
if peer == cluster.self {
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
} else {
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
}
if reply.IsErrorReply(resp) {
errReply = resp
rollback = true
break
}
}
if rollback {
// rollback
RequestRollback(cluster, c, txId, groupMap)
} else {
_, errReply = RequestCommit(cluster, c, txId, groupMap)
rollback = errReply != nil
}
if !rollback {
return &reply.OkReply{}
}
return errReply
}
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
var peer string
size := argCount / 2
for i := 0; i < size; i++ {
key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key)
if peer == "" {
peer = currentPeer
} else {
if peer != currentPeer {
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
}
}
}
return cluster.Relay(peer, c, args)
argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
}
var peer string
size := argCount / 2
for i := 0; i < size; i++ {
key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key)
if peer == "" {
peer = currentPeer
} else {
if peer != currentPeer {
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
}
}
}
return cluster.Relay(peer, c, args)
}

View File

@@ -1,40 +1,39 @@
package cluster
import (
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
)
// TODO: support multiplex slots
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
}
src := string(args[1])
dest := string(args[2])
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
}
src := string(args[1])
dest := string(args[2])
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
}
src := string(args[1])
dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
}
return cluster.Relay(srcPeer, c, args)
}

View File

@@ -3,106 +3,106 @@ package cluster
import "github.com/HDT3213/godis/src/interface/redis"
func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping
routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping
routerMap["commit"] = Commit
routerMap["rollback"] = Rollback
routerMap["del"] = Del
routerMap["preparedel"] = PrepareDel
routerMap["preparemset"] = PrepareMSet
routerMap["commit"] = Commit
routerMap["rollback"] = Rollback
routerMap["del"] = Del
routerMap["preparedel"] = PrepareDel
routerMap["preparemset"] = PrepareMSet
routerMap["expire"] = defaultFunc
routerMap["expireat"] = defaultFunc
routerMap["pexpire"] = defaultFunc
routerMap["pexpireat"] = defaultFunc
routerMap["ttl"] = defaultFunc
routerMap["pttl"] = defaultFunc
routerMap["persist"] = defaultFunc
routerMap["exists"] = defaultFunc
routerMap["type"] = defaultFunc
routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx
routerMap["expire"] = defaultFunc
routerMap["expireat"] = defaultFunc
routerMap["pexpire"] = defaultFunc
routerMap["pexpireat"] = defaultFunc
routerMap["ttl"] = defaultFunc
routerMap["pttl"] = defaultFunc
routerMap["persist"] = defaultFunc
routerMap["exists"] = defaultFunc
routerMap["type"] = defaultFunc
routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx
routerMap["set"] = defaultFunc
routerMap["setnx"] = defaultFunc
routerMap["setex"] = defaultFunc
routerMap["psetex"] = defaultFunc
routerMap["mset"] = MSet
routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX
routerMap["get"] = defaultFunc
routerMap["getset"] = defaultFunc
routerMap["incr"] = defaultFunc
routerMap["incrby"] = defaultFunc
routerMap["incrbyfloat"] = defaultFunc
routerMap["decr"] = defaultFunc
routerMap["decrby"] = defaultFunc
routerMap["set"] = defaultFunc
routerMap["setnx"] = defaultFunc
routerMap["setex"] = defaultFunc
routerMap["psetex"] = defaultFunc
routerMap["mset"] = MSet
routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX
routerMap["get"] = defaultFunc
routerMap["getset"] = defaultFunc
routerMap["incr"] = defaultFunc
routerMap["incrby"] = defaultFunc
routerMap["incrbyfloat"] = defaultFunc
routerMap["decr"] = defaultFunc
routerMap["decrby"] = defaultFunc
routerMap["lpush"] = defaultFunc
routerMap["lpushx"] = defaultFunc
routerMap["rpush"] = defaultFunc
routerMap["rpushx"] = defaultFunc
routerMap["lpop"] = defaultFunc
routerMap["rpop"] = defaultFunc
//routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = defaultFunc
routerMap["llen"] = defaultFunc
routerMap["lindex"] = defaultFunc
routerMap["lset"] = defaultFunc
routerMap["lrange"] = defaultFunc
routerMap["lpush"] = defaultFunc
routerMap["lpushx"] = defaultFunc
routerMap["rpush"] = defaultFunc
routerMap["rpushx"] = defaultFunc
routerMap["lpop"] = defaultFunc
routerMap["rpop"] = defaultFunc
//routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = defaultFunc
routerMap["llen"] = defaultFunc
routerMap["lindex"] = defaultFunc
routerMap["lset"] = defaultFunc
routerMap["lrange"] = defaultFunc
routerMap["hset"] = defaultFunc
routerMap["hsetnx"] = defaultFunc
routerMap["hget"] = defaultFunc
routerMap["hexists"] = defaultFunc
routerMap["hdel"] = defaultFunc
routerMap["hlen"] = defaultFunc
routerMap["hmget"] = defaultFunc
routerMap["hmset"] = defaultFunc
routerMap["hkeys"] = defaultFunc
routerMap["hvals"] = defaultFunc
routerMap["hgetall"] = defaultFunc
routerMap["hincrby"] = defaultFunc
routerMap["hincrbyfloat"] = defaultFunc
routerMap["hset"] = defaultFunc
routerMap["hsetnx"] = defaultFunc
routerMap["hget"] = defaultFunc
routerMap["hexists"] = defaultFunc
routerMap["hdel"] = defaultFunc
routerMap["hlen"] = defaultFunc
routerMap["hmget"] = defaultFunc
routerMap["hmset"] = defaultFunc
routerMap["hkeys"] = defaultFunc
routerMap["hvals"] = defaultFunc
routerMap["hgetall"] = defaultFunc
routerMap["hincrby"] = defaultFunc
routerMap["hincrbyfloat"] = defaultFunc
routerMap["sadd"] = defaultFunc
routerMap["sismember"] = defaultFunc
routerMap["srem"] = defaultFunc
routerMap["scard"] = defaultFunc
routerMap["smembers"] = defaultFunc
routerMap["sinter"] = defaultFunc
routerMap["sinterstore"] = defaultFunc
routerMap["sunion"] = defaultFunc
routerMap["sunionstore"] = defaultFunc
routerMap["sdiff"] = defaultFunc
routerMap["sdiffstore"] = defaultFunc
routerMap["srandmember"] = defaultFunc
routerMap["sadd"] = defaultFunc
routerMap["sismember"] = defaultFunc
routerMap["srem"] = defaultFunc
routerMap["scard"] = defaultFunc
routerMap["smembers"] = defaultFunc
routerMap["sinter"] = defaultFunc
routerMap["sinterstore"] = defaultFunc
routerMap["sunion"] = defaultFunc
routerMap["sunionstore"] = defaultFunc
routerMap["sdiff"] = defaultFunc
routerMap["sdiffstore"] = defaultFunc
routerMap["srandmember"] = defaultFunc
routerMap["zadd"] = defaultFunc
routerMap["zscore"] = defaultFunc
routerMap["zincrby"] = defaultFunc
routerMap["zrank"] = defaultFunc
routerMap["zcount"] = defaultFunc
routerMap["zrevrank"] = defaultFunc
routerMap["zcard"] = defaultFunc
routerMap["zrange"] = defaultFunc
routerMap["zrevrange"] = defaultFunc
routerMap["zrangebyscore"] = defaultFunc
routerMap["zrevrangebyscore"] = defaultFunc
routerMap["zrem"] = defaultFunc
routerMap["zremrangebyscore"] = defaultFunc
routerMap["zremrangebyrank"] = defaultFunc
routerMap["zadd"] = defaultFunc
routerMap["zscore"] = defaultFunc
routerMap["zincrby"] = defaultFunc
routerMap["zrank"] = defaultFunc
routerMap["zcount"] = defaultFunc
routerMap["zrevrank"] = defaultFunc
routerMap["zcard"] = defaultFunc
routerMap["zrange"] = defaultFunc
routerMap["zrevrange"] = defaultFunc
routerMap["zrangebyscore"] = defaultFunc
routerMap["zrevrangebyscore"] = defaultFunc
routerMap["zrem"] = defaultFunc
routerMap["zremrangebyscore"] = defaultFunc
routerMap["zremrangebyrank"] = defaultFunc
//routerMap["flushdb"] = FlushDB
//routerMap["flushall"] = FlushAll
//routerMap["keys"] = Keys
//routerMap["flushdb"] = FlushDB
//routerMap["flushall"] = FlushAll
//routerMap["keys"] = Keys
return routerMap
return routerMap
}
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
key := string(args[1])
peer := cluster.peerPicker.Get(key)
return cluster.Relay(peer, c, args)
key := string(args[1])
peer := cluster.peerPicker.Get(key)
return cluster.Relay(peer, c, args)
}

View File

@@ -1,188 +1,188 @@
package cluster
import (
"context"
"fmt"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/marshal/gob"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
"strings"
"time"
"context"
"fmt"
"github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/marshal/gob"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
"strings"
"time"
)
type Transaction struct {
id string // transaction id
args [][]byte // cmd args
cluster *Cluster
conn redis.Connection
id string // transaction id
args [][]byte // cmd args
cluster *Cluster
conn redis.Connection
keys []string // related keys
undoLog map[string][]byte // store data for undoLog
keys []string // related keys
undoLog map[string][]byte // store data for undoLog
lockUntil time.Time
ctx context.Context
cancel context.CancelFunc
status int8
lockUntil time.Time
ctx context.Context
cancel context.CancelFunc
status int8
}
const (
maxLockTime = 3 * time.Second
maxLockTime = 3 * time.Second
CreatedStatus = 0
PreparedStatus = 1
CommitedStatus = 2
RollbackedStatus = 3
CreatedStatus = 0
PreparedStatus = 1
CommitedStatus = 2
RollbackedStatus = 3
)
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
return &Transaction{
id: id,
args: args,
cluster: cluster,
conn: c,
keys: keys,
status: CreatedStatus,
}
return &Transaction{
id: id,
args: args,
cluster: cluster,
conn: c,
keys: keys,
status: CreatedStatus,
}
}
// t should contains Keys field
func (tx *Transaction) prepare() error {
// lock keys
tx.cluster.db.Locks(tx.keys...)
// lock keys
tx.cluster.db.Locks(tx.keys...)
// use context to manage
//tx.lockUntil = time.Now().Add(maxLockTime)
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
//tx.ctx = ctx
//tx.cancel = cancel
// use context to manage
//tx.lockUntil = time.Now().Add(maxLockTime)
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
//tx.ctx = ctx
//tx.cancel = cancel
// build undoLog
tx.undoLog = make(map[string][]byte)
for _, key := range tx.keys {
entity, ok := tx.cluster.db.Get(key)
if ok {
blob, err := gob.Marshal(entity)
if err != nil {
return err
}
tx.undoLog[key] = blob
} else {
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
}
}
tx.status = PreparedStatus
return nil
// build undoLog
tx.undoLog = make(map[string][]byte)
for _, key := range tx.keys {
entity, ok := tx.cluster.db.Get(key)
if ok {
blob, err := gob.Marshal(entity)
if err != nil {
return err
}
tx.undoLog[key] = blob
} else {
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
}
}
tx.status = PreparedStatus
return nil
}
func (tx *Transaction) rollback() error {
for key, blob := range tx.undoLog {
if len(blob) > 0 {
entity := &db.DataEntity{}
err := gob.UnMarshal(blob, entity)
if err != nil {
return err
}
tx.cluster.db.Put(key, entity)
} else {
tx.cluster.db.Remove(key)
}
}
if tx.status != CommitedStatus {
tx.cluster.db.UnLocks(tx.keys...)
}
tx.status = RollbackedStatus
return nil
for key, blob := range tx.undoLog {
if len(blob) > 0 {
entity := &db.DataEntity{}
err := gob.UnMarshal(blob, entity)
if err != nil {
return err
}
tx.cluster.db.Put(key, entity)
} else {
tx.cluster.db.Remove(key)
}
}
if tx.status != CommitedStatus {
tx.cluster.db.UnLocks(tx.keys...)
}
tx.status = RollbackedStatus
return nil
}
// rollback local transaction
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
err := tx.rollback()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return reply.MakeIntReply(1)
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
err := tx.rollback()
if err != nil {
return reply.MakeErrReply(err.Error())
}
return reply.MakeIntReply(1)
}
// commit local transaction as a worker
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
}
txId := string(args[1])
raw, ok := cluster.transactions.Get(txId)
if !ok {
return reply.MakeIntReply(0)
}
tx, _ := raw.(*Transaction)
// finish transaction
defer func() {
cluster.db.UnLocks(tx.keys...)
tx.status = CommitedStatus
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
}()
// finish transaction
defer func() {
cluster.db.UnLocks(tx.keys...)
tx.status = CommitedStatus
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
}()
cmd := strings.ToLower(string(tx.args[0]))
var result redis.Reply
if cmd == "del" {
result = CommitDel(cluster, c, tx)
} else if cmd == "mset" {
result = CommitMSet(cluster, c, tx)
}
cmd := strings.ToLower(string(tx.args[0]))
var result redis.Reply
if cmd == "del" {
result = CommitDel(cluster, c, tx)
} else if cmd == "mset" {
result = CommitMSet(cluster, c, tx)
}
if reply.IsErrorReply(result) {
// failed
err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
}
if reply.IsErrorReply(result) {
// failed
err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
}
return result
return result
}
// request all node commit transaction as leader
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
var errReply reply.ErrorReply
txIdStr := strconv.FormatInt(txId, 10)
respList := make([]redis.Reply, 0, len(peers))
for peer := range peers {
var resp redis.Reply
if peer == cluster.self {
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
} else {
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
}
if reply.IsErrorReply(resp) {
errReply = resp.(reply.ErrorReply)
break
}
respList = append(respList, resp)
}
if errReply != nil {
RequestRollback(cluster, c, txId, peers)
return nil, errReply
}
return respList, nil
var errReply reply.ErrorReply
txIdStr := strconv.FormatInt(txId, 10)
respList := make([]redis.Reply, 0, len(peers))
for peer := range peers {
var resp redis.Reply
if peer == cluster.self {
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
} else {
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
}
if reply.IsErrorReply(resp) {
errReply = resp.(reply.ErrorReply)
break
}
respList = append(respList, resp)
}
if errReply != nil {
RequestRollback(cluster, c, txId, peers)
return nil, errReply
}
return respList, nil
}
// request all node rollback transaction as leader
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
txIdStr := strconv.FormatInt(txId, 10)
for peer := range peers {
if peer == cluster.self {
Rollback(cluster, c, makeArgs("rollback", txIdStr))
} else {
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
}
}
txIdStr := strconv.FormatInt(txId, 10)
for peer := range peers {
if peer == cluster.self {
Rollback(cluster, c, makeArgs("rollback", txIdStr))
} else {
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
}
}
}

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/lib/logger"
RedisServer "github.com/HDT3213/godis/src/redis/server"
RedisServer "github.com/HDT3213/godis/src/redis/server"
"github.com/HDT3213/godis/src/tcp"
"os"
)
@@ -23,6 +23,6 @@ func main() {
})
tcp.ListenAndServe(&tcp.Config{
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
}, RedisServer.MakeHandler())
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
}, RedisServer.MakeHandler())
}

View File

@@ -23,12 +23,12 @@ type PropertyHolder struct {
var Properties *PropertyHolder
func init() {
// default config
Properties = &PropertyHolder{
Bind: "127.0.0.1",
Port: 6379,
AppendOnly: false,
}
// default config
Properties = &PropertyHolder{
Bind: "127.0.0.1",
Port: 6379,
AppendOnly: false,
}
}
func LoadConfig(configFilename string) *PropertyHolder {

View File

@@ -1,16 +1,16 @@
package dict
type Consumer func(key string, val interface{})bool
type Consumer func(key string, val interface{}) bool
type Dict interface {
Get(key string) (val interface{}, exists bool)
Len() int
Put(key string, val interface{}) (result int)
PutIfAbsent(key string, val interface{}) (result int)
PutIfExists(key string, val interface{}) (result int)
Remove(key string) (result int)
ForEach(consumer Consumer)
Keys() []string
RandomKeys(limit int) []string
RandomDistinctKeys(limit int) []string
Get(key string) (val interface{}, exists bool)
Len() int
Put(key string, val interface{}) (result int)
PutIfAbsent(key string, val interface{}) (result int)
PutIfExists(key string, val interface{}) (result int)
Remove(key string) (result int)
ForEach(consumer Consumer)
Keys() []string
RandomKeys(limit int) []string
RandomDistinctKeys(limit int) []string
}

View File

@@ -1,244 +1,244 @@
package dict
import (
"strconv"
"sync"
"testing"
"strconv"
"sync"
"testing"
)
func TestPut(t *testing.T) {
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
ret := d.Put(key, i)
if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
wg.Done()
}(i)
}
wg.Wait()
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
ret := d.Put(key, i)
if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
wg.Done()
}(i)
}
wg.Wait()
}
func TestPutIfAbsent(t *testing.T) {
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
ret := d.PutIfAbsent(key, i)
if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
", key: " + key)
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
ret := d.PutIfAbsent(key, i)
if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
", key: " + key)
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
// update
ret = d.PutIfAbsent(key, i * 10)
if ret != 0 { // no update
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
}
val, ok = d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
}
} else {
t.Error("put test failed: expected true, actual: false, key: " + key)
}
wg.Done()
}(i)
}
wg.Wait()
// update
ret = d.PutIfAbsent(key, i*10)
if ret != 0 { // no update
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
}
val, ok = d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
}
} else {
t.Error("put test failed: expected true, actual: false, key: " + key)
}
wg.Done()
}(i)
}
wg.Wait()
}
func TestPutIfExists(t *testing.T) {
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
// insert
ret := d.PutIfExists(key, i)
if ret != 0 { // insert
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
}
d := MakeConcurrent(0)
count := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
// insert
key := "k" + strconv.Itoa(i)
// insert
ret := d.PutIfExists(key, i)
if ret != 0 { // insert
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
}
d.Put(key, i)
ret = d.PutIfExists(key, 10 * i)
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != 10 * i {
t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
wg.Done()
}(i)
}
wg.Wait()
d.Put(key, i)
ret = d.PutIfExists(key, 10*i)
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != 10*i {
t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
_, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
}
wg.Done()
}(i)
}
wg.Wait()
}
func TestRemove(t *testing.T) {
d := MakeConcurrent(0)
d := MakeConcurrent(0)
// remove head node
for i := 0; i < 100; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
for i := 0; i < 100; i++ {
key := "k" + strconv.Itoa(i)
// remove head node
for i := 0; i < 100; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
for i := 0; i < 100; i++ {
key := "k" + strconv.Itoa(i)
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
// remove tail node
d = MakeConcurrent(0)
for i := 0; i < 100; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i)
// remove tail node
d = MakeConcurrent(0)
for i := 0; i < 100; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i)
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
// remove middle node
d = MakeConcurrent(0)
d.Put("head", 0)
for i := 0; i < 10; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
d.Put("tail", 0)
for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i)
// remove middle node
d = MakeConcurrent(0)
d.Put("head", 0)
for i := 0; i < 10; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
d.Put("tail", 0)
for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i)
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
val, ok := d.Get(key)
if ok {
intVal, _ := val.(int)
if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
} else {
t.Error("put test failed: expected true, actual: false")
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
ret := d.Remove(key)
if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
}
_, ok = d.Get(key)
if ok {
t.Error("remove test failed: expected true, actual false")
}
ret = d.Remove(key)
if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
}
}
}
func TestForEach(t *testing.T) {
d := MakeConcurrent(0)
size := 100
for i := 0; i < size; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
i := 0
d.ForEach(func(key string, value interface{})bool {
intVal, _ := value.(int)
expectedKey := "k" + strconv.Itoa(intVal)
if key != expectedKey {
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
}
i++
return true
})
if i != size {
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
}
d := MakeConcurrent(0)
size := 100
for i := 0; i < size; i++ {
// insert
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
i := 0
d.ForEach(func(key string, value interface{}) bool {
intVal, _ := value.(int)
expectedKey := "k" + strconv.Itoa(intVal)
if key != expectedKey {
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
}
i++
return true
})
if i != size {
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
}
}

View File

@@ -1,108 +1,108 @@
package dict
type SimpleDict struct {
m map[string]interface{}
m map[string]interface{}
}
func MakeSimple() *SimpleDict {
return &SimpleDict{
m: make(map[string]interface{}),
}
return &SimpleDict{
m: make(map[string]interface{}),
}
}
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
val, ok := dict.m[key]
return val, ok
val, ok := dict.m[key]
return val, ok
}
func (dict *SimpleDict) Len() int {
if dict.m == nil {
panic("m is nil")
}
return len(dict.m)
if dict.m == nil {
panic("m is nil")
}
return len(dict.m)
}
func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
_, existed := dict.m[key]
dict.m[key] = val
if existed {
return 0
} else {
return 1
}
_, existed := dict.m[key]
dict.m[key] = val
if existed {
return 0
} else {
return 1
}
}
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
_, existed := dict.m[key]
if existed {
return 0
} else {
dict.m[key] = val
return 1
}
_, existed := dict.m[key]
if existed {
return 0
} else {
dict.m[key] = val
return 1
}
}
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
_, existed := dict.m[key]
if existed {
dict.m[key] = val
return 1
} else {
return 0
}
_, existed := dict.m[key]
if existed {
dict.m[key] = val
return 1
} else {
return 0
}
}
func (dict *SimpleDict) Remove(key string) (result int) {
_, existed := dict.m[key]
delete(dict.m, key)
if existed {
return 1
} else {
return 0
}
_, existed := dict.m[key]
delete(dict.m, key)
if existed {
return 1
} else {
return 0
}
}
func (dict *SimpleDict) Keys() []string {
result := make([]string, len(dict.m))
i := 0
for k := range dict.m {
result[i] = k
}
return result
result := make([]string, len(dict.m))
i := 0
for k := range dict.m {
result[i] = k
}
return result
}
func (dict *SimpleDict) ForEach(consumer Consumer) {
for k, v := range dict.m {
if !consumer(k, v) {
break
}
}
for k, v := range dict.m {
if !consumer(k, v) {
break
}
}
}
func (dict *SimpleDict) RandomKeys(limit int) []string {
result := make([]string, limit)
for i := 0; i < limit; i++ {
for k := range dict.m {
result[i] = k
break
}
}
return result
result := make([]string, limit)
for i := 0; i < limit; i++ {
for k := range dict.m {
result[i] = k
break
}
}
return result
}
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
size := limit
if size > len(dict.m) {
size = len(dict.m)
}
result := make([]string, size)
i := 0
for k := range dict.m {
if i == limit {
break
}
result[i] = k
i++
}
return result
size := limit
if size > len(dict.m) {
size = len(dict.m)
}
result := make([]string, size)
i := 0
for k := range dict.m {
if i == limit {
break
}
result[i] = k
i++
}
return result
}

View File

@@ -3,324 +3,324 @@ package list
import "github.com/HDT3213/godis/src/datastruct/utils"
type LinkedList struct {
first *node
last *node
size int
first *node
last *node
size int
}
type node struct {
val interface{}
prev *node
next * node
val interface{}
prev *node
next *node
}
func (list *LinkedList)Add(val interface{}) {
if list == nil {
panic("list is nil")
}
n := &node{
val: val,
}
if list.last == nil {
// empty list
list.first = n
list.last = n
} else {
n.prev = list.last
list.last.next = n
list.last = n
}
list.size++
func (list *LinkedList) Add(val interface{}) {
if list == nil {
panic("list is nil")
}
n := &node{
val: val,
}
if list.last == nil {
// empty list
list.first = n
list.last = n
} else {
n.prev = list.last
list.last.next = n
list.last = n
}
list.size++
}
func (list *LinkedList)find(index int)(n *node) {
if index < list.size / 2 {
n := list.first
for i := 0; i < index; i++ {
n = n.next
}
return n
} else {
n := list.last
for i := list.size - 1; i > index; i-- {
n = n.prev
}
return n
}
func (list *LinkedList) find(index int) (n *node) {
if index < list.size/2 {
n := list.first
for i := 0; i < index; i++ {
n = n.next
}
return n
} else {
n := list.last
for i := list.size - 1; i > index; i-- {
n = n.prev
}
return n
}
}
func (list *LinkedList)Get(index int)(val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index >= list.size {
panic("index out of bound")
}
return list.find(index).val
func (list *LinkedList) Get(index int) (val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index >= list.size {
panic("index out of bound")
}
return list.find(index).val
}
func (list *LinkedList)Set(index int, val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index > list.size {
panic("index out of bound")
}
n := list.find(index)
n.val = val
func (list *LinkedList) Set(index int, val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index > list.size {
panic("index out of bound")
}
n := list.find(index)
n.val = val
}
func (list *LinkedList)Insert(index int, val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index > list.size {
panic("index out of bound")
}
func (list *LinkedList) Insert(index int, val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index > list.size {
panic("index out of bound")
}
if index == list.size {
list.Add(val)
return
} else {
// list is not empty
pivot := list.find(index)
n := &node{
val: val,
prev: pivot.prev,
next: pivot,
}
if pivot.prev == nil {
list.first = n
} else {
pivot.prev.next = n
}
pivot.prev = n
list.size++
}
if index == list.size {
list.Add(val)
return
} else {
// list is not empty
pivot := list.find(index)
n := &node{
val: val,
prev: pivot.prev,
next: pivot,
}
if pivot.prev == nil {
list.first = n
} else {
pivot.prev.next = n
}
pivot.prev = n
list.size++
}
}
func (list *LinkedList)removeNode(n *node) {
if n.prev == nil {
list.first = n.next
} else {
n.prev.next = n.next
}
if n.next == nil {
list.last = n.prev
} else {
n.next.prev = n.prev
}
func (list *LinkedList) removeNode(n *node) {
if n.prev == nil {
list.first = n.next
} else {
n.prev.next = n.next
}
if n.next == nil {
list.last = n.prev
} else {
n.next.prev = n.prev
}
// for gc
n.prev = nil
n.next = nil
// for gc
n.prev = nil
n.next = nil
list.size--
list.size--
}
func (list *LinkedList)Remove(index int)(val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index >= list.size {
panic("index out of bound")
}
func (list *LinkedList) Remove(index int) (val interface{}) {
if list == nil {
panic("list is nil")
}
if index < 0 || index >= list.size {
panic("index out of bound")
}
n := list.find(index)
list.removeNode(n)
return n.val
n := list.find(index)
list.removeNode(n)
return n.val
}
func (list *LinkedList)RemoveLast()(val interface{}) {
if list == nil {
panic("list is nil")
}
if list.last == nil {
// empty list
return nil
}
n := list.last
list.removeNode(n)
return n.val
func (list *LinkedList) RemoveLast() (val interface{}) {
if list == nil {
panic("list is nil")
}
if list.last == nil {
// empty list
return nil
}
n := list.last
list.removeNode(n)
return n.val
}
func (list *LinkedList)RemoveAllByVal(val interface{})int {
if list == nil {
panic("list is nil")
}
n := list.first
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.next == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.next
}
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
}
return removed
func (list *LinkedList) RemoveAllByVal(val interface{}) int {
if list == nil {
panic("list is nil")
}
n := list.first
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.next == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.next
}
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
}
return removed
}
/**
* remove at most `count` values of the specified value in this list
* scan from left to right
*/
func (list *LinkedList) RemoveByVal(val interface{}, count int)int {
if list == nil {
panic("list is nil")
}
n := list.first
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.next == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.next
}
func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
if list == nil {
panic("list is nil")
}
n := list.first
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.next == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.next
}
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
if removed == count {
break
}
}
return removed
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
if removed == count {
break
}
}
return removed
}
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int {
if list == nil {
panic("list is nil")
}
n := list.last
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.prev == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.prev
}
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
if list == nil {
panic("list is nil")
}
n := list.last
removed := 0
for n != nil {
var toRemoveNode *node
if utils.Equals(n.val, val) {
toRemoveNode = n
}
if n.prev == nil {
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
break
} else {
n = n.prev
}
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
if removed == count {
break
}
}
return removed
if toRemoveNode != nil {
removed++
list.removeNode(toRemoveNode)
}
if removed == count {
break
}
}
return removed
}
func (list *LinkedList)Len()int {
if list == nil {
panic("list is nil")
}
return list.size
func (list *LinkedList) Len() int {
if list == nil {
panic("list is nil")
}
return list.size
}
func (list *LinkedList)ForEach(consumer func(int, interface{})bool) {
if list == nil {
panic("list is nil")
}
n := list.first
i := 0
for n != nil {
goNext := consumer(i, n.val)
if !goNext || n.next == nil {
break
} else {
i++
n = n.next
}
}
func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
if list == nil {
panic("list is nil")
}
n := list.first
i := 0
for n != nil {
goNext := consumer(i, n.val)
if !goNext || n.next == nil {
break
} else {
i++
n = n.next
}
}
}
func (list *LinkedList)Contains(val interface{})bool {
contains := false
list.ForEach(func(i int, actual interface{}) bool {
if actual == val {
contains = true
return false
}
return true
})
return contains
func (list *LinkedList) Contains(val interface{}) bool {
contains := false
list.ForEach(func(i int, actual interface{}) bool {
if actual == val {
contains = true
return false
}
return true
})
return contains
}
func (list *LinkedList)Range(start int, stop int)[]interface{} {
if list == nil {
panic("list is nil")
}
if start < 0 || start >= list.size {
panic("`start` out of range")
}
if stop < start || stop > list.size {
panic("`stop` out of range")
}
func (list *LinkedList) Range(start int, stop int) []interface{} {
if list == nil {
panic("list is nil")
}
if start < 0 || start >= list.size {
panic("`start` out of range")
}
if stop < start || stop > list.size {
panic("`stop` out of range")
}
sliceSize := stop - start
slice := make([]interface{}, sliceSize)
n := list.first
i := 0
sliceIndex := 0
for n != nil {
if i >= start && i < stop {
slice[sliceIndex] = n.val
sliceIndex++
} else if i >= stop {
break
}
if n.next == nil {
break
} else {
i++
n = n.next
}
}
return slice
sliceSize := stop - start
slice := make([]interface{}, sliceSize)
n := list.first
i := 0
sliceIndex := 0
for n != nil {
if i >= start && i < stop {
slice[sliceIndex] = n.val
sliceIndex++
} else if i >= stop {
break
}
if n.next == nil {
break
} else {
i++
n = n.next
}
}
return slice
}
func Make(vals ...interface{}) *LinkedList {
list := LinkedList{}
for _, v := range vals {
list.Add(v)
}
return &list
list := LinkedList{}
for _, v := range vals {
list.Add(v)
}
return &list
}
func MakeBytesList(vals ...[]byte) *LinkedList {
list := LinkedList{}
for _, v := range vals {
list.Add(v)
}
return &list
list := LinkedList{}
for _, v := range vals {
list.Add(v)
}
return &list
}

View File

@@ -1,215 +1,215 @@
package list
import (
"testing"
"strconv"
"strings"
"strconv"
"strings"
"testing"
)
func ToString(list *LinkedList) string {
arr := make([]string, list.size)
list.ForEach(func(i int, v interface{}) bool {
integer, _ := v.(int)
arr[i] = strconv.Itoa(integer)
return true
})
return "[" + strings.Join(arr, ", ") + "]"
arr := make([]string, list.size)
list.ForEach(func(i int, v interface{}) bool {
integer, _ := v.(int)
arr[i] = strconv.Itoa(integer)
return true
})
return "[" + strings.Join(arr, ", ") + "]"
}
func TestAdd(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
}
func TestGet(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 0; i < 10; i++ {
v := list.Get(i)
k, _ := v.(int)
if i != k {
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
}
}
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 0; i < 10; i++ {
v := list.Get(i)
k, _ := v.(int)
if i != k {
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
}
}
}
func TestRemove(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 9; i >= 0; i-- {
list.Remove(i)
if i != list.Len() {
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
}
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 9; i >= 0; i-- {
list.Remove(i)
if i != list.Len() {
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
}
}
func TestRemoveVal(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for index := 0; index < list.Len(); index++ {
list.RemoveAllByVal(index)
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal == index {
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
}
return true
})
}
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for index := 0; index < list.Len(); index++ {
list.RemoveAllByVal(index)
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal == index {
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
}
return true
})
}
list = Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1)
}
if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
}
list = Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1)
}
if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
}
list = Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1)
}
if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
}
list = Make()
for i := 0; i < 10; i++ {
list.Add(i)
list.Add(i)
}
for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1)
}
list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int)
if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
return true
})
for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1)
}
if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
}
}
func TestInsert(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 0; i < 10; i++ {
list.Insert(i*2, i)
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 0; i < 10; i++ {
list.Insert(i*2, i)
list.ForEach(func(j int, v interface{}) bool {
var expected int
if j < (i + 1) * 2 {
if j%2 == 0 {
expected = j / 2
} else {
expected = (j - 1) / 2
}
} else {
expected = j - i - 1
}
actual, _ := list.Get(j).(int)
if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
}
return true
})
list.ForEach(func(j int, v interface{}) bool {
var expected int
if j < (i+1)*2 {
if j%2 == 0 {
expected = j / 2
} else {
expected = (j - 1) / 2
}
} else {
expected = j - i - 1
}
actual, _ := list.Get(j).(int)
if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
}
return true
})
for j := 0; j < list.Len(); j++ {
var expected int
if j < (i + 1) * 2 {
if j%2 == 0 {
expected = j / 2
} else {
expected = (j - 1) / 2
}
} else {
expected = j - i - 1
}
actual, _ := list.Get(j).(int)
if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
}
}
for j := 0; j < list.Len(); j++ {
var expected int
if j < (i+1)*2 {
if j%2 == 0 {
expected = j / 2
} else {
expected = (j - 1) / 2
}
} else {
expected = j - i - 1
}
actual, _ := list.Get(j).(int)
if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
}
}
}
}
}
func TestRemoveLast(t *testing.T) {
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 9; i >= 0; i-- {
val := list.RemoveLast()
intVal, _ := val.(int)
if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
}
list := Make()
for i := 0; i < 10; i++ {
list.Add(i)
}
for i := 9; i >= 0; i-- {
val := list.RemoveLast()
intVal, _ := val.(int)
if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
}
}
}
func TestRange(t *testing.T) {
list := Make()
size := 10
for i := 0; i < size; i++ {
list.Add(i)
}
for start := 0; start < size; start++ {
for stop := start; stop < size; stop++ {
slice := list.Range(start, stop)
if len(slice) != stop - start {
t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
}
sliceIndex := 0
for i := start; i < stop; i++ {
val := slice[sliceIndex]
intVal, _ := val.(int)
if intVal != i {
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
}
sliceIndex++
}
}
}
list := Make()
size := 10
for i := 0; i < size; i++ {
list.Add(i)
}
for start := 0; start < size; start++ {
for stop := start; stop < size; stop++ {
slice := list.Range(start, stop)
if len(slice) != stop-start {
t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
}
sliceIndex := 0
for i := start; i < stop; i++ {
val := slice[sliceIndex]
intVal, _ := val.(int)
if intVal != i {
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
}
sliceIndex++
}
}
}
}

View File

@@ -1,153 +1,152 @@
package lock
import (
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
)
const (
prime32 = uint32(16777619)
prime32 = uint32(16777619)
)
type Locks struct {
table []*sync.RWMutex
table []*sync.RWMutex
}
func Make(tableSize int) *Locks {
table := make([]*sync.RWMutex, tableSize)
for i := 0; i < tableSize; i++ {
table[i] = &sync.RWMutex{}
}
return &Locks{
table: table,
}
table := make([]*sync.RWMutex, tableSize)
for i := 0; i < tableSize; i++ {
table[i] = &sync.RWMutex{}
}
return &Locks{
table: table,
}
}
func fnv32(key string) uint32 {
hash := uint32(2166136261)
for i := 0; i < len(key); i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
hash := uint32(2166136261)
for i := 0; i < len(key); i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}
func (locks *Locks) spread(hashCode uint32) uint32 {
if locks == nil {
panic("dict is nil")
}
tableSize := uint32(len(locks.table))
return (tableSize - 1) & uint32(hashCode)
if locks == nil {
panic("dict is nil")
}
tableSize := uint32(len(locks.table))
return (tableSize - 1) & uint32(hashCode)
}
func (locks *Locks)Lock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.Lock()
func (locks *Locks) Lock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.Lock()
}
func (locks *Locks)RLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.RLock()
func (locks *Locks) RLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.RLock()
}
func (locks *Locks)UnLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.Unlock()
func (locks *Locks) UnLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.Unlock()
}
func (locks *Locks)RUnLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.RUnlock()
func (locks *Locks) RUnLock(key string) {
index := locks.spread(fnv32(key))
mu := locks.table[index]
mu.RUnlock()
}
func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
indexMap := make(map[uint32]bool)
for _, key := range keys {
index := locks.spread(fnv32(key))
indexMap[index] = true
}
indices := make([]uint32, 0, len(indexMap))
for index := range indexMap {
indices = append(indices, index)
}
sort.Slice(indices, func(i, j int) bool {
if !reverse {
return indices[i] < indices[j]
} else {
return indices[i] > indices[j]
}
})
return indices
indexMap := make(map[uint32]bool)
for _, key := range keys {
index := locks.spread(fnv32(key))
indexMap[index] = true
}
indices := make([]uint32, 0, len(indexMap))
for index := range indexMap {
indices = append(indices, index)
}
sort.Slice(indices, func(i, j int) bool {
if !reverse {
return indices[i] < indices[j]
} else {
return indices[i] > indices[j]
}
})
return indices
}
func (locks *Locks)Locks(keys ...string) {
indices := locks.toLockIndices(keys, false)
for _, index := range indices {
mu := locks.table[index]
mu.Lock()
}
func (locks *Locks) Locks(keys ...string) {
indices := locks.toLockIndices(keys, false)
for _, index := range indices {
mu := locks.table[index]
mu.Lock()
}
}
func (locks *Locks)RLocks(keys ...string) {
indices := locks.toLockIndices(keys, false)
for _, index := range indices {
mu := locks.table[index]
mu.RLock()
}
func (locks *Locks) RLocks(keys ...string) {
indices := locks.toLockIndices(keys, false)
for _, index := range indices {
mu := locks.table[index]
mu.RLock()
}
}
func (locks *Locks)UnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true)
for _, index := range indices {
mu := locks.table[index]
mu.Unlock()
}
func (locks *Locks) UnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true)
for _, index := range indices {
mu := locks.table[index]
mu.Unlock()
}
}
func (locks *Locks)RUnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true)
for _, index := range indices {
mu := locks.table[index]
mu.RUnlock()
}
func (locks *Locks) RUnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true)
for _, index := range indices {
mu := locks.table[index]
mu.RUnlock()
}
}
func GoID() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func debug(testing.T) {
lm := Locks{}
size := 10
var wg sync.WaitGroup
wg.Add(size)
for i := 0; i < size; i++ {
go func(i int) {
lm.Locks("1", "2")
println("go: " + strconv.Itoa(GoID()))
time.Sleep(time.Second)
println("go: " + strconv.Itoa(GoID()))
lm.UnLocks("1", "2")
wg.Done()
}(i)
}
wg.Wait()
lm := Locks{}
size := 10
var wg sync.WaitGroup
wg.Add(size)
for i := 0; i < size; i++ {
go func(i int) {
lm.Locks("1", "2")
println("go: " + strconv.Itoa(GoID()))
time.Sleep(time.Second)
println("go: " + strconv.Itoa(GoID()))
lm.UnLocks("1", "2")
wg.Done()
}(i)
}
wg.Wait()
}

View File

@@ -1,32 +1,32 @@
package set
import (
"strconv"
"testing"
"strconv"
"testing"
)
func TestSet(t *testing.T) {
size := 10
set := Make()
for i := 0; i < size; i++ {
set.Add(strconv.Itoa(i))
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if !ok {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Remove(strconv.Itoa(i))
if ok != 1 {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if ok {
t.Error("expected false actual true, key: " + strconv.Itoa(i))
}
}
size := 10
set := Make()
for i := 0; i < size; i++ {
set.Add(strconv.Itoa(i))
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if !ok {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Remove(strconv.Itoa(i))
if ok != 1 {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if ok {
t.Error("expected false actual true, key: " + strconv.Itoa(i))
}
}
}

View File

@@ -1,8 +1,8 @@
package sortedset
import (
"errors"
"strconv"
"errors"
"strconv"
)
/*
@@ -14,78 +14,78 @@ import (
*/
const (
negativeInf int8 = -1
positiveInf int8 = 1
negativeInf int8 = -1
positiveInf int8 = 1
)
type ScoreBorder struct {
Inf int8
Value float64
Exclude bool
Inf int8
Value float64
Exclude bool
}
// if max.greater(score) then the score is within the upper border
// do not use min.greater()
func (border *ScoreBorder)greater(value float64)bool {
if border.Inf == negativeInf {
return false
} else if border.Inf == positiveInf {
return true
}
if border.Exclude {
return border.Value > value
} else {
return border.Value >= value
}
func (border *ScoreBorder) greater(value float64) bool {
if border.Inf == negativeInf {
return false
} else if border.Inf == positiveInf {
return true
}
if border.Exclude {
return border.Value > value
} else {
return border.Value >= value
}
}
func (border *ScoreBorder)less(value float64)bool {
if border.Inf == negativeInf {
return true
} else if border.Inf == positiveInf {
return false
}
if border.Exclude {
return border.Value < value
} else {
return border.Value <= value
}
func (border *ScoreBorder) less(value float64) bool {
if border.Inf == negativeInf {
return true
} else if border.Inf == positiveInf {
return false
}
if border.Exclude {
return border.Value < value
} else {
return border.Value <= value
}
}
var positiveInfBorder = &ScoreBorder {
Inf: positiveInf,
var positiveInfBorder = &ScoreBorder{
Inf: positiveInf,
}
var negativeInfBorder = &ScoreBorder {
Inf: negativeInf,
var negativeInfBorder = &ScoreBorder{
Inf: negativeInf,
}
func ParseScoreBorder(s string)(*ScoreBorder, error) {
if s == "inf" || s == "+inf" {
return positiveInfBorder, nil
}
if s == "-inf" {
return negativeInfBorder, nil
}
if s[0] == '(' {
value, err := strconv.ParseFloat(s[1:], 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: true,
}, nil
} else {
value, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: false,
}, nil
}
func ParseScoreBorder(s string) (*ScoreBorder, error) {
if s == "inf" || s == "+inf" {
return positiveInfBorder, nil
}
if s == "-inf" {
return negativeInfBorder, nil
}
if s[0] == '(' {
value, err := strconv.ParseFloat(s[1:], 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: true,
}, nil
} else {
value, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: false,
}, nil
}
}

View File

@@ -3,130 +3,129 @@ package sortedset
import "math/rand"
const (
maxLevel = 16
maxLevel = 16
)
type Element struct {
Member string
Score float64
Member string
Score float64
}
// level aspect of a Node
type Level struct {
forward *Node // forward node has greater score
span int64
forward *Node // forward node has greater score
span int64
}
type Node struct {
Element
backward *Node
level []*Level // level[0] is base level
Element
backward *Node
level []*Level // level[0] is base level
}
type skiplist struct {
header *Node
tail *Node
length int64
level int16
header *Node
tail *Node
length int64
level int16
}
func makeNode(level int16, score float64, member string)*Node {
n := &Node{
Element: Element{
Score: score,
Member: member,
},
level: make([]*Level, level),
}
for i := range n.level {
n.level[i] = new(Level)
}
return n
func makeNode(level int16, score float64, member string) *Node {
n := &Node{
Element: Element{
Score: score,
Member: member,
},
level: make([]*Level, level),
}
for i := range n.level {
n.level[i] = new(Level)
}
return n
}
func makeSkiplist()*skiplist {
return &skiplist{
level: 1,
header: makeNode(maxLevel, 0, ""),
}
func makeSkiplist() *skiplist {
return &skiplist{
level: 1,
header: makeNode(maxLevel, 0, ""),
}
}
func randomLevel() int16 {
level := int16(1)
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
level++
}
if level < maxLevel {
return level
}
return maxLevel
level := int16(1)
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
level++
}
if level < maxLevel {
return level
}
return maxLevel
}
func (skiplist *skiplist)insert(member string, score float64)*Node {
update := make([]*Node, maxLevel) // link new node with node in `update`
rank := make([]int64, maxLevel)
func (skiplist *skiplist) insert(member string, score float64) *Node {
update := make([]*Node, maxLevel) // link new node with node in `update`
rank := make([]int64, maxLevel)
// find position to insert
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
if i == skiplist.level - 1 {
rank[i] = 0
} else {
rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position
}
if node.level[i] != nil {
// traverse the skip list
for node.level[i].forward != nil &&
(node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
rank[i] += node.level[i].span
node = node.level[i].forward
}
}
update[i] = node
}
// find position to insert
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
if i == skiplist.level-1 {
rank[i] = 0
} else {
rank[i] = rank[i+1] // store rank that is crossed to reach the insert position
}
if node.level[i] != nil {
// traverse the skip list
for node.level[i].forward != nil &&
(node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
rank[i] += node.level[i].span
node = node.level[i].forward
}
}
update[i] = node
}
level := randomLevel()
// extend skiplist level
if level > skiplist.level {
for i := skiplist.level; i < level; i++ {
rank[i] = 0
update[i] = skiplist.header
update[i].level[i].span = skiplist.length
}
skiplist.level = level
}
level := randomLevel()
// extend skiplist level
if level > skiplist.level {
for i := skiplist.level; i < level; i++ {
rank[i] = 0
update[i] = skiplist.header
update[i].level[i].span = skiplist.length
}
skiplist.level = level
}
// make node and link into skiplist
node = makeNode(level, score, member)
for i := int16(0); i < level; i++ {
node.level[i].forward = update[i].level[i].forward
update[i].level[i].forward = node
// make node and link into skiplist
node = makeNode(level, score, member)
for i := int16(0); i < level; i++ {
node.level[i].forward = update[i].level[i].forward
update[i].level[i].forward = node
// update span covered by update[i] as node is inserted here
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
update[i].level[i].span = (rank[0] - rank[i]) + 1
}
// update span covered by update[i] as node is inserted here
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
update[i].level[i].span = (rank[0] - rank[i]) + 1
}
// increment span for untouched levels
for i := level; i < skiplist.level; i++ {
update[i].level[i].span++
}
// increment span for untouched levels
for i := level; i < skiplist.level; i++ {
update[i].level[i].span++
}
// set backward node
if update[0] == skiplist.header {
node.backward = nil
} else {
node.backward = update[0]
}
if node.level[0].forward != nil {
node.level[0].forward.backward = node
} else {
skiplist.tail = node
}
skiplist.length++
return node
// set backward node
if update[0] == skiplist.header {
node.backward = nil
} else {
node.backward = update[0]
}
if node.level[0].forward != nil {
node.level[0].forward.backward = node
} else {
skiplist.tail = node
}
skiplist.length++
return node
}
/*
@@ -134,212 +133,212 @@ func (skiplist *skiplist)insert(member string, score float64)*Node {
* param update: backward node (of target)
*/
func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
for i := int16(0); i < skiplist.level; i++ {
if update[i].level[i].forward == node {
update[i].level[i].span += node.level[i].span - 1
update[i].level[i].forward = node.level[i].forward
} else {
update[i].level[i].span--
}
}
if node.level[0].forward != nil {
node.level[0].forward.backward = node.backward
} else {
skiplist.tail = node.backward
}
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
skiplist.level--
}
skiplist.length--
for i := int16(0); i < skiplist.level; i++ {
if update[i].level[i].forward == node {
update[i].level[i].span += node.level[i].span - 1
update[i].level[i].forward = node.level[i].forward
} else {
update[i].level[i].span--
}
}
if node.level[0].forward != nil {
node.level[0].forward.backward = node.backward
} else {
skiplist.tail = node.backward
}
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
skiplist.level--
}
skiplist.length--
}
/*
* return: has found and removed node
*/
func (skiplist *skiplist) remove(member string, score float64)bool {
/*
* find backward node (of target) or last node of each level
* their forward need to be updated
*/
update := make([]*Node, maxLevel)
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil &&
(node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score &&
node.level[i].forward.Member < member)) {
node = node.level[i].forward
}
update[i] = node
}
node = node.level[0].forward
if node != nil && score == node.Score && node.Member == member {
skiplist.removeNode(node, update)
// free x
return true
}
return false
func (skiplist *skiplist) remove(member string, score float64) bool {
/*
* find backward node (of target) or last node of each level
* their forward need to be updated
*/
update := make([]*Node, maxLevel)
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil &&
(node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score &&
node.level[i].forward.Member < member)) {
node = node.level[i].forward
}
update[i] = node
}
node = node.level[0].forward
if node != nil && score == node.Score && node.Member == member {
skiplist.removeNode(node, update)
// free x
return true
}
return false
}
/*
* return: 1 based rank, 0 means member not found
*/
func (skiplist *skiplist) getRank(member string, score float64)int64 {
var rank int64 = 0
x := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for x.level[i].forward != nil &&
(x.level[i].forward.Score < score ||
(x.level[i].forward.Score == score &&
x.level[i].forward.Member <= member)) {
rank += x.level[i].span
x = x.level[i].forward
}
func (skiplist *skiplist) getRank(member string, score float64) int64 {
var rank int64 = 0
x := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for x.level[i].forward != nil &&
(x.level[i].forward.Score < score ||
(x.level[i].forward.Score == score &&
x.level[i].forward.Member <= member)) {
rank += x.level[i].span
x = x.level[i].forward
}
/* x might be equal to zsl->header, so test if obj is non-NULL */
if x.Member == member {
return rank
}
}
return 0
/* x might be equal to zsl->header, so test if obj is non-NULL */
if x.Member == member {
return rank
}
}
return 0
}
/*
* 1-based rank
*/
func (skiplist *skiplist) getByRank(rank int64)*Node {
var i int64 = 0
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
i += n.level[level].span
n = n.level[level].forward
}
if i == rank {
return n
}
}
return nil
func (skiplist *skiplist) getByRank(rank int64) *Node {
var i int64 = 0
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
i += n.level[level].span
n = n.level[level].forward
}
if i == rank {
return n
}
}
return nil
}
func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
// min & max = empty
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
return false
}
// min > tail
n := skiplist.tail
if n == nil || !min.less(n.Score) {
return false
}
// max < head
n = skiplist.header.level[0].forward
if n == nil || !max.greater(n.Score) {
return false
}
return true
// min & max = empty
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
return false
}
// min > tail
n := skiplist.tail
if n == nil || !min.less(n.Score) {
return false
}
// max < head
n = skiplist.header.level[0].forward
if n == nil || !max.greater(n.Score) {
return false
}
return true
}
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
if !skiplist.hasInRange(min, max) {
return nil
}
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
// if forward is not in range than move forward
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
n = n.level[level].forward
}
}
/* This is an inner range, so the next node cannot be NULL. */
n = n.level[0].forward
if !max.greater(n.Score) {
return nil
}
return n
if !skiplist.hasInRange(min, max) {
return nil
}
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
// if forward is not in range than move forward
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
n = n.level[level].forward
}
}
/* This is an inner range, so the next node cannot be NULL. */
n = n.level[0].forward
if !max.greater(n.Score) {
return nil
}
return n
}
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
if !skiplist.hasInRange(min, max) {
return nil
}
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
n = n.level[level].forward
}
}
if !min.less(n.Score) {
return nil
}
return n
if !skiplist.hasInRange(min, max) {
return nil
}
n := skiplist.header
// scan from top level
for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
n = n.level[level].forward
}
}
if !min.less(n.Score) {
return nil
}
return n
}
/*
* return removed elements
*/
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) {
update := make([]*Node, maxLevel)
removed = make([]*Element, 0)
// find backward nodes (of target range) or last node of each level
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil {
if min.less(node.level[i].forward.Score) { // already in range
break
}
node = node.level[i].forward
}
update[i] = node
}
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
update := make([]*Node, maxLevel)
removed = make([]*Element, 0)
// find backward nodes (of target range) or last node of each level
node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil {
if min.less(node.level[i].forward.Score) { // already in range
break
}
node = node.level[i].forward
}
update[i] = node
}
// node is the first one within range
node = node.level[0].forward
// node is the first one within range
node = node.level[0].forward
// remove nodes in range
for node != nil {
if !max.greater(node.Score) { // already out of range
break
}
next := node.level[0].forward
removedElement := node.Element
removed = append(removed, &removedElement)
skiplist.removeNode(node, update)
node = next
}
return removed
// remove nodes in range
for node != nil {
if !max.greater(node.Score) { // already out of range
break
}
next := node.level[0].forward
removedElement := node.Element
removed = append(removed, &removedElement)
skiplist.removeNode(node, update)
node = next
}
return removed
}
// 1-based rank, including start, exclude stop
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) {
var i int64 = 0 // rank of iterator
update := make([]*Node, maxLevel)
removed = make([]*Element, 0)
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
var i int64 = 0 // rank of iterator
update := make([]*Node, maxLevel)
removed = make([]*Element, 0)
// scan from top level
node := skiplist.header
for level := skiplist.level - 1; level >= 0; level-- {
for node.level[level].forward != nil && (i+node.level[level].span) < start {
i += node.level[level].span
node = node.level[level].forward
}
update[level] = node
}
// scan from top level
node := skiplist.header
for level := skiplist.level - 1; level >= 0; level-- {
for node.level[level].forward != nil && (i+node.level[level].span) < start {
i += node.level[level].span
node = node.level[level].forward
}
update[level] = node
}
i++
node = node.level[0].forward // first node in range
i++
node = node.level[0].forward // first node in range
// remove nodes in range
for node != nil && i < stop {
next := node.level[0].forward
removedElement := node.Element
removed = append(removed, &removedElement)
skiplist.removeNode(node, update)
node = next
i++
}
return removed
// remove nodes in range
for node != nil && i < stop {
next := node.level[0].forward
removedElement := node.Element
removed = append(removed, &removedElement)
skiplist.removeNode(node, update)
node = next
i++
}
return removed
}

View File

@@ -1,227 +1,226 @@
package sortedset
import (
"strconv"
"strconv"
)
type SortedSet struct {
dict map[string]*Element
skiplist *skiplist
dict map[string]*Element
skiplist *skiplist
}
func Make()*SortedSet {
return &SortedSet{
dict: make(map[string]*Element),
skiplist: makeSkiplist(),
}
func Make() *SortedSet {
return &SortedSet{
dict: make(map[string]*Element),
skiplist: makeSkiplist(),
}
}
/*
* return: has inserted new node
*/
func (sortedSet *SortedSet)Add(member string, score float64)bool {
element, ok := sortedSet.dict[member]
sortedSet.dict[member] = &Element{
Member: member,
Score: score,
}
if ok {
if score != element.Score {
sortedSet.skiplist.remove(member, score)
sortedSet.skiplist.insert(member, score)
}
return false
} else {
sortedSet.skiplist.insert(member, score)
return true
}
func (sortedSet *SortedSet) Add(member string, score float64) bool {
element, ok := sortedSet.dict[member]
sortedSet.dict[member] = &Element{
Member: member,
Score: score,
}
if ok {
if score != element.Score {
sortedSet.skiplist.remove(member, score)
sortedSet.skiplist.insert(member, score)
}
return false
} else {
sortedSet.skiplist.insert(member, score)
return true
}
}
func (sortedSet *SortedSet) Len()int64 {
return int64(len(sortedSet.dict))
func (sortedSet *SortedSet) Len() int64 {
return int64(len(sortedSet.dict))
}
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
element, ok = sortedSet.dict[member]
if !ok {
return nil, false
}
return element, true
element, ok = sortedSet.dict[member]
if !ok {
return nil, false
}
return element, true
}
func (sortedSet *SortedSet) Remove(member string)bool {
v, ok := sortedSet.dict[member]
if ok {
sortedSet.skiplist.remove(member, v.Score)
delete(sortedSet.dict, member)
return true
}
return false
func (sortedSet *SortedSet) Remove(member string) bool {
v, ok := sortedSet.dict[member]
if ok {
sortedSet.skiplist.remove(member, v.Score)
delete(sortedSet.dict, member)
return true
}
return false
}
/**
* get 0-based rank
*/
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
element, ok := sortedSet.dict[member]
if !ok {
return -1
}
r := sortedSet.skiplist.getRank(member, element.Score)
if desc {
r = sortedSet.skiplist.length - r
} else {
r--
}
return r
element, ok := sortedSet.dict[member]
if !ok {
return -1
}
r := sortedSet.skiplist.getRank(member, element.Score)
if desc {
r = sortedSet.skiplist.length - r
} else {
r--
}
return r
}
/**
* traverse [start, stop), 0-based rank
*/
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) {
size := int64(sortedSet.Len())
if start < 0 || start >= size {
panic("illegal start " + strconv.FormatInt(start, 10))
}
if stop < start || stop > size {
panic("illegal end " + strconv.FormatInt(stop, 10))
}
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
size := int64(sortedSet.Len())
if start < 0 || start >= size {
panic("illegal start " + strconv.FormatInt(start, 10))
}
if stop < start || stop > size {
panic("illegal end " + strconv.FormatInt(stop, 10))
}
// find start node
var node *Node
if desc {
node = sortedSet.skiplist.tail
if start > 0 {
node = sortedSet.skiplist.getByRank(int64(size - start))
}
} else {
node = sortedSet.skiplist.header.level[0].forward
if start > 0 {
node = sortedSet.skiplist.getByRank(int64(start + 1))
}
}
// find start node
var node *Node
if desc {
node = sortedSet.skiplist.tail
if start > 0 {
node = sortedSet.skiplist.getByRank(int64(size - start))
}
} else {
node = sortedSet.skiplist.header.level[0].forward
if start > 0 {
node = sortedSet.skiplist.getByRank(int64(start + 1))
}
}
sliceSize := int(stop - start)
for i := 0; i < sliceSize; i++ {
if !consumer(&node.Element) {
break
}
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
}
sliceSize := int(stop - start)
for i := 0; i < sliceSize; i++ {
if !consumer(&node.Element) {
break
}
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
}
}
/**
* return [start, stop), 0-based rank
* assert start in [0, size), stop in [start, size]
*/
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element {
sliceSize := int(stop - start)
slice := make([]*Element, sliceSize)
i := 0
sortedSet.ForEach(start, stop, desc, func(element *Element)bool {
slice[i] = element
i++
return true
})
return slice
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
sliceSize := int(stop - start)
slice := make([]*Element, sliceSize)
i := 0
sortedSet.ForEach(start, stop, desc, func(element *Element) bool {
slice[i] = element
i++
return true
})
return slice
}
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 {
var i int64 = 0
// ascending order
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
gtMin := min.less(element.Score) // greater than min
if !gtMin {
// has not into range, continue foreach
return true
}
ltMax := max.greater(element.Score) // less than max
if !ltMax {
// break through score border, break foreach
return false
}
// gtMin && ltMax
i++
return true
})
return i
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
var i int64 = 0
// ascending order
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
gtMin := min.less(element.Score) // greater than min
if !gtMin {
// has not into range, continue foreach
return true
}
ltMax := max.greater(element.Score) // less than max
if !ltMax {
// break through score border, break foreach
return false
}
// gtMin && ltMax
i++
return true
})
return i
}
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
// find start node
var node *Node
if desc {
node = sortedSet.skiplist.getLastInScoreRange(min, max)
} else {
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
}
// find start node
var node *Node
if desc {
node = sortedSet.skiplist.getLastInScoreRange(min, max)
} else {
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
}
for node != nil && offset > 0 {
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
offset--
}
for node != nil && offset > 0 {
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
offset--
}
// A negative limit returns all elements from the offset
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
if !consumer(&node.Element) {
break
}
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
if node == nil {
break
}
gtMin := min.less(node.Element.Score) // greater than min
ltMax := max.greater(node.Element.Score)
if !gtMin || !ltMax {
break // break through score border
}
}
// A negative limit returns all elements from the offset
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
if !consumer(&node.Element) {
break
}
if desc {
node = node.backward
} else {
node = node.level[0].forward
}
if node == nil {
break
}
gtMin := min.less(node.Element.Score) // greater than min
ltMax := max.greater(node.Element.Score)
if !gtMin || !ltMax {
break // break through score border
}
}
}
/*
* param limit: <0 means no limit
*/
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element {
if limit == 0 || offset < 0{
return make([]*Element, 0)
}
slice := make([]*Element, 0)
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
slice = append(slice, element)
return true
})
return slice
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
if limit == 0 || offset < 0 {
return make([]*Element, 0)
}
slice := make([]*Element, 0)
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
slice = append(slice, element)
return true
})
return slice
}
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 {
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
for _, element := range removed {
delete(sortedSet.dict, element.Member)
}
return int64(len(removed))
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
for _, element := range removed {
delete(sortedSet.dict, element.Member)
}
return int64(len(removed))
}
/*
* 0-based rank, [start, stop)
*/
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 {
removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1)
for _, element := range removed {
delete(sortedSet.dict, element.Member)
}
return int64(len(removed))
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
for _, element := range removed {
delete(sortedSet.dict, element.Member)
}
return int64(len(removed))
}

View File

@@ -1,28 +1,28 @@
package utils
func Equals(a interface{}, b interface{})bool {
sliceA, okA := a.([]byte)
sliceB, okB := b.([]byte)
if okA && okB {
return BytesEquals(sliceA, sliceB)
}
return a == b
func Equals(a interface{}, b interface{}) bool {
sliceA, okA := a.([]byte)
sliceB, okB := b.([]byte)
if okA && okB {
return BytesEquals(sliceA, sliceB)
}
return a == b
}
func BytesEquals(a []byte, b []byte) bool {
if (a == nil && b != nil) || (a != nil && b == nil) {
return false
}
if len(a) != len(b) {
return false
}
size := len(a)
for i := 0; i < size; i++ {
av := a[i]
bv := b[i]
if av != bv {
return false
}
}
return true
if (a == nil && b != nil) || (a != nil && b == nil) {
return false
}
if len(a) != len(b) {
return false
}
size := len(a)
for i := 0; i < size; i++ {
av := a[i]
bv := b[i]
if av != bv {
return false
}
}
return true
}

View File

@@ -2,7 +2,7 @@ package db
import (
"bufio"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock"
@@ -29,18 +29,18 @@ func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
}
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
params := make([][]byte, len(args)+1)
copy(params[1:], args)
params[0] = []byte(cmd)
return reply.MakeMultiBulkReply(params)
params := make([][]byte, len(args)+1)
copy(params[1:], args)
params[0] = []byte(cmd)
return reply.MakeMultiBulkReply(params)
}
// send command to aof
func (db *DB) AddAof(args *reply.MultiBulkReply) {
// aofChan == nil when loadAof
if config.Properties.AppendOnly && db.aofChan != nil {
db.aofChan <- args
}
// aofChan == nil when loadAof
if config.Properties.AppendOnly && db.aofChan != nil {
db.aofChan <- args
}
}
// listen aof channel and write into file
@@ -72,12 +72,12 @@ func trim(msg []byte) string {
// read aof file
func (db *DB) loadAof(maxBytes int) {
// delete aofChan to prevent write again
aofChan := db.aofChan
db.aofChan = nil
defer func(aofChan chan *reply.MultiBulkReply) {
db.aofChan = aofChan
}(aofChan)
// delete aofChan to prevent write again
aofChan := db.aofChan
db.aofChan = nil
defer func(aofChan chan *reply.MultiBulkReply) {
db.aofChan = aofChan
}(aofChan)
file, err := os.Open(db.aofFilename)
if err != nil {

View File

@@ -1,297 +1,297 @@
package db
import (
"fmt"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/pubsub"
"github.com/HDT3213/godis/src/redis/reply"
"os"
"runtime/debug"
"strings"
"sync"
"time"
"fmt"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/pubsub"
"github.com/HDT3213/godis/src/redis/reply"
"os"
"runtime/debug"
"strings"
"sync"
"time"
)
type DataEntity struct {
Data interface{}
Data interface{}
}
const (
dataDictSize = 1 << 16
ttlDictSize = 1 << 10
lockerSize = 128
aofQueueSize = 1 << 16
dataDictSize = 1 << 16
ttlDictSize = 1 << 10
lockerSize = 128
aofQueueSize = 1 << 16
)
// args don't include cmd line
type CmdFunc func(db *DB, args [][]byte) redis.Reply
type DB struct {
// key -> DataEntity
Data dict.Dict
// key -> expireTime (time.Time)
TTLMap dict.Dict
// channel -> list<*client>
SubMap dict.Dict
// key -> DataEntity
Data dict.Dict
// key -> expireTime (time.Time)
TTLMap dict.Dict
// channel -> list<*client>
SubMap dict.Dict
// dict will ensure thread safety of its method
// use this mutex for complicated command only, eg. rpush, incr ...
Locker *lock.Locks
// dict will ensure thread safety of its method
// use this mutex for complicated command only, eg. rpush, incr ...
Locker *lock.Locks
// TimerTask interval
interval time.Duration
// TimerTask interval
interval time.Duration
stopWorld sync.WaitGroup
stopWorld sync.WaitGroup
hub *pubsub.Hub
hub *pubsub.Hub
// main goroutine send commands to aof goroutine through aofChan
aofChan chan *reply.MultiBulkReply
aofFile *os.File
aofFilename string
// main goroutine send commands to aof goroutine through aofChan
aofChan chan *reply.MultiBulkReply
aofFile *os.File
aofFilename string
aofRewriteChan chan *reply.MultiBulkReply
pausingAof sync.RWMutex
aofRewriteChan chan *reply.MultiBulkReply
pausingAof sync.RWMutex
}
var router = MakeRouter()
func MakeDB() *DB {
db := &DB{
Data: dict.MakeConcurrent(dataDictSize),
TTLMap: dict.MakeConcurrent(ttlDictSize),
Locker: lock.Make(lockerSize),
interval: 5 * time.Second,
hub: pubsub.MakeHub(),
}
db := &DB{
Data: dict.MakeConcurrent(dataDictSize),
TTLMap: dict.MakeConcurrent(ttlDictSize),
Locker: lock.Make(lockerSize),
interval: 5 * time.Second,
hub: pubsub.MakeHub(),
}
// aof
if config.Properties.AppendOnly {
db.aofFilename = config.Properties.AppendFilename
db.loadAof(0)
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
logger.Warn(err)
} else {
db.aofFile = aofFile
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
}
go func() {
db.handleAof()
}()
}
// aof
if config.Properties.AppendOnly {
db.aofFilename = config.Properties.AppendFilename
db.loadAof(0)
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
logger.Warn(err)
} else {
db.aofFile = aofFile
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
}
go func() {
db.handleAof()
}()
}
// start timer
db.TimerTask()
return db
// start timer
db.TimerTask()
return db
}
func (db *DB) Close() {
if db.aofFile != nil {
err := db.aofFile.Close()
if err != nil {
logger.Warn(err)
}
}
if db.aofFile != nil {
err := db.aofFile.Close()
if err != nil {
logger.Warn(err)
}
}
}
func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
defer func() {
if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{}
}
}()
defer func() {
if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{}
}
}()
cmd := strings.ToLower(string(args[0]))
cmd := strings.ToLower(string(args[0]))
// special commands
if cmd == "subscribe" {
if len(args) < 2 {
return &reply.ArgNumErrReply{Cmd: "subscribe"}
}
return pubsub.Subscribe(db.hub, c, args[1:])
} else if cmd == "publish" {
return pubsub.Publish(db.hub, args[1:])
} else if cmd == "unsubscribe" {
return pubsub.UnSubscribe(db.hub, c, args[1:])
} else if cmd == "bgrewriteaof" {
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
reply := BGRewriteAOF(db, args[1:])
return reply
}
// special commands
if cmd == "subscribe" {
if len(args) < 2 {
return &reply.ArgNumErrReply{Cmd: "subscribe"}
}
return pubsub.Subscribe(db.hub, c, args[1:])
} else if cmd == "publish" {
return pubsub.Publish(db.hub, args[1:])
} else if cmd == "unsubscribe" {
return pubsub.UnSubscribe(db.hub, c, args[1:])
} else if cmd == "bgrewriteaof" {
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
reply := BGRewriteAOF(db, args[1:])
return reply
}
// normal commands
cmdFunc, ok := router[cmd]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
}
if len(args) > 1 {
result = cmdFunc(db, args[1:])
} else {
result = cmdFunc(db, [][]byte{})
}
// normal commands
cmdFunc, ok := router[cmd]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
}
if len(args) > 1 {
result = cmdFunc(db, args[1:])
} else {
result = cmdFunc(db, [][]byte{})
}
// aof
// aof
return
return
}
/* ---- Data Access ----- */
func (db *DB) Get(key string) (*DataEntity, bool) {
db.stopWorld.Wait()
db.stopWorld.Wait()
raw, ok := db.Data.Get(key)
if !ok {
return nil, false
}
if db.IsExpired(key) {
return nil, false
}
entity, _ := raw.(*DataEntity)
return entity, true
raw, ok := db.Data.Get(key)
if !ok {
return nil, false
}
if db.IsExpired(key) {
return nil, false
}
entity, _ := raw.(*DataEntity)
return entity, true
}
func (db *DB) Put(key string, entity *DataEntity) int {
db.stopWorld.Wait()
return db.Data.Put(key, entity)
db.stopWorld.Wait()
return db.Data.Put(key, entity)
}
func (db *DB) PutIfExists(key string, entity *DataEntity) int {
db.stopWorld.Wait()
return db.Data.PutIfExists(key, entity)
db.stopWorld.Wait()
return db.Data.PutIfExists(key, entity)
}
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int {
db.stopWorld.Wait()
return db.Data.PutIfAbsent(key, entity)
db.stopWorld.Wait()
return db.Data.PutIfAbsent(key, entity)
}
func (db *DB) Remove(key string) {
db.stopWorld.Wait()
db.Data.Remove(key)
db.TTLMap.Remove(key)
db.stopWorld.Wait()
db.Data.Remove(key)
db.TTLMap.Remove(key)
}
func (db *DB) Removes(keys ...string) (deleted int) {
db.stopWorld.Wait()
deleted = 0
for _, key := range keys {
_, exists := db.Data.Get(key)
if exists {
db.Data.Remove(key)
db.TTLMap.Remove(key)
deleted++
}
}
return deleted
db.stopWorld.Wait()
deleted = 0
for _, key := range keys {
_, exists := db.Data.Get(key)
if exists {
db.Data.Remove(key)
db.TTLMap.Remove(key)
deleted++
}
}
return deleted
}
func (db *DB) Flush() {
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.Data = dict.MakeConcurrent(dataDictSize)
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
db.Locker = lock.Make(lockerSize)
db.Data = dict.MakeConcurrent(dataDictSize)
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
db.Locker = lock.Make(lockerSize)
}
/* ---- Lock Function ----- */
func (db *DB) Lock(key string) {
db.Locker.Lock(key)
db.Locker.Lock(key)
}
func (db *DB) RLock(key string) {
db.Locker.RLock(key)
db.Locker.RLock(key)
}
func (db *DB) UnLock(key string) {
db.Locker.UnLock(key)
db.Locker.UnLock(key)
}
func (db *DB) RUnLock(key string) {
db.Locker.RUnLock(key)
db.Locker.RUnLock(key)
}
func (db *DB) Locks(keys ...string) {
db.Locker.Locks(keys...)
db.Locker.Locks(keys...)
}
func (db *DB) RLocks(keys ...string) {
db.Locker.RLocks(keys...)
db.Locker.RLocks(keys...)
}
func (db *DB) UnLocks(keys ...string) {
db.Locker.UnLocks(keys...)
db.Locker.UnLocks(keys...)
}
func (db *DB) RUnLocks(keys ...string) {
db.Locker.RUnLocks(keys...)
db.Locker.RUnLocks(keys...)
}
/* ---- TTL Functions ---- */
func (db *DB) Expire(key string, expireTime time.Time) {
db.stopWorld.Wait()
db.TTLMap.Put(key, expireTime)
db.stopWorld.Wait()
db.TTLMap.Put(key, expireTime)
}
func (db *DB) Persist(key string) {
db.stopWorld.Wait()
db.TTLMap.Remove(key)
db.stopWorld.Wait()
db.TTLMap.Remove(key)
}
func (db *DB) IsExpired(key string) bool {
rawExpireTime, ok := db.TTLMap.Get(key)
if !ok {
return false
}
expireTime, _ := rawExpireTime.(time.Time)
expired := time.Now().After(expireTime)
if expired {
db.Remove(key)
}
return expired
rawExpireTime, ok := db.TTLMap.Get(key)
if !ok {
return false
}
expireTime, _ := rawExpireTime.(time.Time)
expired := time.Now().After(expireTime)
if expired {
db.Remove(key)
}
return expired
}
func (db *DB) CleanExpired() {
now := time.Now()
toRemove := &List.LinkedList{}
db.TTLMap.ForEach(func(key string, val interface{}) bool {
expireTime, _ := val.(time.Time)
if now.After(expireTime) {
// expired
db.Data.Remove(key)
toRemove.Add(key)
}
return true
})
toRemove.ForEach(func(i int, val interface{}) bool {
key, _ := val.(string)
db.TTLMap.Remove(key)
return true
})
now := time.Now()
toRemove := &List.LinkedList{}
db.TTLMap.ForEach(func(key string, val interface{}) bool {
expireTime, _ := val.(time.Time)
if now.After(expireTime) {
// expired
db.Data.Remove(key)
toRemove.Add(key)
}
return true
})
toRemove.ForEach(func(i int, val interface{}) bool {
key, _ := val.(string)
db.TTLMap.Remove(key)
return true
})
}
func (db *DB) TimerTask() {
ticker := time.NewTicker(db.interval)
go func() {
for range ticker.C {
db.CleanExpired()
}
}()
ticker := time.NewTicker(db.interval)
go func() {
for range ticker.C {
db.CleanExpired()
}
}()
}
/* ---- Subscribe Functions ---- */
func (db *DB) AfterClientClose(c redis.Connection) {
pubsub.UnsubscribeAll(db.hub, c)
pubsub.UnsubscribeAll(db.hub, c)
}

View File

@@ -1,422 +1,422 @@
package db
import (
Dict "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/shopspring/decimal"
"strconv"
Dict "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/shopspring/decimal"
"strconv"
)
func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) {
entity, exists := db.Get(key)
if !exists {
return nil, nil
}
dict, ok := entity.Data.(Dict.Dict)
if !ok {
return nil, &reply.WrongTypeErrReply{}
}
return dict, nil
entity, exists := db.Get(key)
if !exists {
return nil, nil
}
dict, ok := entity.Data.(Dict.Dict)
if !ok {
return nil, &reply.WrongTypeErrReply{}
}
return dict, nil
}
func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) {
dict, errReply = db.getAsDict(key)
if errReply != nil {
return nil, false, errReply
}
inited = false
if dict == nil {
dict = Dict.MakeSimple()
db.Put(key, &DataEntity{
Data: dict,
})
inited = true
}
return dict, inited, nil
dict, errReply = db.getAsDict(key)
if errReply != nil {
return nil, false, errReply
}
inited = false
if dict == nil {
dict = Dict.MakeSimple()
db.Put(key, &DataEntity{
Data: dict,
})
inited = true
}
return dict, inited, nil
}
func HSet(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
}
key := string(args[0])
field := string(args[1])
value := args[2]
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
}
key := string(args[0])
field := string(args[1])
value := args[2]
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
result := dict.Put(field, value)
db.AddAof(makeAofCmd("hset", args))
return reply.MakeIntReply(int64(result))
result := dict.Put(field, value)
db.AddAof(makeAofCmd("hset", args))
return reply.MakeIntReply(int64(result))
}
func HSetNX(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
}
key := string(args[0])
field := string(args[1])
value := args[2]
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
}
key := string(args[0])
field := string(args[1])
value := args[2]
db.Lock(key)
defer db.UnLock(key)
db.Lock(key)
defer db.UnLock(key)
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
result := dict.PutIfAbsent(field, value)
if result > 0 {
db.AddAof(makeAofCmd("hsetnx", args))
result := dict.PutIfAbsent(field, value)
if result > 0 {
db.AddAof(makeAofCmd("hsetnx", args))
}
return reply.MakeIntReply(int64(result))
}
return reply.MakeIntReply(int64(result))
}
func HGet(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
}
key := string(args[0])
field := string(args[1])
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
}
key := string(args[0])
field := string(args[1])
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.NullBulkReply{}
}
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.NullBulkReply{}
}
raw, exists := dict.Get(field)
if !exists {
return &reply.NullBulkReply{}
}
value, _ := raw.([]byte)
return reply.MakeBulkReply(value)
raw, exists := dict.Get(field)
if !exists {
return &reply.NullBulkReply{}
}
value, _ := raw.([]byte)
return reply.MakeBulkReply(value)
}
func HExists(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
}
key := string(args[0])
field := string(args[1])
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
}
key := string(args[0])
field := string(args[1])
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
_, exists := dict.Get(field)
if exists {
return reply.MakeIntReply(1)
}
return reply.MakeIntReply(0)
_, exists := dict.Get(field)
if exists {
return reply.MakeIntReply(1)
}
return reply.MakeIntReply(0)
}
func HDel(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
}
key := string(args[0])
fields := make([]string, len(args) - 1)
fieldArgs := args[1:]
for i, v := range fieldArgs {
fields[i] = string(v)
}
// parse args
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
}
key := string(args[0])
fields := make([]string, len(args)-1)
fieldArgs := args[1:]
for i, v := range fieldArgs {
fields[i] = string(v)
}
db.Lock(key)
defer db.UnLock(key)
db.Lock(key)
defer db.UnLock(key)
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
deleted := 0
for _, field := range fields {
result := dict.Remove(field)
deleted += result
}
if dict.Len() == 0 {
db.Remove(key)
}
if deleted > 0 {
db.AddAof(makeAofCmd("hdel", args))
}
deleted := 0
for _, field := range fields {
result := dict.Remove(field)
deleted += result
}
if dict.Len() == 0 {
db.Remove(key)
}
if deleted > 0 {
db.AddAof(makeAofCmd("hdel", args))
}
return reply.MakeIntReply(int64(deleted))
return reply.MakeIntReply(int64(deleted))
}
func HLen(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
}
key := string(args[0])
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
}
key := string(args[0])
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
return reply.MakeIntReply(int64(dict.Len()))
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeIntReply(0)
}
return reply.MakeIntReply(int64(dict.Len()))
}
func HMSet(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) < 3 || len(args) % 2 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
}
key := string(args[0])
size := (len(args) - 1) / 2
fields := make([]string, size)
values := make([][]byte, size)
for i := 0; i < size; i++ {
fields[i] = string(args[2 * i + 1])
values[i] = args[2 * i + 2]
}
// parse args
if len(args) < 3 || len(args)%2 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
}
key := string(args[0])
size := (len(args) - 1) / 2
fields := make([]string, size)
values := make([][]byte, size)
for i := 0; i < size; i++ {
fields[i] = string(args[2*i+1])
values[i] = args[2*i+2]
}
// lock key
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// lock key
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
// put data
for i, field := range fields {
value := values[i]
dict.Put(field, value)
}
db.AddAof(makeAofCmd("hmset", args))
return &reply.OkReply{}
// put data
for i, field := range fields {
value := values[i]
dict.Put(field, value)
}
db.AddAof(makeAofCmd("hmset", args))
return &reply.OkReply{}
}
func HMGet(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
}
key := string(args[0])
size := len(args) - 1
fields := make([]string, size)
for i := 0; i < size; i++ {
fields[i] = string(args[i + 1])
}
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
}
key := string(args[0])
size := len(args) - 1
fields := make([]string, size)
for i := 0; i < size; i++ {
fields[i] = string(args[i+1])
}
db.RLock(key)
defer db.RUnLock(key)
db.RLock(key)
defer db.RUnLock(key)
// get entity
result := make([][]byte, size)
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeMultiBulkReply(result)
}
// get entity
result := make([][]byte, size)
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return reply.MakeMultiBulkReply(result)
}
for i, field := range fields {
value, ok := dict.Get(field)
if !ok {
result[i] = nil
} else {
bytes, _ := value.([]byte)
result[i] = bytes
}
}
return reply.MakeMultiBulkReply(result)
for i, field := range fields {
value, ok := dict.Get(field)
if !ok {
result[i] = nil
} else {
bytes, _ := value.([]byte)
result[i] = bytes
}
}
return reply.MakeMultiBulkReply(result)
}
func HKeys(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
}
key := string(args[0])
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
}
key := string(args[0])
db.RLock(key)
defer db.RUnLock(key)
db.RLock(key)
defer db.RUnLock(key)
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
fields := make([][]byte, dict.Len())
i := 0
dict.ForEach(func(key string, val interface{})bool {
fields[i] = []byte(key)
i++
return true
})
return reply.MakeMultiBulkReply(fields[:i])
fields := make([][]byte, dict.Len())
i := 0
dict.ForEach(func(key string, val interface{}) bool {
fields[i] = []byte(key)
i++
return true
})
return reply.MakeMultiBulkReply(fields[:i])
}
func HVals(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
}
key := string(args[0])
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
}
key := string(args[0])
db.RLock(key)
defer db.RUnLock(key)
db.RLock(key)
defer db.RUnLock(key)
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
values := make([][]byte, dict.Len())
i := 0
dict.ForEach(func(key string, val interface{})bool {
values[i], _ = val.([]byte)
i++
return true
})
return reply.MakeMultiBulkReply(values[:i])
values := make([][]byte, dict.Len())
i := 0
dict.ForEach(func(key string, val interface{}) bool {
values[i], _ = val.([]byte)
i++
return true
})
return reply.MakeMultiBulkReply(values[:i])
}
func HGetAll(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
}
key := string(args[0])
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
}
key := string(args[0])
db.RLock(key)
defer db.RUnLock(key)
db.RLock(key)
defer db.RUnLock(key)
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &reply.EmptyMultiBulkReply{}
}
size := dict.Len()
result := make([][]byte, size * 2)
i := 0
dict.ForEach(func(key string, val interface{})bool {
result[i] = []byte(key)
i++
result[i], _ = val.([]byte)
i++
return true
})
return reply.MakeMultiBulkReply(result[:i])
size := dict.Len()
result := make([][]byte, size*2)
i := 0
dict.ForEach(func(key string, val interface{}) bool {
result[i] = []byte(key)
i++
result[i], _ = val.([]byte)
i++
return true
})
return reply.MakeMultiBulkReply(result[:i])
}
func HIncrBy(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
}
key := string(args[0])
field := string(args[1])
rawDelta := string(args[2])
delta, err := strconv.ParseInt(rawDelta, 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
}
key := string(args[0])
field := string(args[1])
rawDelta := string(args[2])
delta, err := strconv.ParseInt(rawDelta, 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
value, exists := dict.Get(field)
if !exists {
dict.Put(field, args[2])
db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(args[2])
} else {
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR hash value is not an integer")
}
val += delta
bytes := []byte(strconv.FormatInt(val, 10))
dict.Put(field, bytes)
db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(bytes)
}
value, exists := dict.Get(field)
if !exists {
dict.Put(field, args[2])
db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(args[2])
} else {
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR hash value is not an integer")
}
val += delta
bytes := []byte(strconv.FormatInt(val, 10))
dict.Put(field, bytes)
db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(bytes)
}
}
func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
}
key := string(args[0])
field := string(args[1])
rawDelta := string(args[2])
delta, err := decimal.NewFromString(rawDelta)
if err != nil {
return reply.MakeErrReply("ERR value is not a valid float")
}
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
}
key := string(args[0])
field := string(args[1])
rawDelta := string(args[2])
delta, err := decimal.NewFromString(rawDelta)
if err != nil {
return reply.MakeErrReply("ERR value is not a valid float")
}
db.Lock(key)
defer db.UnLock(key)
db.Lock(key)
defer db.UnLock(key)
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
// get or init entity
dict, _, errReply := db.getOrInitDict(key)
if errReply != nil {
return errReply
}
value, exists := dict.Get(field)
if !exists {
dict.Put(field, args[2])
return reply.MakeBulkReply(args[2])
} else {
val, err := decimal.NewFromString(string(value.([]byte)))
if err != nil {
return reply.MakeErrReply("ERR hash value is not a float")
}
result := val.Add(delta)
resultBytes:= []byte(result.String())
dict.Put(field, resultBytes)
db.AddAof(makeAofCmd("hincrbyfloat", args))
return reply.MakeBulkReply(resultBytes)
}
value, exists := dict.Get(field)
if !exists {
dict.Put(field, args[2])
return reply.MakeBulkReply(args[2])
} else {
val, err := decimal.NewFromString(string(value.([]byte)))
if err != nil {
return reply.MakeErrReply("ERR hash value is not a float")
}
result := val.Add(delta)
resultBytes := []byte(result.String())
dict.Put(field, resultBytes)
db.AddAof(makeAofCmd("hincrbyfloat", args))
return reply.MakeBulkReply(resultBytes)
}
}

View File

@@ -1,439 +1,439 @@
package db
import (
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) {
entity, ok := db.Get(key)
if !ok {
return nil, nil
}
bytes, ok := entity.Data.(*List.LinkedList)
if !ok {
return nil, &reply.WrongTypeErrReply{}
}
return bytes, nil
func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) {
entity, ok := db.Get(key)
if !ok {
return nil, nil
}
bytes, ok := entity.Data.(*List.LinkedList)
if !ok {
return nil, &reply.WrongTypeErrReply{}
}
return bytes, nil
}
func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
list, errReply = db.getAsList(key)
if errReply != nil {
return nil, false, errReply
}
inited = false
if list == nil {
list = &List.LinkedList{}
db.Put(key, &DataEntity{
Data: list,
})
inited = true
}
return list, inited, nil
func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
list, errReply = db.getAsList(key)
if errReply != nil {
return nil, false, errReply
}
inited = false
if list == nil {
list = &List.LinkedList{}
db.Put(key, &DataEntity{
Data: list,
})
inited = true
}
return list, inited, nil
}
func LIndex(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
}
key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
index := int(index64)
// parse args
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
}
key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
index := int(index64)
// get entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
// get entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
size := list.Len() // assert: size > 0
if index < -1 * size {
return &reply.NullBulkReply{}
} else if index < 0 {
index = size + index
} else if index >= size {
return &reply.NullBulkReply{}
}
size := list.Len() // assert: size > 0
if index < -1*size {
return &reply.NullBulkReply{}
} else if index < 0 {
index = size + index
} else if index >= size {
return &reply.NullBulkReply{}
}
val, _ := list.Get(index).([]byte)
return reply.MakeBulkReply(val)
val, _ := list.Get(index).([]byte)
return reply.MakeBulkReply(val)
}
func LLen(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
}
key := string(args[0])
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
}
key := string(args[0])
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
size := int64(list.Len())
return reply.MakeIntReply(size)
size := int64(list.Len())
return reply.MakeIntReply(size)
}
func LPop(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
}
key := string(args[0])
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
}
key := string(args[0])
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
val, _ := list.Remove(0).([]byte)
if list.Len() == 0 {
db.Remove(key)
}
db.AddAof(makeAofCmd("lpop", args))
return reply.MakeBulkReply(val)
val, _ := list.Remove(0).([]byte)
if list.Len() == 0 {
db.Remove(key)
}
db.AddAof(makeAofCmd("lpop", args))
return reply.MakeBulkReply(val)
}
func LPush(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
}
key := string(args[0])
values := args[1:]
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
}
key := string(args[0])
values := args[1:]
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// get or init entity
list, _, errReply := db.getOrInitList(key)
if errReply != nil {
return errReply
}
// get or init entity
list, _, errReply := db.getOrInitList(key)
if errReply != nil {
return errReply
}
// insert
for _, value := range values {
list.Insert(0, value)
}
// insert
for _, value := range values {
list.Insert(0, value)
}
db.AddAof(makeAofCmd("lpush", args))
return reply.MakeIntReply(int64(list.Len()))
db.AddAof(makeAofCmd("lpush", args))
return reply.MakeIntReply(int64(list.Len()))
}
func LPushX(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
}
key := string(args[0])
values := args[1:]
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
}
key := string(args[0])
values := args[1:]
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// get or init entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
// get or init entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
// insert
for _, value := range values {
list.Insert(0, value)
}
db.AddAof(makeAofCmd("lpushx", args))
return reply.MakeIntReply(int64(list.Len()))
// insert
for _, value := range values {
list.Insert(0, value)
}
db.AddAof(makeAofCmd("lpushx", args))
return reply.MakeIntReply(int64(list.Len()))
}
func LRange(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
}
key := string(args[0])
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
start := int(start64)
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
stop := int(stop64)
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
}
key := string(args[0])
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
start := int(start64)
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
stop := int(stop64)
// lock key
db.RLock(key)
defer db.RUnLock(key)
// lock key
db.RLock(key)
defer db.RUnLock(key)
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.EmptyMultiBulkReply{}
}
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.EmptyMultiBulkReply{}
}
// compute index
size := list.Len() // assert: size > 0
if start < -1 * size {
start = 0
} else if start < 0 {
start = size + start
} else if start >= size {
return &reply.EmptyMultiBulkReply{}
}
if stop < -1 * size {
stop = 0
} else if stop < 0 {
stop = size + stop + 1
} else if stop < size {
stop = stop + 1
} else {
stop = size
}
if stop < start {
stop = start
}
// compute index
size := list.Len() // assert: size > 0
if start < -1*size {
start = 0
} else if start < 0 {
start = size + start
} else if start >= size {
return &reply.EmptyMultiBulkReply{}
}
if stop < -1*size {
stop = 0
} else if stop < 0 {
stop = size + stop + 1
} else if stop < size {
stop = stop + 1
} else {
stop = size
}
if stop < start {
stop = start
}
// assert: start in [0, size - 1], stop in [start, size]
slice := list.Range(start, stop)
result := make([][]byte, len(slice))
for i, raw := range slice {
bytes, _ := raw.([]byte)
result[i] = bytes
}
return reply.MakeMultiBulkReply(result)
// assert: start in [0, size - 1], stop in [start, size]
slice := list.Range(start, stop)
result := make([][]byte, len(slice))
for i, raw := range slice {
bytes, _ := raw.([]byte)
result[i] = bytes
}
return reply.MakeMultiBulkReply(result)
}
func LRem(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
}
key := string(args[0])
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
count := int(count64)
value := args[2]
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
}
key := string(args[0])
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
count := int(count64)
value := args[2]
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get data entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
// get data entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
var removed int
if count == 0 {
removed = list.RemoveAllByVal(value)
} else if count > 0 {
removed = list.RemoveByVal(value, count)
} else {
removed = list.ReverseRemoveByVal(value, -count)
}
var removed int
if count == 0 {
removed = list.RemoveAllByVal(value)
} else if count > 0 {
removed = list.RemoveByVal(value, count)
} else {
removed = list.ReverseRemoveByVal(value, -count)
}
if list.Len() == 0 {
db.Remove(key)
}
if removed > 0 {
db.AddAof(makeAofCmd("lrem", args))
}
if list.Len() == 0 {
db.Remove(key)
}
if removed > 0 {
db.AddAof(makeAofCmd("lrem", args))
}
return reply.MakeIntReply(int64(removed))
return reply.MakeIntReply(int64(removed))
}
func LSet(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
}
key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
index := int(index64)
value := args[2]
// parse args
if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
}
key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
index := int(index64)
value := args[2]
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// lock
db.Locker.Lock(key)
defer db.Locker.UnLock(key)
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeErrReply("ERR no such key")
}
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeErrReply("ERR no such key")
}
size := list.Len() // assert: size > 0
if index < -1 * size {
return reply.MakeErrReply("ERR index out of range")
} else if index < 0 {
index = size + index
} else if index >= size {
return reply.MakeErrReply("ERR index out of range")
}
size := list.Len() // assert: size > 0
if index < -1*size {
return reply.MakeErrReply("ERR index out of range")
} else if index < 0 {
index = size + index
} else if index >= size {
return reply.MakeErrReply("ERR index out of range")
}
list.Set(index, value)
db.AddAof(makeAofCmd("lset", args))
return &reply.OkReply{}
list.Set(index, value)
db.AddAof(makeAofCmd("lset", args))
return &reply.OkReply{}
}
func RPop(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
}
key := string(args[0])
// parse args
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
}
key := string(args[0])
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
// get data
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return &reply.NullBulkReply{}
}
val, _ := list.RemoveLast().([]byte)
if list.Len() == 0 {
db.Remove(key)
}
db.AddAof(makeAofCmd("rpop", args))
return reply.MakeBulkReply(val)
val, _ := list.RemoveLast().([]byte)
if list.Len() == 0 {
db.Remove(key)
}
db.AddAof(makeAofCmd("rpop", args))
return reply.MakeBulkReply(val)
}
func RPopLPush(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
}
sourceKey := string(args[0])
destKey := string(args[1])
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
}
sourceKey := string(args[0])
destKey := string(args[1])
// lock
db.Locks(sourceKey, destKey)
defer db.UnLocks(sourceKey, destKey)
// lock
db.Locks(sourceKey, destKey)
defer db.UnLocks(sourceKey, destKey)
// get source entity
sourceList, errReply := db.getAsList(sourceKey)
if errReply != nil {
return errReply
}
if sourceList == nil {
return &reply.NullBulkReply{}
}
// get source entity
sourceList, errReply := db.getAsList(sourceKey)
if errReply != nil {
return errReply
}
if sourceList == nil {
return &reply.NullBulkReply{}
}
// get dest entity
destList, _, errReply := db.getOrInitList(destKey)
if errReply != nil {
return errReply
}
// get dest entity
destList, _, errReply := db.getOrInitList(destKey)
if errReply != nil {
return errReply
}
// pop and push
val, _ := sourceList.RemoveLast().([]byte)
destList.Insert(0, val)
// pop and push
val, _ := sourceList.RemoveLast().([]byte)
destList.Insert(0, val)
if sourceList.Len() == 0 {
db.Remove(sourceKey)
}
if sourceList.Len() == 0 {
db.Remove(sourceKey)
}
db.AddAof(makeAofCmd("rpoplpush", args))
return reply.MakeBulkReply(val)
db.AddAof(makeAofCmd("rpoplpush", args))
return reply.MakeBulkReply(val)
}
func RPush(db *DB, args [][]byte) redis.Reply {
// parse args
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
}
key := string(args[0])
values := args[1:]
// parse args
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
}
key := string(args[0])
values := args[1:]
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get or init entity
list, _, errReply := db.getOrInitList(key)
if errReply != nil {
return errReply
}
// get or init entity
list, _, errReply := db.getOrInitList(key)
if errReply != nil {
return errReply
}
// put list
for _, value := range values {
list.Add(value)
}
db.AddAof(makeAofCmd("rpush", args))
return reply.MakeIntReply(int64(list.Len()))
// put list
for _, value := range values {
list.Add(value)
}
db.AddAof(makeAofCmd("rpush", args))
return reply.MakeIntReply(int64(list.Len()))
}
func RPushX(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
}
key := string(args[0])
values := args[1:]
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
}
key := string(args[0])
values := args[1:]
// lock
db.Lock(key)
defer db.UnLock(key)
// lock
db.Lock(key)
defer db.UnLock(key)
// get or init entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
// get or init entity
list, errReply := db.getAsList(key)
if errReply != nil {
return errReply
}
if list == nil {
return reply.MakeIntReply(0)
}
// put list
for _, value := range values {
list.Add(value)
}
db.AddAof(makeAofCmd("rpushx", args))
// put list
for _, value := range values {
list.Add(value)
}
db.AddAof(makeAofCmd("rpushx", args))
return reply.MakeIntReply(int64(list.Len()))
return reply.MakeIntReply(int64(list.Len()))
}

View File

@@ -1,102 +1,102 @@
package db
func MakeRouter()map[string]CmdFunc {
routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping
func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping
routerMap["del"] = Del
routerMap["expire"] = Expire
routerMap["expireat"] = ExpireAt
routerMap["pexpire"] = PExpire
routerMap["pexpireat"] = PExpireAt
routerMap["ttl"] = TTL
routerMap["pttl"] = PTTL
routerMap["persist"] = Persist
routerMap["exists"] = Exists
routerMap["type"] = Type
routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx
routerMap["del"] = Del
routerMap["expire"] = Expire
routerMap["expireat"] = ExpireAt
routerMap["pexpire"] = PExpire
routerMap["pexpireat"] = PExpireAt
routerMap["ttl"] = TTL
routerMap["pttl"] = PTTL
routerMap["persist"] = Persist
routerMap["exists"] = Exists
routerMap["type"] = Type
routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx
routerMap["set"] = Set
routerMap["setnx"] = SetNX
routerMap["setex"] = SetEX
routerMap["psetex"] = PSetEX
routerMap["mset"] = MSet
routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX
routerMap["get"] = Get
routerMap["getset"] = GetSet
routerMap["incr"] = Incr
routerMap["incrby"] = IncrBy
routerMap["incrbyfloat"] = IncrByFloat
routerMap["decr"] = Decr
routerMap["decrby"] = DecrBy
routerMap["set"] = Set
routerMap["setnx"] = SetNX
routerMap["setex"] = SetEX
routerMap["psetex"] = PSetEX
routerMap["mset"] = MSet
routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX
routerMap["get"] = Get
routerMap["getset"] = GetSet
routerMap["incr"] = Incr
routerMap["incrby"] = IncrBy
routerMap["incrbyfloat"] = IncrByFloat
routerMap["decr"] = Decr
routerMap["decrby"] = DecrBy
routerMap["lpush"] = LPush
routerMap["lpushx"] = LPushX
routerMap["rpush"] = RPush
routerMap["rpushx"] = RPushX
routerMap["lpop"] = LPop
routerMap["rpop"] = RPop
routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = LRem
routerMap["llen"] = LLen
routerMap["lindex"] = LIndex
routerMap["lset"] = LSet
routerMap["lrange"] = LRange
routerMap["lpush"] = LPush
routerMap["lpushx"] = LPushX
routerMap["rpush"] = RPush
routerMap["rpushx"] = RPushX
routerMap["lpop"] = LPop
routerMap["rpop"] = RPop
routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = LRem
routerMap["llen"] = LLen
routerMap["lindex"] = LIndex
routerMap["lset"] = LSet
routerMap["lrange"] = LRange
routerMap["hset"] = HSet
routerMap["hsetnx"] = HSetNX
routerMap["hget"] = HGet
routerMap["hexists"] = HExists
routerMap["hdel"] = HDel
routerMap["hlen"] = HLen
routerMap["hmget"] = HMGet
routerMap["hmset"] = HMSet
routerMap["hkeys"] = HKeys
routerMap["hvals"] = HVals
routerMap["hgetall"] = HGetAll
routerMap["hincrby"] = HIncrBy
routerMap["hincrbyfloat"] = HIncrByFloat
routerMap["hset"] = HSet
routerMap["hsetnx"] = HSetNX
routerMap["hget"] = HGet
routerMap["hexists"] = HExists
routerMap["hdel"] = HDel
routerMap["hlen"] = HLen
routerMap["hmget"] = HMGet
routerMap["hmset"] = HMSet
routerMap["hkeys"] = HKeys
routerMap["hvals"] = HVals
routerMap["hgetall"] = HGetAll
routerMap["hincrby"] = HIncrBy
routerMap["hincrbyfloat"] = HIncrByFloat
routerMap["sadd"] = SAdd
routerMap["sismember"] = SIsMember
routerMap["srem"] = SRem
routerMap["scard"] = SCard
routerMap["smembers"] = SMembers
routerMap["sinter"] = SInter
routerMap["sinterstore"] = SInterStore
routerMap["sunion"] = SUnion
routerMap["sunionstore"] = SUnionStore
routerMap["sdiff"] = SDiff
routerMap["sdiffstore"] = SDiffStore
routerMap["srandmember"] = SRandMember
routerMap["sadd"] = SAdd
routerMap["sismember"] = SIsMember
routerMap["srem"] = SRem
routerMap["scard"] = SCard
routerMap["smembers"] = SMembers
routerMap["sinter"] = SInter
routerMap["sinterstore"] = SInterStore
routerMap["sunion"] = SUnion
routerMap["sunionstore"] = SUnionStore
routerMap["sdiff"] = SDiff
routerMap["sdiffstore"] = SDiffStore
routerMap["srandmember"] = SRandMember
routerMap["zadd"] = ZAdd
routerMap["zscore"] = ZScore
routerMap["zincrby"] = ZIncrBy
routerMap["zrank"] = ZRank
routerMap["zcount"] = ZCount
routerMap["zrevrank"] = ZRevRank
routerMap["zcard"] = ZCard
routerMap["zrange"] = ZRange
routerMap["zrevrange"] = ZRevRange
routerMap["zrangebyscore"] = ZRangeByScore
routerMap["zrevrangebyscore"] = ZRevRangeByScore
routerMap["zrem"] = ZRem
routerMap["zremrangebyscore"] = ZRemRangeByScore
routerMap["zremrangebyrank"] = ZRemRangeByRank
routerMap["zadd"] = ZAdd
routerMap["zscore"] = ZScore
routerMap["zincrby"] = ZIncrBy
routerMap["zrank"] = ZRank
routerMap["zcount"] = ZCount
routerMap["zrevrank"] = ZRevRank
routerMap["zcard"] = ZCard
routerMap["zrange"] = ZRange
routerMap["zrevrange"] = ZRevRange
routerMap["zrangebyscore"] = ZRangeByScore
routerMap["zrevrangebyscore"] = ZRevRangeByScore
routerMap["zrem"] = ZRem
routerMap["zremrangebyscore"] = ZRemRangeByScore
routerMap["zremrangebyrank"] = ZRemRangeByRank
routerMap["geoadd"] = GeoAdd
routerMap["geopos"] = GeoPos
routerMap["geodist"] = GeoDist
routerMap["geohash"] = GeoHash
routerMap["georadius"] = GeoRadius
routerMap["georadiusbymember"] = GeoRadiusByMember
routerMap["geoadd"] = GeoAdd
routerMap["geopos"] = GeoPos
routerMap["geodist"] = GeoDist
routerMap["geohash"] = GeoHash
routerMap["georadius"] = GeoRadius
routerMap["georadiusbymember"] = GeoRadiusByMember
routerMap["flushdb"] = FlushDB
routerMap["flushall"] = FlushAll
routerMap["keys"] = Keys
routerMap["flushdb"] = FlushDB
routerMap["flushall"] = FlushAll
routerMap["keys"] = Keys
return routerMap
return routerMap
}

View File

@@ -1,16 +1,16 @@
package db
import (
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
)
func Ping(db *DB, args [][]byte) redis.Reply {
if len(args) == 0 {
return &reply.PongReply{}
} else if len(args) == 1 {
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
} else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
}
if len(args) == 0 {
return &reply.PongReply{}
} else if len(args) == 1 {
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
} else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ package db
import "github.com/HDT3213/godis/src/interface/redis"
type DB interface {
Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection)
Close()
Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection)
Close()
}

View File

@@ -1,11 +1,11 @@
package redis
type Connection interface {
Write([]byte) error
Write([]byte) error
// client should keep its subscribing channels
SubsChannel(channel string)
UnSubsChannel(channel string)
SubsCount()int
GetChannels()[]string
// client should keep its subscribing channels
SubsChannel(channel string)
UnSubsChannel(channel string)
SubsCount() int
GetChannels() []string
}

View File

@@ -1,5 +1,5 @@
package redis
type Reply interface {
ToBytes()[]byte
ToBytes() []byte
}

View File

@@ -1,13 +1,13 @@
package tcp
import (
"net"
"context"
"context"
"net"
)
type HandleFunc func(ctx context.Context, conn net.Conn)
type Handler interface {
Handle(ctx context.Context, conn net.Conn)
Close()error
Handle(ctx context.Context, conn net.Conn)
Close() error
}

View File

@@ -1,80 +1,80 @@
package consistenthash
import (
"hash/crc32"
"sort"
"strconv"
"strings"
"hash/crc32"
"sort"
"strconv"
"strings"
)
type HashFunc func(data []byte) uint32
type Map struct {
hashFunc HashFunc
replicas int
keys []int // sorted
hashMap map[int]string
hashFunc HashFunc
replicas int
keys []int // sorted
hashMap map[int]string
}
func New(replicas int, fn HashFunc) *Map {
m := &Map{
replicas: replicas,
hashFunc: fn,
hashMap: make(map[int]string),
}
if m.hashFunc == nil {
m.hashFunc = crc32.ChecksumIEEE
}
return m
m := &Map{
replicas: replicas,
hashFunc: fn,
hashMap: make(map[int]string),
}
if m.hashFunc == nil {
m.hashFunc = crc32.ChecksumIEEE
}
return m
}
func (m *Map) IsEmpty() bool {
return len(m.keys) == 0
return len(m.keys) == 0
}
func (m *Map) Add(keys ...string) {
for _, key := range keys {
if key == "" {
continue
}
for i := 0; i < m.replicas; i++ {
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
for _, key := range keys {
if key == "" {
continue
}
for i := 0; i < m.replicas; i++ {
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash)
m.hashMap[hash] = key
}
}
sort.Ints(m.keys)
}
// support hash tag
func getPartitionKey(key string) string {
beg := strings.Index(key, "{")
if beg == -1 {
return key
}
end := strings.Index(key, "}")
if end == -1 || end == beg+1 {
return key
}
return key[beg+1 : end]
beg := strings.Index(key, "{")
if beg == -1 {
return key
}
end := strings.Index(key, "}")
if end == -1 || end == beg+1 {
return key
}
return key[beg+1 : end]
}
// Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string {
if m.IsEmpty() {
return ""
}
if m.IsEmpty() {
return ""
}
partitionKey := getPartitionKey(key)
hash := int(m.hashFunc([]byte(partitionKey)))
partitionKey := getPartitionKey(key)
hash := int(m.hashFunc([]byte(partitionKey)))
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
// Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
// Means we have cycled back to the first replica.
if idx == len(m.keys) {
idx = 0
}
// Means we have cycled back to the first replica.
if idx == len(m.keys) {
idx = 0
}
return m.hashMap[m.keys[idx]]
return m.hashMap[m.keys[idx]]
}

View File

@@ -1,78 +1,78 @@
package files
import (
"mime/multipart"
"io/ioutil"
"path"
"os"
"fmt"
"fmt"
"io/ioutil"
"mime/multipart"
"os"
"path"
)
func GetSize(f multipart.File) (int, error) {
content, err := ioutil.ReadAll(f)
content, err := ioutil.ReadAll(f)
return len(content), err
return len(content), err
}
func GetExt(fileName string) string {
return path.Ext(fileName)
return path.Ext(fileName)
}
func CheckNotExist(src string) bool {
_, err := os.Stat(src)
_, err := os.Stat(src)
return os.IsNotExist(err)
return os.IsNotExist(err)
}
func CheckPermission(src string) bool {
_, err := os.Stat(src)
_, err := os.Stat(src)
return os.IsPermission(err)
return os.IsPermission(err)
}
func IsNotExistMkDir(src string) error {
if notExist := CheckNotExist(src); notExist == true {
if err := MkDir(src); err != nil {
return err
}
}
if notExist := CheckNotExist(src); notExist == true {
if err := MkDir(src); err != nil {
return err
}
}
return nil
return nil
}
func MkDir(src string) error {
err := os.MkdirAll(src, os.ModePerm)
if err != nil {
return err
}
err := os.MkdirAll(src, os.ModePerm)
if err != nil {
return err
}
return nil
return nil
}
func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
f, err := os.OpenFile(name, flag, perm)
if err != nil {
return nil, err
}
f, err := os.OpenFile(name, flag, perm)
if err != nil {
return nil, err
}
return f, nil
return f, nil
}
func MustOpen(fileName, dir string) (*os.File, error) {
perm := CheckPermission(dir)
if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir)
}
perm := CheckPermission(dir)
if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir)
}
err := IsNotExistMkDir(dir)
if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
}
err := IsNotExistMkDir(dir)
if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
}
f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err)
}
f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err)
}
return f, nil
return f, nil
}

View File

@@ -1,9 +1,9 @@
package geohash
import (
"bytes"
"encoding/base32"
"encoding/binary"
"bytes"
"encoding/base32"
"encoding/binary"
)
var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1}
@@ -13,95 +13,95 @@ const defaultBitSize = 64 // 32 bits for latitude, another 32 bits for longitude
// return: geohash, box
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
box := [2][2]float64{
{-180, 180}, // lng
{-90, 90}, // lat
}
pos := [2]float64{longitude, latitude}
hash := &bytes.Buffer{}
bit := 0
var precision uint = 0
code := uint8(0)
for precision < bitSize {
for direction, val := range pos {
mid := (box[direction][0] + box[direction][1]) / 2
if val < mid {
box[direction][1] = mid
} else {
box[direction][0] = mid
code |= bits[bit]
}
bit++
if bit == 8 {
hash.WriteByte(code)
bit = 0
code = 0
}
precision++
if precision == bitSize {
break
}
}
}
// precision%8 > 0
if code > 0 {
hash.WriteByte(code)
}
return hash.Bytes(), box
box := [2][2]float64{
{-180, 180}, // lng
{-90, 90}, // lat
}
pos := [2]float64{longitude, latitude}
hash := &bytes.Buffer{}
bit := 0
var precision uint = 0
code := uint8(0)
for precision < bitSize {
for direction, val := range pos {
mid := (box[direction][0] + box[direction][1]) / 2
if val < mid {
box[direction][1] = mid
} else {
box[direction][0] = mid
code |= bits[bit]
}
bit++
if bit == 8 {
hash.WriteByte(code)
bit = 0
code = 0
}
precision++
if precision == bitSize {
break
}
}
}
// precision%8 > 0
if code > 0 {
hash.WriteByte(code)
}
return hash.Bytes(), box
}
func Encode(latitude, longitude float64) uint64 {
buf, _ := encode0(latitude, longitude, defaultBitSize)
return binary.BigEndian.Uint64(buf)
buf, _ := encode0(latitude, longitude, defaultBitSize)
return binary.BigEndian.Uint64(buf)
}
func decode0(hash []byte) [][]float64 {
box := [][]float64{
{-180, 180},
{-90, 90},
}
direction := 0
for i := 0; i < len(hash); i++ {
code := hash[i]
for j := 0; j < len(bits); j++ {
mid := (box[direction][0] + box[direction][1]) / 2
mask := bits[j]
if mask&code > 0 {
box[direction][0] = mid
} else {
box[direction][1] = mid
}
direction = (direction + 1) % 2
}
}
return box
box := [][]float64{
{-180, 180},
{-90, 90},
}
direction := 0
for i := 0; i < len(hash); i++ {
code := hash[i]
for j := 0; j < len(bits); j++ {
mid := (box[direction][0] + box[direction][1]) / 2
mask := bits[j]
if mask&code > 0 {
box[direction][0] = mid
} else {
box[direction][1] = mid
}
direction = (direction + 1) % 2
}
}
return box
}
func Decode(code uint64) (float64, float64) {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code)
box := decode0(buf)
lng := float64(box[0][0]+box[0][1]) / 2
lat := float64(box[1][0]+box[1][1]) / 2
return lat, lng
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code)
box := decode0(buf)
lng := float64(box[0][0]+box[0][1]) / 2
lat := float64(box[1][0]+box[1][1]) / 2
return lat, lng
}
func ToString(buf []byte) string {
return enc.EncodeToString(buf)
return enc.EncodeToString(buf)
}
func ToInt(buf []byte) uint64 {
// padding
if len(buf) < 8 {
buf2 := make([]byte, 8)
copy(buf2, buf)
return binary.BigEndian.Uint64(buf2)
}
return binary.BigEndian.Uint64(buf)
// padding
if len(buf) < 8 {
buf2 := make([]byte, 8)
copy(buf2, buf)
return binary.BigEndian.Uint64(buf2)
}
return binary.BigEndian.Uint64(buf)
}
func FromInt(code uint64) []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code)
return buf
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code)
return buf
}

View File

@@ -1,39 +1,39 @@
package geohash
import (
"fmt"
"math"
"testing"
"fmt"
"math"
"testing"
)
func TestToRange(t *testing.T) {
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
range_ := ToRange(neighbor, 36)
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
if expectedLower != range_[0] {
t.Error("incorrect lower")
}
if expectedUpper != range_[1] {
t.Error("incorrect upper")
}
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
range_ := ToRange(neighbor, 36)
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
if expectedLower != range_[0] {
t.Error("incorrect lower")
}
if expectedUpper != range_[1] {
t.Error("incorrect upper")
}
}
func TestEncode(t *testing.T) {
lat0 := 48.669
lng0 := -4.32913
hash := Encode(lat0, lng0)
str := ToString(FromInt(hash))
if str != "gbsuv7zt7zntw" {
t.Error("encode error")
}
lat, lng := Decode(hash)
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
t.Error("decode error")
}
lat0 := 48.669
lng0 := -4.32913
hash := Encode(lat0, lng0)
str := ToString(FromInt(hash))
if str != "gbsuv7zt7zntw" {
t.Error("encode error")
}
lat, lng := Decode(hash)
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
t.Error("decode error")
}
}
func TestGetNeighbours(t *testing.T) {
ranges := GetNeighbours(90, 180, 630*1000)
fmt.Printf("%#v", ranges)
ranges := GetNeighbours(90, 180, 630*1000)
fmt.Printf("%#v", ranges)
}

View File

@@ -3,134 +3,134 @@ package geohash
import "math"
const (
DR = math.Pi / 180.0
EarthRadius = 6372797.560856
MercatorMax = 20037726.37 // pi * EarthRadius
MercatorMin = -20037726.37
DR = math.Pi / 180.0
EarthRadius = 6372797.560856
MercatorMax = 20037726.37 // pi * EarthRadius
MercatorMin = -20037726.37
)
func degRad(ang float64) float64 {
return ang * DR
return ang * DR
}
func radDeg(ang float64) float64 {
return ang / DR
return ang / DR
}
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
minLat, maxLat, minLng, maxLng float64) {
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if minLng < -180 {
minLng = -180
}
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if maxLng > 180 {
maxLng = 180
}
minLat = latitude - radDeg(radiusMeters/EarthRadius)
if minLat < -90 {
minLat = -90
}
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
if maxLat > 90 {
maxLat = 90
}
return
minLat, maxLat, minLng, maxLng float64) {
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if minLng < -180 {
minLng = -180
}
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if maxLng > 180 {
maxLng = 180
}
minLat = latitude - radDeg(radiusMeters/EarthRadius)
if minLat < -90 {
minLat = -90
}
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
if maxLat > 90 {
maxLat = 90
}
return
}
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
if radiusMeters == 0 {
return defaultBitSize - 1
}
var precision uint = 1
for radiusMeters < MercatorMax {
radiusMeters *= 2
precision++
}
/* Make sure range is included in most of the base cases. */
precision -= 2
if latitude > 66 || latitude < -66 {
precision--
if latitude > 80 || latitude < -80 {
precision--
}
}
if precision < 1 {
precision = 1
}
if precision > 32 {
precision = 32
}
return precision*2 - 1
if radiusMeters == 0 {
return defaultBitSize - 1
}
var precision uint = 1
for radiusMeters < MercatorMax {
radiusMeters *= 2
precision++
}
/* Make sure range is included in most of the base cases. */
precision -= 2
if latitude > 66 || latitude < -66 {
precision--
if latitude > 80 || latitude < -80 {
precision--
}
}
if precision < 1 {
precision = 1
}
if precision > 32 {
precision = 32
}
return precision*2 - 1
}
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
radLat1 := degRad(latitude1)
radLat2 := degRad(latitude2)
a := radLat1 - radLat2
b := degRad(longitude1) - degRad(longitude2)
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) +
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
radLat1 := degRad(latitude1)
radLat2 := degRad(latitude2)
a := radLat1 - radLat2
b := degRad(longitude1) - degRad(longitude2)
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
}
func ToRange(scope []byte, precision uint) [2]uint64 {
lower := ToInt(scope)
radius := uint64(1 << (64 - precision))
upper := lower + radius
return [2]uint64{lower, upper}
lower := ToInt(scope)
radius := uint64(1 << (64 - precision))
upper := lower + radius
return [2]uint64{lower, upper}
}
func ensureValidLat(lat float64) float64 {
if lat > 90 {
return 90
}
if lat < -90 {
return -90
}
return lat
if lat > 90 {
return 90
}
if lat < -90 {
return -90
}
return lat
}
func ensureValidLng(lng float64) float64 {
if lng > 180 {
return -360 + lng
}
if lng < -180 {
return 360 + lng
}
return lng
if lng > 180 {
return -360 + lng
}
if lng < -180 {
return 360 + lng
}
return lng
}
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
precision := estimatePrecisionByRadius(radiusMeters, latitude)
precision := estimatePrecisionByRadius(radiusMeters, latitude)
center, box := encode0(latitude, longitude, precision)
height := box[0][1] - box[0][0]
width := box[1][1] - box[1][0]
centerLng := (box[0][1] + box[0][0]) / 2
centerLat := (box[1][1] + box[1][0]) / 2
maxLat := ensureValidLat(centerLat + height)
minLat := ensureValidLat(centerLat - height)
maxLng := ensureValidLng(centerLng + width)
minLng := ensureValidLng(centerLng - width)
center, box := encode0(latitude, longitude, precision)
height := box[0][1] - box[0][0]
width := box[1][1] - box[1][0]
centerLng := (box[0][1] + box[0][0]) / 2
centerLat := (box[1][1] + box[1][0]) / 2
maxLat := ensureValidLat(centerLat + height)
minLat := ensureValidLat(centerLat - height)
maxLng := ensureValidLng(centerLng + width)
minLng := ensureValidLng(centerLng - width)
var result [10][2]uint64
leftUpper, _ := encode0(maxLat, minLng, precision)
result[1] = ToRange(leftUpper, precision)
upper, _ := encode0(maxLat, centerLng, precision)
result[2] = ToRange(upper, precision)
rightUpper, _ := encode0(maxLat, maxLng, precision)
result[3] = ToRange(rightUpper, precision)
left, _ := encode0(centerLat, minLng, precision)
result[4] = ToRange(left, precision)
result[5] = ToRange(center, precision)
right, _ := encode0(centerLat, maxLng, precision)
result[6] = ToRange(right, precision)
leftDown, _ := encode0(minLat, minLng, precision)
result[7] = ToRange(leftDown, precision)
down, _ := encode0(minLat, centerLng, precision)
result[8] = ToRange(down, precision)
rightDown, _ := encode0(minLat, maxLng, precision)
result[9] = ToRange(rightDown, precision)
var result [10][2]uint64
leftUpper, _ := encode0(maxLat, minLng, precision)
result[1] = ToRange(leftUpper, precision)
upper, _ := encode0(maxLat, centerLng, precision)
result[2] = ToRange(upper, precision)
rightUpper, _ := encode0(maxLat, maxLng, precision)
result[3] = ToRange(rightUpper, precision)
left, _ := encode0(centerLat, minLng, precision)
result[4] = ToRange(left, precision)
result[5] = ToRange(center, precision)
right, _ := encode0(centerLat, maxLng, precision)
result[6] = ToRange(right, precision)
leftDown, _ := encode0(minLat, minLng, precision)
result[7] = ToRange(leftDown, precision)
down, _ := encode0(minLat, centerLng, precision)
result[8] = ToRange(down, precision)
rightDown, _ := encode0(minLat, maxLng, precision)
result[9] = ToRange(rightDown, precision)
return result[1:]
return result[1:]
}

View File

@@ -1,21 +1,21 @@
package gob
import (
"bytes"
"encoding/gob"
"bytes"
"encoding/gob"
)
func Marshal(obj interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func UnMarshal(data []byte, obj interface{}) error {
dec := gob.NewDecoder(bytes.NewBuffer(data))
return dec.Decode(obj)
dec := gob.NewDecoder(bytes.NewBuffer(data))
return dec.Decode(obj)
}

View File

@@ -4,16 +4,14 @@ import "sync/atomic"
type AtomicBool uint32
func (b *AtomicBool)Get()bool {
return atomic.LoadUint32((*uint32)(b)) != 0
func (b *AtomicBool) Get() bool {
return atomic.LoadUint32((*uint32)(b)) != 0
}
func (b *AtomicBool)Set(v bool) {
if v {
atomic.StoreUint32((*uint32)(b), 1)
} else {
atomic.StoreUint32((*uint32)(b), 0)
}
func (b *AtomicBool) Set(v bool) {
if v {
atomic.StoreUint32((*uint32)(b), 1)
} else {
atomic.StoreUint32((*uint32)(b), 0)
}
}

View File

@@ -1,38 +1,38 @@
package wait
import (
"sync"
"time"
"sync"
"time"
)
type Wait struct {
wg sync.WaitGroup
wg sync.WaitGroup
}
func (w *Wait)Add(delta int) {
w.wg.Add(delta)
func (w *Wait) Add(delta int) {
w.wg.Add(delta)
}
func (w *Wait)Done() {
w.wg.Done()
func (w *Wait) Done() {
w.wg.Done()
}
func (w *Wait)Wait() {
w.wg.Wait()
func (w *Wait) Wait() {
w.wg.Wait()
}
// return isTimeout
func (w *Wait)WaitWithTimeout(timeout time.Duration)bool {
c := make(chan bool)
go func() {
defer close(c)
w.wg.Wait()
c <- true
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool)
go func() {
defer close(c)
w.wg.Wait()
c <- true
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}

View File

@@ -1,95 +1,95 @@
package wildcard
const (
normal = iota
all // *
any // ?
set_ // []
normal = iota
all // *
any // ?
set_ // []
)
type item struct {
character byte
set map[byte]bool
typeCode int
character byte
set map[byte]bool
typeCode int
}
func (i *item) contains(c byte) bool {
_, ok := i.set[c]
return ok
_, ok := i.set[c]
return ok
}
type Pattern struct {
items []*item
items []*item
}
func CompilePattern(src string) *Pattern {
items := make([]*item, 0)
escape := false
inSet := false
var set map[byte]bool
for _, v := range src {
c := byte(v)
if escape {
items = append(items, &item{typeCode: normal, character: c})
escape = false
} else if c == '*' {
items = append(items, &item{typeCode: all})
} else if c == '?' {
items = append(items, &item{typeCode: any})
} else if c == '\\' {
escape = true
} else if c == '[' {
if !inSet {
inSet = true
set = make(map[byte]bool)
} else {
set[c] = true
}
} else if c == ']' {
if inSet {
inSet = false
items = append(items, &item{typeCode: set_, set: set})
} else {
items = append(items, &item{typeCode: normal, character: c})
}
} else {
if inSet {
set[c] = true
} else {
items = append(items, &item{typeCode: normal, character: c})
}
}
}
return &Pattern{
items: items,
}
items := make([]*item, 0)
escape := false
inSet := false
var set map[byte]bool
for _, v := range src {
c := byte(v)
if escape {
items = append(items, &item{typeCode: normal, character: c})
escape = false
} else if c == '*' {
items = append(items, &item{typeCode: all})
} else if c == '?' {
items = append(items, &item{typeCode: any})
} else if c == '\\' {
escape = true
} else if c == '[' {
if !inSet {
inSet = true
set = make(map[byte]bool)
} else {
set[c] = true
}
} else if c == ']' {
if inSet {
inSet = false
items = append(items, &item{typeCode: set_, set: set})
} else {
items = append(items, &item{typeCode: normal, character: c})
}
} else {
if inSet {
set[c] = true
} else {
items = append(items, &item{typeCode: normal, character: c})
}
}
}
return &Pattern{
items: items,
}
}
func (p *Pattern) IsMatch(s string) bool {
if len(p.items) == 0 {
return len(s) == 0
}
m := len(s)
n := len(p.items)
table := make([][]bool, m+1)
for i := 0; i < m+1; i++ {
table[i] = make([]bool, n+1)
}
table[0][0] = true
for j := 1; j < n+1; j++ {
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
}
for i := 1; i < m+1; i++ {
for j := 1; j < n+1; j++ {
if p.items[j-1].typeCode == all {
table[i][j] = table[i-1][j] || table[i][j-1]
} else {
table[i][j] = table[i-1][j-1] &&
(p.items[j-1].typeCode == any ||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
}
}
}
return table[m][n]
if len(p.items) == 0 {
return len(s) == 0
}
m := len(s)
n := len(p.items)
table := make([][]bool, m+1)
for i := 0; i < m+1; i++ {
table[i] = make([]bool, n+1)
}
table[0][0] = true
for j := 1; j < n+1; j++ {
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
}
for i := 1; i < m+1; i++ {
for j := 1; j < n+1; j++ {
if p.items[j-1].typeCode == all {
table[i][j] = table[i-1][j] || table[i][j-1]
} else {
table[i][j] = table[i-1][j-1] &&
(p.items[j-1].typeCode == any ||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
}
}
}
return table[m][n]
}

View File

@@ -3,73 +3,73 @@ package wildcard
import "testing"
func TestWildCard(t *testing.T) {
p := CompilePattern("a")
if !p.IsMatch("a") {
t.Error("expect true actually false")
}
if p.IsMatch("b") {
t.Error("expect false actually true")
}
p := CompilePattern("a")
if !p.IsMatch("a") {
t.Error("expect true actually false")
}
if p.IsMatch("b") {
t.Error("expect false actually true")
}
// test '?'
p = CompilePattern("a?")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if p.IsMatch("a") {
t.Error("expect false actually true")
}
if p.IsMatch("abb") {
t.Error("expect false actually true")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test '?'
p = CompilePattern("a?")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if p.IsMatch("a") {
t.Error("expect false actually true")
}
if p.IsMatch("abb") {
t.Error("expect false actually true")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test *
p = CompilePattern("a*")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if !p.IsMatch("a") {
t.Error("expect true actually false")
}
if !p.IsMatch("abb") {
t.Error("expect true actually false")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test *
p = CompilePattern("a*")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if !p.IsMatch("a") {
t.Error("expect true actually false")
}
if !p.IsMatch("abb") {
t.Error("expect true actually false")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test []
p = CompilePattern("a[ab[]")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if !p.IsMatch("aa") {
t.Error("expect true actually false")
}
if !p.IsMatch("a[") {
t.Error("expect true actually false")
}
if p.IsMatch("abb") {
t.Error("expect false actually true")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test []
p = CompilePattern("a[ab[]")
if !p.IsMatch("ab") {
t.Error("expect true actually false")
}
if !p.IsMatch("aa") {
t.Error("expect true actually false")
}
if !p.IsMatch("a[") {
t.Error("expect true actually false")
}
if p.IsMatch("abb") {
t.Error("expect false actually true")
}
if p.IsMatch("bb") {
t.Error("expect false actually true")
}
// test escape
p = CompilePattern("\\\\") // pattern: \\
if !p.IsMatch("\\") {
t.Error("expect true actually false")
}
// test escape
p = CompilePattern("\\\\") // pattern: \\
if !p.IsMatch("\\") {
t.Error("expect true actually false")
}
p = CompilePattern("\\*")
if !p.IsMatch("*") {
t.Error("expect true actually false")
}
if p.IsMatch("a") {
t.Error("expect false actually true")
}
p = CompilePattern("\\*")
if !p.IsMatch("*") {
t.Error("expect true actually false")
}
if p.IsMatch("a") {
t.Error("expect false actually true")
}
}

View File

@@ -1,20 +1,20 @@
package pubsub
import (
"github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/lock"
"github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/lock"
)
type Hub struct {
// channel -> list(*Client)
subs dict.Dict
// lock channel
subsLocker *lock.Locks
// channel -> list(*Client)
subs dict.Dict
// lock channel
subsLocker *lock.Locks
}
func MakeHub() *Hub {
return &Hub{
subs: dict.MakeConcurrent(4),
subsLocker: lock.Make(16),
}
return &Hub{
subs: dict.MakeConcurrent(4),
subsLocker: lock.Make(16),
}
}

View File

@@ -1,23 +1,23 @@
package pubsub
import (
"github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
"github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
var (
_subscribe = "subscribe"
_unsubscribe = "unsubscribe"
messageBytes = []byte("message")
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
_subscribe = "subscribe"
_unsubscribe = "unsubscribe"
messageBytes = []byte("message")
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
)
func makeMsg(t string, channel string, code int64) []byte {
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
":" + strconv.FormatInt(code, 10) + reply.CRLF)
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
":" + strconv.FormatInt(code, 10) + reply.CRLF)
}
/*
@@ -25,22 +25,22 @@ func makeMsg(t string, channel string, code int64) []byte {
* return: is new subscribed
*/
func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.SubsChannel(channel)
client.SubsChannel(channel)
// add into hub.subs
raw, ok := hub.subs.Get(channel)
var subscribers *list.LinkedList
if ok {
subscribers, _ = raw.(*list.LinkedList)
} else {
subscribers = list.Make()
hub.subs.Put(channel, subscribers)
}
if subscribers.Contains(client) {
return false
}
subscribers.Add(client)
return true
// add into hub.subs
raw, ok := hub.subs.Get(channel)
var subscribers *list.LinkedList
if ok {
subscribers, _ = raw.(*list.LinkedList)
} else {
subscribers = list.Make()
hub.subs.Put(channel, subscribers)
}
if subscribers.Contains(client) {
return false
}
subscribers.Add(client)
return true
}
/*
@@ -48,102 +48,102 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
* return: is actually un-subscribe
*/
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.UnSubsChannel(channel)
client.UnSubsChannel(channel)
// remove from hub.subs
raw, ok := hub.subs.Get(channel)
if ok {
subscribers, _ := raw.(*list.LinkedList)
subscribers.RemoveAllByVal(client)
// remove from hub.subs
raw, ok := hub.subs.Get(channel)
if ok {
subscribers, _ := raw.(*list.LinkedList)
subscribers.RemoveAllByVal(client)
if subscribers.Len() == 0 {
// clean
hub.subs.Remove(channel)
}
return true
}
return false
if subscribers.Len() == 0 {
// clean
hub.subs.Remove(channel)
}
return true
}
return false
}
func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
channels := make([]string, len(args))
for i, b := range args {
channels[i] = string(b)
}
channels := make([]string, len(args))
for i, b := range args {
channels[i] = string(b)
}
hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...)
hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...)
for _, channel := range channels {
if subscribe0(hub, channel, c) {
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
}
}
return &reply.NoReply{}
for _, channel := range channels {
if subscribe0(hub, channel, c) {
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
}
}
return &reply.NoReply{}
}
func UnsubscribeAll(hub *Hub, c redis.Connection) {
channels := c.GetChannels()
channels := c.GetChannels()
hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...)
hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...)
for _, channel := range channels {
unsubscribe0(hub, channel, c)
}
for _, channel := range channels {
unsubscribe0(hub, channel, c)
}
}
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
var channels []string
if len(args) > 0 {
channels = make([]string, len(args))
for i, b := range args {
channels[i] = string(b)
}
} else {
channels = c.GetChannels()
}
var channels []string
if len(args) > 0 {
channels = make([]string, len(args))
for i, b := range args {
channels[i] = string(b)
}
} else {
channels = c.GetChannels()
}
db.subsLocker.Locks(channels...)
defer db.subsLocker.UnLocks(channels...)
db.subsLocker.Locks(channels...)
defer db.subsLocker.UnLocks(channels...)
if len(channels) == 0 {
_ = c.Write(unSubscribeNothing)
return &reply.NoReply{}
}
if len(channels) == 0 {
_ = c.Write(unSubscribeNothing)
return &reply.NoReply{}
}
for _, channel := range channels {
if unsubscribe0(db, channel, c) {
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
}
}
return &reply.NoReply{}
for _, channel := range channels {
if unsubscribe0(db, channel, c) {
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
}
}
return &reply.NoReply{}
}
func Publish(hub *Hub, args [][]byte) redis.Reply {
if len(args) != 2 {
return &reply.ArgNumErrReply{Cmd: "publish"}
}
channel := string(args[0])
message := args[1]
if len(args) != 2 {
return &reply.ArgNumErrReply{Cmd: "publish"}
}
channel := string(args[0])
message := args[1]
hub.subsLocker.Lock(channel)
defer hub.subsLocker.UnLock(channel)
hub.subsLocker.Lock(channel)
defer hub.subsLocker.UnLock(channel)
raw, ok := hub.subs.Get(channel)
if !ok {
return reply.MakeIntReply(0)
}
subscribers, _ := raw.(*list.LinkedList)
subscribers.ForEach(func(i int, c interface{}) bool {
client, _ := c.(redis.Connection)
replyArgs := make([][]byte, 3)
replyArgs[0] = messageBytes
replyArgs[1] = []byte(channel)
replyArgs[2] = message
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
return true
})
return reply.MakeIntReply(int64(subscribers.Len()))
raw, ok := hub.subs.Get(channel)
if !ok {
return reply.MakeIntReply(0)
}
subscribers, _ := raw.(*list.LinkedList)
subscribers.ForEach(func(i int, c interface{}) bool {
client, _ := c.(redis.Connection)
replyArgs := make([][]byte, 3)
replyArgs[0] = messageBytes
replyArgs[1] = []byte(channel)
replyArgs[2] = message
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
return true
})
return reply.MakeIntReply(int64(subscribers.Len()))
}

View File

@@ -1,322 +1,321 @@
package client
import (
"bufio"
"context"
"errors"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/wait"
"github.com/HDT3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"bufio"
"context"
"errors"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/wait"
"github.com/HDT3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
)
type Client struct {
conn net.Conn
sendingReqs chan *Request // waiting sending
waitingReqs chan *Request // waiting response
ticker *time.Ticker
addr string
conn net.Conn
sendingReqs chan *Request // waiting sending
waitingReqs chan *Request // waiting response
ticker *time.Ticker
addr string
ctx context.Context
cancelFunc context.CancelFunc
writing *sync.WaitGroup
ctx context.Context
cancelFunc context.CancelFunc
writing *sync.WaitGroup
}
type Request struct {
id uint64
args [][]byte
reply redis.Reply
heartbeat bool
waiting *wait.Wait
err error
id uint64
args [][]byte
reply redis.Reply
heartbeat bool
waiting *wait.Wait
err error
}
const (
chanSize = 256
maxWait = 3 * time.Second
chanSize = 256
maxWait = 3 * time.Second
)
func MakeClient(addr string) (*Client, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
return &Client{
addr: addr,
conn: conn,
sendingReqs: make(chan *Request, chanSize),
waitingReqs: make(chan *Request, chanSize),
ctx: ctx,
cancelFunc: cancel,
writing: &sync.WaitGroup{},
}, nil
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
return &Client{
addr: addr,
conn: conn,
sendingReqs: make(chan *Request, chanSize),
waitingReqs: make(chan *Request, chanSize),
ctx: ctx,
cancelFunc: cancel,
writing: &sync.WaitGroup{},
}, nil
}
func (client *Client) Start() {
client.ticker = time.NewTicker(10 * time.Second)
go client.handleWrite()
go func() {
err := client.handleRead()
logger.Warn(err)
}()
go client.heartbeat()
client.ticker = time.NewTicker(10 * time.Second)
go client.handleWrite()
go func() {
err := client.handleRead()
logger.Warn(err)
}()
go client.heartbeat()
}
func (client *Client) Close() {
// stop new request
close(client.sendingReqs)
// stop new request
close(client.sendingReqs)
// wait stop process
client.writing.Wait()
// wait stop process
client.writing.Wait()
// clean
client.cancelFunc()
_ = client.conn.Close()
close(client.waitingReqs)
// clean
client.cancelFunc()
_ = client.conn.Close()
close(client.waitingReqs)
}
func (client *Client) handleConnectionError(err error) error {
err1 := client.conn.Close()
if err1 != nil {
if opErr, ok := err1.(*net.OpError); ok {
if opErr.Err.Error() != "use of closed network connection" {
return err1
}
} else {
return err1
}
}
conn, err1 := net.Dial("tcp", client.addr)
if err1 != nil {
logger.Error(err1)
return err1
}
client.conn = conn
go func() {
_ = client.handleRead()
}()
return nil
err1 := client.conn.Close()
if err1 != nil {
if opErr, ok := err1.(*net.OpError); ok {
if opErr.Err.Error() != "use of closed network connection" {
return err1
}
} else {
return err1
}
}
conn, err1 := net.Dial("tcp", client.addr)
if err1 != nil {
logger.Error(err1)
return err1
}
client.conn = conn
go func() {
_ = client.handleRead()
}()
return nil
}
func (client *Client) heartbeat() {
loop:
for {
select {
case <-client.ticker.C:
client.sendingReqs <- &Request{
args: [][]byte{[]byte("PING")},
heartbeat: true,
}
case <-client.ctx.Done():
break loop
}
}
for {
select {
case <-client.ticker.C:
client.sendingReqs <- &Request{
args: [][]byte{[]byte("PING")},
heartbeat: true,
}
case <-client.ctx.Done():
break loop
}
}
}
func (client *Client) handleWrite() {
loop:
for {
select {
case req := <-client.sendingReqs:
client.writing.Add(1)
client.doRequest(req)
case <-client.ctx.Done():
break loop
}
}
for {
select {
case req := <-client.sendingReqs:
client.writing.Add(1)
client.doRequest(req)
case <-client.ctx.Done():
break loop
}
}
}
// todo: wait with timeout
func (client *Client) Send(args [][]byte) redis.Reply {
request := &Request{
args: args,
heartbeat: false,
waiting: &wait.Wait{},
}
request.waiting.Add(1)
client.sendingReqs <- request
timeout := request.waiting.WaitWithTimeout(maxWait)
if timeout {
return reply.MakeErrReply("server time out")
}
if request.err != nil {
return reply.MakeErrReply("request failed")
}
return request.reply
request := &Request{
args: args,
heartbeat: false,
waiting: &wait.Wait{},
}
request.waiting.Add(1)
client.sendingReqs <- request
timeout := request.waiting.WaitWithTimeout(maxWait)
if timeout {
return reply.MakeErrReply("server time out")
}
if request.err != nil {
return reply.MakeErrReply("request failed")
}
return request.reply
}
func (client *Client) doRequest(req *Request) {
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
_, err := client.conn.Write(bytes)
i := 0
for err != nil && i < 3 {
err = client.handleConnectionError(err)
if err == nil {
_, err = client.conn.Write(bytes)
}
i++
}
if err == nil {
client.waitingReqs <- req
} else {
req.err = err
req.waiting.Done()
client.writing.Done()
}
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
_, err := client.conn.Write(bytes)
i := 0
for err != nil && i < 3 {
err = client.handleConnectionError(err)
if err == nil {
_, err = client.conn.Write(bytes)
}
i++
}
if err == nil {
client.waitingReqs <- req
} else {
req.err = err
req.waiting.Done()
client.writing.Done()
}
}
func (client *Client) finishRequest(reply redis.Reply) {
request := <-client.waitingReqs
request.reply = reply
if request.waiting != nil {
request.waiting.Done()
}
client.writing.Done()
request := <-client.waitingReqs
request.reply = reply
if request.waiting != nil {
request.waiting.Done()
}
client.writing.Done()
}
func (client *Client) handleRead() error {
reader := bufio.NewReader(client.conn)
downloading := false
expectedArgsCount := 0
receivedCount := 0
msgType := byte(0) // first char of msg
var args [][]byte
var fixedLen int64 = 0
var err error
var msg []byte
for {
// read line
if fixedLen == 0 { // read normal line
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
logger.Info("connection close")
} else {
logger.Warn(err)
}
reader := bufio.NewReader(client.conn)
downloading := false
expectedArgsCount := 0
receivedCount := 0
msgType := byte(0) // first char of msg
var args [][]byte
var fixedLen int64 = 0
var err error
var msg []byte
for {
// read line
if fixedLen == 0 { // read normal line
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
logger.Info("connection close")
} else {
logger.Warn(err)
}
return errors.New("connection closed")
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
return errors.New("protocol error")
}
} else { // read bulk line (binary safe)
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return errors.New("connection closed")
} else {
return err
}
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
return errors.New("protocol error")
}
fixedLen = 0
}
return errors.New("connection closed")
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
return errors.New("protocol error")
}
} else { // read bulk line (binary safe)
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return errors.New("connection closed")
} else {
return err
}
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
return errors.New("protocol error")
}
fixedLen = 0
}
// parse line
if !downloading {
// receive new response
if msg[0] == '*' { // multi bulk response
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
return errors.New("protocol error: " + err.Error())
}
if expectedLine == 0 {
client.finishRequest(&reply.EmptyMultiBulkReply{})
} else if expectedLine > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = int(expectedLine)
receivedCount = 0
args = make([][]byte, expectedLine)
} else {
return errors.New("protocol error")
}
} else if msg[0] == '$' { // bulk response
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil {
return err
}
if fixedLen == -1 { // null bulk
client.finishRequest(&reply.NullBulkReply{})
fixedLen = 0
} else if fixedLen > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = 1
receivedCount = 0
args = make([][]byte, 1)
} else {
return errors.New("protocol error")
}
} else { // single line response
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
var result redis.Reply
switch msg[0] {
case '+':
result = reply.MakeStatusReply(str[1:])
case '-':
result = reply.MakeErrReply(str[1:])
case ':':
val, err := strconv.ParseInt(str[1:], 10, 64)
if err != nil {
return errors.New("protocol error")
}
result = reply.MakeIntReply(val)
}
client.finishRequest(result)
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return err
}
if fixedLen <= 0 { // null bulk in multi bulks
args[receivedCount] = []byte{}
receivedCount++
fixedLen = 0
}
} else {
args[receivedCount] = line
receivedCount++
}
// parse line
if !downloading {
// receive new response
if msg[0] == '*' { // multi bulk response
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
return errors.New("protocol error: " + err.Error())
}
if expectedLine == 0 {
client.finishRequest(&reply.EmptyMultiBulkReply{})
} else if expectedLine > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = int(expectedLine)
receivedCount = 0
args = make([][]byte, expectedLine)
} else {
return errors.New("protocol error")
}
} else if msg[0] == '$' { // bulk response
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil {
return err
}
if fixedLen == -1 { // null bulk
client.finishRequest(&reply.NullBulkReply{})
fixedLen = 0
} else if fixedLen > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = 1
receivedCount = 0
args = make([][]byte, 1)
} else {
return errors.New("protocol error")
}
} else { // single line response
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
var result redis.Reply
switch msg[0] {
case '+':
result = reply.MakeStatusReply(str[1:])
case '-':
result = reply.MakeErrReply(str[1:])
case ':':
val, err := strconv.ParseInt(str[1:], 10, 64)
if err != nil {
return errors.New("protocol error")
}
result = reply.MakeIntReply(val)
}
client.finishRequest(result)
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return err
}
if fixedLen <= 0 { // null bulk in multi bulks
args[receivedCount] = []byte{}
receivedCount++
fixedLen = 0
}
} else {
args[receivedCount] = line
receivedCount++
}
// if sending finished
if receivedCount == expectedArgsCount {
downloading = false // finish downloading progress
// if sending finished
if receivedCount == expectedArgsCount {
downloading = false // finish downloading progress
if msgType == '*' {
reply := reply.MakeMultiBulkReply(args)
client.finishRequest(reply)
} else if msgType == '$' {
reply := reply.MakeBulkReply(args[0])
client.finishRequest(reply)
}
if msgType == '*' {
reply := reply.MakeMultiBulkReply(args)
client.finishRequest(reply)
} else if msgType == '$' {
reply := reply.MakeBulkReply(args[0])
client.finishRequest(reply)
}
// finish reply
expectedArgsCount = 0
receivedCount = 0
args = nil
msgType = byte(0)
}
}
}
// finish reply
expectedArgsCount = 0
receivedCount = 0
args = nil
msgType = byte(0)
}
}
}
}

View File

@@ -1,104 +1,104 @@
package client
import (
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/redis/reply"
"testing"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/redis/reply"
"testing"
)
func TestClient(t *testing.T) {
logger.Setup(&logger.Settings{
Path: "logs",
Name: "godis",
Ext: ".log",
TimeFormat: "2006-01-02",
})
client, err := MakeClient("localhost:6379")
if err != nil {
t.Error(err)
}
client.Start()
logger.Setup(&logger.Settings{
Path: "logs",
Name: "godis",
Ext: ".log",
TimeFormat: "2006-01-02",
})
client, err := MakeClient("localhost:6379")
if err != nil {
t.Error(err)
}
client.Start()
result := client.Send([][]byte{
[]byte("PING"),
})
if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "PONG" {
t.Error("`ping` failed, result: " + statusRet.Status)
}
}
result := client.Send([][]byte{
[]byte("PING"),
})
if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "PONG" {
t.Error("`ping` failed, result: " + statusRet.Status)
}
}
result = client.Send([][]byte{
[]byte("SET"),
[]byte("a"),
[]byte("a"),
})
if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "OK" {
t.Error("`set` failed, result: " + statusRet.Status)
}
}
result = client.Send([][]byte{
[]byte("SET"),
[]byte("a"),
[]byte("a"),
})
if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "OK" {
t.Error("`set` failed, result: " + statusRet.Status)
}
}
result = client.Send([][]byte{
[]byte("GET"),
[]byte("a"),
})
if bulkRet, ok := result.(*reply.BulkReply); ok {
if string(bulkRet.Arg) != "a" {
t.Error("`get` failed, result: " + string(bulkRet.Arg))
}
}
result = client.Send([][]byte{
[]byte("GET"),
[]byte("a"),
})
if bulkRet, ok := result.(*reply.BulkReply); ok {
if string(bulkRet.Arg) != "a" {
t.Error("`get` failed, result: " + string(bulkRet.Arg))
}
}
result = client.Send([][]byte{
[]byte("DEL"),
[]byte("a"),
})
if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 1 {
t.Error("`del` failed, result: " + string(intRet.Code))
}
}
result = client.Send([][]byte{
[]byte("DEL"),
[]byte("a"),
})
if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 1 {
t.Error("`del` failed, result: " + string(intRet.Code))
}
}
result = client.Send([][]byte{
[]byte("GET"),
[]byte("a"),
})
if _, ok := result.(*reply.NullBulkReply); !ok {
t.Error("`get` failed, result: " + string(result.ToBytes()))
}
result = client.Send([][]byte{
[]byte("GET"),
[]byte("a"),
})
if _, ok := result.(*reply.NullBulkReply); !ok {
t.Error("`get` failed, result: " + string(result.ToBytes()))
}
result = client.Send([][]byte{
[]byte("DEL"),
[]byte("arr"),
})
result = client.Send([][]byte{
[]byte("DEL"),
[]byte("arr"),
})
result = client.Send([][]byte{
[]byte("RPUSH"),
[]byte("arr"),
[]byte("1"),
[]byte("2"),
[]byte("c"),
})
if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 3 {
t.Error("`rpush` failed, result: " + string(intRet.Code))
}
}
result = client.Send([][]byte{
[]byte("RPUSH"),
[]byte("arr"),
[]byte("1"),
[]byte("2"),
[]byte("c"),
})
if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 3 {
t.Error("`rpush` failed, result: " + string(intRet.Code))
}
}
result = client.Send([][]byte{
[]byte("LRANGE"),
[]byte("arr"),
[]byte("0"),
[]byte("-1"),
})
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
if len(multiBulkRet.Args) != 3 ||
string(multiBulkRet.Args[0]) != "1" ||
string(multiBulkRet.Args[1]) != "2" ||
string(multiBulkRet.Args[2]) != "c" {
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
}
}
result = client.Send([][]byte{
[]byte("LRANGE"),
[]byte("arr"),
[]byte("0"),
[]byte("-1"),
})
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
if len(multiBulkRet.Args) != 3 ||
string(multiBulkRet.Args[0]) != "1" ||
string(multiBulkRet.Args[1]) != "2" ||
string(multiBulkRet.Args[2]) != "c" {
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
}
}
client.Close()
client.Close()
}

View File

@@ -1,42 +1,42 @@
package reply
type PongReply struct {}
type PongReply struct{}
var PongBytes = []byte("+PONG\r\n")
func (r *PongReply)ToBytes()[]byte {
return PongBytes
func (r *PongReply) ToBytes() []byte {
return PongBytes
}
type OkReply struct {}
type OkReply struct{}
var okBytes = []byte("+OK\r\n")
func (r *OkReply)ToBytes()[]byte {
return okBytes
func (r *OkReply) ToBytes() []byte {
return okBytes
}
var nullBulkBytes = []byte("$-1\r\n")
type NullBulkReply struct {}
type NullBulkReply struct{}
func (r *NullBulkReply)ToBytes()[]byte {
return nullBulkBytes
func (r *NullBulkReply) ToBytes() []byte {
return nullBulkBytes
}
var emptyMultiBulkBytes = []byte("*0\r\n")
type EmptyMultiBulkReply struct {}
type EmptyMultiBulkReply struct{}
func (r *EmptyMultiBulkReply)ToBytes()[]byte {
return emptyMultiBulkBytes
func (r *EmptyMultiBulkReply) ToBytes() []byte {
return emptyMultiBulkBytes
}
// reply nothing, for commands like subscribe
type NoReply struct {}
type NoReply struct{}
var NoBytes = []byte("")
func (r *NoReply)ToBytes()[]byte {
return NoBytes
func (r *NoReply) ToBytes() []byte {
return NoBytes
}

View File

@@ -1,67 +1,67 @@
package reply
// UnknownErr
type UnknownErrReply struct {}
type UnknownErrReply struct{}
var unknownErrBytes = []byte("-Err unknown\r\n")
func (r *UnknownErrReply)ToBytes()[]byte {
return unknownErrBytes
func (r *UnknownErrReply) ToBytes() []byte {
return unknownErrBytes
}
func (r *UnknownErrReply) Error()string {
return "Err unknown"
func (r *UnknownErrReply) Error() string {
return "Err unknown"
}
// ArgNumErr
type ArgNumErrReply struct {
Cmd string
Cmd string
}
func (r *ArgNumErrReply)ToBytes()[]byte {
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
func (r *ArgNumErrReply) ToBytes() []byte {
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
}
func (r *ArgNumErrReply) Error()string {
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
func (r *ArgNumErrReply) Error() string {
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
}
// SyntaxErr
type SyntaxErrReply struct {}
type SyntaxErrReply struct{}
var syntaxErrBytes = []byte("-Err syntax error\r\n")
func (r *SyntaxErrReply)ToBytes()[]byte {
return syntaxErrBytes
func (r *SyntaxErrReply) ToBytes() []byte {
return syntaxErrBytes
}
func (r *SyntaxErrReply)Error()string {
return "Err syntax error"
func (r *SyntaxErrReply) Error() string {
return "Err syntax error"
}
// WrongTypeErr
type WrongTypeErrReply struct {}
type WrongTypeErrReply struct{}
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
func (r *WrongTypeErrReply)ToBytes()[]byte {
return wrongTypeErrBytes
func (r *WrongTypeErrReply) ToBytes() []byte {
return wrongTypeErrBytes
}
func (r *WrongTypeErrReply)Error()string {
return "WRONGTYPE Operation against a key holding the wrong kind of value"
func (r *WrongTypeErrReply) Error() string {
return "WRONGTYPE Operation against a key holding the wrong kind of value"
}
// ProtocolErr
type ProtocolErrReply struct {
Msg string
Msg string
}
func (r *ProtocolErrReply)ToBytes()[]byte {
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
func (r *ProtocolErrReply) ToBytes() []byte {
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
}
func (r *ProtocolErrReply) Error()string {
return "ERR Protocol error: '" + r.Msg
func (r *ProtocolErrReply) Error() string {
return "ERR Protocol error: '" + r.Msg
}

View File

@@ -1,141 +1,140 @@
package reply
import (
"bytes"
"github.com/HDT3213/godis/src/interface/redis"
"strconv"
"bytes"
"github.com/HDT3213/godis/src/interface/redis"
"strconv"
)
var (
nullBulkReplyBytes = []byte("$-1")
CRLF = "\r\n"
nullBulkReplyBytes = []byte("$-1")
CRLF = "\r\n"
)
/* ---- Bulk Reply ---- */
type BulkReply struct {
Arg []byte
Arg []byte
}
func MakeBulkReply(arg []byte) *BulkReply {
return &BulkReply{
Arg: arg,
}
return &BulkReply{
Arg: arg,
}
}
func (r *BulkReply) ToBytes() []byte {
if len(r.Arg) == 0 {
return nullBulkReplyBytes
}
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
if len(r.Arg) == 0 {
return nullBulkReplyBytes
}
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
}
/* ---- Multi Bulk Reply ---- */
type MultiBulkReply struct {
Args [][]byte
Args [][]byte
}
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
return &MultiBulkReply{
Args: args,
}
return &MultiBulkReply{
Args: args,
}
}
func (r *MultiBulkReply) ToBytes() []byte {
argLen := len(r.Args)
var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args {
if arg == nil {
buf.WriteString("$-1" + CRLF)
} else {
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
}
}
return buf.Bytes()
argLen := len(r.Args)
var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args {
if arg == nil {
buf.WriteString("$-1" + CRLF)
} else {
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
}
}
return buf.Bytes()
}
/* ---- Multi Raw Reply ---- */
type MultiRawReply struct {
Args [][]byte
Args [][]byte
}
func MakeMultiRawReply(args [][]byte) *MultiRawReply {
return &MultiRawReply{
Args: args,
}
return &MultiRawReply{
Args: args,
}
}
func (r *MultiRawReply) ToBytes() []byte {
argLen := len(r.Args)
var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args {
buf.Write(arg)
}
return buf.Bytes()
argLen := len(r.Args)
var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args {
buf.Write(arg)
}
return buf.Bytes()
}
/* ---- Status Reply ---- */
type StatusReply struct {
Status string
Status string
}
func MakeStatusReply(status string) *StatusReply {
return &StatusReply{
Status: status,
}
return &StatusReply{
Status: status,
}
}
func (r *StatusReply) ToBytes() []byte {
return []byte("+" + r.Status + "\r\n")
return []byte("+" + r.Status + "\r\n")
}
/* ---- Int Reply ---- */
type IntReply struct {
Code int64
Code int64
}
func MakeIntReply(code int64) *IntReply {
return &IntReply{
Code: code,
}
return &IntReply{
Code: code,
}
}
func (r *IntReply) ToBytes() []byte {
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
}
/* ---- Error Reply ---- */
type ErrorReply interface {
Error() string
ToBytes() []byte
Error() string
ToBytes() []byte
}
type StandardErrReply struct {
Status string
Status string
}
func MakeErrReply(status string) *StandardErrReply {
return &StandardErrReply{
Status: status,
}
return &StandardErrReply{
Status: status,
}
}
func IsErrorReply(reply redis.Reply) bool {
return reply.ToBytes()[0] == '-'
return reply.ToBytes()[0] == '-'
}
func (r *StandardErrReply) ToBytes() []byte {
return []byte("-" + r.Status + "\r\n")
return []byte("-" + r.Status + "\r\n")
}
func (r *StandardErrReply) Error() string {
return r.Status
return r.Status
}

View File

@@ -1,95 +1,95 @@
package server
import (
"github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/lib/sync/wait"
"net"
"sync"
"time"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/lib/sync/wait"
"net"
"sync"
"time"
)
// abstract of active client
type Client struct {
conn net.Conn
conn net.Conn
// waiting util reply finished
waitingReply wait.Wait
// waiting util reply finished
waitingReply wait.Wait
// is sending request in progress
uploading atomic.AtomicBool
// multi bulk msg lineCount - 1(first line)
expectedArgsCount uint32
// sent line count, exclude first line
receivedCount uint32
// sent lines, exclude first line
args [][]byte
// is sending request in progress
uploading atomic.AtomicBool
// multi bulk msg lineCount - 1(first line)
expectedArgsCount uint32
// sent line count, exclude first line
receivedCount uint32
// sent lines, exclude first line
args [][]byte
// lock while server sending response
mu sync.Mutex
// lock while server sending response
mu sync.Mutex
// subscribing channels
subs map[string]bool
// subscribing channels
subs map[string]bool
}
func (c *Client)Close()error {
c.waitingReply.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close()
return nil
func (c *Client) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close()
return nil
}
func MakeClient(conn net.Conn) *Client {
return &Client{
conn: conn,
}
return &Client{
conn: conn,
}
}
func (c *Client)Write(b []byte)error {
if b == nil || len(b) == 0 {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
func (c *Client) Write(b []byte) error {
if b == nil || len(b) == 0 {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.conn.Write(b)
return err
_, err := c.conn.Write(b)
return err
}
func (c *Client)SubsChannel(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
func (c *Client) SubsChannel(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
if c.subs == nil {
c.subs = make(map[string]bool)
}
c.subs[channel] = true
if c.subs == nil {
c.subs = make(map[string]bool)
}
c.subs[channel] = true
}
func (c *Client)UnSubsChannel(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
func (c *Client) UnSubsChannel(channel string) {
c.mu.Lock()
defer c.mu.Unlock()
if c.subs == nil {
return
}
delete(c.subs, channel)
if c.subs == nil {
return
}
delete(c.subs, channel)
}
func (c *Client)SubsCount()int {
if c.subs == nil {
return 0
}
return len(c.subs)
func (c *Client) SubsCount() int {
if c.subs == nil {
return 0
}
return len(c.subs)
}
func (c *Client)GetChannels()[]string {
if c.subs == nil {
return make([]string, 0)
}
channels := make([]string, len(c.subs))
i := 0
for channel := range c.subs {
channels[i] = channel
i++
}
return channels
func (c *Client) GetChannels() []string {
if c.subs == nil {
return make([]string, 0)
}
channels := make([]string, len(c.subs))
i := 0
for channel := range c.subs {
channels[i] = channel
i++
}
return channels
}

View File

@@ -5,192 +5,192 @@ package server
*/
import (
"bufio"
"context"
"github.com/HDT3213/godis/src/cluster"
"github.com/HDT3213/godis/src/config"
DBImpl "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/db"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"sync"
"bufio"
"context"
"github.com/HDT3213/godis/src/cluster"
"github.com/HDT3213/godis/src/config"
DBImpl "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/db"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"sync"
)
var (
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
)
type Handler struct {
activeConn sync.Map // *client -> placeholder
db db.DB
closing atomic.AtomicBool // refusing new client and new request
activeConn sync.Map // *client -> placeholder
db db.DB
closing atomic.AtomicBool // refusing new client and new request
}
func MakeHandler() *Handler {
var db db.DB
if config.Properties.Peers != nil &&
len(config.Properties.Peers) > 0 {
db = cluster.MakeCluster()
} else {
db = DBImpl.MakeDB()
}
return &Handler{
db: db,
}
var db db.DB
if config.Properties.Peers != nil &&
len(config.Properties.Peers) > 0 {
db = cluster.MakeCluster()
} else {
db = DBImpl.MakeDB()
}
return &Handler{
db: db,
}
}
func (h *Handler) closeClient(client *Client) {
_ = client.Close()
h.db.AfterClientClose(client)
h.activeConn.Delete(client)
_ = client.Close()
h.db.AfterClientClose(client)
h.activeConn.Delete(client)
}
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() {
// closing handler refuse new connection
_ = conn.Close()
}
if h.closing.Get() {
// closing handler refuse new connection
_ = conn.Close()
}
client := MakeClient(conn)
h.activeConn.Store(client, 1)
client := MakeClient(conn)
h.activeConn.Store(client, 1)
reader := bufio.NewReader(conn)
var fixedLen int64 = 0
var err error
var msg []byte
for {
if fixedLen == 0 {
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
reader := bufio.NewReader(conn)
var fixedLen int64 = 0
var err error
var msg []byte
for {
if fixedLen == 0 {
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
// after client close
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
// after client close
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
// after client close
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
fixedLen = 0
}
// after client close
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
fixedLen = 0
}
if !client.uploading.Get() {
// new request
if msg[0] == '*' {
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
_, _ = client.conn.Write(UnknownErrReplyBytes)
continue
}
client.waitingReply.Add(1)
client.uploading.Set(true)
client.expectedArgsCount = uint32(expectedLine)
client.receivedCount = 0
client.args = make([][]byte, expectedLine)
} else {
// text protocol
// remove \r or \n or \r\n in the end of line
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
strs := strings.Split(str, " ")
args := make([][]byte, len(strs))
for i, s := range strs {
args[i] = []byte(s)
}
if !client.uploading.Get() {
// new request
if msg[0] == '*' {
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
_, _ = client.conn.Write(UnknownErrReplyBytes)
continue
}
client.waitingReply.Add(1)
client.uploading.Set(true)
client.expectedArgsCount = uint32(expectedLine)
client.receivedCount = 0
client.args = make([][]byte, expectedLine)
} else {
// text protocol
// remove \r or \n or \r\n in the end of line
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
strs := strings.Split(str, " ")
args := make([][]byte, len(strs))
for i, s := range strs {
args[i] = []byte(s)
}
// send reply
result := h.db.Exec(client, args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
_, _ = client.conn.Write(errReply.ToBytes())
}
if fixedLen <= 0 {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
client.args[client.receivedCount] = line
client.receivedCount++
}
// send reply
result := h.db.Exec(client, args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
_, _ = client.conn.Write(errReply.ToBytes())
}
if fixedLen <= 0 {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
client.args[client.receivedCount] = line
client.receivedCount++
}
// if sending finished
if client.receivedCount == client.expectedArgsCount {
client.uploading.Set(false) // finish sending progress
// if sending finished
if client.receivedCount == client.expectedArgsCount {
client.uploading.Set(false) // finish sending progress
// send reply
result := h.db.Exec(client, client.args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
// send reply
result := h.db.Exec(client, client.args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
// finish reply
client.expectedArgsCount = 0
client.receivedCount = 0
client.args = nil
client.waitingReply.Done()
}
}
// finish reply
client.expectedArgsCount = 0
client.receivedCount = 0
client.args = nil
client.waitingReply.Done()
}
}
}
}
}
func (h *Handler) Close() error {
logger.Info("handler shuting down...")
h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client)
_ = client.Close()
return true
})
h.db.Close()
return nil
logger.Info("handler shuting down...")
h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client)
_ = client.Close()
return true
})
h.db.Close()
return nil
}

View File

@@ -5,79 +5,78 @@ package tcp
*/
import (
"net"
"context"
"bufio"
"github.com/HDT3213/godis/src/lib/logger"
"sync"
"io"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"time"
"github.com/HDT3213/godis/src/lib/sync/wait"
"bufio"
"context"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/lib/sync/wait"
"io"
"net"
"sync"
"time"
)
type EchoHandler struct {
activeConn sync.Map
closing atomic.AtomicBool
activeConn sync.Map
closing atomic.AtomicBool
}
func MakeEchoHandler()(*EchoHandler) {
return &EchoHandler{
}
func MakeEchoHandler() *EchoHandler {
return &EchoHandler{}
}
type Client struct {
Conn net.Conn
Waiting wait.Wait
Conn net.Conn
Waiting wait.Wait
}
func (c *Client)Close()error {
c.Waiting.WaitWithTimeout(10 * time.Second)
c.Conn.Close()
return nil
func (c *Client) Close() error {
c.Waiting.WaitWithTimeout(10 * time.Second)
c.Conn.Close()
return nil
}
func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() {
// closing handler refuse new connection
conn.Close()
}
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() {
// closing handler refuse new connection
conn.Close()
}
client := &Client {
Conn: conn,
}
h.activeConn.Store(client, 1)
client := &Client{
Conn: conn,
}
h.activeConn.Store(client, 1)
reader := bufio.NewReader(conn)
for {
// may occurs: client EOF, client timeout, server early close
msg, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
logger.Info("connection close")
h.activeConn.Delete(client)
} else {
logger.Warn(err)
}
return
}
client.Waiting.Add(1)
//logger.Info("sleeping")
//time.Sleep(10 * time.Second)
b := []byte(msg)
conn.Write(b)
client.Waiting.Done()
}
reader := bufio.NewReader(conn)
for {
// may occurs: client EOF, client timeout, server early close
msg, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
logger.Info("connection close")
h.activeConn.Delete(client)
} else {
logger.Warn(err)
}
return
}
client.Waiting.Add(1)
//logger.Info("sleeping")
//time.Sleep(10 * time.Second)
b := []byte(msg)
conn.Write(b)
client.Waiting.Done()
}
}
func (h *EchoHandler)Close()error {
logger.Info("handler shuting down...")
h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{})bool {
client := key.(*Client)
client.Close()
return true
})
return nil
func (h *EchoHandler) Close() error {
logger.Info("handler shuting down...")
h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client)
client.Close()
return true
})
return nil
}

View File

@@ -5,75 +5,74 @@ package tcp
*/
import (
"context"
"fmt"
"github.com/HDT3213/godis/src/interface/tcp"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
"context"
"fmt"
"github.com/HDT3213/godis/src/interface/tcp"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
type Config struct {
Address string `yaml:"address"`
MaxConnect uint32 `yaml:"max-connect"`
Timeout time.Duration `yaml:"timeout"`
Address string `yaml:"address"`
MaxConnect uint32 `yaml:"max-connect"`
Timeout time.Duration `yaml:"timeout"`
}
func ListenAndServe(cfg *Config, handler tcp.Handler) {
listener, err := net.Listen("tcp", cfg.Address)
if err != nil {
logger.Fatal(fmt.Sprintf("listen err: %v", err))
}
listener, err := net.Listen("tcp", cfg.Address)
if err != nil {
logger.Fatal(fmt.Sprintf("listen err: %v", err))
}
// listen signal
var closing atomic.AtomicBool
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
go func() {
sig := <-sigCh
switch sig {
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
logger.Info("shuting down...")
closing.Set(true)
_ = listener.Close() // listener.Accept() will return err immediately
_ = handler.Close() // close connections
}
}()
// listen signal
var closing atomic.AtomicBool
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
go func() {
sig := <-sigCh
switch sig {
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
logger.Info("shuting down...")
closing.Set(true)
_ = listener.Close() // listener.Accept() will return err immediately
_ = handler.Close() // close connections
}
}()
// listen port
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
defer func() {
// close during unexpected error
_ = listener.Close()
_ = handler.Close()
}()
ctx, _ := context.WithCancel(context.Background())
var waitDone sync.WaitGroup
for {
conn, err := listener.Accept()
if err != nil {
if closing.Get() {
logger.Info("waiting disconnect...")
waitDone.Wait()
return // handler will be closed by defer
}
logger.Error(fmt.Sprintf("accept err: %v", err))
continue
}
// handle
logger.Info("accept link")
waitDone.Add(1)
go func() {
defer func() {
waitDone.Done()
}()
handler.Handle(ctx, conn)
}()
}
// listen port
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
defer func() {
// close during unexpected error
_ = listener.Close()
_ = handler.Close()
}()
ctx, _ := context.WithCancel(context.Background())
var waitDone sync.WaitGroup
for {
conn, err := listener.Accept()
if err != nil {
if closing.Get() {
logger.Info("waiting disconnect...")
waitDone.Wait()
return // handler will be closed by defer
}
logger.Error(fmt.Sprintf("accept err: %v", err))
continue
}
// handle
logger.Info("accept link")
waitDone.Add(1)
go func() {
defer func() {
waitDone.Done()
}()
handler.Handle(ctx, conn)
}()
}
}